package main import ( "fmt" "go/format" "io" "io/fs" "net/http" "os/exec" "runtime" "sort" "strings" "sync" "text/template" "time" "unicode" "github.com/gin-contrib/cors" "github.com/gin-gonic/gin" "github.com/iancoleman/strcase" "github.com/jinzhu/inflection" "github.com/knocknote/vitess-sqlparser/tidbparser/ast" "github.com/knocknote/vitess-sqlparser/tidbparser/dependency/mysql" "github.com/knocknote/vitess-sqlparser/tidbparser/dependency/types" "github.com/knocknote/vitess-sqlparser/tidbparser/parser" "github.com/pkg/errors" ) type NullStyle int const ( NullDisable NullStyle = iota NullInSql NullInPointer ) type Option func(*options) type options struct { Charset string Collation string JsonTag bool ZhTag bool TablePrefix string ColumnPrefix string NoNullType bool NullStyle NullStyle Package string GormType bool ForceTableName bool Camel bool // 是否json字段驼峰 } var defaultOptions = options{ NullStyle: NullInSql, Package: "model", } var Open = "" func OpenBrowser(url string) { var err error switch runtime.GOOS { case "linux": err = exec.Command("xdg-open", url).Start() case "windows": err = exec.Command("rundll32", "url.dll,FileProtocolHandler", url).Start() case "darwin": err = exec.Command("open", url).Start() } if err != nil { fmt.Println(err) } } func WithCharset(charset string) Option { return func(o *options) { o.Charset = charset } } func WithCollation(collation string) Option { return func(o *options) { o.Collation = collation } } func WithTablePrefix(p string) Option { return func(o *options) { o.TablePrefix = p } } func WithColumnPrefix(p string) Option { return func(o *options) { o.ColumnPrefix = p } } func WithJsonTag() Option { return func(o *options) { o.JsonTag = true } } func WithZhTag() Option { return func(o *options) { o.ZhTag = true } } func WithNoNullType() Option { return func(o *options) { o.NoNullType = true } } func WithNullStyle(s NullStyle) Option { return func(o *options) { o.NullStyle = s } } func WithPackage(pkg string) Option { return func(o *options) { o.Package = pkg } } func WithCamel() Option { return func(o *options) { o.Camel = true } } // WithGormType will write type in gorm tag func WithGormType() Option { return func(o *options) { o.GormType = true } } func WithForceTableName() Option { return func(o *options) { o.ForceTableName = true } } func parseOption(options []Option) options { o := defaultOptions for _, f := range options { f(&o) } if o.NoNullType { o.NullStyle = NullDisable } return o } var ( structTmplRaw string fileTmplRaw string structTmpl *template.Template fileTmpl *template.Template tmplParseOnce sync.Once ) type ModelCodes struct { Package string ImportPath []string StructCode []string } func ParseSql(sql string, options ...Option) (*ModelCodes, error) { initTemplate() opt := parseOption(options) stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation) if err != nil { return nil, err } tableStr := make([]string, 0, len(stmts)) importPath := make(map[string]struct{}) for _, stmt := range stmts { if ct, ok := stmt.(*ast.CreateTableStmt); ok { s, ipt, err := makeCode(ct, opt) if err != nil { return nil, err } tableStr = append(tableStr, s) for _, s := range ipt { importPath[s] = struct{}{} } } } importPathArr := make([]string, 0, len(importPath)) for s := range importPath { importPathArr = append(importPathArr, s) } sort.Strings(importPathArr) return &ModelCodes{ Package: opt.Package, ImportPath: importPathArr, StructCode: tableStr, }, nil } func ParseSqlToWrite(sql string, writer io.Writer, options ...Option) error { data, err := ParseSql(sql, options...) if err != nil { return err } err = fileTmpl.Execute(writer, data) if err != nil { return err } return nil } func ParseSqlFormat(sql string, options ...Option) ([]byte, error) { w := strings.Builder{} err := ParseSqlToWrite(sql, &w, options...) if err != nil { return nil, err } return format.Source([]byte(w.String())) } type tmplData struct { TableName string NameFunc bool RawTableName string Fields []tmplField Comment string } type tmplField struct { Name string GoType string Tag string Comment string } // 下划线写法转为小驼峰写法 func Case2Camel(name string) string { name = strings.Replace(name, "_", " ", -1) name = strings.Title(name) return Lcfirst(strings.Replace(name, " ", "", -1)) } // 首字母大写 func Ucfirst(str string) string { for i, v := range str { return string(unicode.ToUpper(v)) + str[i+1:] } return "" } // 首字母小写 func Lcfirst(str string) string { for i, v := range str { return string(unicode.ToLower(v)) + str[i+1:] } return "" } func makeCode(stmt *ast.CreateTableStmt, opt options) (string, []string, error) { importPath := make([]string, 0, 1) data := tmplData{ TableName: stmt.Table.Name.String(), RawTableName: stmt.Table.Name.String(), Fields: make([]tmplField, 0, 1), } tablePrefix := opt.TablePrefix if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) { data.NameFunc = true data.TableName = data.TableName[len(tablePrefix):] } if opt.ForceTableName || data.RawTableName != inflection.Plural(data.RawTableName) { data.NameFunc = true } data.TableName = strcase.ToCamel(data.TableName) // find table comment for _, opt := range stmt.Options { if opt.Tp == ast.TableOptionComment { data.Comment = opt.StrValue break } } isPrimaryKey := make(map[string]bool) for _, con := range stmt.Constraints { if con.Tp == ast.ConstraintPrimaryKey { isPrimaryKey[con.Keys[0].Column.String()] = true } } columnPrefix := opt.ColumnPrefix for _, col := range stmt.Cols { colName := col.Name.Name.String() goFieldName := colName if columnPrefix != "" && strings.HasPrefix(goFieldName, columnPrefix) { goFieldName = goFieldName[len(columnPrefix):] } field := tmplField{ Name: strcase.ToCamel(goFieldName), } tags := make([]string, 0, 4) // make GORM's tag gormTag := strings.Builder{} gormTag.WriteString("column:") gormTag.WriteString(colName) if opt.GormType { gormTag.WriteString(";type:") gormTag.WriteString(col.Tp.InfoSchemaStr()) } if isPrimaryKey[colName] { gormTag.WriteString(";primary_key") } isNotNull := false canNull := false for _, o := range col.Options { switch o.Tp { case ast.ColumnOptionPrimaryKey: if !isPrimaryKey[colName] { gormTag.WriteString(";primary_key") isPrimaryKey[colName] = true } case ast.ColumnOptionNotNull: isNotNull = true case ast.ColumnOptionAutoIncrement: gormTag.WriteString(";AUTO_INCREMENT") case ast.ColumnOptionDefaultValue: if value := getDefaultValue(o.Expr); value != "" { gormTag.WriteString(";default:") gormTag.WriteString(value) } case ast.ColumnOptionUniqKey: gormTag.WriteString(";unique") case ast.ColumnOptionNull: //gormTag.WriteString(";NULL") canNull = true case ast.ColumnOptionOnUpdate: // For Timestamp and Datetime only. case ast.ColumnOptionFulltext: case ast.ColumnOptionComment: field.Comment = o.Expr.GetDatum().GetString() default: //return "", nil, errors.Errorf(" unsupport option %d\n", o.Tp) } } if !isPrimaryKey[colName] && isNotNull { gormTag.WriteString(";NOT NULL") } tags = append(tags, "gorm", gormTag.String()) if opt.JsonTag { if opt.Camel { tags = append(tags, "json", Case2Camel(colName)) } else { tags = append(tags, "json", colName) } } if opt.ZhTag { tags = append(tags, "zh-cn", field.Comment) } field.Tag = makeTagStr(tags) // get type in golang nullStyle := opt.NullStyle if !canNull { nullStyle = NullDisable } goType, pkg := mysqlToGoType(col.Tp, nullStyle) if pkg != "" { importPath = append(importPath, pkg) } field.GoType = goType data.Fields = append(data.Fields, field) } builder := strings.Builder{} err := structTmpl.Execute(&builder, data) if err != nil { return "", nil, err } code, err := format.Source([]byte(builder.String())) if err != nil { return string(code), importPath, errors.WithMessage(err, "format golang code error") } return string(code), importPath, nil } func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) { if style == NullInSql { path = "database/sql" switch colTp.Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: name = "int64" case mysql.TypeLonglong: name = "int64" case mysql.TypeFloat, mysql.TypeDouble: name = "float64" case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: name = "string" case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: name = "int64" case mysql.TypeDecimal, mysql.TypeNewDecimal: name = "float64" case mysql.TypeJSON, mysql.TypeEnum: name = "string" default: return "UnSupport", "" } } else { switch colTp.Tp { case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong: if mysql.HasUnsignedFlag(colTp.Flag) { name = "int64" } else { name = "int64" } case mysql.TypeLonglong: if mysql.HasUnsignedFlag(colTp.Flag) { name = "int64" } else { name = "int64" } case mysql.TypeFloat, mysql.TypeDouble: name = "float64" case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: name = "string" case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate: name = "int64" case mysql.TypeDecimal, mysql.TypeNewDecimal: name = "float64" case mysql.TypeJSON: name = "string" case mysql.TypeEnum: name = "string" default: return "UnSupport", "" } if style == NullInPointer { name = "*" + name } } return } func makeTagStr(tags []string) string { builder := strings.Builder{} for i := 0; i < len(tags)/2; i++ { builder.WriteString(tags[i*2]) builder.WriteString(`:"`) builder.WriteString(tags[i*2+1]) builder.WriteString(`" `) } if builder.Len() > 0 { return builder.String()[:builder.Len()-1] } return builder.String() } func getDefaultValue(expr ast.ExprNode) (value string) { if expr.GetDatum().Kind() != types.KindNull { value = fmt.Sprintf("%v", expr.GetDatum().GetValue()) } else if expr.GetFlag() != ast.FlagConstant { if expr.GetFlag() == ast.FlagHasFunc { if funcExpr, ok := expr.(*ast.FuncCallExpr); ok { value = funcExpr.FnName.O } } } return } func initTemplate() { tmplParseOnce.Do(func() { var err error structTmpl, err = template.New("goStruct").Parse(structTmplRaw) if err != nil { panic(err) } fileTmpl, err = template.New("goFile").Parse(fileTmplRaw) if err != nil { panic(err) } }) } func init() { structTmplRaw = ` {{- if .Comment -}} // {{.Comment}} {{end -}} type {{.TableName}} struct { {{- range .Fields}} {{.Name}} {{.GoType}} {{if .Tag}}` + "`{{.Tag}}`" + `{{end}}{{if .Comment}} // {{.Comment}}{{end}} {{- end}} } {{if .NameFunc}} func (m *{{.TableName}}) TableName() string { return "{{.RawTableName}}" } {{end}}` fileTmplRaw = ` package {{.Package}} {{if .ImportPath}} import ( {{- range .ImportPath}} "{{.}}" {{- end}} ) {{- end}} {{range .StructCode}} {{.}} {{end}} ` } func main() { r := gin.Default() r.Use(Cors()) r.POST("/sql", func(c *gin.Context) { var req struct { Content string `json:"content"` Camel bool `json:"camel"` TablePrefix string `json:"table_prefix"` // 去除表前缀 } err := c.BindJSON(&req) if err != nil { c.String(http.StatusBadRequest, err.Error()) return } opts := []Option{} opts = append(opts, WithGormType()) opts = append(opts, WithJsonTag()) opts = append(opts, WithZhTag()) opts = append(opts, WithTablePrefix(req.TablePrefix)) if req.Camel { opts = append(opts, WithCamel()) } res, err := ParseSqlFormat(req.Content, opts..., ) if err != nil { c.String(http.StatusInternalServerError, err.Error()) return } c.String(http.StatusOK, string(res)) }) // static 嵌入 r.GET("/", func(c *gin.Context) { indexHTML, _ := dist.ReadFile("dist/index.html") c.Writer.Write(indexHTML) }) static, _ := fs.Sub(dist, "dist/static") r.StaticFS("/static", http.FS(static)) if len(Open) == 0 { OpenBrowser("http://localhost:8333") } r.Run(":8333") } func Cors() gin.HandlerFunc { return cors.New(cors.Config{ AllowAllOrigins: true, AllowCredentials: true, AllowHeaders: []string{"*"}, MaxAge: time.Second * time.Duration(7200), }) }