5 changed files with 465 additions and 474 deletions
-
2Dockerfile
-
2go.mod
-
464main.go
-
108option.go
-
363parser.go
@ -1,108 +0,0 @@ |
|||||
package main |
|
||||
|
|
||||
type NullStyle int |
|
||||
|
|
||||
const ( |
|
||||
NullDisable NullStyle = iota |
|
||||
NullInSql |
|
||||
NullInPointer |
|
||||
) |
|
||||
|
|
||||
type Option func(*options) |
|
||||
|
|
||||
type options struct { |
|
||||
Charset string |
|
||||
Collation string |
|
||||
JsonTag bool |
|
||||
ZhTag bool |
|
||||
TablePrefix string |
|
||||
ColumnPrefix string |
|
||||
NoNullType bool |
|
||||
NullStyle NullStyle |
|
||||
Package string |
|
||||
GormType bool |
|
||||
ForceTableName bool |
|
||||
} |
|
||||
|
|
||||
var defaultOptions = options{ |
|
||||
NullStyle: NullInSql, |
|
||||
Package: "model", |
|
||||
} |
|
||||
|
|
||||
func WithCharset(charset string) Option { |
|
||||
return func(o *options) { |
|
||||
o.Charset = charset |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithCollation(collation string) Option { |
|
||||
return func(o *options) { |
|
||||
o.Collation = collation |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithTablePrefix(p string) Option { |
|
||||
return func(o *options) { |
|
||||
o.TablePrefix = p |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithColumnPrefix(p string) Option { |
|
||||
return func(o *options) { |
|
||||
o.ColumnPrefix = p |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithJsonTag() Option { |
|
||||
return func(o *options) { |
|
||||
o.JsonTag = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithZhTag() Option { |
|
||||
return func(o *options) { |
|
||||
o.ZhTag = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithNoNullType() Option { |
|
||||
return func(o *options) { |
|
||||
o.NoNullType = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithNullStyle(s NullStyle) Option { |
|
||||
return func(o *options) { |
|
||||
o.NullStyle = s |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithPackage(pkg string) Option { |
|
||||
return func(o *options) { |
|
||||
o.Package = pkg |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
// WithGormType will write type in gorm tag
|
|
||||
func WithGormType() Option { |
|
||||
return func(o *options) { |
|
||||
o.GormType = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func WithForceTableName() Option { |
|
||||
return func(o *options) { |
|
||||
o.ForceTableName = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
func parseOption(options []Option) options { |
|
||||
o := defaultOptions |
|
||||
for _, f := range options { |
|
||||
f(&o) |
|
||||
} |
|
||||
if o.NoNullType { |
|
||||
o.NullStyle = NullDisable |
|
||||
} |
|
||||
return o |
|
||||
} |
|
||||
@ -1,363 +0,0 @@ |
|||||
package main |
|
||||
|
|
||||
// restruct from https://github.com/miaogaolin/gotl/blob/main/common/sql2gorm/parser/parser.go
|
|
||||
import ( |
|
||||
"fmt" |
|
||||
"go/format" |
|
||||
"io" |
|
||||
"sort" |
|
||||
"strings" |
|
||||
"sync" |
|
||||
"text/template" |
|
||||
|
|
||||
"github.com/iancoleman/strcase" |
|
||||
"github.com/jinzhu/inflection" |
|
||||
"github.com/knocknote/vitess-sqlparser/tidbparser/ast" |
|
||||
"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/mysql" |
|
||||
"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/types" |
|
||||
"github.com/knocknote/vitess-sqlparser/tidbparser/parser" |
|
||||
"github.com/pkg/errors" |
|
||||
) |
|
||||
|
|
||||
var ( |
|
||||
structTmplRaw string |
|
||||
fileTmplRaw string |
|
||||
structTmpl *template.Template |
|
||||
fileTmpl *template.Template |
|
||||
tmplParseOnce sync.Once |
|
||||
) |
|
||||
|
|
||||
type ModelCodes struct { |
|
||||
Package string |
|
||||
ImportPath []string |
|
||||
StructCode []string |
|
||||
} |
|
||||
|
|
||||
func ParseSql(sql string, options ...Option) (*ModelCodes, error) { |
|
||||
initTemplate() |
|
||||
opt := parseOption(options) |
|
||||
|
|
||||
stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
tableStr := make([]string, 0, len(stmts)) |
|
||||
importPath := make(map[string]struct{}) |
|
||||
for _, stmt := range stmts { |
|
||||
if ct, ok := stmt.(*ast.CreateTableStmt); ok { |
|
||||
s, ipt, err := makeCode(ct, opt) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
tableStr = append(tableStr, s) |
|
||||
for _, s := range ipt { |
|
||||
importPath[s] = struct{}{} |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
importPathArr := make([]string, 0, len(importPath)) |
|
||||
for s := range importPath { |
|
||||
importPathArr = append(importPathArr, s) |
|
||||
} |
|
||||
sort.Strings(importPathArr) |
|
||||
return &ModelCodes{ |
|
||||
Package: opt.Package, |
|
||||
ImportPath: importPathArr, |
|
||||
StructCode: tableStr, |
|
||||
}, nil |
|
||||
} |
|
||||
|
|
||||
func ParseSqlToWrite(sql string, writer io.Writer, options ...Option) error { |
|
||||
data, err := ParseSql(sql, options...) |
|
||||
if err != nil { |
|
||||
return err |
|
||||
} |
|
||||
err = fileTmpl.Execute(writer, data) |
|
||||
if err != nil { |
|
||||
return err |
|
||||
} |
|
||||
|
|
||||
return nil |
|
||||
} |
|
||||
|
|
||||
func ParseSqlFormat(sql string, options ...Option) ([]byte, error) { |
|
||||
w := strings.Builder{} |
|
||||
err := ParseSqlToWrite(sql, &w, options...) |
|
||||
if err != nil { |
|
||||
return nil, err |
|
||||
} |
|
||||
return format.Source([]byte(w.String())) |
|
||||
} |
|
||||
|
|
||||
type tmplData struct { |
|
||||
TableName string |
|
||||
NameFunc bool |
|
||||
RawTableName string |
|
||||
Fields []tmplField |
|
||||
Comment string |
|
||||
} |
|
||||
|
|
||||
type tmplField struct { |
|
||||
Name string |
|
||||
GoType string |
|
||||
Tag string |
|
||||
Comment string |
|
||||
} |
|
||||
|
|
||||
func makeCode(stmt *ast.CreateTableStmt, opt options) (string, []string, error) { |
|
||||
importPath := make([]string, 0, 1) |
|
||||
data := tmplData{ |
|
||||
TableName: stmt.Table.Name.String(), |
|
||||
RawTableName: stmt.Table.Name.String(), |
|
||||
Fields: make([]tmplField, 0, 1), |
|
||||
} |
|
||||
tablePrefix := opt.TablePrefix |
|
||||
if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) { |
|
||||
data.NameFunc = true |
|
||||
data.TableName = data.TableName[len(tablePrefix):] |
|
||||
} |
|
||||
if opt.ForceTableName || data.RawTableName != inflection.Plural(data.RawTableName) { |
|
||||
data.NameFunc = true |
|
||||
} |
|
||||
|
|
||||
data.TableName = strcase.ToCamel(data.TableName) |
|
||||
|
|
||||
// find table comment
|
|
||||
for _, opt := range stmt.Options { |
|
||||
if opt.Tp == ast.TableOptionComment { |
|
||||
data.Comment = opt.StrValue |
|
||||
break |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
isPrimaryKey := make(map[string]bool) |
|
||||
for _, con := range stmt.Constraints { |
|
||||
if con.Tp == ast.ConstraintPrimaryKey { |
|
||||
isPrimaryKey[con.Keys[0].Column.String()] = true |
|
||||
} |
|
||||
} |
|
||||
|
|
||||
columnPrefix := opt.ColumnPrefix |
|
||||
for _, col := range stmt.Cols { |
|
||||
colName := col.Name.Name.String() |
|
||||
goFieldName := colName |
|
||||
if columnPrefix != "" && strings.HasPrefix(goFieldName, columnPrefix) { |
|
||||
goFieldName = goFieldName[len(columnPrefix):] |
|
||||
} |
|
||||
|
|
||||
field := tmplField{ |
|
||||
Name: strcase.ToCamel(goFieldName), |
|
||||
} |
|
||||
|
|
||||
tags := make([]string, 0, 4) |
|
||||
// make GORM's tag
|
|
||||
gormTag := strings.Builder{} |
|
||||
gormTag.WriteString("column:") |
|
||||
gormTag.WriteString(colName) |
|
||||
if opt.GormType { |
|
||||
gormTag.WriteString(";type:") |
|
||||
gormTag.WriteString(col.Tp.InfoSchemaStr()) |
|
||||
} |
|
||||
if isPrimaryKey[colName] { |
|
||||
gormTag.WriteString(";primary_key") |
|
||||
} |
|
||||
isNotNull := false |
|
||||
canNull := false |
|
||||
for _, o := range col.Options { |
|
||||
switch o.Tp { |
|
||||
case ast.ColumnOptionPrimaryKey: |
|
||||
if !isPrimaryKey[colName] { |
|
||||
gormTag.WriteString(";primary_key") |
|
||||
isPrimaryKey[colName] = true |
|
||||
} |
|
||||
case ast.ColumnOptionNotNull: |
|
||||
isNotNull = true |
|
||||
case ast.ColumnOptionAutoIncrement: |
|
||||
gormTag.WriteString(";AUTO_INCREMENT") |
|
||||
case ast.ColumnOptionDefaultValue: |
|
||||
if value := getDefaultValue(o.Expr); value != "" { |
|
||||
gormTag.WriteString(";default:") |
|
||||
gormTag.WriteString(value) |
|
||||
} |
|
||||
case ast.ColumnOptionUniqKey: |
|
||||
gormTag.WriteString(";unique") |
|
||||
case ast.ColumnOptionNull: |
|
||||
//gormTag.WriteString(";NULL")
|
|
||||
canNull = true |
|
||||
case ast.ColumnOptionOnUpdate: // For Timestamp and Datetime only.
|
|
||||
case ast.ColumnOptionFulltext: |
|
||||
case ast.ColumnOptionComment: |
|
||||
field.Comment = o.Expr.GetDatum().GetString() |
|
||||
default: |
|
||||
//return "", nil, errors.Errorf(" unsupport option %d\n", o.Tp)
|
|
||||
} |
|
||||
} |
|
||||
if !isPrimaryKey[colName] && isNotNull { |
|
||||
gormTag.WriteString(";NOT NULL") |
|
||||
} |
|
||||
tags = append(tags, "gorm", gormTag.String()) |
|
||||
|
|
||||
if opt.JsonTag { |
|
||||
tags = append(tags, "json", colName) |
|
||||
} |
|
||||
|
|
||||
if opt.ZhTag { |
|
||||
tags = append(tags, "zh-cn", field.Comment) |
|
||||
} |
|
||||
field.Tag = makeTagStr(tags) |
|
||||
|
|
||||
// get type in golang
|
|
||||
nullStyle := opt.NullStyle |
|
||||
if !canNull { |
|
||||
nullStyle = NullDisable |
|
||||
} |
|
||||
goType, pkg := mysqlToGoType(col.Tp, nullStyle) |
|
||||
if pkg != "" { |
|
||||
importPath = append(importPath, pkg) |
|
||||
} |
|
||||
field.GoType = goType |
|
||||
|
|
||||
data.Fields = append(data.Fields, field) |
|
||||
} |
|
||||
|
|
||||
builder := strings.Builder{} |
|
||||
err := structTmpl.Execute(&builder, data) |
|
||||
if err != nil { |
|
||||
return "", nil, err |
|
||||
} |
|
||||
code, err := format.Source([]byte(builder.String())) |
|
||||
if err != nil { |
|
||||
return string(code), importPath, errors.WithMessage(err, "format golang code error") |
|
||||
} |
|
||||
return string(code), importPath, nil |
|
||||
} |
|
||||
|
|
||||
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) { |
|
||||
if style == NullInSql { |
|
||||
path = "database/sql" |
|
||||
switch colTp.Tp { |
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: |
|
||||
name = "int64" |
|
||||
case mysql.TypeLonglong: |
|
||||
name = "int64" |
|
||||
case mysql.TypeFloat, mysql.TypeDouble: |
|
||||
name = "float64" |
|
||||
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, |
|
||||
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: |
|
||||
name = "string" |
|
||||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: |
|
||||
name = "int64" |
|
||||
case mysql.TypeDecimal, mysql.TypeNewDecimal: |
|
||||
name = "float64" |
|
||||
case mysql.TypeJSON, mysql.TypeEnum: |
|
||||
name = "string" |
|
||||
default: |
|
||||
return "UnSupport", "" |
|
||||
} |
|
||||
} else { |
|
||||
switch colTp.Tp { |
|
||||
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: |
|
||||
if mysql.HasUnsignedFlag(colTp.Flag) { |
|
||||
name = "int64" |
|
||||
} else { |
|
||||
name = "int64" |
|
||||
} |
|
||||
case mysql.TypeLonglong: |
|
||||
if mysql.HasUnsignedFlag(colTp.Flag) { |
|
||||
name = "int64" |
|
||||
} else { |
|
||||
name = "int64" |
|
||||
} |
|
||||
case mysql.TypeFloat, mysql.TypeDouble: |
|
||||
name = "float64" |
|
||||
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, |
|
||||
mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: |
|
||||
name = "string" |
|
||||
case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: |
|
||||
name = "int64" |
|
||||
case mysql.TypeDecimal, mysql.TypeNewDecimal: |
|
||||
name = "float64" |
|
||||
case mysql.TypeJSON: |
|
||||
name = "string" |
|
||||
case mysql.TypeEnum: |
|
||||
name = "string" |
|
||||
default: |
|
||||
return "UnSupport", "" |
|
||||
} |
|
||||
if style == NullInPointer { |
|
||||
name = "*" + name |
|
||||
} |
|
||||
} |
|
||||
return |
|
||||
} |
|
||||
|
|
||||
func makeTagStr(tags []string) string { |
|
||||
builder := strings.Builder{} |
|
||||
for i := 0; i < len(tags)/2; i++ { |
|
||||
builder.WriteString(tags[i*2]) |
|
||||
builder.WriteString(`:"`) |
|
||||
builder.WriteString(tags[i*2+1]) |
|
||||
builder.WriteString(`" `) |
|
||||
} |
|
||||
if builder.Len() > 0 { |
|
||||
return builder.String()[:builder.Len()-1] |
|
||||
} |
|
||||
return builder.String() |
|
||||
} |
|
||||
|
|
||||
func getDefaultValue(expr ast.ExprNode) (value string) { |
|
||||
if expr.GetDatum().Kind() != types.KindNull { |
|
||||
value = fmt.Sprintf("%v", expr.GetDatum().GetValue()) |
|
||||
} else if expr.GetFlag() != ast.FlagConstant { |
|
||||
if expr.GetFlag() == ast.FlagHasFunc { |
|
||||
if funcExpr, ok := expr.(*ast.FuncCallExpr); ok { |
|
||||
value = funcExpr.FnName.O |
|
||||
} |
|
||||
} |
|
||||
} |
|
||||
return |
|
||||
} |
|
||||
|
|
||||
func initTemplate() { |
|
||||
tmplParseOnce.Do(func() { |
|
||||
var err error |
|
||||
structTmpl, err = template.New("goStruct").Parse(structTmplRaw) |
|
||||
if err != nil { |
|
||||
panic(err) |
|
||||
} |
|
||||
fileTmpl, err = template.New("goFile").Parse(fileTmplRaw) |
|
||||
if err != nil { |
|
||||
panic(err) |
|
||||
} |
|
||||
}) |
|
||||
} |
|
||||
|
|
||||
func init() { |
|
||||
structTmplRaw = ` |
|
||||
{{- if .Comment -}} |
|
||||
// {{.Comment}}
|
|
||||
{{end -}} |
|
||||
type {{.TableName}} struct { |
|
||||
{{- range .Fields}} |
|
||||
{{.Name}} {{.GoType}} {{if .Tag}}` + "`{{.Tag}}`" + `{{end}}{{if .Comment}} // {{.Comment}}{{end}}
|
|
||||
{{- end}} |
|
||||
} |
|
||||
{{if .NameFunc}} |
|
||||
func (m *{{.TableName}}) TableName() string { |
|
||||
return "{{.RawTableName}}" |
|
||||
} |
|
||||
{{end}}` |
|
||||
fileTmplRaw = ` |
|
||||
package {{.Package}} |
|
||||
{{if .ImportPath}} |
|
||||
import ( |
|
||||
{{- range .ImportPath}} |
|
||||
"{{.}}" |
|
||||
{{- end}} |
|
||||
) |
|
||||
{{- end}} |
|
||||
{{range .StructCode}} |
|
||||
{{.}} |
|
||||
{{end}} |
|
||||
` |
|
||||
} |
|
||||
Write
Preview
Loading…
Cancel
Save
Reference in new issue