You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

584 lines
13 KiB

package main
import (
	"fmt"
	"go/format"
	"io"
	"io/fs"
	"net/http"
	"os/exec"
	"runtime"
	"sort"
	"strings"
	"sync"
	"text/template"
	"time"
	"unicode"
	"unicode/utf8"

	"github.com/gin-contrib/cors"
	"github.com/gin-gonic/gin"
	"github.com/iancoleman/strcase"
	"github.com/jinzhu/inflection"
	"github.com/knocknote/vitess-sqlparser/tidbparser/ast"
	"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/mysql"
	"github.com/knocknote/vitess-sqlparser/tidbparser/dependency/types"
	"github.com/knocknote/vitess-sqlparser/tidbparser/parser"
	"github.com/pkg/errors"
)
// NullStyle selects how columns declared NULL are typed in the generated
// struct. It is configured with WithNullStyle (or forced to NullDisable by
// WithNoNullType).
type NullStyle int

// The available NullStyle values. Their numeric values are fixed.
const (
	// NullDisable gives nullable columns the same plain Go type as NOT NULL ones.
	NullDisable NullStyle = 0
	// NullInSql is the default style (see defaultOptions). NOTE(review): the
	// name suggests database/sql Null* wrappers; check mysqlToGoType for the
	// mapping actually produced.
	NullInSql NullStyle = 1
	// NullInPointer gives nullable columns a pointer type (*T).
	NullInPointer NullStyle = 2
)
type (
	// Option configures code generation. Pass any number of them to ParseSql,
	// ParseSqlToWrite or ParseSqlFormat; they are applied in order.
	Option func(*options)

	// options holds the resolved generator settings.
	options struct {
		Charset        string    // charset handed to the SQL parser
		Collation      string    // collation handed to the SQL parser
		JsonTag        bool      // emit a json tag on every field
		ZhTag          bool      // emit a zh-cn tag holding the column comment
		TablePrefix    string    // prefix stripped from table names for struct names
		ColumnPrefix   string    // prefix stripped from column names for field names
		NoNullType     bool      // force NullDisable regardless of NullStyle
		NullStyle      NullStyle // how columns declared NULL are typed
		Package        string    // package clause of the generated file
		GormType       bool      // include the SQL column type in the gorm tag
		ForceTableName bool      // always generate a TableName method
		Camel          bool      // json tag names in lowerCamelCase instead of raw column names
	}
)

// defaultOptions is the baseline that Option values modify: nullable columns
// use NullInSql and the generated package is named "model".
var defaultOptions = options{
	NullStyle: NullInSql,
	Package:   "model",
}
// Open controls the browser launch in main. When it is empty (the default),
// main opens http://localhost:8333 on startup; any non-empty value suppresses
// that. Presumably set at build time via -ldflags "-X main.Open=..." — confirm.
var Open string
// OpenBrowser asks the operating system to open url in the default browser.
//
// Supported platforms:
//   - Linux: xdg-open
//   - Windows: rundll32 url.dll,FileProtocolHandler
//   - macOS: open
//
// Opening a browser is best-effort, so failures are printed rather than
// returned. This includes failing to start the launcher and running on any
// other platform. The launcher is started without blocking the caller.
func OpenBrowser(url string) {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "linux":
		cmd = exec.Command("xdg-open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	case "darwin":
		cmd = exec.Command("open", url)
	default:
		fmt.Printf("open browser: unsupported platform %q\n", runtime.GOOS)
		return
	}
	if err := cmd.Start(); err != nil {
		fmt.Println(err)
		return
	}
	// Wait in the background so the launcher is reaped when it exits. Otherwise
	// it lingers as a zombie process (on Unix) for the lifetime of the server.
	go func() { _ = cmd.Wait() }()
}
// WithCharset sets the charset handed to the SQL parser.
func WithCharset(charset string) Option {
	return func(cfg *options) {
		cfg.Charset = charset
	}
}

// WithCollation sets the collation handed to the SQL parser.
func WithCollation(collation string) Option {
	return func(cfg *options) {
		cfg.Collation = collation
	}
}

