// Package db is a thin convenience layer over sqlx for MySQL and PostgreSQL.
//
// It offers a process-wide default connection pool (see SetDefaultOption and
// the package-level helper functions) as well as independently configured
// pools created with New.
package db

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"sync"

	_ "github.com/go-sql-driver/mysql" // MySQL driver
	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq" // PostgreSQL driver
)

// DB wraps a sqlx connection pool and, optionally, one in-flight transaction
// started with BeginTrans.
//
// The transaction slot is a plain field without locking, so BeginTrans and
// the DB.Trans* methods must not be used from several goroutines on the same
// DB value. Use Trans, which returns an independent *Tx, for that.
type DB struct {
	// Driver is the driver name the pool was opened with; Limit uses it to
	// pick the SQL dialect.
	Driver string
	c      *sqlx.DB
	tx     *sqlx.Tx
}

// Tx is a handle on a single transaction returned by DB.Trans or BeginTrans.
type Tx struct {
	tx *sqlx.Tx
}

var (
	defaultConfig Option
	defaultDb     *DB
	db            *sqlx.DB
	// connectErr remembers the outcome of the one-time default connection
	// attempt so that every caller of connect sees it, not only the first.
	connectErr error
	once       sync.Once
)

// errNoTransaction is returned by transaction methods when no transaction
// has been started on the receiver.
var errNoTransaction = errors.New("db: no active transaction")

// SetDefaultOption stores opt as the default configuration and opens the
// default connection pool.
//
// The pool is opened at most once per process: later calls replace the
// default DB value (and its Driver) but keep using the pool, or the error,
// produced by the first attempt.
func SetDefaultOption(opt Option) (err error) {
	defaultConfig = opt
	defaultDb = &DB{Driver: opt.Driver}
	defaultDb.c, err = connect()
	return err
}

// ReleaseDefault closes the default connection pool, if one was opened, and
// reports any error from closing it.
func ReleaseDefault() error {
	if defaultDb == nil || defaultDb.c == nil {
		return nil
	}
	return defaultDb.c.Close()
}

// New opens a dedicated connection pool configured by opt and returns a DB
// bound to it. On failure the returned *DB is nil and any partially opened
// pool has been closed.
func New(opt Option) (dbx *DB, err error) {
	var conn *sqlx.DB
	conn, err = sqlx.Connect(opt.Driver, opt.DNS)
	if err != nil {
		return nil, err
	}
	conn.SetMaxOpenConns(opt.MaxOpenConns)
	conn.SetMaxIdleConns(opt.MaxIdle)
	conn.SetConnMaxLifetime(opt.MaxLifetime)
	if err = conn.Ping(); err != nil {
		// Do not leak the pool when it turns out to be unusable.
		_ = conn.Close()
		return nil, err
	}
	return &DB{Driver: opt.Driver, c: conn}, nil
}

// Release closes the connection pool held by dbx. A nil dbx or a DB that was
// never connected is a no-op.
func Release(dbx *DB) (err error) {
	if dbx == nil || dbx.c == nil {
		return nil
	}
	return dbx.c.Close()
}

// connect returns the shared default pool, opening it from defaultConfig on
// the first call. The result of that first attempt, success or failure, is
// returned to every subsequent caller.
func connect() (dbx *sqlx.DB, err error) {
	once.Do(func() {
		db, connectErr = sqlx.Connect(defaultConfig.Driver, defaultConfig.DNS)
		if connectErr != nil {
			db = nil
			return
		}
		db.SetMaxOpenConns(defaultConfig.MaxOpenConns)
		db.SetMaxIdleConns(defaultConfig.MaxIdle)
		db.SetConnMaxLifetime(defaultConfig.MaxLifetime)
		if connectErr = db.Ping(); connectErr != nil {
			// Never hand out a pool that failed its health check.
			_ = db.Close()
			db = nil
		}
	})
	return db, connectErr
}

// Connect makes sure d has a connection pool. A DB that already has one
// keeps it; otherwise d is attached to the shared default pool.
func (d *DB) Connect() (err error) {
	if d.c != nil {
		return nil
	}
	d.c, err = connect()
	return err
}

/*
// Close close database connect
func (d *DB) Close() error {
	// not use pool
	if d.c != nil {
		return d.c.Close()
	}
	return nil
}
// */

// Ping verifies a connection to the database is still alive, establishing a
// connection if necessary.
func (d *DB) Ping() error {
	if err := d.Connect(); err != nil {
		return err
	}
	return d.c.Ping()
}

// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (d *DB) PingContext(ctx context.Context) error {
	if err := d.Connect(); err != nil {
		return err
	}
	return d.c.PingContext(ctx)
}

// Stats returns database statistics. A DB without a connection pool reports
// zero-valued statistics.
func (d *DB) Stats() sql.DBStats {
	if d.c == nil {
		return sql.DBStats{}
	}
	return d.c.Stats()
}

