浏览代码

update db

ls 8 月之前
父节点
当前提交
471186dfe2
共有 5 个文件被更改,包括 70 次插入23 次删除
  1. 2 2
      binding/protobuf.go
  2. 51 5
      db/db.go
  3. 5 0
      db/struct.go
  4. 10 14
      request.go
  5. 2 2
      types/gzip.go

+ 2 - 2
binding/protobuf.go

@@ -5,7 +5,7 @@
 package binding
 
 import (
-	"io/ioutil"
+	"io"
 	"net/http"
 
 	"github.com/golang/protobuf/proto"
@@ -18,7 +18,7 @@ func (protobufBinding) Name() string {
 }
 
 func (b protobufBinding) Bind(req *http.Request, obj interface{}) error {
-	buf, err := ioutil.ReadAll(req.Body)
+	buf, err := io.ReadAll(req.Body)
 	if err != nil {
 		return err
 	}

+ 51 - 5
db/db.go

@@ -64,9 +64,8 @@ func ReleaseConfig(dbx *DB) {
 
 func connect() (dbx *sqlx.DB, err error) {
 	once.Do(func() {
-		db, err = sqlx.Connect(config.Driver, config.DNS)
+		db, err = sqlx.Open(config.Driver, config.DNS) // sqlx.Connect(config.Driver, config.DNS)
 		if err != nil {
-			fmt.Println("Connect ERR", err)
 			return
 		}
 		db.DB.SetMaxOpenConns(config.MaxOpenConns)
@@ -76,7 +75,6 @@ func connect() (dbx *sqlx.DB, err error) {
 		/*
 			err = db.Ping()
 			if err != nil {
-				fmt.Println("Connect Ping", err)
 				return
 			}
 			// */
@@ -93,7 +91,6 @@ func connectContext(ctx context.Context) (dbx *sqlx.DB, err error) {
 	once.Do(func() {
 		db, err = sqlx.ConnectContext(ctx, config.Driver, config.DNS)
 		if err != nil {
-			fmt.Println("Connect ERR", err)
 			return
 		}
 		db.DB.SetMaxOpenConns(config.MaxOpenConns)
@@ -267,6 +264,30 @@ func (d *DB) TransRow(dest interface{}, query string, args interface{}) (err err
 	return err
 }
 
+// Preparex a statement within a transaction.
+func (d *DB) Preparex(query string) (stmt *Stmt, err error) {
+	stmt, err = d.conn.Preparex(query)
+	return
+}
+
+// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
+func (d *DB) PreparexContext(ctx context.Context, query string) (stmt *Stmt, err error) {
+	stmt, err = d.conn.PreparexContext(ctx, query)
+	return
+}
+
+// PrepareNamed returns an sqlx.NamedStmt
+func (d *DB) PrepareNamed(query string) (stmt *NamedStmt, err error) {
+	stmt, err = d.conn.PrepareNamed(query)
+	return
+}
+
+// PrepareNamedContext returns an sqlx.NamedStmt
+func (d *DB) PrepareNamedContext(ctx context.Context, query string) (stmt *NamedStmt, err error) {
+	stmt, err = d.conn.PrepareNamedContext(ctx, query)
+	return
+}
+
 // Select select
 func (d *DB) Select(dest interface{}, query string, args ...interface{}) (err error) {
 	err = d.Connect()
@@ -453,7 +474,7 @@ func (d *DB) RowContext(ctx context.Context, dest interface{}, query string, arg
 	return
 }
 
-/*
+// *
 // In expands slice values in args, returning the modified query string and a new arg list that can be executed by a database. The `query` should use the `?` bindVar. The return value uses the `?` bindVar.
 func (d *DB) In(query string, args ...interface{}) (q string, params []interface{}, err error) {
 	err = d.Connect()
@@ -466,6 +487,7 @@ func (d *DB) In(query string, args ...interface{}) (q string, params []interface
 	q = d.conn.Rebind(s)
 	return
 }
+
 //*/
 
 // InsertReply insert and return DbReply
@@ -699,6 +721,30 @@ func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{})
 	return
 }
 
+// Preparex a statement within a transaction.
+func (t *Tx) Preparex(query string) (stmt *Stmt, err error) {
+	stmt, err = t.tx.Preparex(query)
+	return
+}
+
+// PreparexContext a statement within a transaction.
+func (t *Tx) PreparexContext(ctx context.Context, query string) (stmt *Stmt, err error) {
+	stmt, err = t.tx.PreparexContext(ctx, query)
+	return
+}
+
+// PrepareNamed returns an sqlx.NamedStmt
+func (t *Tx) PrepareNamed(query string) (stmt *NamedStmt, err error) {
+	stmt, err = t.tx.PrepareNamed(query)
+	return
+}
+
+// PrepareNamedContext returns an sqlx.NamedStmt
+func (t *Tx) PrepareNamedContext(ctx context.Context, query string) (stmt *NamedStmt, err error) {
+	stmt, err = t.tx.PrepareNamedContext(ctx, query)
+	return
+}
+
 // Query executes a query that returns rows, typically a SELECT. with named args
 func (t *Tx) Query(dest interface{}, query string, args interface{}) (err error) {
 	nstmt := &sqlx.NamedStmt{}

+ 5 - 0
db/struct.go

@@ -1 +1,6 @@
 package db
+
+import "github.com/jmoiron/sqlx"
+
+type Stmt = sqlx.Stmt
+type NamedStmt = sqlx.NamedStmt

+ 10 - 14
request.go

@@ -8,7 +8,6 @@ import (
 	"encoding/xml"
 	"errors"
 	"io"
-	"io/ioutil"
 	"net"
 	"net/http"
 	"time"
@@ -85,12 +84,13 @@ func (m HTTPMessage) JSONQuery() (jq *JSONQuery, err error) {
 func newRequest(method, uri, certPath, keyPath string, header map[string]string, body io.Reader) (res *http.Response, err error) {
 	t := &http.Transport{
 		Dial: func(netw, addr string) (net.Conn, error) {
-			conn, err := net.DialTimeout(netw, addr, time.Second*RequestTimeOut)
+			var c net.Conn
+			c, err = net.DialTimeout(netw, addr, time.Second*RequestTimeOut)
 			if err != nil {
 				return nil, err
 			}
-			conn.SetDeadline(time.Now().Add(time.Second * reqTimeOut))
-			return conn, nil
+			c.SetDeadline(time.Now().Add(time.Second * reqTimeOut))
+			return c, nil
 		},
 		ResponseHeaderTimeout: time.Second * reqTimeOut,
 	}
@@ -104,17 +104,13 @@ func newRequest(method, uri, certPath, keyPath string, header map[string]string,
 			t.TLSClientConfig = &tls.Config{InsecureSkipVerify: true, Certificates: []tls.Certificate{cert}, RootCAs: pool}
 		}
 	}
-	client := &http.Client{Transport: t}
+
 	var (
-		req *http.Request
+		req    *http.Request
+		client = &http.Client{Transport: t}
 	)
 
-	if body != nil {
-		req, err = http.NewRequest(method, uri, body)
-	} else {
-		req, err = http.NewRequest(method, uri, nil)
-	}
-
+	req, err = http.NewRequest(method, uri, body)
 	if err != nil {
 		return
 	}
@@ -137,10 +133,10 @@ func readBody(res *http.Response) (msg HTTPMessage, err error) {
 	case "gzip":
 		reader, err = gzip.NewReader(res.Body)
 		if err == nil {
-			body, err = ioutil.ReadAll(reader)
+			body, err = io.ReadAll(reader)
 		}
 	default:
-		body, err = ioutil.ReadAll(res.Body)
+		body, err = io.ReadAll(res.Body)
 	}
 	if err != nil {
 		return

+ 2 - 2
types/gzip.go

@@ -5,7 +5,7 @@ import (
 	"compress/gzip"
 	"database/sql/driver"
 	"errors"
-	"io/ioutil"
+	"io"
 )
 
 // GzippedText is a []byte which transparently gzips data being submitted to
@@ -41,7 +41,7 @@ func (g *GzippedText) Scan(src interface{}) error {
 		return err
 	}
 	defer reader.Close()
-	b, err := ioutil.ReadAll(reader)
+	b, err := io.ReadAll(reader)
 	if err != nil {
 		return err
 	}