ls 4 anos atrás
pai
commit
8ca334a961
1 arquivos alterados com 212 adições e 7 exclusões
  1. 212 7
      types/string.go

+ 212 - 7
types/string.go

@@ -1,7 +1,13 @@
 package types
 
 import (
+	"database/sql"
 	"database/sql/driver"
+	"errors"
+	"fmt"
+	"reflect"
+	"strconv"
+	"time"
 )
 
 // TextNull is an implementation of a string for the MySQL type char/varchar/text ....
@@ -16,17 +22,216 @@ func (s TextNull) Value() (driver.Value, error) {
 // Scan implements the sql.Scanner interface,
 // and turns the bytes incoming from MySQL into a string
 func (s *TextNull) Scan(src interface{}) error {
-	if src != nil {
-		v, ok := src.([]byte)
-		if !ok {
-			*s = TextNull("")
+	if src == nil {
+		*s = TextNull("")
+	}
+
+	convertAssign(&s, src)
+
+	return nil
+}
+
+var errNilPtr = errors.New("destination pointer is nil") // embedded in descriptive error
+// convertAssign is the same as convertAssignRows, but without the optional
+// rows argument.
+func convertAssign(dest, src interface{}) error {
+	return convertAssignRows(dest, src)
+}
+
+// convertAssignRows copies to dest the value in src, converting it if possible.
+// An error is returned if the copy would result in loss of information.
+// dest should be a pointer type. If rows is passed in, the rows will
+// be used as the parent for any cursor values converted from a
+// driver.Rows to a *Rows.
+func convertAssignRows(dest, src interface{}) error {
+	// Common cases, without reflect.
+	switch s := src.(type) {
+	case string:
+		switch d := dest.(type) {
+		case *string:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = s
+			return nil
+		case *[]byte:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = []byte(s)
+			return nil
+		case *sql.RawBytes:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = append((*d)[:0], s...)
+			return nil
+		}
+	case []byte:
+		switch d := dest.(type) {
+		case *string:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = string(s)
+			return nil
+		case *interface{}:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = cloneBytes(s)
+			return nil
+		case *[]byte:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = cloneBytes(s)
+			return nil
+		case *sql.RawBytes:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = s
+			return nil
+		}
+	case time.Time:
+		switch d := dest.(type) {
+		case *time.Time:
+			*d = s
+			return nil
+		case *string:
+			*d = s.Format(time.RFC3339Nano)
+			return nil
+		case *[]byte:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = []byte(s.Format(time.RFC3339Nano))
+			return nil
+		case *sql.RawBytes:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = s.AppendFormat((*d)[:0], time.RFC3339Nano)
+			return nil
+		}
+	case nil:
+		switch d := dest.(type) {
+		case *interface{}:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = nil
+			return nil
+		case *[]byte:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = nil
+			return nil
+		case *sql.RawBytes:
+			if d == nil {
+				return errNilPtr
+			}
+			*d = nil
+			return nil
+		}
+	// The driver is returning a cursor the client may iterate over.
+	case driver.Rows:
+		switch d := dest.(type) {
+		case *sql.Rows:
+			if d == nil {
+				return errNilPtr
+			}
 			return nil
 		}
+	}
+
+	var sv reflect.Value
 
-		*s = TextNull(v)
+	switch d := dest.(type) {
+	case *string:
+		sv = reflect.ValueOf(src)
+		switch sv.Kind() {
+		case reflect.Bool,
+			reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
+			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
+			reflect.Float32, reflect.Float64:
+			*d = asString(src)
+			return nil
+		}
+	case *[]byte:
+		sv = reflect.ValueOf(src)
+		if b, ok := asBytes(nil, sv); ok {
+			*d = b
+			return nil
+		}
+	case *sql.RawBytes:
+		sv = reflect.ValueOf(src)
+		if b, ok := asBytes([]byte(*d)[:0], sv); ok {
+			*d = sql.RawBytes(b)
+			return nil
+		}
+	case *bool:
+		bv, err := driver.Bool.ConvertValue(src)
+		if err == nil {
+			*d = bv.(bool)
+		}
+		return err
+	case *interface{}:
+		*d = src
 		return nil
 	}
+	return fmt.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, dest)
+}
 
-	*s = TextNull("")
-	return nil
+func cloneBytes(b []byte) []byte {
+	if b == nil {
+		return nil
+	}
+	c := make([]byte, len(b))
+	copy(c, b)
+	return c
+}
+
+func asString(src interface{}) string {
+	switch v := src.(type) {
+	case string:
+		return v
+	case []byte:
+		return string(v)
+	}
+	rv := reflect.ValueOf(src)
+	switch rv.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return strconv.FormatInt(rv.Int(), 10)
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return strconv.FormatUint(rv.Uint(), 10)
+	case reflect.Float64:
+		return strconv.FormatFloat(rv.Float(), 'g', -1, 64)
+	case reflect.Float32:
+		return strconv.FormatFloat(rv.Float(), 'g', -1, 32)
+	case reflect.Bool:
+		return strconv.FormatBool(rv.Bool())
+	}
+	return fmt.Sprintf("%v", src)
+}
+
+func asBytes(buf []byte, rv reflect.Value) (b []byte, ok bool) {
+	switch rv.Kind() {
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return strconv.AppendInt(buf, rv.Int(), 10), true
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return strconv.AppendUint(buf, rv.Uint(), 10), true
+	case reflect.Float32:
+		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 32), true
+	case reflect.Float64:
+		return strconv.AppendFloat(buf, rv.Float(), 'g', -1, 64), true
+	case reflect.Bool:
+		return strconv.AppendBool(buf, rv.Bool()), true
+	case reflect.String:
+		s := rv.String()
+		return append(buf, s...), true
+	}
+	return
 }