// BeginTrans starts a transaction and stores it on d for use by Commit,
// Rollback and the DB.Trans* methods. Any transaction previously stored on d
// is replaced without being committed or rolled back.
func (d *DB) BeginTrans() (err error) {
	if err = d.Connect(); err != nil {
		return err
	}
	// Beginx reports failure as an error, where MustBegin would panic.
	d.tx, err = d.c.Beginx()
	return err
}

// Trans starts a transaction and returns it as an independent *Tx. Nothing
// is stored on d.
func (d *DB) Trans() (tx *Tx, err error) {
	if err = d.Connect(); err != nil {
		return nil, err
	}
	var inner *sqlx.Tx
	inner, err = d.c.Beginx()
	if err != nil {
		return nil, err
	}
	return &Tx{tx: inner}, nil
}

// current exposes the transaction stored by BeginTrans through the Tx API so
// the DB.Trans* methods can share the Tx implementations.
func (d *DB) current() *Tx {
	return &Tx{tx: d.tx}
}

// Commit commits the transaction started with BeginTrans.
func (d *DB) Commit() error {
	return d.current().Commit()
}

// Rollback rolls back the transaction started with BeginTrans.
func (d *DB) Rollback() error {
	return d.current().Rollback()
}

// TransExec executes query inside the transaction started with BeginTrans
// and returns the last insert id and the number of affected rows.
func (d *DB) TransExec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
	return d.current().TransExec(query, args...)
}

// TransNamedExec is TransExec with named bindvars taken from args.
func (d *DB) TransNamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	return d.current().TransNamedExec(query, args)
}

// TransGet loads a single row into dest inside the transaction started with
// BeginTrans.
func (d *DB) TransGet(dest interface{}, query string, args ...interface{}) (err error) {
	return d.current().TransGet(dest, query, args...)
}

// TransNamedGet is TransGet with named bindvars taken from args.
func (d *DB) TransNamedGet(dest interface{}, query string, args interface{}) (err error) {
	return d.current().TransNamedGet(dest, query, args)
}

// TransSelect loads all result rows into dest inside the transaction started
// with BeginTrans.
func (d *DB) TransSelect(dest interface{}, query string, args ...interface{}) (err error) {
	return d.current().TransSelect(dest, query, args...)
}

// Get loads a single row into dest. The pooled connection is intentionally
// left open afterwards.
func (d *DB) Get(dest interface{}, query string, args ...interface{}) (err error) {
	if err = d.Connect(); err != nil {
		return err
	}
	return d.c.Get(dest, query, args...)
}

// NamedGet loads a single row into dest using named bindvars taken from args.
func (d *DB) NamedGet(dest interface{}, query string, args interface{}) (err error) {
	if err = d.Connect(); err != nil {
		return err
	}
	var stmt *sqlx.NamedStmt
	stmt, err = d.c.PrepareNamed(query)
	if err != nil {
		return err
	}
	defer stmt.Close()
	return stmt.Get(dest, args)
}

// Select loads all result rows into dest.
func (d *DB) Select(dest interface{}, query string, args ...interface{}) error {
	if err := d.Connect(); err != nil {
		return err
	}
	return d.c.Select(dest, query, args...)
}

// NamedSelect loads all result rows into dest using named bindvars taken
// from args.
func (d *DB) NamedSelect(dest interface{}, query string, args interface{}) (err error) {
	if err = d.Connect(); err != nil {
		return err
	}
	var stmt *sqlx.NamedStmt
	stmt, err = d.c.PrepareNamed(query)
	if err != nil {
		return err
	}
	defer stmt.Close()
	return stmt.Select(dest, args)
}

// Exec executes query and returns the last insert id and the number of
// affected rows.
func (d *DB) Exec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
	if err = d.Connect(); err != nil {
		return 0, 0, err
	}
	var rs sql.Result
	rs, err = d.c.Exec(query, args...)
	if err != nil {
		return 0, 0, err
	}
	LastInsertId, RowsAffected = execResult(rs)
	return LastInsertId, RowsAffected, nil
}

// NamedExec is Exec with named bindvars taken from args.
func (d *DB) NamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	if err = d.Connect(); err != nil {
		return 0, 0, err
	}
	var rs sql.Result
	rs, err = d.c.NamedExec(query, args)
	if err != nil {
		return 0, 0, err
	}
	LastInsertId, RowsAffected = execResult(rs)
	return LastInsertId, RowsAffected, nil
}

// Limit renders a pagination clause for 1-based page numbers in the dialect
// selected by d.Driver: " LIMIT offset, size" for MySQL, otherwise
// " LIMIT size OFFSET offset". A page below 1 is treated as the first page
// and a negative pagesize as 0, so the clause never carries a negative
// number.
func (d *DB) Limit(page, pagesize int) string {
	if page < 1 {
		page = 1
	}
	if pagesize < 0 {
		pagesize = 0
	}
	offset := (page - 1) * pagesize
	if d.Driver == DriverMySQL {
		return fmt.Sprintf(" LIMIT %d, %d", offset, pagesize)
	}
	return fmt.Sprintf(" LIMIT %d OFFSET %d", pagesize, offset)
}

