package service

import (
	"crypto/md5"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"

	"backend/session"

	"github.com/gin-gonic/gin"
)

// Errors reported to clients by the user-management endpoints.
var (
	errInvalidIDOrPass    = errors.New("invalid id or pass")
	errEmptyCredentials   = errors.New("username and password are required")
	errSessionNotLoggedIn = errors.New("not logged in")
	errEmptyNewPassword   = errors.New("password is required")
)

// Identity is the JSON body accepted by the login endpoint.
type Identity struct {
	UserName string `json:"username"`
	Password string `json:"password"`
}

// UserManage serves login, password change and logout for the accounts
// stored in the "user" table.
type UserManage struct {
	// Id is retained only for backward compatibility and is no longer
	// written per request: sharing one field across concurrent requests let
	// simultaneous logins overwrite each other's credentials.
	Id Identity
	DB *sql.DB
	ss *session.SessionManager
}

// NewUserManage returns a UserManage backed by db and ss, creating the
// "user" table if it does not already exist. It panics if the table cannot
// be created.
func NewUserManage(db *sql.DB, ss *session.SessionManager) *UserManage {
	const schema = `create table if not exists user(username varchar(20) unique, password varchar(33))`
	if _, err := db.Exec(schema); err != nil {
		panic(fmt.Errorf("creating user table: %w", err))
	}
	return &UserManage{
		Id: Identity{},
		DB: db,
		ss: ss,
	}
}

// hashUserPassword returns the lowercase hex MD5 digest of password, which
// is the form stored in the user.password column.
//
// NOTE(review): unsalted MD5 is not a safe password hash. Moving to bcrypt or
// argon2 requires migrating existing rows, so it is intentionally unchanged.
func hashUserPassword(password string) string {
	return fmt.Sprintf("%x", md5.Sum([]byte(password)))
}

// loginImpl authenticates id against the user table.
//
// If the table is empty, the supplied credentials are inserted and become the
// first account. Otherwise the username/password-hash pair must match an
// existing row; errInvalidIDOrPass is returned when it does not. Database
// errors are returned unchanged.
func (u *UserManage) loginImpl(id Identity) error {
	if id.UserName == "" || id.Password == "" {
		// Without this check, an empty JSON body could create the first account.
		return errEmptyCredentials
	}
	epass := hashUserPassword(id.Password)

	var count int
	if err := u.DB.QueryRow("select count(username) from user").Scan(&count); err != nil {
		return err
	}
	if count == 0 {
		// First login ever: bootstrap the account from these credentials.
		// NOTE(review): the count and the insert are not atomic, so two
		// concurrent first logins may both see an empty table.
		_, err := u.DB.Exec("insert into user(username, password) values(?, ?)", id.UserName, epass)
		return err
	}

	var name string
	err := u.DB.QueryRow("select username from user where username = ? and password = ?",
		id.UserName, epass).Scan(&name)
	if errors.Is(err, sql.ErrNoRows) {
		return errInvalidIDOrPass
	}
	return err
}

// login handles POST. A session already marked as authenticated is answered
// with "ok" immediately. Otherwise the JSON body is bound into a
// per-request Identity and verified. On success the session is marked as
// authenticated and the username is stored under the session id.
func (u *UserManage) login(c *gin.Context) {
	ss := u.ss.StartSession(c.Writer, c.Request)
	if ss.Get(u.ss.SessionKey()) == u.ss.SessionVal() {
		newError(c, 0, "ok")
		return
	}

	// Local variable, not u.Id, so concurrent requests cannot race.
	var id Identity
	if err := c.BindJSON(&id); err != nil {
		newError(c, -1, err.Error())
		return
	}
	if err := u.loginImpl(id); err != nil {
		newError(c, -1, err.Error())
		return
	}

	ss.Set(u.ss.SessionKey(), u.ss.SessionVal())
	ss.Set(ss.Id(), id.UserName)
	newError(c, 0, "ok")
}

// changePassImpl updates the password of the user logged in on the current
// session. The request body must be JSON of the form {"password": "..."}.
//
// It fails if the session is not authenticated (the marker set by login is
// absent), if the body cannot be decoded, if the new password is empty, or
// if the database update fails.
func (u *UserManage) changePassImpl(c *gin.Context) error {
	ss := u.ss.StartSession(c.Writer, c.Request)
	if ss.Get(u.ss.SessionKey()) != u.ss.SessionVal() {
		// Without this check, an anonymous request "succeeded" with a no-op
		// update, and the caller then cleared the session store.
		return errSessionNotLoggedIn
	}

	var req struct {
		Password string `json:"password"`
	}
	if err := json.NewDecoder(c.Request.Body).Decode(&req); err != nil {
		return err
	}
	if req.Password == "" {
		return errEmptyNewPassword
	}

	// login stores the username under the session id.
	username := ss.Get(ss.Id())
	_, err := u.DB.Exec("update user set password = ? where username = ?",
		hashUserPassword(req.Password), username)
	return err
}

// changePass handles PUT. On success it calls u.ss.Clear() and replies "ok".
// Presumably Clear drops all sessions, forcing a fresh login; confirm in the
// session package.
func (u *UserManage) changePass(c *gin.Context) {
	if err := u.changePassImpl(c); err != nil {
		newError(c, -1, err.Error())
		return
	}
	u.ss.Clear()
	newError(c, 0, "ok")
}

// logout handles DELETE by destroying the caller's session.
func (u *UserManage) logout(c *gin.Context) {
	u.ss.SessionDestroy(c.Writer, c.Request)
	newError(c, 0, "ok")
}

// Serve dispatches on the HTTP method: POST logs in, PUT changes the
// password and DELETE logs out. Any other method is reported through
// newError with code -1.
func (u *UserManage) Serve(c *gin.Context) {
	switch c.Request.Method {
	case http.MethodPost:
		u.login(c)
	case http.MethodPut:
		u.changePass(c)
	case http.MethodDelete:
		u.logout(c)
	default:
		newError(c, -1, "404 not found")
	}
}

// Close closes the underlying database handle. The close error is
// deliberately discarded because Close has no way to report it.
func (u *UserManage) Close() {
	_ = u.DB.Close()
}