package db

import (
	"database/sql"
	"errors"
	"fmt"
	"sync"
	"time"

	// PostgreSQL
	_ "github.com/lib/pq"
	// MySQL
	//_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)

var (
	// config holds the settings installed by SetConfig for the shared pool.
	config Config
	// db is the shared connection pool, opened lazily by connect.
	db *sqlx.DB
	// err is not used by connect; it is retained for compatibility with any
	// other code in the package that references it.
	err error
	// once guards the one-time opening of db.
	once sync.Once
	// defaultDB is the handle created by SetConfig.
	defaultDB *DB
	// connErr is the outcome of the one-time attempt to open db. It lives at
	// package level so that every caller of connect sees a failed
	// initialisation, not just the first one.
	connErr error
)

// DB wraps a database connection pool and an optional in-flight transaction.
type DB struct {
	Driver string // driver name handed to sqlx (Limit checks for "mysql")
	conn   *sqlx.DB
	tx     *sqlx.Tx
}

// SetConfig stores the configuration used by the shared connection pool and
// replaces the package-level default DB handle.
//
// cfg.MaxLifetime is multiplied by time.Second, i.e. it is interpreted as a
// number of seconds.
//
// NOTE(review): connect opens the pool only once, so calling SetConfig after
// the first query does not reconfigure an already-open pool.
func SetConfig(cfg Config) {
	config.Driver = cfg.Driver
	config.DNS = cfg.DNS
	config.MaxOpenConns = cfg.MaxOpenConns
	config.MaxIdle = cfg.MaxIdle
	config.MaxLifetime = cfg.MaxLifetime * time.Second
	defaultDB = &DB{Driver: config.Driver}
}

// New returns a DB for the configured driver. It connects to the shared pool
// lazily, on first use.
func New() *DB {
	return &DB{Driver: config.Driver}
}

// Release closes the shared connection pool, if it has been opened.
//
// NOTE(review): the pool is guarded by a sync.Once and is not reopened after
// Release; later queries through connect receive the closed pool.
func Release() {
	if db != nil {
		db.Close()
	}
}

// NewConfig opens a dedicated (non-shared) connection pool described by
// config and returns a DB bound to it. Free it with ReleaseConfig.
//
// config.MaxLifetime is passed to SetConnMaxLifetime unchanged, so it must
// already be a complete duration.
// NOTE(review): SetConfig scales MaxLifetime by time.Second instead; confirm
// which convention callers expect.
//
// On failure it returns a nil *DB together with the connection error.
func NewConfig(config Config) (*DB, error) {
	// sqlx.Connect opens the pool and verifies it with a ping.
	conn, err := sqlx.Connect(config.Driver, config.DNS)
	if err != nil {
		return nil, err
	}
	conn.SetMaxOpenConns(config.MaxOpenConns)
	conn.SetMaxIdleConns(config.MaxIdle)
	conn.SetConnMaxLifetime(config.MaxLifetime)
	return &DB{Driver: config.Driver, conn: conn}, nil
}

// ReleaseConfig closes the dedicated pool of a DB created by NewConfig. It
// does nothing for a nil DB or a DB without a connection.
func ReleaseConfig(dbx *DB) {
	if dbx == nil || dbx.conn == nil {
		return
	}
	dbx.conn.Close()
}

// connect returns the shared connection pool. On first use it opens the pool
// from the configuration stored by SetConfig.
//
// The open is attempted exactly once. If it fails, the same error is
// returned on every later call, so callers never get a nil pool with a nil
// error.
func connect() (*sqlx.DB, error) {
	once.Do(func() {
		// sqlx.Connect opens the pool and verifies it with a ping; it returns
		// a nil *sqlx.DB on failure.
		db, connErr = sqlx.Connect(config.Driver, config.DNS)
		if connErr != nil {
			return
		}
		db.DB.SetMaxOpenConns(config.MaxOpenConns)
		db.DB.SetMaxIdleConns(config.MaxIdle)
		db.DB.SetConnMaxLifetime(config.MaxLifetime)
	})
	return db, connErr
}

// Stats returns the sql.DBStats of the connection pool this DB is bound to.
func (d *DB) Stats() (s sql.DBStats) { s = d.conn.DB.Stats() return } // Connect connect to database func (d *DB) Connect() (err error) { if d.conn != nil { return } d.conn, err = connect() return } // Close close database connect func (d *DB) Close() { // use pool //d.conn.Close() } // BeginTrans begin trans func (d *DB) BeginTrans() (err error) { d.conn, err = connect() if err != nil { return } d.tx = d.conn.MustBegin() return } // Commit commit func (d *DB) Commit() error { return d.tx.Commit() } // Rollback rollback func (d *DB) Rollback() error { return d.tx.Rollback() } // TransExec trans execute func (d *DB) TransExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) { if rs, err := d.tx.NamedExec(query, args); err == nil { RowsAffected, _ = rs.RowsAffected() LastInsertId, _ = rs.LastInsertId() } return } // TransUpdate trans update func (d *DB) TransUpdate(query string, args interface{}) (reply Reply) { var ( err error rs sql.Result ) if rs, err = d.tx.NamedExec(query, args); err == nil { a, _ := rs.RowsAffected() reply = ReplyOk(a, 0) } else { reply = ReplyFaild(ErrException, err, errors.New(`数据执行错误`)) } return } // TransRow trans get row func (d *DB) TransRow(dest interface{}, query string, args interface{}) error { nstmt, err := d.tx.PrepareNamed(query) if err != nil { return err } defer nstmt.Close() err = nstmt.Get(dest, args) //err = d.tx.Get(dest, query, args) return err } // Select select func (d *DB) Select(dest interface{}, query string, args ...interface{}) error { err := d.Connect() if err != nil { return err } defer d.Close() err = d.conn.Select(dest, query, args...) 
return err } // Rows get rows func (d *DB) Rows(dest interface{}, query string, args interface{}) error { err := d.Connect() if err != nil { return err } defer d.Close() nstmt, err := d.conn.PrepareNamed(query) if err != nil { return err } defer nstmt.Close() err = nstmt.Select(dest, args) return err } // Get get func (d *DB) Get(dest interface{}, query string, args ...interface{}) error { err := d.Connect() if err != nil { return err } defer d.Close() err = d.conn.Get(dest, query, args...) return err } // Row get row func (d *DB) Row(dest interface{}, query string, args interface{}) error { err := d.Connect() if err != nil { return err } defer d.Close() nstmt, err := d.conn.PrepareNamed(query) if err != nil { return err } defer nstmt.Close() err = nstmt.Get(dest, args) return err } // InsertReply insert and return DbReply func (d *DB) InsertReply(query string, args interface{}) (reply Reply) { var ( err error rs sql.Result ) err = d.Connect() if err != nil { reply = ReplyFaild(ErrNotConnect, err, errors.New(`数据库连接错误`)) return } defer d.Close() if rs, err = d.conn.NamedExec(query, args); err == nil { a, _ := rs.RowsAffected() n, _ := rs.LastInsertId() reply = ReplyOk(a, n) } else { reply = ReplyFaild(ErrException, err, errors.New(`数据执行错误`)) } return } // UpdateReply update/delete and return DbReply func (d *DB) UpdateReply(query string, args interface{}) (reply Reply) { var ( err error rs sql.Result ) err = d.Connect() if err != nil { reply = ReplyFaild(ErrNotConnect, err, errors.New(`数据库连接错误`)) return } defer d.Close() if rs, err = d.conn.NamedExec(query, args); err == nil { a, _ := rs.RowsAffected() reply = ReplyOk(a, 0) } else { reply = ReplyFaild(ErrException, err, errors.New(`数据执行错误`)) } return } // Insert insert into func (d *DB) Insert(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) { err = d.Connect() if err != nil { return } defer d.Close() var rs sql.Result if rs, err = d.conn.NamedExec(query, args); err == nil { 
LastInsertId, _ = rs.LastInsertId() RowsAffected, _ = rs.RowsAffected() } return } // Update update/delete func (d *DB) Update(query string, args interface{}) (RowsAffected int64, err error) { err = d.Connect() if err != nil { return } defer d.Close() var rs sql.Result if rs, err = d.conn.NamedExec(query, args); err == nil { RowsAffected, _ = rs.RowsAffected() } return } // Exec exec func (d *DB) Exec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) { err = d.Connect() if err != nil { return } defer d.Close() var rs sql.Result if rs, err = d.conn.Exec(query, args...); err == nil { LastInsertId, _ = rs.LastInsertId() RowsAffected, _ = rs.RowsAffected() } return } // Limit MySQL limit func (d *DB) Limit(page, pagesize int) string { // MySQL limit 0, size if d.Driver == `mysql` { return fmt.Sprintf(" limit %d, %d", (page-1)*pagesize, pagesize) } // // PostgreSQL limit size offset 0 return fmt.Sprintf(" limit %d offset %d", pagesize, (page-1)*pagesize) } // Select select func Select(dest interface{}, query string, args ...interface{}) (err error) { defaultDB.conn, err = connect() if err != nil { return err } err = defaultDB.conn.Select(dest, query, args...) return } // Rows get rows func Rows(dest interface{}, query string, args interface{}) (err error) { defaultDB.conn, err = connect() if err != nil { return err } nstmt, err := defaultDB.conn.PrepareNamed(query) if err != nil { return } defer nstmt.Close() err = nstmt.Select(dest, args) return } // Get get func Get(dest interface{}, query string, args ...interface{}) (err error) { defaultDB.conn, err = connect() if err != nil { return } err = defaultDB.conn.Get(dest, query, args...) 
return } // Row get row func Row(dest interface{}, query string, args interface{}) (err error) { defaultDB.conn, err = connect() if err != nil { return } nstmt, err := defaultDB.conn.PrepareNamed(query) if err != nil { return err } defer nstmt.Close() err = nstmt.Get(dest, args) return } func Exec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) { defaultDB.conn, err = connect() if err != nil { return } LastInsertId, RowsAffected, err = defaultDB.conn.Exec(query, args...) return }