104 lines
2.2 KiB
Go
104 lines
2.2 KiB
Go
package lishwist
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/google/uuid"
|
|
|
|
"lishwist/core/internal/db"
|
|
"lishwist/core/internal/normalize"
|
|
)
|
|
|
|
// User is one account as read back by the query helpers in this file.
//
// Id is the row id in decimal string form. NormalName is the normalized login
// name (the `name` column) and Name is the human-facing display name (the
// `display_name` column). Reference is the random UUID string assigned when
// the user is created. IsAdmin and IsLive mirror the `is_admin` and `is_live`
// columns.
type User struct {
	Id                          string
	NormalName, Name, Reference string
	IsAdmin, IsLive             bool
}
|
|
|
|
func queryManyUsers(query string, args ...any) ([]User, error) {
|
|
rows, err := db.Connection.Query(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
users := []User{}
|
|
for rows.Next() {
|
|
var u User
|
|
err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, u)
|
|
}
|
|
err = rows.Err()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func queryOneUser(query string, args ...any) (*User, error) {
|
|
users, err := queryManyUsers(query, args...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if len(users) < 1 {
|
|
return nil, nil
|
|
}
|
|
return &users[0], nil
|
|
}
|
|
|
|
func (u *User) GetAdmin() *Admin {
|
|
if u.IsAdmin {
|
|
return &Admin{u}
|
|
} else {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func getUserByName(username string) (*User, error) {
|
|
username = normalize.Name(username)
|
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
|
return queryOneUser(stmt, username)
|
|
}
|
|
|
|
func createUser(name string, passHash []byte) (*User, error) {
|
|
username := normalize.Name(name)
|
|
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
|
reference, err := uuid.NewRandom()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
result, err := db.Connection.Exec(stmt, username, name, reference, passHash)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
user := User{
|
|
Id: fmt.Sprintf("%d", id),
|
|
Name: name,
|
|
}
|
|
return &user, nil
|
|
}
|
|
|
|
func (u *User) getPassHash() ([]byte, error) {
|
|
stmt := "SELECT password_hash FROM v_user WHERE id = ?"
|
|
var passHash string
|
|
err := db.Connection.QueryRow(stmt, u.Id).Scan(&passHash)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return []byte(passHash), nil
|
|
}
|
|
|
|
func getUserByReference(reference string) (*User, error) {
|
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
|
return queryOneUser(stmt, reference)
|
|
}
|