ls 5 years ago
parent
commit
08abaa93dd
10 changed files with 293 additions and 188 deletions
  1. 9 48
      binding/binding.go
  2. 16 20
      binding/default_validator.go
  3. 33 2
      binding/form.go
  4. 197 90
      binding/form_mapping.go
  5. 1 0
      binding/json.go
  6. 16 6
      binding/msgpack.go
  7. 18 17
      binding/protobuf.go
  8. 1 3
      binding/query.go
  9. 2 2
      binding/uri.go
  10. 0 0
      binding/xml.go

+ 9 - 48
binding/binding.go

@@ -4,11 +4,7 @@
 
 package binding
 
-import (
-	"net/http"
-	//validator "gopkg.in/go-playground/validator.v8"
-	//validator "gopkg.in/go-playground/validator.v8"
-)
+import "net/http"
 
 // Content-Type MIME of the most common data formats.
 const (
@@ -40,14 +36,13 @@ type BindingBody interface {
 	BindBody([]byte, interface{}) error
 }
 
-// BindingURI adds BindUri method to Binding. BindUri is similar with Bind,
+// BindingUri adds BindUri method to Binding. BindUri is similar with Bind,
 // but it read the Params.
-type BindingURI interface {
+type BindingUri interface {
 	Name() string
-	BindURI(map[string][]string, interface{}) error
+	BindUri(map[string][]string, interface{}) error
 }
 
-/*
 // StructValidator is the minimal interface which needs to be implemented in
 // order for it to be used as the validator engine for ensuring the correctness
 // of the request. Gin provides a default implementation for this using
@@ -68,8 +63,7 @@ type StructValidator interface {
 // Validator is the default validator which implements the StructValidator
 // interface. It uses https://github.com/go-playground/validator/tree/v8.18.2
 // under the hood.
-var Validator StructValidator = &defaultValidator{}
-// */
+//var Validator StructValidator = &defaultValidator{}
 
 // These implement the Binding interface and can be used to bind the data
 // present in the request to struct instances.
@@ -83,7 +77,7 @@ var (
 	ProtoBuf      = protobufBinding{}
 	MsgPack       = msgpackBinding{}
 	YAML          = yamlBinding{}
-	URI           = uriBinding{}
+	Uri           = uriBinding{}
 )
 
 // Default returns the appropriate Binding instance based on the HTTP method
@@ -104,7 +98,9 @@ func Default(method, contentType string) Binding {
 		return MsgPack
 	case MIMEYAML:
 		return YAML
-	default: //case MIMEPOSTForm, MIMEMultipartPOSTForm:
+	case MIMEMultipartPOSTForm:
+		return FormMultipart
+	default: // case MIMEPOSTForm:
 		return Form
 	}
 }
@@ -115,41 +111,6 @@ func validate(obj interface{}) error {
 		if Validator == nil {
 			return nil
 		}
-
 		return Validator.ValidateStruct(obj)
 		// */
 }
-
-// Bind checks the Content-Type to select a binding engine automatically,
-// Depending the "Content-Type" header different bindings are used:
-//     "application/json" --> JSON binding
-//     "application/xml"  --> XML binding
-// otherwise --> returns an error.
-// It parses the request's body as JSON if Content-Type == "application/json" using JSON or XML as a JSON input.
-// It decodes the json payload into the struct specified as a pointer.
-// It writes a 400 error and sets Content-Type header "text/plain" in the response if input is not valid.
-func Bind(req *http.Request, obj interface{}) error {
-	b := Default(req.Method, ContentType(req))
-	return MustBindWith(req, obj, b)
-}
-
-// MustBindWith binds the passed struct pointer using the specified binding engine.
-// It will abort the request with HTTP 400 if any error occurs.
-// See the binding package.
-func MustBindWith(req *http.Request, obj interface{}, b Binding) (err error) {
-	return b.Bind(req, obj)
-}
-
-// ContentType returns the Content-Type header of the request.
-func ContentType(req *http.Request) string {
-	return filterFlags(req.Header.Get("Content-Type"))
-}
-
-func filterFlags(content string) string {
-	for i, char := range content {
-		if char == ' ' || char == ';' {
-			return content[:i]
-		}
-	}
-	return content
-}

+ 16 - 20
binding/default_validator.go

@@ -14,39 +14,35 @@ type defaultValidator struct {
 
 var _ StructValidator = &defaultValidator{}
 
+// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
 func (v *defaultValidator) ValidateStruct(obj interface{}) error {
-	if kindOfData(obj) == reflect.Struct {
+	value := reflect.ValueOf(obj)
+	valueType := value.Kind()
+	if valueType == reflect.Ptr {
+		valueType = value.Elem().Kind()
+	}
+	if valueType == reflect.Struct {
 		v.lazyinit()
 		if err := v.validate.Struct(obj); err != nil {
-			buf := bytes.NewBufferString("")
-			for _, v := range err.(validator.ValidationErrors) {
-				buf.WriteString(fmt.Sprintf("%s %s %s", v.Name, v.Tag, v.Param))
-				buf.WriteString(";")
-			}
-			return errors.New(buf.String()) //error(err)
+			return err
 		}
 	}
 	return nil
 }
 
-func (v *defaultValidator) RegisterValidation(key string, fn validator.Func) error {
+// Engine returns the underlying validator engine which powers the default
+// Validator instance. This is useful if you want to register custom validations
+// or struct level validations. See validator GoDoc for more info -
+// https://godoc.org/gopkg.in/go-playground/validator.v8
+func (v *defaultValidator) Engine() interface{} {
 	v.lazyinit()
-	return v.validate.RegisterValidation(key, fn)
+	return v.validate
 }
 
 func (v *defaultValidator) lazyinit() {
 	v.once.Do(func() {
-		config := &validator.Config{TagName: "validate"}
+		config := &validator.Config{TagName: "binding"}
 		v.validate = validator.New(config)
 	})
 }
-
-func kindOfData(data interface{}) reflect.Kind {
-	value := reflect.ValueOf(data)
-	valueType := value.Kind()
-	if valueType == reflect.Ptr {
-		valueType = value.Elem().Kind()
-	}
-	return valueType
-}
-//*/
+// */

+ 33 - 2
binding/form.go

@@ -5,7 +5,9 @@
 package binding
 
 import (
+	"mime/multipart"
 	"net/http"
+	"reflect"
 )
 
 const defaultMemory = 32 * 1024 * 1024
@@ -22,7 +24,11 @@ func (formBinding) Bind(req *http.Request, obj interface{}) error {
 	if err := req.ParseForm(); err != nil {
 		return err
 	}
-	req.ParseMultipartForm(defaultMemory)
+	if err := req.ParseMultipartForm(defaultMemory); err != nil {
+		if err != http.ErrNotMultipart {
+			return err
+		}
+	}
 	if err := mapForm(obj, req.Form); err != nil {
 		return err
 	}
@@ -51,8 +57,33 @@ func (formMultipartBinding) Bind(req *http.Request, obj interface{}) error {
 	if err := req.ParseMultipartForm(defaultMemory); err != nil {
 		return err
 	}
-	if err := mapForm(obj, req.MultipartForm.Value); err != nil {
+	if err := mappingByPtr(obj, (*multipartRequest)(req), "form"); err != nil {
 		return err
 	}
+
 	return validate(obj)
 }
+
+type multipartRequest http.Request
+
+var _ setter = (*multipartRequest)(nil)
+
+var (
+	multipartFileHeaderStructType = reflect.TypeOf(multipart.FileHeader{})
+)
+
+// TrySet tries to set a value by the multipart request with the binding a form file
+func (r *multipartRequest) TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error) {
+	if value.Type() == multipartFileHeaderStructType {
+		_, file, err := (*http.Request)(r).FormFile(key)
+		if err != nil {
+			return false, err
+		}
+		if file != nil {
+			value.Set(reflect.ValueOf(*file))
+			return true, nil
+		}
+	}
+
+	return setByForm(value, field, r.MultipartForm.Value, key, opt)
+}

+ 197 - 90
binding/form_mapping.go

@@ -5,20 +5,19 @@
 package binding
 
 import (
+	"encoding/json"
 	"errors"
-	"net/url"
+	"fmt"
 	"reflect"
 	"strconv"
 	"strings"
 	"time"
+	//"github.com/gin-gonic/gin/internal/json"
 )
 
-// MapForm form values map to struct
-func MapForm(ptr interface{}, values url.Values) error {
-	return mapFormByTag(ptr, values, "form")
-}
+var errUnknownType = errors.New("Unknown type")
 
-func mapURI(ptr interface{}, m map[string][]string) error {
+func mapUri(ptr interface{}, m map[string][]string) error {
 	return mapFormByTag(ptr, m, "uri")
 }
 
@@ -26,121 +25,192 @@ func mapForm(ptr interface{}, form map[string][]string) error {
 	return mapFormByTag(ptr, form, "form")
 }
 
+var emptyField = reflect.StructField{}
+
 func mapFormByTag(ptr interface{}, form map[string][]string, tag string) error {
-	typ := reflect.TypeOf(ptr).Elem()
-	val := reflect.ValueOf(ptr).Elem()
-	for i := 0; i < typ.NumField(); i++ {
-		typeField := typ.Field(i)
-		structField := val.Field(i)
-		if !structField.CanSet() {
-			continue
-		}
+	return mappingByPtr(ptr, formSource(form), tag)
+}
 
-		structFieldKind := structField.Kind()
-		inputFieldName := typeField.Tag.Get(tag)
-		inputFieldNameList := strings.Split(inputFieldName, ",")
-		inputFieldName = inputFieldNameList[0]
-		var defaultValue string
-		if len(inputFieldNameList) > 1 {
-			defaultList := strings.SplitN(inputFieldNameList[1], "=", 2)
-			if defaultList[0] == "default" {
-				defaultValue = defaultList[1]
-			}
+// setter tries to set value on a walking by fields of a struct
+type setter interface {
+	TrySet(value reflect.Value, field reflect.StructField, key string, opt setOptions) (isSetted bool, err error)
+}
+
+type formSource map[string][]string
+
+var _ setter = formSource(nil)
+
+// TrySet tries to set a value by request's form source (like map[string][]string)
+func (form formSource) TrySet(value reflect.Value, field reflect.StructField, tagValue string, opt setOptions) (isSetted bool, err error) {
+	return setByForm(value, field, form, tagValue, opt)
+}
+
+func mappingByPtr(ptr interface{}, setter setter, tag string) error {
+	_, err := mapping(reflect.ValueOf(ptr), emptyField, setter, tag)
+	return err
+}
+
+func mapping(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
+	var vKind = value.Kind()
+
+	if vKind == reflect.Ptr {
+		var isNew bool
+		vPtr := value
+		if value.IsNil() {
+			isNew = true
+			vPtr = reflect.New(value.Type().Elem())
 		}
-		if inputFieldName == "" {
-			inputFieldName = typeField.Name
-
-			// if "form" tag is nil, we inspect if the field is a struct or struct pointer.
-			// this would not make sense for JSON parsing but it does for a form
-			// since data is flatten
-			if structFieldKind == reflect.Ptr {
-				if !structField.Elem().IsValid() {
-					structField.Set(reflect.New(structField.Type().Elem()))
-				}
-				structField = structField.Elem()
-				structFieldKind = structField.Kind()
-			}
-			if structFieldKind == reflect.Struct {
-				err := mapFormByTag(structField.Addr().Interface(), form, tag)
-				if err != nil {
-					return err
-				}
-				continue
-			}
+		isSetted, err := mapping(vPtr.Elem(), field, setter, tag)
+		if err != nil {
+			return false, err
 		}
-		inputValue, exists := form[inputFieldName]
+		if isNew && isSetted {
+			value.Set(vPtr)
+		}
+		return isSetted, nil
+	}
+
+	ok, err := tryToSetValue(value, field, setter, tag)
+	if err != nil {
+		return false, err
+	}
+	if ok {
+		return true, nil
+	}
+
+	if vKind == reflect.Struct {
+		tValue := value.Type()
 
-		if !exists {
-			if defaultValue == "" {
+		var isSetted bool
+		for i := 0; i < value.NumField(); i++ {
+			if !value.Field(i).CanSet() {
 				continue
 			}
-			inputValue = make([]string, 1)
-			inputValue[0] = defaultValue
+			ok, err := mapping(value.Field(i), tValue.Field(i), setter, tag)
+			if err != nil {
+				return false, err
+			}
+			isSetted = isSetted || ok
 		}
+		return isSetted, nil
+	}
+	return false, nil
+}
 
-		numElems := len(inputValue)
-		if structFieldKind == reflect.Slice && numElems > 0 {
-			sliceOf := structField.Type().Elem().Kind()
-			slice := reflect.MakeSlice(structField.Type(), numElems, numElems)
-			for i := 0; i < numElems; i++ {
-				if err := setWithProperType(sliceOf, inputValue[i], slice.Index(i)); err != nil {
-					return err
-				}
-			}
-			val.Field(i).Set(slice)
-			continue
+type setOptions struct {
+	isDefaultExists bool
+	defaultValue    string
+}
+
+func tryToSetValue(value reflect.Value, field reflect.StructField, setter setter, tag string) (bool, error) {
+	var tagValue string
+	var setOpt setOptions
+
+	tagValue = field.Tag.Get(tag)
+	tagValue, opts := head(tagValue, ",")
+
+	if tagValue == "-" { // just ignoring this field
+		return false, nil
+	}
+	if tagValue == "" { // default value is FieldName
+		tagValue = field.Name
+	}
+	if tagValue == "" { // when field is "emptyField" variable
+		return false, nil
+	}
+
+	var opt string
+	for len(opts) > 0 {
+		opt, opts = head(opts, ",")
+
+		k, v := head(opt, "=")
+		switch k {
+		case "default":
+			setOpt.isDefaultExists = true
+			setOpt.defaultValue = v
 		}
-		if _, isTime := structField.Interface().(time.Time); isTime {
-			if err := setTimeField(inputValue[0], typeField, structField); err != nil {
-				return err
-			}
-			continue
+	}
+
+	return setter.TrySet(value, field, tagValue, setOpt)
+}
+
+func setByForm(value reflect.Value, field reflect.StructField, form map[string][]string, tagValue string, opt setOptions) (isSetted bool, err error) {
+	vs, ok := form[tagValue]
+	if !ok && !opt.isDefaultExists {
+		return false, nil
+	}
+
+	switch value.Kind() {
+	case reflect.Slice:
+		if !ok {
+			vs = []string{opt.defaultValue}
 		}
-		if err := setWithProperType(typeField.Type.Kind(), inputValue[0], structField); err != nil {
-			return err
+		return true, setSlice(vs, value, field)
+	case reflect.Array:
+		if !ok {
+			vs = []string{opt.defaultValue}
+		}
+		if len(vs) != value.Len() {
+			return false, fmt.Errorf("%q is not valid value for %s", vs, value.Type().String())
 		}
+		return true, setArray(vs, value, field)
+	default:
+		var val string
+		if !ok {
+			val = opt.defaultValue
+		}
+
+		if len(vs) > 0 {
+			val = vs[0]
+		}
+		return true, setWithProperType(val, value, field)
 	}
-	return nil
 }
 
-func setWithProperType(valueKind reflect.Kind, val string, structField reflect.Value) error {
-	switch valueKind {
+func setWithProperType(val string, value reflect.Value, field reflect.StructField) error {
+	switch value.Kind() {
 	case reflect.Int:
-		return setIntField(val, 0, structField)
+		return setIntField(val, 0, value)
 	case reflect.Int8:
-		return setIntField(val, 8, structField)
+		return setIntField(val, 8, value)
 	case reflect.Int16:
-		return setIntField(val, 16, structField)
+		return setIntField(val, 16, value)
 	case reflect.Int32:
-		return setIntField(val, 32, structField)
+		return setIntField(val, 32, value)
 	case reflect.Int64:
-		return setIntField(val, 64, structField)
+		switch value.Interface().(type) {
+		case time.Duration:
+			return setTimeDuration(val, value, field)
+		}
+		return setIntField(val, 64, value)
 	case reflect.Uint:
-		return setUintField(val, 0, structField)
+		return setUintField(val, 0, value)
 	case reflect.Uint8:
-		return setUintField(val, 8, structField)
+		return setUintField(val, 8, value)
 	case reflect.Uint16:
-		return setUintField(val, 16, structField)
+		return setUintField(val, 16, value)
 	case reflect.Uint32:
-		return setUintField(val, 32, structField)
+		return setUintField(val, 32, value)
 	case reflect.Uint64:
-		return setUintField(val, 64, structField)
+		return setUintField(val, 64, value)
 	case reflect.Bool:
-		return setBoolField(val, structField)
+		return setBoolField(val, value)
 	case reflect.Float32:
-		return setFloatField(val, 32, structField)
+		return setFloatField(val, 32, value)
 	case reflect.Float64:
-		return setFloatField(val, 64, structField)
+		return setFloatField(val, 64, value)
 	case reflect.String:
-		structField.SetString(val)
-	case reflect.Ptr:
-		if !structField.Elem().IsValid() {
-			structField.Set(reflect.New(structField.Type().Elem()))
+		value.SetString(val)
+	case reflect.Struct:
+		switch value.Interface().(type) {
+		case time.Time:
+			return setTimeField(val, field, value)
 		}
-		structFieldElem := structField.Elem()
-		return setWithProperType(structFieldElem.Kind(), val, structFieldElem)
+		return json.Unmarshal([]byte(val), value.Addr().Interface())
+	case reflect.Map:
+		return json.Unmarshal([]byte(val), value.Addr().Interface())
 	default:
-		return errors.New("Unknown type")
+		return errUnknownType
 	}
 	return nil
 }
@@ -221,3 +291,40 @@ func setTimeField(val string, structField reflect.StructField, value reflect.Val
 	value.Set(reflect.ValueOf(t))
 	return nil
 }
+
+func setArray(vals []string, value reflect.Value, field reflect.StructField) error {
+	for i, s := range vals {
+		err := setWithProperType(s, value.Index(i), field)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func setSlice(vals []string, value reflect.Value, field reflect.StructField) error {
+	slice := reflect.MakeSlice(value.Type(), len(vals), len(vals))
+	err := setArray(vals, slice, field)
+	if err != nil {
+		return err
+	}
+	value.Set(slice)
+	return nil
+}
+
+func setTimeDuration(val string, value reflect.Value, field reflect.StructField) error {
+	d, err := time.ParseDuration(val)
+	if err != nil {
+		return err
+	}
+	value.Set(reflect.ValueOf(d))
+	return nil
+}
+
+func head(str, sep string) (head string, tail string) {
+	idx := strings.Index(str, sep)
+	if idx < 0 {
+		return str, ""
+	}
+	return str[:idx], str[idx+len(sep):]
+}

+ 1 - 0
binding/json.go

@@ -11,6 +11,7 @@ import (
 	"net/http"
 
 	"encoding/json"
+	//"github.com/gin-gonic/gin/internal/json"
 )
 
 // EnableDecoderUseNumber is used to call the UseNumber method on the JSON

+ 16 - 6
binding/msgpack.go

@@ -5,8 +5,11 @@
 package binding
 
 import (
+	"bytes"
+	"io"
 	"net/http"
-	//"github.com/ugorji/go/codec"
+
+	"github.com/ugorji/go/codec"
 )
 
 type msgpackBinding struct{}
@@ -16,10 +19,17 @@ func (msgpackBinding) Name() string {
 }
 
 func (msgpackBinding) Bind(req *http.Request, obj interface{}) error {
-	/*
-		if err := codec.NewDecoder(req.Body, new(codec.MsgpackHandle)).Decode(&obj); err != nil {
-			return err
-		}
-		// */
+	return decodeMsgPack(req.Body, obj)
+}
+
+func (msgpackBinding) BindBody(body []byte, obj interface{}) error {
+	return decodeMsgPack(bytes.NewReader(body), obj)
+}
+
+func decodeMsgPack(r io.Reader, obj interface{}) error {
+	cdc := new(codec.MsgpackHandle)
+	if err := codec.NewDecoder(r, cdc).Decode(&obj); err != nil {
+		return err
+	}
 	return validate(obj)
 }

+ 18 - 17
binding/protobuf.go

@@ -5,8 +5,10 @@
 package binding
 
 import (
+	"io/ioutil"
 	"net/http"
-	//"github.com/golang/protobuf/proto"
+
+	"github.com/golang/protobuf/proto"
 )
 
 type protobufBinding struct{}
@@ -15,21 +17,20 @@ func (protobufBinding) Name() string {
 	return "protobuf"
 }
 
-func (protobufBinding) Bind(req *http.Request, obj interface{}) error {
-	/*
-		buf, err := ioutil.ReadAll(req.Body)
-		if err != nil {
-			return err
-		}
-
-		if err = proto.Unmarshal(buf, obj.(proto.Message)); err != nil {
-			return err
-		}
+func (b protobufBinding) Bind(req *http.Request, obj interface{}) error {
+	buf, err := ioutil.ReadAll(req.Body)
+	if err != nil {
+		return err
+	}
+	return b.BindBody(buf, obj)
+}
 
-		//Here it's same to return validate(obj), but util now we cann't add `binding:""` to the struct
-		//which automatically generate by gen-proto
-		//return nil
-		return validate(obj)
-		// */
-	return validate(obj)
+func (protobufBinding) BindBody(body []byte, obj interface{}) error {
+	if err := proto.Unmarshal(body, obj.(proto.Message)); err != nil {
+		return err
+	}
+	// Here it's same to return validate(obj), but util now we can't add
+	// `binding:""` to the struct which automatically generate by gen-proto
+	return nil
+	// return validate(obj)
 }

+ 1 - 3
binding/query.go

@@ -4,9 +4,7 @@
 
 package binding
 
-import (
-	"net/http"
-)
+import "net/http"
 
 type queryBinding struct{}
 

+ 2 - 2
binding/uri.go

@@ -10,8 +10,8 @@ func (uriBinding) Name() string {
 	return "uri"
 }
 
-func (uriBinding) BindURI(m map[string][]string, obj interface{}) error {
-	if err := mapURI(obj, m); err != nil {
+func (uriBinding) BindUri(m map[string][]string, obj interface{}) error {
+	if err := mapUri(obj, m); err != nil {
 		return err
 	}
 	return validate(obj)

+ 0 - 0
binding/xml.go