Browse Source

fix update trans

ls 2 years ago
parent
commit
f43c9126d7
1 changed files with 77 additions and 38 deletions
  1. 77 38
      db/sqlx.go

+ 77 - 38
db/sqlx.go

@@ -20,6 +20,10 @@ type DB struct {
 	tx     *sqlx.Tx
 }
 
+type Tx struct {
+	tx *sqlx.Tx
+}
+
 var (
 	defaultConfig Option
 	defaultDb     *DB
@@ -131,6 +135,17 @@ func (d *DB) BeginTrans() (err error) {
 	return
 }
 
+// Trans begin trans
+func (d *DB) Trans() (tx *Tx, err error) {
+	d.c, err = connect()
+	if err != nil {
+		return
+	}
+
+	tx.tx = d.c.MustBegin()
+	return
+}
+
 // Commit commit
 func (d *DB) Commit() error {
 	return d.tx.Commit()
@@ -167,7 +182,7 @@ func (d *DB) TransNamedExec(query string, args interface{}) (LastInsertId, RowsA
 
 // TransGet trans get row
 func (d *DB) TransGet(dest interface{}, query string, args ...interface{}) (err error) {
-	d.tx.Get(dest, query, args...)
+	err = d.tx.Get(dest, query, args...)
 	return
 }
 
@@ -298,6 +313,65 @@ func (d *DB) Limit(page, pagesize int) string {
 	return fmt.Sprintf(" LIMIT %d OFFSET %d", pagesize, (page-1)*pagesize)
 }
 
+// Commit commit
+func (t *Tx) Commit() error {
+	return t.tx.Commit()
+}
+
+// Rollback rollback
+func (t *Tx) Rollback() error {
+	return t.tx.Rollback()
+}
+
+// TransNamedExec trans execute
+func (t *Tx) TransExec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
+	rs, err := t.tx.Exec(query, args...)
+	if err != nil {
+		return
+	}
+
+	RowsAffected, _ = rs.RowsAffected()
+	LastInsertId, _ = rs.LastInsertId()
+	return
+}
+
+// TransNamedExec trans execute, named bindvars
+func (t *Tx) TransNamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
+	rs, err := t.tx.NamedExec(query, args)
+	if err != nil {
+		return
+	}
+
+	RowsAffected, _ = rs.RowsAffected()
+	LastInsertId, _ = rs.LastInsertId()
+	return
+}
+
+// TransGet trans get row
+func (t *Tx) TransGet(dest interface{}, query string, args ...interface{}) (err error) {
+	err = t.tx.Get(dest, query, args...)
+	return
+}
+
+// TransNamedGet trans get row, named bindvars
+func (t *Tx) TransNamedGet(dest interface{}, query string, args interface{}) (err error) {
+	var nstmt *sqlx.NamedStmt
+	nstmt, err = t.tx.PrepareNamed(query)
+	if err != nil {
+		return
+	}
+	defer nstmt.Close()
+
+	err = nstmt.Get(dest, args)
+	return
+}
+
+// TransSelect trans get rows
+func (t *Tx) TransSelect(dest interface{}, query string, args ...interface{}) (err error) {
+	err = t.tx.Select(dest, query, args...)
+	return
+}
+
 // Ping verifies a connection to the database is still alive, establishing a connection if necessary.
 func Ping() error {
 	return defaultDb.Ping()
@@ -321,43 +395,8 @@ func Stats() sql.DBStats {
 }
 
 // BeginTrans begin trans
-func BeginTrans() (err error) {
-	return defaultDb.BeginTrans()
-}
-
-// Commit commit
-func Commit() error {
-	return defaultDb.Commit()
-}
-
-// Rollback rollback
-func Rollback() error {
-	return defaultDb.Rollback()
-}
-
-// TransNamedExec trans execute
-func TransExec(query string, args ...interface{}) (LastInsertId, RowsAffected int64, err error) {
-	return defaultDb.TransExec(query, args...)
-}
-
-// TransNamedExec trans execute, named bindvars
-func TransNamedExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
-	return defaultDb.TransNamedExec(query, args)
-}
-
-// TransGet trans get row
-func TransGet(dest interface{}, query string, args interface{}) (err error) {
-	return defaultDb.TransGet(dest, query, args)
-}
-
-// TransNamedGet trans get row, named bindvars
-func TransNamedGet(dest interface{}, query string, args interface{}) (err error) {
-	return defaultDb.TransNamedGet(dest, query, args)
-}
-
-// TransSelect trans get rows
-func TransSelect(dest interface{}, query string, args ...interface{}) (err error) {
-	return defaultDb.TransSelect(dest, query, args...)
+func BeginTrans() (tx *Tx, err error) {
+	return defaultDb.Trans()
 }
 
 // Get get one