// Command gorm-gen generates GORM model structs from the tables of a
// model-registry database.
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/spf13/cobra"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/driver/sqlite"
|
|
"gorm.io/gen"
|
|
"gorm.io/gen/field"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
// dbType names the database backend to use. It is bound to the --db-type
// flag in init and may be replaced by GORM_GEN_DB_TYPE in runGenerate.
var dbType string

// dsn is the database connection string. It is bound to the --dsn flag in
// init and may be replaced by GORM_GEN_DSN in runGenerate.
var dsn string
|
|
|
|
// genModels is gorm/gen generated models
|
|
func genModels(g *gen.Generator, db *gorm.DB, tables []string) (err error) {
|
|
if len(tables) == 0 {
|
|
// Execute tasks for all tables in the database
|
|
tables, err = db.Migrator().GetTables()
|
|
if err != nil {
|
|
return fmt.Errorf("GORM migrator get all tables fail: %w", err)
|
|
}
|
|
}
|
|
|
|
// Custom ModelOpt to remove default tag for nullable fields
|
|
modelOpt := gen.FieldGORMTag("*", func(tag field.GormTag) field.GormTag {
|
|
if vals, ok := tag["default"]; ok {
|
|
if len(vals) > 0 {
|
|
val := strings.Trim(strings.TrimSpace(vals[0]), `"'`)
|
|
if strings.ToUpper(val) == "NULL" || val == "0" || val == "" {
|
|
tag.Remove("default")
|
|
}
|
|
}
|
|
}
|
|
return tag
|
|
})
|
|
|
|
// Execute some data table tasks
|
|
for _, tableName := range tables {
|
|
if tableName == "Type" {
|
|
// Special handling for Type table to set TypeKind as int32
|
|
g.GenerateModel(tableName, gen.FieldType("type_kind", "int32"), modelOpt)
|
|
} else {
|
|
g.GenerateModel(tableName, modelOpt)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getDialector returns the appropriate GORM dialector based on database type and DSN
|
|
func getDialector(dbType, dsn string) (gorm.Dialector, error) {
|
|
switch dbType {
|
|
case "mysql":
|
|
return mysql.Open(dsn), nil
|
|
case "postgres", "postgresql":
|
|
return postgres.Open(dsn), nil
|
|
case "sqlite":
|
|
return sqlite.Open(dsn), nil
|
|
default:
|
|
return nil, fmt.Errorf("unsupported database type: %s. Supported types: mysql, postgres, sqlite, sqlserver", dbType)
|
|
}
|
|
}
|
|
|
|
// rootCmd represents the base command
|
|
var rootCmd = &cobra.Command{
|
|
Use: "gorm-gen",
|
|
Short: "GORM code generator for model-registry database schemas",
|
|
Long: `GORM code generator for model-registry database schemas.
|
|
|
|
This tool generates GORM model structs from database tables for the model-registry project.
|
|
It supports multiple database types including MySQL, PostgreSQL, SQLite, and SQL Server.
|
|
|
|
The generated models are placed in the ../internal/db/schema directory.`,
|
|
RunE: func(cmd *cobra.Command, args []string) error {
|
|
return runGenerate()
|
|
},
|
|
}
|
|
|
|
func runGenerate() error {
|
|
// Allow environment variable overrides
|
|
if envDBType := os.Getenv("GORM_GEN_DB_TYPE"); envDBType != "" {
|
|
dbType = envDBType
|
|
}
|
|
|
|
if envDSN := os.Getenv("GORM_GEN_DSN"); envDSN != "" {
|
|
dsn = envDSN
|
|
}
|
|
|
|
// Use default DSN if not provided
|
|
if dsn == "" {
|
|
return fmt.Errorf("Please provide a DSN using --dsn flag or GORM_GEN_DSN environment variable for %s database", dbType)
|
|
}
|
|
|
|
fmt.Printf("Connecting to %s database...\n", dbType)
|
|
|
|
// Get the appropriate dialector
|
|
dialector, err := getDialector(dbType, dsn)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get database dialector: %w", err)
|
|
}
|
|
|
|
// Connect to database
|
|
db, err := gorm.Open(dialector, &gorm.Config{})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to database: %w", err)
|
|
}
|
|
|
|
fmt.Println("Database connection successful!")
|
|
|
|
// Initialize the generator with configuration for models only
|
|
config := gen.Config{
|
|
OutPath: "../internal/db/schema",
|
|
ModelPkgPath: "schema",
|
|
Mode: 0,
|
|
FieldNullable: true,
|
|
FieldCoverable: true,
|
|
FieldSignable: true,
|
|
FieldWithIndexTag: false,
|
|
FieldWithTypeTag: false,
|
|
}
|
|
|
|
if dbType == "postgres" || dbType == "postgresql" {
|
|
dataTypeMap := map[string]func(gorm.ColumnType) string{
|
|
"bytea": func(columnType gorm.ColumnType) string {
|
|
return "[]byte"
|
|
},
|
|
}
|
|
config.WithDataTypeMap(dataTypeMap)
|
|
}
|
|
g := gen.NewGenerator(config)
|
|
|
|
// Use the database connection
|
|
g.UseDB(db)
|
|
|
|
// Generate models for all tables using custom function
|
|
err = genModels(g, db, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to generate models: %w", err)
|
|
}
|
|
|
|
// Generate the code
|
|
fmt.Printf("Generating GORM models for %s database...\n", dbType)
|
|
g.Execute()
|
|
fmt.Println("GORM models generated successfully!")
|
|
|
|
return nil
|
|
}
|
|
|
|
func init() {
|
|
// Define flags
|
|
rootCmd.Flags().StringVar(&dbType, "db-type", "mysql", "Database type (mysql, postgres, sqlite, sqlserver)")
|
|
rootCmd.Flags().StringVar(&dsn, "dsn", "", "Database connection string (DSN). If not provided, uses default for the database type")
|
|
|
|
// Add examples to the help
|
|
rootCmd.Example = ` # Generate models for MySQL (default)
|
|
gorm-gen --db-type=mysql --dsn="user:pass@tcp(localhost:3306)/dbname"
|
|
|
|
# Generate models for PostgreSQL
|
|
gorm-gen --db-type=postgres --dsn="host=localhost user=postgres dbname=mydb"
|
|
|
|
# Generate models for SQLite
|
|
gorm-gen --db-type=sqlite --dsn="./database.db"
|
|
|
|
# Use environment variables
|
|
export GORM_GEN_DB_TYPE=postgres
|
|
export GORM_GEN_DSN="host=localhost user=postgres dbname=mydb"
|
|
gorm-gen`
|
|
}
|
|
|
|
func main() {
|
|
if err := rootCmd.Execute(); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|