5 changed files with 465 additions and 474 deletions
-
2Dockerfile
-
2go.mod
-
464main.go
-
108option.go
-
363parser.go
@ -1,108 +0,0 @@ |
|||
package main |
|||
|
|||
type NullStyle int |
|||
|
|||
const ( |
|||
NullDisable NullStyle = iota |
|||
NullInSql |
|||
NullInPointer |
|||
) |
|||
|
|||
type Option func(*options) |
|||
|
|||
type options struct { |
|||
Charset string |
|||
Collation string |
|||
JsonTag bool |
|||
ZhTag bool |
|||
TablePrefix string |
|||
ColumnPrefix string |
|||
NoNullType bool |
|||
NullStyle NullStyle |
|||
Package string |
|||
GormType bool |
|||
ForceTableName bool |
|||
} |
|||
|
|||
var defaultOptions = options{ |
|||
NullStyle: NullInSql, |
|||
Package: "model", |
|||
} |
|||
|
|||
func WithCharset(charset string) Option { |
|||
return func(o *options) { |
|||
o.Charset = charset |
|||
} |
|||
} |
|||
|
|||
func WithCollation(collation string) Option { |
|||
return func(o *options) { |
|||
o.Collation = collation |
|||
} |
|||
} |
|||
|
|||
func WithTablePrefix(p string) Option { |
|||
return func(o *options) { |
|||
o.TablePrefix = p |
|||
} |
|||
} |
|||
|
|||
func WithColumnPrefix(p string) Option { |
|||
return func(o *options) { |
|||
o.ColumnPrefix = p |
|||
} |
|||
} |
|||
|
|||
func WithJsonTag() Option { |
|||
return func(o *options) { |
|||
o.JsonTag = true |
|||
} |
|||
} |
|||
|
|||
func WithZhTag() Option { |
|||
return func(o *options) { |
|||
o.ZhTag = true |
|||
} |
|||
} |
|||
|
|||
func WithNoNullType() Option { |
|||
return func(o *options) { |
|||
o.NoNullType = true |
|||
} |
|||
} |
|||
|
|||
func WithNullStyle(s NullStyle) Option { |
|||
return func(o *options) { |
|||
o.NullStyle = s |
|||
} |
|||
} |
|||
|
|||
func WithPackage(pkg string) Option { |
|||
return func(o *options) { |
|||
o.Package = pkg |
|||
} |
|||
} |
|||
|
|||
// WithGormType will write type in gorm tag
|
|||
func WithGormType() Option { |
|||
return func(o *options) { |
|||
o.GormType = true |
|||
} |
|||
} |
|||
|
|||
func WithForceTableName() Option { |
|||
return func(o *options) { |
|||
o.ForceTableName = true |
|||
} |
|||
} |
|||
|
|||
func parseOption(options []Option) options { |
|||
o := defaultOptions |
|||
for _, f := range options { |
|||
f(&o) |
|||
} |
|||
if o.NoNullType { |
|||
o.NullStyle = NullDisable |
|||
} |
|||
return o |
|||
} |
|||
@ -1,363 +0,0 @@ |
|||
package main |
|||
|
|||
// restruct from https://github.com/miaogaolin/gotl/blob/main/common/sql2gorm/parser/parser.go
|
|||
import ( |
|||
"fmt" |
|||
"go/format" |
|||
"io" |
|||
"sort" |
|||
"strings" |
|||
"sync" |
|||
"text/template" |
|||
|
|||
"github.com/iancoleman/strcase" |
|||
"github.com/jinzhu/inflection" |
|||
"github.com/knocknote/vitess-sqlparser/tidbparser/ast" |
|||
"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/mysql" |
|||
"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/types" |
|||
"github.com/knocknote/vitess-sqlparser/tidbparser/parser" |
|||
"github.com/pkg/errors" |
|||
) |
|||
|
|||
var ( |
|||
structTmplRaw string |
|||
fileTmplRaw string |
|||
structTmpl *template.Template |
|||
fileTmpl *template.Template |
|||
tmplParseOnce sync.Once |
|||
) |
|||
|
|||
type ModelCodes struct { |
|||
Package string |
|||
ImportPath []string |
|||
StructCode []string |
|||
} |
|||
|
|||
func ParseSql(sql string, options ...Option) (*ModelCodes, error) { |
|||
initTemplate() |
|||
opt := parseOption(options) |
|||
|
|||
stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tableStr := make([]string, 0, len(stmts)) |
|||
importPath := make(map[string]struct{}) |
|||
for _, stmt := range stmts { |
|||
if ct, ok := stmt.(*ast.CreateTableStmt); ok { |
|||
s, ipt, err := makeCode(ct, opt) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
tableStr = append(tableStr, s) |
|||
for _, s := range ipt { |
|||
importPath[s] = struct{}{} |
|||
} |
|||
} |
|||
} |
|||
importPathArr := make([]string, 0, len(importPath)) |
|||
for s := range importPath { |
|||
importPathArr = append(importPathArr, s) |
|||
} |
|||
sort.Strings(importPathArr) |
|||
return &ModelCodes{ |
|||
Package: opt.Package, |
|||
ImportPath: importPathArr, |
|||
StructCode: tableStr, |
|||
}, nil |
|||
} |
|||
|
|||
func ParseSqlToWrite(sql string, writer io.Writer, options ...Option) error { |
|||
data, err := ParseSql(sql, options...) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
err = fileTmpl.Execute(writer, data) |
|||
if err != nil { |
|||
return err |
|||
} |
|||
|
|||
return nil |
|||
} |
|||
|
|||
func ParseSqlFormat(sql string, options ...Option) ([]byte, error) { |
|||
w := strings.Builder{} |
|||
err := ParseSqlToWrite(sql, &w, options...) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
return format.Source([]byte(w.String())) |
|||
} |
|||
|
|||
type tmplData struct { |
|||
TableName string |
|||
NameFunc bool |
|||
RawTableName string |
|||
Fields []tmplField |
|||
Comment string |
|||
} |
|||
|
|||
type tmplField struct { |
|||
Name string |
|||
GoType string |
|||
Tag string |
|||
Comment string |
|||
} |
|||
|
|||
func makeCode(stmt *ast.CreateTableStmt, opt options) (string, []string, error) { |
|||
importPath := make([]string, 0, 1) |
|||
data := tmplData{ |
|||
TableName: stmt.Table.Name.String(), |
|||
RawTableName: stmt.Table.Name.String(), |
|||
Fields: make([]tmplField, 0, 1), |
|||
} |
|||
tablePrefix := opt.TablePrefix |
|||
if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) { |
|||
data.NameFunc = true |
|||
data.TableName = data.TableName[len(tablePrefix):] |
|||
} |
|||
if opt.ForceTableName || data.RawTableName != inflection.Plural(data.RawTableName) { |
|||
data.NameFunc = true |
|||
} |
|||
|
|||
data.TableName = strcase.ToCamel(data.TableName) |
|||
|
|||
// find table comment
|
|||
for _, opt := range stmt.Options { |
|||
if opt.Tp == ast.TableOptionComment { |
|||
data.Comment = opt.StrValue |
|||
break |
|||
} |
|||
} |
|||
|
|||
isPrimaryKey := make(map[string]bool) |
|||
for _, con := range stmt.Constraints { |
|||
if con.Tp == ast.ConstraintPrimaryKey { |
|||
isPrimaryKey[con.Keys[0].Column.String()] = true |
|||
} |
|||
} |
|||
|
|||
columnPrefix := opt.ColumnPrefix |
|||
for _, col := range stmt.Cols { |
|||
colName := col.Name.Name.String() |
|||
goFieldName := colName |
|||
if columnPrefix != "" && strings.HasPrefix(goFieldName, columnPrefix) { |
|||
goFieldName = goFieldName[len(columnPrefix):] |
|||
} |
|||
|
|||
field := tmplField{ |
|||
Name: strcase.ToCamel(goFieldName), |
|||
} |
|||
|
|||
tags := make([]string, 0, 4) |
|||
// make GORM's tag
|
|||
gormTag := strings.Builder{} |
|||
gormTag.WriteString("column:") |
|||
gormTag.WriteString(colName) |
|||
if opt.GormType { |
|||
gormTag.WriteString(";type:") |
|||
gormTag.WriteString(col.Tp.InfoSchemaStr()) |
|||
} |
|||
if isPrimaryKey[colName] { |
|||
gormTag.WriteString(";primary_key") |
|||
} |
|||
isNotNull := false |
|||
canNull := false |
|||
for _, o := range col.Options { |
|||
switch o.Tp { |
|||
case ast.ColumnOptionPrimaryKey: |
|||
if !isPrimaryKey[colName] { |
|||
gormTag.WriteString(";primary_key") |
|||
isPrimaryKey[colName] = true |
|||
} |
|||
case ast.ColumnOptionNotNull: |
|||
isNotNull = true |
|||
case ast.ColumnOptionAutoIncrement: |
|||
gormTag.WriteString(";AUTO_INCREMENT") |
|||
case ast.ColumnOptionDefaultValue: |
|||
if value := getDefaultValue(o.Expr); value != "" { |
|||
gormTag.WriteString(";default:") |
|||
gormTag.WriteString(value) |
|||
} |
|||
case ast.ColumnOptionUniqKey: |
|||
gormTag.WriteString(";unique") |
|||
case ast.ColumnOptionNull: |
|||
//gormTag.WriteString(";NULL")
|
|||
canNull = true |
|||
case ast.ColumnOptionOnUpdate: // For Timestamp and Datetime only.
|
|||
case ast.ColumnOptionFulltext: |
|||
case ast.ColumnOptionComment: |
|||
field.Comment = o.Expr.GetDatum().GetString() |
|||
default: |
|||
//return "", nil, errors.Errorf(" unsupport option %d\n", o.Tp)
|
|||
} |
|||
} |
|||
if !isPrimaryKey[colName] && isNotNull { |
|||
gormTag.WriteString(";NOT NULL") |
|||
} |
|||
tags = append(tags, "gorm", gormTag.String()) |
|||
|
|||
if opt.JsonTag { |
|||
tags = append(tags, "json", colName) |
|||
} |
|||
|
|||
if opt.ZhTag { |
|||
tags = append(tags, "zh-cn", field.Comment) |
|||
} |
|||
field.Tag = makeTagStr(tags) |
|||
|
|||
// get type in golang
|
|||
nullStyle := opt.NullStyle |
|||
if !canNull { |
|||
nullStyle = NullDisable |
|||
} |
|||
goType, pkg := mysqlToGoType(col.Tp, nullStyle) |
|||
if pkg != "" { |
|||
importPath = append(importPath, pkg) |
|||
} |
|||
field.GoType = goType |
|||
|
|||
data.Fields = append(data.Fields, field) |
|||
} |
|||
|
|||
builder := strings.Builder{} |
|||
err := structTmpl.Execute(&builder, data) |
|||
if err != nil { |
|||
return "", nil, err |
|||
} |
|||
code, err := format.Source([]byte(builder.String())) |
|||
if err != nil { |
|||
return string(code), importPath, errors.WithMessage(err, "format golang code error") |
|||
} |
|||
return string(code), importPath, nil |
|||
} |
|||
|
|||
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) { |
|||
if style == NullInSql { |
|||
path = "database/sql" |
|||
switch colTp.Tp { |
|||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: |
|||
name = "int64" |
|||
case mysql.TypeLonglong: |
|||
name = "int64" |
|||
case mysql.TypeFloat, mysql.TypeDouble: |
|||
name = "float64" |
|||
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, |
|||
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: |
|||
name = "string" |
|||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: |
|||
name = "int64" |
|||
case mysql.TypeDecimal, mysql.TypeNewDecimal: |
|||
name = "float64" |
|||
case mysql.TypeJSON, mysql.TypeEnum: |
|||
name = "string" |
|||
default: |
|||
return "UnSupport", "" |
|||
} |
|||
} else { |
|||
switch colTp.Tp { |
|||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: |
|||
if mysql.HasUnsignedFlag(colTp.Flag) { |
|||
name = "int64" |
|||
} else { |
|||
name = "int64" |
|||
} |
|||
case mysql.TypeLonglong: |
|||
if mysql.HasUnsignedFlag(colTp.Flag) { |
|||
name = "int64" |
|||
} else { |
|||
name = "int64" |
|||
} |
|||
case mysql.TypeFloat, mysql.TypeDouble: |
|||
name = "float64" |
|||
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, |
|||
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: |
|||
name = "string" |
|||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: |
|||
name = "int64" |
|||
case mysql.TypeDecimal, mysql.TypeNewDecimal: |
|||
name = "float64" |
|||
case mysql.TypeJSON: |
|||
name = "string" |
|||
case mysql.TypeEnum: |
|||
name = "string" |
|||
default: |
|||
return "UnSupport", "" |
|||
} |
|||
if style == NullInPointer { |
|||
name = "*" + name |
|||
} |
|||
} |
|||
return |
|||
} |
|||
|
|||
func makeTagStr(tags []string) string { |
|||
builder := strings.Builder{} |
|||
for i := 0; i < len(tags)/2; i++ { |
|||
builder.WriteString(tags[i*2]) |
|||
builder.WriteString(`:"`) |
|||
builder.WriteString(tags[i*2+1]) |
|||
builder.WriteString(`" `) |
|||
} |
|||
if builder.Len() > 0 { |
|||
return builder.String()[:builder.Len()-1] |
|||
} |
|||
return builder.String() |
|||
} |
|||
|
|||
func getDefaultValue(expr ast.ExprNode) (value string) { |
|||
if expr.GetDatum().Kind() != types.KindNull { |
|||
value = fmt.Sprintf("%v", expr.GetDatum().GetValue()) |
|||
} else if expr.GetFlag() != ast.FlagConstant { |
|||
if expr.GetFlag() == ast.FlagHasFunc { |
|||
if funcExpr, ok := expr.(*ast.FuncCallExpr); ok { |
|||
value = funcExpr.FnName.O |
|||
} |
|||
} |
|||
} |
|||
return |
|||
} |
|||
|
|||
func initTemplate() { |
|||
tmplParseOnce.Do(func() { |
|||
var err error |
|||
structTmpl, err = template.New("goStruct").Parse(structTmplRaw) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
fileTmpl, err = template.New("goFile").Parse(fileTmplRaw) |
|||
if err != nil { |
|||
panic(err) |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func init() { |
|||
structTmplRaw = ` |
|||
{{- if .Comment -}} |
|||
// {{.Comment}}
|
|||
{{end -}} |
|||
type {{.TableName}} struct { |
|||
{{- range .Fields}} |
|||
{{.Name}} {{.GoType}} {{if .Tag}}` + "`{{.Tag}}`" + `{{end}}{{if .Comment}} // {{.Comment}}{{end}}
|
|||
{{- end}} |
|||
} |
|||
{{if .NameFunc}} |
|||
func (m *{{.TableName}}) TableName() string { |
|||
return "{{.RawTableName}}" |
|||
} |
|||
{{end}}` |
|||
fileTmplRaw = ` |
|||
package {{.Package}} |
|||
{{if .ImportPath}} |
|||
import ( |
|||
{{- range .ImportPath}} |
|||
"{{.}}" |
|||
{{- end}} |
|||
) |
|||
{{- end}} |
|||
{{range .StructCode}} |
|||
{{.}} |
|||
{{end}} |
|||
` |
|||
} |
|||
Write
Preview
Loading…
Cancel
Save
Reference in new issue