diff --git a/main.go b/main.go index 5e4d14d..77d23aa 100644 --- a/main.go +++ b/main.go @@ -55,7 +55,6 @@ func main() { }() defer exec.DockerContainer.RemoveContainer() defer sqlite3.CloseDB() - sqlite3.InitGormDB() go func() { const reconnectDelay = 5 * time.Second for !stopRun { diff --git a/sqlite3/gorm.go b/sqlite3/gorm.go index 3ba3e23..279a9e5 100644 --- a/sqlite3/gorm.go +++ b/sqlite3/gorm.go @@ -1,12 +1,16 @@ package sqlite3 import ( + "sync" + "gorm.io/driver/sqlite" "gorm.io/gorm" ) var db *gorm.DB +var OnceInitGormDB sync.Once + func InitGormDB() { var err error db, err = gorm.Open(sqlite.Open("data.db"), &gorm.Config{}) @@ -16,27 +20,30 @@ func InitGormDB() { } func GetGormDB() *gorm.DB { + if db == nil { + OnceInitGormDB.Do(InitGormDB) + } return db } // TryCreateTable 使用GORM执行原始SQL创建表语句 func TryCreateTable(query string) error { - return db.Exec(query).Error + return GetGormDB().Exec(query).Error } // AutoMigrate 使用GORM的自动迁移功能 func AutoMigrate(models ...interface{}) error { - return db.AutoMigrate(models...) + return GetGormDB().AutoMigrate(models...) } // GetGormTx 获取GORM事务 func GetGormTx() *gorm.DB { - return db.Begin() + return GetGormDB().Begin() } // CloseDB 关闭数据库连接 func CloseDB() { - if sqlDB, err := db.DB(); err == nil { + if sqlDB, err := GetGormDB().DB(); err == nil { sqlDB.Close() } }