package db

import (
	"database/sql"
	"errors"
	"fmt"
	"sync"
	"time"

	// PostgreSQL
	_ "github.com/lib/pq"
	// MySQL
	_ "github.com/go-sql-driver/mysql"
	"github.com/jmoiron/sqlx"
)

// config is the configuration for the shared pool, installed by SetConfig
// and read by connect.
var config Config

// db is the shared connection pool; connect opens it at most once and
// Release closes it.
var db *sqlx.DB

// err is the package-level error slot, presumably meant to record the result
// of opening db in connect. NOTE(review): verify connect assigns this variable
// rather than a shadowing local of the same name.
var err error

// once ensures connect attempts to open db only once per process.
var once sync.Once

// DB is a handle used to run queries against a connection pool.
//
// Driver is the database/sql driver name (for example "mysql"); Limit uses it
// to choose the pagination syntax. conn is the pool queries run on: either the
// shared package-level pool attached by Connect, or a dedicated pool opened by
// NewConfig. tx holds the transaction started by BeginTrans and is used by
// TransExec, TransUpdate, Commit and Rollback.
type DB struct {
	Driver string
	conn   *sqlx.DB
	tx     *sqlx.Tx
}

// SetConfig installs cfg as the configuration of the shared package-level
// pool, which connect opens lazily on first use.
//
// Only Driver, DNS, MaxOpenConns, MaxIdle and MaxLifetime are copied.
// cfg.MaxLifetime is treated as a count of seconds and scaled by time.Second
// here. Passing an already-scaled duration (e.g. time.Minute) would be scaled
// again. Because connect reads the configuration only inside its once.Do,
// calling SetConfig after the shared pool is open does not affect that pool.
func SetConfig(cfg Config) {
	lifetime := cfg.MaxLifetime * time.Second

	config.Driver, config.DNS = cfg.Driver, cfg.DNS
	config.MaxOpenConns, config.MaxIdle = cfg.MaxOpenConns, cfg.MaxIdle
	config.MaxLifetime = lifetime
}

// New returns a fresh, unconnected DB whose Driver is taken from the
// configuration installed by SetConfig. Its connection is attached lazily
// by Connect (or BeginTrans).
func New() *DB {
	d := new(DB)
	d.Driver = config.Driver
	return d
}

// Release closes the shared package-level pool opened by connect, if it was
// opened. The error returned by Close is discarded.
//
// NOTE(review): db is not reset and once has already fired, so after Release
// connect keeps handing out the closed pool. This looks intended only for
// program shutdown.
func Release() {
	if db == nil {
		return
	}
	_ = db.Close()
}

// NewConfig opens a dedicated connection pool described by config. Unlike
// New, the returned DB does not use the shared package-level pool, and
// Connect leaves its connection in place.
//
// The returned *DB is never nil; when err is non-nil its connection is nil.
// Close the pool with ReleaseConfig when it is no longer needed.
//
// NOTE(review): config.MaxLifetime is passed to SetConnMaxLifetime as-is,
// whereas SetConfig multiplies the same field by time.Second. Confirm which
// unit callers of NewConfig use.
func NewConfig(config Config) (dbx *DB, err error) {
	dbx = &DB{Driver: config.Driver}

	// sqlx.Connect opens the pool and already verifies it with a ping, so no
	// separate Ping (whose error would otherwise need handling) is required.
	conn, err := sqlx.Connect(config.Driver, config.DNS)
	if err != nil {
		return dbx, err
	}
	conn.SetMaxOpenConns(config.MaxOpenConns)
	conn.SetMaxIdleConns(config.MaxIdle)
	conn.SetConnMaxLifetime(config.MaxLifetime)

	dbx.conn = conn
	return dbx, nil
}

// ReleaseConfig closes the dedicated pool of a DB created by NewConfig.
//
// A nil dbx is ignored, as is one whose connection was never opened (e.g.
// NewConfig returned an error). The error returned by Close is discarded.
func ReleaseConfig(dbx *DB) {
	if dbx == nil || dbx.conn == nil {
		return
	}
	_ = dbx.conn.Close()
}

// connect opens the shared package-level pool on first use, using the
// configuration installed by SetConfig, and returns it.
//
// The outcome of that single attempt (the pool or the error) is cached in the
// package-level db and err variables and returned to every caller. Concurrent
// first callers wait for the attempt to finish (sync.Once semantics). A failed
// attempt is not retried.
func connect() (*sqlx.DB, error) {
	once.Do(func() {
		// No named results here, so this assigns the package-level db and
		// err. Callers after a failed first attempt therefore see the original
		// error instead of a nil pool paired with a nil error.
		db, err = sqlx.Connect(config.Driver, config.DNS)
		if err != nil {
			return
		}
		// sqlx.Connect has already verified the pool with a ping.
		db.SetMaxOpenConns(config.MaxOpenConns)
		db.SetMaxIdleConns(config.MaxIdle)
		db.SetConnMaxLifetime(config.MaxLifetime)
	})
	return db, err
}

