package myth

import (
	"compress/gzip"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"encoding/xml"
	"errors"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"time"
)

const (
	// ReqContentTypeURL is the Content-Type header value for URL-encoded
	// form bodies (application/x-www-form-urlencoded), declared as UTF-8.
	ReqContentTypeURL = `application/x-www-form-urlencoded; charset=utf-8`
	// ReqContentTypeJSON is the Content-Type header value for JSON bodies
	// (application/json), declared as UTF-8.
	ReqContentTypeJSON = `application/json; charset=utf-8`
	// ReqContentTypeXML is the Content-Type header value for XML bodies
	// (application/xml), declared as UTF-8.
	ReqContentTypeXML = `application/xml; charset=utf-8`
	// ReqContentTypeMultipart is the bare multipart/form-data media type.
	// The value carries no boundary parameter.
	// NOTE(review): presumably callers append "; boundary=..." themselves
	// before sending a multipart body - confirm against the call sites.
	ReqContentTypeMultipart = `multipart/form-data`

	// RequestTimeOut is the HTTP request timeout as a number of seconds.
	// It is an untyped constant, not a time.Duration; newRequest multiplies
	// it by time.Second.
	RequestTimeOut = 30
)

var (
	// reqTimeOut is RequestTimeOut converted to the time.Duration type.
	// It still holds a bare count of seconds (the conversion alone would
	// mean nanoseconds), so it must be multiplied by time.Second before it
	// is used as a duration, which is what newRequest does.
	reqTimeOut = time.Duration(RequestTimeOut)
)

// GetRealIP returns the first value of the X-Real-Ip header of req, or the
// empty string when the header is absent or carries no values.
//
// The lookup goes through Header.Get, so a header entry that exists but holds
// an empty value list yields "" instead of an index-out-of-range panic.
//
// NOTE(review): the header is taken at face value; presumably a trusted
// reverse proxy sets it in front of this service - confirm before using the
// result for access control.
func GetRealIP(req *http.Request) (ip string) {
	return req.Header.Get("X-Real-Ip")
}

// HTTPMessage is a fully buffered HTTP response: status code, headers and
// the complete body held in memory. readBody builds it from an
// *http.Response.
type HTTPMessage struct {
	StatusCode int         // numeric HTTP status code of the response
	Body       []byte      // entire response body as read by readBody
	Header     http.Header // response headers
}

// JSON decodes the response body as JSON into dest.
//
// dest must be a non-nil pointer to the value to fill, exactly as required by
// json.Unmarshal. A response whose status code is not 200 is rejected with an
// error before any decoding happens, and dest is left untouched.
//
// It returns the json.Unmarshal error for malformed input or for an unusable
// destination (nil or non-pointer).
func (m HTTPMessage) JSON(dest interface{}) (err error) {
	if m.StatusCode != http.StatusOK {
		return errors.New(`StatusCode not 200`)
	}
	// Hand dest over as-is. Passing &dest (a *interface{}) would let a nil
	// or non-pointer destination "succeed" by decoding into this function's
	// local copy, so the caller would get no data and no error.
	return json.Unmarshal(m.Body, dest)
}

// XML decodes the response body as XML into dest.
//
// dest must be a non-nil pointer to the value to fill, exactly as required by
// xml.Unmarshal. A response whose status code is not 200 is rejected with an
// error before any decoding happens, and dest is left untouched.
//
// It returns the xml.Unmarshal error for malformed input or for an unusable
// destination (nil or non-pointer).
func (m HTTPMessage) XML(dest interface{}) (err error) {
	if m.StatusCode != http.StatusOK {
		return errors.New(`StatusCode not 200`)
	}
	// Hand dest over as-is. With &dest (a *interface{}) the decoder skips
	// the document and returns nil when dest is nil or not a pointer, so the
	// caller would get no data and no error.
	return xml.Unmarshal(m.Body, dest)
}

// JSONQuery hands the response body to NewJSONQuery and returns its result.
//
// Only a response with status code 200 is processed; any other status yields
// a nil query together with an error, and NewJSONQuery is not called.
func (m HTTPMessage) JSONQuery() (jq *JSONQuery, err error) {
	if m.StatusCode == http.StatusOK {
		return NewJSONQuery(m.Body)
	}
	return nil, errors.New(`StatusCode not 200`)
}