// WithTablePrefix sets a prefix that is removed from table names before they
// become struct names. When the prefix is removed, a TableName method
// returning the original table name is generated.
func WithTablePrefix(p string) Option {
	return func(cfg *options) {
		cfg.TablePrefix = p
	}
}

// WithColumnPrefix sets a prefix that is removed from column names before they
// become field names. The gorm column tag still uses the full column name.
func WithColumnPrefix(p string) Option {
	return func(cfg *options) {
		cfg.ColumnPrefix = p
	}
}

// WithJsonTag adds a json tag, named after the column, to every field.
func WithJsonTag() Option {
	return func(cfg *options) {
		cfg.JsonTag = true
	}
}

// WithZhTag adds a zh-cn tag holding the column comment to every field.
func WithZhTag() Option {
	return func(cfg *options) {
		cfg.ZhTag = true
	}
}

// WithNoNullType makes columns declared NULL use plain Go types. It takes
// precedence over WithNullStyle regardless of option order.
func WithNoNullType() Option {
	return func(cfg *options) {
		cfg.NoNullType = true
	}
}

// WithNullStyle selects how columns declared NULL are typed.
func WithNullStyle(s NullStyle) Option {
	return func(cfg *options) {
		cfg.NullStyle = s
	}
}

// WithPackage sets the package name of the generated file (default "model").
func WithPackage(pkg string) Option {
	return func(cfg *options) {
		cfg.Package = pkg
	}
}

// WithCamel makes json tag names lowerCamelCase versions of the column names
// (user_name -> userName). It only has an effect together with WithJsonTag.
func WithCamel() Option {
	return func(cfg *options) {
		cfg.Camel = true
	}
}

// WithGormType writes the column's SQL type into the gorm tag (";type:...").
func WithGormType() Option {
	return func(cfg *options) {
		cfg.GormType = true
	}
}

// WithForceTableName always generates a TableName method, even when it would
// otherwise be omitted.
func WithForceTableName() Option {
	return func(cfg *options) {
		cfg.ForceTableName = true
	}
}
// parseOption applies the given options, in order, on top of defaultOptions
// and returns the result. WithNoNullType wins over any NullStyle setting.
func parseOption(options []Option) options {
	resolved := defaultOptions
	for _, apply := range options {
		apply(&resolved)
	}
	if resolved.NoNullType {
		resolved.NullStyle = NullDisable
	}
	return resolved
}
// Raw template sources; assigned in init.
var (
	structTmplRaw string
	fileTmplRaw   string
)