// Connect makes sure d has a connection pool. A DB that already holds one
// (for example from NewConfig or an earlier Connect) is left untouched.
// Otherwise the shared package-level pool is obtained through connect and any
// error from opening it is returned.
func (d *DB) Connect() (err error) {
	if d.conn == nil {
		d.conn, err = connect()
	}
	return err
}

// Close deliberately does nothing. d's connection is normally the pool shared
// by every DB (see connect), so closing it here would break other users. Use
// Release for the shared pool, or ReleaseConfig for a pool from NewConfig.
func (d *DB) Close() {}

// BeginTrans starts a transaction on d's connection and stores it in d.tx,
// where TransExec, TransUpdate, Commit and Rollback pick it up.
//
// A pool already held by d (e.g. from NewConfig) is used as-is; otherwise the
// shared pool is attached first. Failure to connect or to begin the
// transaction is returned as an error instead of panicking, and in that case
// d.tx is left unchanged.
func (d *DB) BeginTrans() (err error) {
	// Connect rather than connect: calling connect directly would replace a
	// dedicated NewConfig pool with the shared one.
	if err = d.Connect(); err != nil {
		return err
	}
	tx, err := d.conn.Beginx()
	if err != nil {
		return err
	}
	d.tx = tx
	return nil
}

// Commit commits the transaction started by BeginTrans and returns its error.
// If no transaction was ever started it returns an error instead of
// panicking on the nil transaction.
func (d *DB) Commit() error {
	if d.tx == nil {
		return errors.New("db: commit without a transaction started by BeginTrans")
	}
	return d.tx.Commit()
}

// Rollback aborts the transaction started by BeginTrans and returns its
// error. If no transaction was ever started it returns an error instead of
// panicking, so it is safe to defer even when BeginTrans failed.
func (d *DB) Rollback() error {
	if d.tx == nil {
		return errors.New("db: rollback without a transaction started by BeginTrans")
	}
	return d.tx.Rollback()
}

// TransExec runs a named-parameter statement inside the transaction started by
// BeginTrans. It returns the last insert id and the number of affected rows.
//
// If the statement fails, err is returned and both counts are 0. The result
// accessors' own errors are ignored and leave the corresponding value at 0;
// database/sql notes that not every driver supports LastInsertId.
func (d *DB) TransExec(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	// Plain assignment into the named err. The previous `if rs, err := ...`
	// shadowed it, so execution errors were silently dropped.
	rs, err := d.tx.NamedExec(query, args)
	if err != nil {
		return 0, 0, err
	}
	RowsAffected, _ = rs.RowsAffected()
	LastInsertId, _ = rs.LastInsertId()
	return LastInsertId, RowsAffected, nil
}

// TransUpdate runs a named-parameter update/delete inside the transaction
// started by BeginTrans and wraps the outcome in a Reply. On success the
// Reply carries the affected-row count (error from RowsAffected ignored) and
// an id of 0. On failure it is built by ReplyFaild with ErrException.
func (d *DB) TransUpdate(query string, args interface{}) (reply Reply) {
	var (
		result  sql.Result
		execErr error
	)
	if result, execErr = d.tx.NamedExec(query, args); execErr != nil {
		return ReplyFaild(ErrException, execErr, errors.New(`数据执行错误`))
	}
	affected, _ := result.RowsAffected()
	return ReplyOk(affected, 0)
}

// Rows prepares query as a sqlx named statement, executes it with args and
// scans every resulting row into dest via the statement's Select. Errors from
// connecting, preparing or selecting are returned unchanged. The prepared
// statement is always closed before returning.
func (d *DB) Rows(dest interface{}, query string, args interface{}) error {
	if connErr := d.Connect(); connErr != nil {
		return connErr
	}
	defer d.Close()

	stmt, prepErr := d.conn.PrepareNamed(query)
	if prepErr != nil {
		return prepErr
	}
	defer stmt.Close()

	return stmt.Select(dest, args)
}