// newRequest builds a single HTTP request, sends it through a dedicated
// client and returns the raw response. The caller owns res.Body and must
// close it.
//
// Parameters:
//   - method, uri: passed to http.NewRequest unchanged.
//   - certPath, keyPath: optional PEM key pair used as the TLS client
//     certificate; TLS is left at its defaults when certPath is empty.
//   - header: each entry is added to the request headers.
//   - body: request body, or nil for none.
//
// Timeouts: dialing is limited to RequestTimeOut seconds, and one absolute
// deadline of reqTimeOut seconds, set right after the dial, bounds every
// later read and write on the connection. Response headers must arrive within
// reqTimeOut seconds as well.
//
// An error is returned when the request cannot be constructed or when the
// round trip fails.
func newRequest(method, uri, certPath, keyPath string, header map[string]string, body io.Reader) (res *http.Response, err error) {
	t := &http.Transport{
		Dial: func(netw, addr string) (net.Conn, error) {
			conn, err := net.DialTimeout(netw, addr, time.Second*RequestTimeOut)
			if err != nil {
				return nil, err
			}
			// A connection without its deadline would be unbounded, so
			// treat a failure to arm it as a failed dial.
			if err := conn.SetDeadline(time.Now().Add(time.Second * reqTimeOut)); err != nil {
				conn.Close()
				return nil, err
			}
			return conn, nil
		},
		ResponseHeaderTimeout: time.Second * reqTimeOut,
		// This transport serves exactly one request. Without this flag the
		// finished connection is parked in an idle pool nobody can reach
		// and stays open until its deadline fires.
		DisableKeepAlives: true,
	}

	if certPath != "" {
		// NOTE(review): server certificate verification is switched off
		// whenever a client certificate path is supplied. That exposes the
		// request to man-in-the-middle attacks; confirm it is intended.
		cfg := &tls.Config{InsecureSkipVerify: true}
		// Best effort, as before: if the key pair cannot be loaded the
		// request is still sent, just without a client certificate, and the
		// load error is dropped.
		if cert, e := tls.LoadX509KeyPair(certPath, keyPath); e == nil {
			cfg.Certificates = []tls.Certificate{cert}
			cfg.RootCAs = x509.NewCertPool()
		}
		t.TLSClientConfig = cfg
	}

	// A nil body is passed straight through; http.NewRequest treats it as
	// "no body".
	req, err := http.NewRequest(method, uri, body)
	if err != nil {
		return nil, err
	}
	for k, v := range header {
		req.Header.Add(k, v)
	}

	return (&http.Client{Transport: t}).Do(req)
}

// readBody reads the whole body of res into memory and returns it together
// with the status code and headers as an HTTPMessage.
//
// When the Content-Encoding header is exactly "gzip" the body is read through
// a gzip reader; any other value is read verbatim. On a gzip header error or
// a read error the zero HTTPMessage is returned with that error. readBody
// does not close res.Body.
func readBody(res *http.Response) (msg HTTPMessage, err error) {
	var src io.Reader = res.Body
	if res.Header.Get("Content-Encoding") == "gzip" {
		if src, err = gzip.NewReader(res.Body); err != nil {
			return msg, err
		}
	}

	payload, err := ioutil.ReadAll(src)
	if err != nil {
		return msg, err
	}

	return HTTPMessage{
		StatusCode: res.StatusCode,
		Body:       payload,
		Header:     res.Header,
	}, nil
}

// Post sends an HTTP POST request to uri and returns the buffered response.
//
// certPath, keyPath and header are forwarded to newRequest unchanged and data
// is used as the request body. The response body is read completely by
// readBody and closed before Post returns. If the request itself fails, the
// zero HTTPMessage is returned together with the error.
func Post(uri, certPath, keyPath string, header map[string]string, data io.Reader) (msg HTTPMessage, err error) {
	res, err := newRequest(http.MethodPost, uri, certPath, keyPath, header, data)
	if err != nil {
		return msg, err
	}
	defer res.Body.Close()

	return readBody(res)
}

// Get sends an HTTP GET request to uri, without a request body, and returns
// the buffered response.
//
// certPath, keyPath and header are forwarded to newRequest unchanged. The
// response body is read completely by readBody and closed before Get returns.
// If the request itself fails, the zero HTTPMessage is returned together with
// the error.
func Get(uri, certPath, keyPath string, header map[string]string) (msg HTTPMessage, err error) {
	res, err := newRequest(http.MethodGet, uri, certPath, keyPath, header, nil)
	if err != nil {
		return msg, err
	}
	defer res.Body.Close()

	return readBody(res)
}