// execResult extracts the last insert id and affected-row count from rs.
// Errors are deliberately ignored and reported as 0: not every driver
// implements both values (lib/pq, for one, has no LastInsertId), and the
// statement itself has already succeeded at this point.
func execResult(rs sql.Result) (lastInsertID, rowsAffected int64) {
	lastInsertID, _ = rs.LastInsertId()
	rowsAffected, _ = rs.RowsAffected()
	return lastInsertID, rowsAffected
}

// handle returns the underlying sqlx transaction, or errNoTransaction when
// the receiver does not hold one.
func (t *Tx) handle() (*sqlx.Tx, error) {
	if t == nil || t.tx == nil {
		return nil, errNoTransaction
	}
	return t.tx, nil
}

// Commit commits the transaction.
func (t *Tx) Commit() error {
	tx, err := t.handle()
	if err != nil {
		return err
	}
	return tx.Commit()
}

// Rollback rolls the transaction back.
func (t *Tx) Rollback() error {
	tx, err := t.handle()
	if err != nil {
		return err
	}
	return tx.Rollback()
}

// TransExec executes query inside the transaction and returns the last
// insert id and the number of affected rows.
func (t *Tx) TransExec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
	tx, err := t.handle()
	if err != nil {
		return 0, 0, err
	}
	rs, err := tx.Exec(query, args...)
	if err != nil {
		return 0, 0, err
	}
	LastInsertId, RowsAffected = execResult(rs)
	return LastInsertId, RowsAffected, nil
}

// TransNamedExec is TransExec with named bindvars taken from args.
func (t *Tx) TransNamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	tx, err := t.handle()
	if err != nil {
		return 0, 0, err
	}
	rs, err := tx.NamedExec(query, args)
	if err != nil {
		return 0, 0, err
	}
	LastInsertId, RowsAffected = execResult(rs)
	return LastInsertId, RowsAffected, nil
}

// TransGet loads a single row into dest inside the transaction.
func (t *Tx) TransGet(dest interface{}, query string, args ...interface{}) (err error) {
	tx, err := t.handle()
	if err != nil {
		return err
	}
	return tx.Get(dest, query, args...)
}

// TransNamedGet is TransGet with named bindvars taken from args.
func (t *Tx) TransNamedGet(dest interface{}, query string, args interface{}) (err error) {
	tx, err := t.handle()
	if err != nil {
		return err
	}
	stmt, err := tx.PrepareNamed(query)
	if err != nil {
		return err
	}
	defer stmt.Close()
	return stmt.Get(dest, args)
}

// TransSelect loads all result rows into dest inside the transaction.
func (t *Tx) TransSelect(dest interface{}, query string, args ...interface{}) (err error) {
	tx, err := t.handle()
	if err != nil {
		return err
	}
	return tx.Select(dest, query, args...)
}

// The package-level functions below operate on the default DB and require a
// prior call to SetDefaultOption; without one the default DB is nil and they
// panic.

// Ping verifies a connection to the default database is still alive,
// establishing a connection if necessary.
func Ping() error {
	return defaultDb.Ping()
}

// PingContext verifies a connection to the default database is still alive,
// establishing a connection if necessary.
func PingContext(ctx context.Context) error {
	return defaultDb.PingContext(ctx)
}

/*
// Close close database connect
func Close() error {
	return defaultDb.Close()
}
// */

// Stats returns statistics for the default database.
func Stats() sql.DBStats {
	return defaultDb.Stats()
}

// BeginTrans starts a transaction on the default database and returns it as
// an independent *Tx.
func BeginTrans() (tx *Tx, err error) {
	return defaultDb.Trans()
}

// Get loads a single row into dest from the default database.
func Get(dest interface{}, query string, args ...interface{}) error {
	return defaultDb.Get(dest, query, args...)
}

// NamedGet loads a single row into dest from the default database using
// named bindvars taken from args.
func NamedGet(dest interface{}, query string, args interface{}) (err error) {
	return defaultDb.NamedGet(dest, query, args)
}

// Select loads all result rows into dest from the default database.
func Select(dest interface{}, query string, args ...interface{}) error {
	return defaultDb.Select(dest, query, args...)
}

// NamedSelect loads all result rows into dest from the default database
// using named bindvars taken from args.
func NamedSelect(dest interface{}, query string, args interface{}) (err error) {
	return defaultDb.NamedSelect(dest, query, args)
}

// Exec executes query on the default database and returns the last insert id
// and the number of affected rows.
func Exec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
	return defaultDb.Exec(query, args...)
}

// NamedExec is Exec with named bindvars taken from args.
func NamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	return defaultDb.NamedExec(query, args)
}

// Limit renders a pagination clause in the dialect of the default database;
// see DB.Limit.
func Limit(page, pagesize int) string {
	return defaultDb.Limit(page, pagesize)
}