// Row prepares query as a sqlx named statement, executes it with args and
// scans a single row into dest via the statement's Get. Errors from
// connecting, preparing or fetching are returned unchanged. The prepared
// statement is always closed before returning.
func (d *DB) Row(dest interface{}, query string, args interface{}) error {
	if connErr := d.Connect(); connErr != nil {
		return connErr
	}
	defer d.Close()

	stmt, prepErr := d.conn.PrepareNamed(query)
	if prepErr != nil {
		return prepErr
	}
	defer stmt.Close()

	return stmt.Get(dest, args)
}

// InsertReply executes a named-parameter insert outside any transaction and
// wraps the outcome in a Reply.
//
// A connection failure yields ReplyFaild with ErrNotConnect, and an execution
// failure yields ReplyFaild with ErrException. On success the Reply carries
// the affected-row count and the last insert id. Errors from those two result
// accessors are ignored and leave the value at 0.
func (d *DB) InsertReply(query string, args interface{}) (reply Reply) {
	if connErr := d.Connect(); connErr != nil {
		return ReplyFaild(ErrNotConnect, connErr, errors.New(`数据库连接错误`))
	}
	defer d.Close()

	var (
		result  sql.Result
		execErr error
	)
	if result, execErr = d.conn.NamedExec(query, args); execErr != nil {
		return ReplyFaild(ErrException, execErr, errors.New(`数据执行错误`))
	}
	affected, _ := result.RowsAffected()
	lastID, _ := result.LastInsertId()
	return ReplyOk(affected, lastID)
}

// UpdateReply executes a named-parameter update/delete outside any
// transaction and wraps the outcome in a Reply.
//
// A connection failure yields ReplyFaild with ErrNotConnect, and an execution
// failure yields ReplyFaild with ErrException. On success the Reply carries
// the affected-row count (error from RowsAffected ignored) and an id of 0.
func (d *DB) UpdateReply(query string, args interface{}) (reply Reply) {
	if connErr := d.Connect(); connErr != nil {
		return ReplyFaild(ErrNotConnect, connErr, errors.New(`数据库连接错误`))
	}
	defer d.Close()

	var (
		result  sql.Result
		execErr error
	)
	if result, execErr = d.conn.NamedExec(query, args); execErr != nil {
		return ReplyFaild(ErrException, execErr, errors.New(`数据执行错误`))
	}
	affected, _ := result.RowsAffected()
	return ReplyOk(affected, 0)
}

// Insert executes a named-parameter insert outside any transaction. It
// returns the last insert id and the affected-row count.
//
// On a connection or execution error both counts are 0 and the error is
// returned unchanged. Errors from the result accessors are ignored and leave
// the corresponding value at 0.
func (d *DB) Insert(query string, args interface{}) (LastInsertId, RowsAffected int64, err error) {
	if connErr := d.Connect(); connErr != nil {
		return 0, 0, connErr
	}
	defer d.Close()

	var (
		result  sql.Result
		execErr error
	)
	if result, execErr = d.conn.NamedExec(query, args); execErr != nil {
		return 0, 0, execErr
	}
	id, _ := result.LastInsertId()
	affected, _ := result.RowsAffected()
	return id, affected, nil
}

// Update executes a named-parameter update/delete outside any transaction
// and returns the affected-row count.
//
// On a connection or execution error the count is 0 and the error is returned
// unchanged. An error from RowsAffected itself is ignored and leaves the
// count at 0.
func (d *DB) Update(query string, args interface{}) (RowsAffected int64, err error) {
	if connErr := d.Connect(); connErr != nil {
		return 0, connErr
	}
	defer d.Close()

	var (
		result  sql.Result
		execErr error
	)
	if result, execErr = d.conn.NamedExec(query, args); execErr != nil {
		return 0, execErr
	}
	affected, _ := result.RowsAffected()
	return affected, nil
}

// Limit returns a pagination clause, with a leading space, for a 1-based page
// number and a page size, in the dialect selected by d.Driver:
//
//	"mysql":    " limit <offset>, <pagesize>"
//	otherwise:  " limit <pagesize> offset <offset>"   (e.g. PostgreSQL)
//
// where offset = (page-1)*pagesize. A page below 1 is treated as page 1 so the
// offset is never negative, since a negative offset is not valid SQL in these
// dialects. pagesize is expected to be non-negative.
func (d *DB) Limit(page, pagesize int) string {
	if page < 1 {
		page = 1
	}
	offset := (page - 1) * pagesize

	if d.Driver == `mysql` {
		// MySQL's "LIMIT offset, count" form.
		return fmt.Sprintf(" limit %d, %d", offset, pagesize)
	}
	// "LIMIT count OFFSET offset" form, as used by PostgreSQL.
	return fmt.Sprintf(" limit %d offset %d", pagesize, offset)
}