// Parsed templates, built from the raw sources exactly once by initTemplate
// (guarded by tmplParseOnce).
var (
	structTmpl    *template.Template
	fileTmpl      *template.Template
	tmplParseOnce sync.Once
)
// ModelCodes is the result of ParseSql and the data rendered by the file
// template.
type ModelCodes struct {
	Package    string   // package clause of the generated file
	ImportPath []string // import paths the structs need, de-duplicated and sorted
	StructCode []string // formatted source of each struct, in statement order
}
// ParseSql parses sql (one or more MySQL statements) and renders every
// CREATE TABLE statement as a Go struct.
//
// Other statement kinds are ignored. The struct sources are returned in
// statement order, together with the import paths they need (de-duplicated
// and sorted) and the configured package name. The first parse or
// code-generation error aborts the call and is returned.
func ParseSql(sql string, options ...Option) (*ModelCodes, error) {
	initTemplate()
	opt := parseOption(options)

	stmts, err := parser.New().Parse(sql, opt.Charset, opt.Collation)
	if err != nil {
		return nil, err
	}

	structCodes := make([]string, 0, len(stmts))
	seenImports := make(map[string]struct{})
	for _, stmt := range stmts {
		createStmt, isCreate := stmt.(*ast.CreateTableStmt)
		if !isCreate {
			continue
		}
		code, imports, genErr := makeCode(createStmt, opt)
		if genErr != nil {
			return nil, genErr
		}
		structCodes = append(structCodes, code)
		for _, p := range imports {
			seenImports[p] = struct{}{}
		}
	}

	// Map iteration order is random; sort for deterministic output.
	sortedImports := make([]string, 0, len(seenImports))
	for p := range seenImports {
		sortedImports = append(sortedImports, p)
	}
	sort.Strings(sortedImports)

	return &ModelCodes{
		Package:    opt.Package,
		ImportPath: sortedImports,
		StructCode: structCodes,
	}, nil
}
// ParseSqlToWrite generates a complete Go source file from sql and writes it
// to writer. The file contains the package clause, any needed imports and
// every generated struct.
//
// The individual structs are already formatted, but the file as a whole is
// not passed through go/format; use ParseSqlFormat for that.
func ParseSqlToWrite(sql string, writer io.Writer, options ...Option) error {
	codes, err := ParseSql(sql, options...)
	if err != nil {
		return err
	}
	return fileTmpl.Execute(writer, codes)
}
// ParseSqlFormat is like ParseSqlToWrite but returns the generated file as
// gofmt-formatted bytes.
func ParseSqlFormat(sql string, options ...Option) ([]byte, error) {
	var out strings.Builder
	if err := ParseSqlToWrite(sql, &out, options...); err != nil {
		return nil, err
	}
	return format.Source([]byte(out.String()))
}
type (
	// tmplData is the data structTmpl renders for one table.
	tmplData struct {
		TableName    string      // Go struct name (table prefix stripped, CamelCase)
		NameFunc     bool        // emit a TableName method returning RawTableName
		RawTableName string      // table name exactly as written in the DDL
		Fields       []tmplField // one entry per column, in DDL order
		Comment      string      // table comment, rendered above the struct
	}

	// tmplField is one struct field rendered by structTmpl.
	tmplField struct {
		Name    string // Go field name
		GoType  string // Go type, e.g. "int64" or "*string"
		Tag     string // struct tag body, without the surrounding backquotes
		Comment string // column comment, rendered as a trailing line comment
	}
)
// Case2Camel converts a snake_case name to lowerCamelCase, e.g.
// "user_name" -> "userName".
//
// Underscores become word breaks and each word's first letter is upper-cased
// with strings.Title. The breaks are then removed and the very first
// character is lower-cased. Letters that are already upper case are left as
// they are, so "USER_ID" becomes "uSERID".
func Case2Camel(name string) string {
	words := strings.ReplaceAll(name, "_", " ")
	// strings.Title is deprecated, but its word-boundary rules define this
	// function's output, so it is kept deliberately.
	titled := strings.Title(words)
	return Lcfirst(strings.ReplaceAll(titled, " ", ""))
}
// Ucfirst returns str with its first character (Unicode code point) upper-cased.
// The remainder is returned unchanged; an empty string yields "".
func Ucfirst(str string) string {
	first, size := utf8.DecodeRuneInString(str)
	if size == 0 {
		return ""
	}
	// Slice by the decoded width: the first character may span several bytes,
	// and slicing at 1 would split it into invalid UTF-8.
	return string(unicode.ToUpper(first)) + str[size:]
}
// Lcfirst returns str with its first character (Unicode code point) lower-cased.
// The remainder is returned unchanged; an empty string yields "".
func Lcfirst(str string) string {
	first, size := utf8.DecodeRuneInString(str)
	if size == 0 {
		return ""
	}
	// Slice by the decoded width: the first character may span several bytes,
	// and slicing at 1 would split it into invalid UTF-8.
	return string(unicode.ToLower(first)) + str[size:]
}
// makeCode renders one CREATE TABLE statement as a Go struct declaration,
// plus an optional TableName method, using structTmpl.
//
// It returns:
//   - the formatted source;
//   - the import paths the struct's field types need (possibly with
//     duplicates; ParseSql de-duplicates them);
//   - an error if template execution or formatting fails.
//
// On a formatting error the formatter output and the import paths are still
// returned alongside the error.
func makeCode(stmt *ast.CreateTableStmt, opt options) (string, []string, error) {
	importPath := make([]string, 0, 1)
	data := tmplData{
		TableName:    stmt.Table.Name.String(),
		RawTableName: stmt.Table.Name.String(),
		Fields:       make([]tmplField, 0, len(stmt.Cols)),
	}

	// Strip the configured table prefix from the struct name. A TableName
	// method is then needed so the struct still maps to the raw table name.
	tablePrefix := opt.TablePrefix
	if tablePrefix != "" && strings.HasPrefix(data.TableName, tablePrefix) {
		data.NameFunc = true
		data.TableName = data.TableName[len(tablePrefix):]
	}
	// Also emit TableName when forced, or when the raw name is not already in
	// plural form. NOTE(review): presumably because GORM's default naming
	// pluralizes the struct name — confirm.
	if opt.ForceTableName || data.RawTableName != inflection.Plural(data.RawTableName) {
		data.NameFunc = true
	}
	data.TableName = strcase.ToCamel(data.TableName)

	// Table comment, if any.
	for _, tblOpt := range stmt.Options {
		if tblOpt.Tp == ast.TableOptionComment {
			data.Comment = tblOpt.StrValue
			break
		}
	}

	// Mark every column named in a table-level PRIMARY KEY constraint. A
	// composite key lists several columns, and each needs the primary_key tag.
	// Keys are lower-cased because MySQL column names are case-insensitive, so
	// the constraint may spell a name differently from its column definition.
	isPrimaryKey := make(map[string]bool)
	for _, con := range stmt.Constraints {
		if con.Tp != ast.ConstraintPrimaryKey {
			continue
		}
		for _, key := range con.Keys {
			isPrimaryKey[strings.ToLower(key.Column.Name.String())] = true
		}
	}

	columnPrefix := opt.ColumnPrefix
	for _, col := range stmt.Cols {
		colName := col.Name.Name.String()
		pkKey := strings.ToLower(colName)

		// The field name drops the column prefix; the gorm column tag keeps it.
		goFieldName := colName
		if columnPrefix != "" && strings.HasPrefix(goFieldName, columnPrefix) {
			goFieldName = goFieldName[len(columnPrefix):]
		}
		field := tmplField{
			Name: strcase.ToCamel(goFieldName),
		}
		// Alternating key/value pairs consumed by makeTagStr.
		tags := make([]string, 0, 4)

		// Build the gorm tag.
		gormTag := strings.Builder{}
		gormTag.WriteString("column:")
		gormTag.WriteString(colName)
		if opt.GormType {
			gormTag.WriteString(";type:")
			gormTag.WriteString(col.Tp.InfoSchemaStr())
		}
		if isPrimaryKey[pkKey] {
			gormTag.WriteString(";primary_key")
		}
		isNotNull := false
		canNull := false
		for _, o := range col.Options {
			switch o.Tp {
			case ast.ColumnOptionPrimaryKey:
				// Inline PRIMARY KEY; skip it if the table-level constraint
				// already tagged this column.
				if !isPrimaryKey[pkKey] {
					gormTag.WriteString(";primary_key")
					isPrimaryKey[pkKey] = true
				}
			case ast.ColumnOptionNotNull:
				isNotNull = true
			case ast.ColumnOptionAutoIncrement:
				gormTag.WriteString(";AUTO_INCREMENT")
			case ast.ColumnOptionDefaultValue:
				if value := getDefaultValue(o.Expr); value != "" {
					gormTag.WriteString(";default:")
					gormTag.WriteString(value)
				}
			case ast.ColumnOptionUniqKey:
				gormTag.WriteString(";unique")
			case ast.ColumnOptionNull:
				// Only columns explicitly declared NULL get a nullable Go type.
				canNull = true
			case ast.ColumnOptionOnUpdate: // For Timestamp and Datetime only.
			case ast.ColumnOptionFulltext:
			case ast.ColumnOptionComment:
				field.Comment = o.Expr.GetDatum().GetString()
			default:
				// Other column options do not affect the generated code.
			}
		}
		// Primary-key columns skip the explicit NOT NULL marker.
		if !isPrimaryKey[pkKey] && isNotNull {
			gormTag.WriteString(";NOT NULL")
		}
		tags = append(tags, "gorm", gormTag.String())

		if opt.JsonTag {
			if opt.Camel {
				tags = append(tags, "json", Case2Camel(colName))
			} else {
				tags = append(tags, "json", colName)
			}
		}
		if opt.ZhTag {
			tags = append(tags, "zh-cn", field.Comment)
		}
		field.Tag = makeTagStr(tags)

		// Resolve the Go type; the null style only applies to nullable columns.
		nullStyle := opt.NullStyle
		if !canNull {
			nullStyle = NullDisable
		}
		goType, pkg := mysqlToGoType(col.Tp, nullStyle)
		if pkg != "" {
			importPath = append(importPath, pkg)
		}
		field.GoType = goType
		data.Fields = append(data.Fields, field)
	}

	builder := strings.Builder{}
	if err := structTmpl.Execute(&builder, data); err != nil {
		return "", nil, err
	}
	code, err := format.Source([]byte(builder.String()))
	if err != nil {
		return string(code), importPath, errors.WithMessage(err, "format golang code error")
	}
	return string(code), importPath, nil
}
// mysqlToGoType maps a MySQL column type to the Go type of the generated
// struct field.
//
// Mapping:
//   - Integers, signed or unsigned, become int64.
//   - DATE/DATETIME/TIMESTAMP deliberately become int64 as well.
//   - Floating-point and decimal types become float64.
//   - Character, blob, JSON and ENUM types become string.
//
// With style NullInPointer the type is prefixed with "*". NullInSql and
// NullDisable both yield the plain type. Unsupported column types map to
// "UnSupport".
//
// path is the import path the returned type needs. Every type produced here
// is predeclared, so path is always empty. (Previously NullInSql reported
// "database/sql" without using it, which made generated files fail to
// compile with an unused import.)
func mysqlToGoType(colTp *types.FieldType, style NullStyle) (name string, path string) {
	switch colTp.Tp {
	case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong:
		name = "int64"
	case mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDate:
		name = "int64"
	case mysql.TypeFloat, mysql.TypeDouble, mysql.TypeDecimal, mysql.TypeNewDecimal:
		name = "float64"
	case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString,
		mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob,
		mysql.TypeJSON, mysql.TypeEnum:
		name = "string"
	default:
		return "UnSupport", ""
	}
	if style == NullInPointer {
		name = "*" + name
	}
	return name, ""
}
// makeTagStr renders key/value pairs as the body of a Go struct tag, e.g.
// []string{"gorm", "column:id", "json", "id"} -> `gorm:"column:id" json:"id"`.
//
// tags holds alternating keys and values; a trailing unpaired key is ignored.
// Each value is written as a Go double-quoted string literal (fmt %q, i.e.
// strconv.Quote), so quotes, backslashes and control characters survive
// reflect.StructTag.Get. Column comments used for the zh-cn tag are one
// source of such characters. Backticks are additionally escaped as \x60
// because the generated tag is wrapped in a raw (backquoted) string literal,
// which cannot contain a backtick.
func makeTagStr(tags []string) string {
	pairs := make([]string, 0, len(tags)/2)
	for i := 0; i+1 < len(tags); i += 2 {
		quoted := strings.ReplaceAll(fmt.Sprintf("%q", tags[i+1]), "`", `\x60`)
		pairs = append(pairs, tags[i]+":"+quoted)
	}
	return strings.Join(pairs, " ")
}
// getDefaultValue returns the text to put after "default:" in the gorm tag
// for a column DEFAULT expression.
//
// A non-NULL literal is formatted with %v. A function call (for example
// CURRENT_TIMESTAMP) yields the function name as written. Anything else,
// including DEFAULT NULL, yields "" (no default written).
func getDefaultValue(expr ast.ExprNode) (value string) {
	datum := expr.GetDatum()
	if datum.Kind() != types.KindNull {
		return fmt.Sprintf("%v", datum.GetValue())
	}
	flag := expr.GetFlag()
	if flag == ast.FlagConstant || flag != ast.FlagHasFunc {
		return ""
	}
	if call, isCall := expr.(*ast.FuncCallExpr); isCall {
		return call.FnName.O
	}
	return ""
}
// initTemplate parses structTmplRaw and fileTmplRaw into structTmpl and
// fileTmpl. The work runs only once, however many times it is called.
//
// The template sources are fixed strings in this file, so a parse failure is
// a programming error and panics.
func initTemplate() {
	tmplParseOnce.Do(func() {
		structTmpl = template.Must(template.New("goStruct").Parse(structTmplRaw))
		fileTmpl = template.Must(template.New("goFile").Parse(fileTmplRaw))
	})
}
// init assigns the raw text of the two text/template sources used for code
// generation. initTemplate parses them on first use.
//
// structTmplRaw renders one model struct from a tmplData value:
//   - an optional table comment;
//   - the struct, one line per field (name, Go type, optional backquoted
//     tag, optional trailing comment);
//   - when NameFunc is set, a TableName method returning RawTableName.
//
// fileTmplRaw renders a whole Go file from a *ModelCodes value: the package
// clause, an import block when ImportPath is non-empty, then every struct
// source.
//
// Whitespace in the output is not significant: makeCode and ParseSqlFormat
// pass the generated text through go/format.
func init() {
// Template for a single model struct; executed with a tmplData value.
structTmplRaw = `
{{- if .Comment -}}
// {{.Comment}}
{{end -}}
type {{.TableName}} struct {
{{- range .Fields}}
{{.Name}} {{.GoType}} {{if .Tag}}` + "`{{.Tag}}`" + `{{end}}{{if .Comment}} // {{.Comment}}{{end}}
{{- end}}
}
{{if .NameFunc}}
func (m *{{.TableName}}) TableName() string {
return "{{.RawTableName}}"
}
{{end}}`
// Template for the complete generated file; executed with a *ModelCodes value.
fileTmplRaw = `
package {{.Package}}
{{if .ImportPath}}
import (
{{- range .ImportPath}}
"{{.}}"
{{- end}}
)
{{- end}}
{{range .StructCode}}
{{.}}
{{end}}
`
}
// main serves the SQL-to-GORM web UI on port 8333.
//
// Routes:
//
//	POST /sql      JSON {content, camel, table_prefix} -> formatted Go model source
//	GET  /         embedded dist/index.html
//	GET  /static/  embedded dist/static assets
//
// Unless Open is set, a browser is pointed at the UI on startup. Startup
// failures (bad embedded asset path, port unavailable) terminate the process
// with a non-zero status instead of exiting silently.
func main() {
	r := gin.Default()
	r.Use(Cors())
	r.POST("/sql", func(c *gin.Context) {
		var req struct {
			Content     string `json:"content"`
			Camel       bool   `json:"camel"`
			TablePrefix string `json:"table_prefix"` // table-name prefix to strip
		}
		if err := c.BindJSON(&req); err != nil {
			c.String(http.StatusBadRequest, err.Error())
			return
		}
		opts := []Option{
			WithGormType(),
			WithJsonTag(),
			WithZhTag(),
			WithTablePrefix(req.TablePrefix),
		}
		if req.Camel {
			opts = append(opts, WithCamel())
		}
		res, err := ParseSqlFormat(req.Content, opts...)
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		c.String(http.StatusOK, string(res))
	})

	// Embedded front-end assets. dist is declared elsewhere; presumably an
	// embed.FS populated by //go:embed — confirm.
	r.GET("/", func(c *gin.Context) {
		indexHTML, err := dist.ReadFile("dist/index.html")
		if err != nil {
			c.String(http.StatusInternalServerError, err.Error())
			return
		}
		c.Data(http.StatusOK, "text/html; charset=utf-8", indexHTML)
	})
	static, err := fs.Sub(dist, "dist/static")
	if err != nil {
		// fs.Sub only fails for an invalid path, i.e. a programming error.
		panic(err)
	}
	r.StaticFS("/static", http.FS(static))

	if len(Open) == 0 {
		OpenBrowser("http://localhost:8333")
	}
	// Run blocks until the server fails (e.g. the port is already in use).
	if err := r.Run(":8333"); err != nil {
		panic(err)
	}
}
// Cors returns middleware that permits cross-origin requests from any origin.
// It allows credentials and any request header, with a MaxAge of two hours
// for preflight results.
func Cors() gin.HandlerFunc {
	cfg := cors.Config{
		AllowAllOrigins:  true,
		AllowCredentials: true,
		AllowHeaders:     []string{"*"},
		MaxAge:           2 * time.Hour,
	}
	return cors.New(cfg)
}