lishwist/core/user.go

203 lines
5.5 KiB
Go

package lishwist
import (
"fmt"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
"lishwist/core/internal/db"
"lishwist/core/internal/normalize"
)
// User is one user account row as scanned by queryManyUsers.
//
// Field-to-column mapping (see the SELECT lists in this file):
//   - Id                -> id (held as a string; createUser formats the inserted integer id)
//   - NormalName        -> name (the value produced by normalize.Name)
//   - Name              -> display_name (the name as the user typed it)
//   - Reference         -> reference (a random UUID generated in createUser)
//   - IsAdmin           -> is_admin
//   - IsLive            -> is_live
//   - PasswordFromAdmin -> password_from_admin (set to 1 by SetUserPassword, 0 by SetPassword)
type User struct {
	Id, NormalName, Name, Reference  string
	IsAdmin, IsLive, PasswordFromAdmin bool
}
// queryManyUsers runs query with args and scans every resulting row into a User.
//
// The query must select exactly these columns, in this order: id, name,
// display_name, reference, is_admin, is_live, password_from_admin.
// When no row matches, it returns an empty but non-nil slice. On any query,
// scan or iteration error, it returns a nil slice and that error.
func queryManyUsers(query string, args ...any) ([]User, error) {
	rows, err := db.Connection.Query(query, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	result := make([]User, 0)
	for rows.Next() {
		u := User{}
		scanErr := rows.Scan(
			&u.Id,
			&u.NormalName,
			&u.Name,
			&u.Reference,
			&u.IsAdmin,
			&u.IsLive,
			&u.PasswordFromAdmin,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		result = append(result, u)
	}

	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}
// queryOneUser runs query via queryManyUsers and returns the first User it
// yields. It returns (nil, nil) when no row matches, so callers must check
// for a nil user as well as a non-nil error. Rows after the first are
// scanned but discarded.
func queryOneUser(query string, args ...any) (*User, error) {
	users, err := queryManyUsers(query, args...)
	switch {
	case err != nil:
		return nil, err
	case len(users) == 0:
		return nil, nil
	}
	first := users[0]
	return &first, nil
}
// getUserByName looks up a user in v_user by name. The input is passed
// through normalize.Name first and compared against the normalized `name`
// column, not against display_name. It returns (nil, nil) when no user matches.
func getUserByName(username string) (*User, error) {
	const stmt = "SELECT id, name, display_name, reference, is_admin, is_live, password_from_admin FROM v_user WHERE name = ?"
	normalized := normalize.Name(username)
	return queryOneUser(stmt, normalized)
}
// createUser inserts a new row into the user table and returns the created User.
//
// name is stored as display_name, and normalize.Name(name) is stored as name.
// passHash is written to password_hash as-is, with no hashing done here.
// A random UUID is generated and stored as the user's reference.
// After a successful insert, recordEventCreateUser is called with the new
// user's id as both arguments.
//
// The returned User carries every field this function knows: Id, NormalName,
// Name, Reference and IsAdmin. IsLive and PasswordFromAdmin come from column
// defaults that the INSERT does not set, so they are left at their zero values.
// NOTE(review): confirm those defaults against the schema.
func createUser(name string, passHash []byte, isAdmin bool) (*User, error) {
	normalName := normalize.Name(name)

	refUUID, err := uuid.NewRandom()
	if err != nil {
		return nil, fmt.Errorf("Failed to generate reference: %w", err)
	}
	// Store the canonical string form explicitly so the returned User.Reference
	// is exactly what was written to the database.
	reference := refUUID.String()

	stmt := "INSERT INTO user (name, display_name, reference, password_hash, is_admin) VALUES (?, ?, ?, ?, ?)"
	result, err := db.Connection.Exec(stmt, normalName, name, reference, passHash, isAdmin)
	if err != nil {
		return nil, fmt.Errorf("Failed to execute query: %w", err)
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("Failed to get last insert id: %w", err)
	}

	user := User{
		Id:         fmt.Sprintf("%d", id),
		NormalName: normalName,
		Name:       name,
		Reference:  reference,
		IsAdmin:    isAdmin,
	}
	// NOTE(review): any result of recordEventCreateUser is discarded here.
	// Confirm that a failed event record should not fail user creation.
	recordEventCreateUser(user.Id, user.Id)
	return &user, nil
}
// getPassHash reads the stored password_hash for u (matched by u.Id) from
// v_user and returns it as bytes.
//
// The column is scanned into a string, so a NULL hash is reported as a scan
// error rather than returned as nil. Presumably db.Connection is a *sql.DB,
// in which case a missing row surfaces as sql.ErrNoRows. TODO confirm.
func (u *User) getPassHash() ([]byte, error) {
	var stored string
	row := db.Connection.QueryRow("SELECT password_hash FROM v_user WHERE id = ?", u.Id)
	if err := row.Scan(&stored); err != nil {
		return nil, err
	}
	return []byte(stored), nil
}
// getUserByReference looks up a user in v_user by its public reference.
// It returns (nil, nil) when no user matches.
func getUserByReference(reference string) (*User, error) {
	const stmt = "SELECT id, name, display_name, reference, is_admin, is_live, password_from_admin FROM v_user WHERE reference = ?"
	user, err := queryOneUser(stmt, reference)
	return user, err
}
// getUserById looks up a user in v_user by database id.
// It returns (nil, nil) when no user matches.
func getUserById(id string) (*User, error) {
	const stmt = "SELECT id, name, display_name, reference, is_admin, is_live, password_from_admin FROM v_user WHERE id = ?"
	user, err := queryOneUser(stmt, id)
	return user, err
}
// hasUsers reports whether v_user contains at least one row with a non-NULL id.
func hasUsers() (bool, error) {
	// LIMIT on a COUNT(...) aggregate is a no-op, because the aggregate already
	// yields a single row, so COUNT would still walk the whole view. EXISTS
	// stops at the first matching row. The id filter keeps the COUNT(id)
	// semantics of ignoring NULL ids.
	const stmt = "SELECT EXISTS (SELECT 1 FROM v_user WHERE id IS NOT NULL)"
	var found int
	if err := db.Connection.QueryRow(stmt).Scan(&found); err != nil {
		return false, err
	}
	return found != 0, nil
}
// ListUsers returns every user row, in whatever order the database yields them.
// No ORDER BY is applied.
//
// Unlike the other lookups in this file, it reads the `user` table directly
// rather than the v_user view. NOTE(review): presumably this is deliberate,
// so that admins also see users that v_user may hide (for example non-live
// ones). Confirm.
func (*Admin) ListUsers() ([]User, error) {
	return queryManyUsers(
		"SELECT id, name, display_name, reference, is_admin, is_live, password_from_admin FROM user",
	)
}
// GetUser returns the user with the given database id, as seen through v_user.
// It returns (nil, nil) when no such user exists.
func (*Admin) GetUser(id string) (*User, error) {
	user, err := getUserById(id)
	return user, err
}
// GetUserByReference returns the user whose public reference matches, as seen
// through v_user. It returns (nil, nil) when no such user exists.
func GetUserByReference(reference string) (*User, error) {
	user, err := getUserByReference(reference)
	return user, err
}
// GetTodo returns the wishes claimed by u (wish.claimant_id = u.Id). Each wish
// carries the display name and reference of its recipient.
//
// Results are ordered by wish.sent ascending, so unsent wishes come first, then
// by wish name. When u has claimed nothing, an empty, non-nil slice is returned.
// Any query, scan or iteration error aborts the call and is returned.
func (u *User) GetTodo() ([]Wish, error) {
	stmt := "SELECT wish.id, wish.name, wish.sent, recipient.display_name, recipient.reference FROM wish JOIN v_user AS user ON wish.claimant_id = user.id JOIN v_user AS recipient ON wish.recipient_id = recipient.id WHERE user.id = ? ORDER BY wish.sent ASC, wish.name"
	rows, err := db.Connection.Query(stmt, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	wishes := []Wish{}
	for rows.Next() {
		var id string
		var name string
		var sent bool
		var recipientName string
		var recipientRef string
		// Propagate scan failures rather than appending a zero-valued wish.
		if err := rows.Scan(&id, &name, &sent, &recipientName, &recipientRef); err != nil {
			return nil, err
		}
		wishes = append(wishes, Wish{
			Id:            id,
			Name:          name,
			Sent:          sent,
			RecipientName: recipientName,
			RecipientRef:  recipientRef,
		})
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return wishes, nil
}
// UserSetLive sets the is_live flag on the user identified by userReference.
//
// NOTE(review): the affected-row count is not checked, so an unknown
// reference is not reported as an error.
func (u *Admin) UserSetLive(userReference string, setting bool) error {
	const query = "UPDATE user SET is_live = ? WHERE reference = ?"
	_, err := db.Connection.Exec(query, setting, userReference)
	return err
}
// RenameUser changes the name of the user identified by userReference.
// displayName is stored as display_name, and normalize.Name(displayName) is
// stored as name, which getUserByName uses for lookups.
//
// NOTE(review): the affected-row count is not checked, so an unknown reference
// is not reported. Whether a clash on the normalized name is rejected depends
// on the schema (e.g. a UNIQUE constraint). Confirm.
func (u *Admin) RenameUser(userReference string, displayName string) error {
	name := normalize.Name(displayName)
	const query = "UPDATE user SET name = ?, display_name = ? WHERE reference = ?"
	_, err := db.Connection.Exec(query, name, displayName, userReference)
	return err
}
// SetUserPassword replaces the password of the user identified by userReference.
// newPassword is bcrypt-hashed before storage, and password_from_admin is set
// to 1 to flag the password as admin-assigned.
//
// It returns a wrapped error if hashing fails (for example, bcrypt rejects
// inputs longer than 72 bytes), or the error from the UPDATE.
// NOTE(review): an unknown reference is not reported, because the affected-row
// count is not checked.
func (u *Admin) SetUserPassword(userReference string, newPassword string) error {
	// DefaultCost instead of MinCost: MinCost (4) makes offline brute-forcing
	// cheap. bcrypt encodes the cost inside each hash, so hashes created at the
	// old cost still verify with CompareHashAndPassword.
	hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("Failed to hash password: %w", err)
	}
	const query = "UPDATE user SET password_hash = ?, password_from_admin = 1 WHERE reference = ?"
	_, err = db.Connection.Exec(query, hashed, userReference)
	return err
}
// SetPassword replaces u's own password (matched by u.Id). newPassword is
// bcrypt-hashed before storage, and password_from_admin is cleared to 0.
//
// It returns a wrapped error if hashing fails (for example, bcrypt rejects
// inputs longer than 72 bytes), or the error from the UPDATE.
func (u *User) SetPassword(newPassword string) error {
	// DefaultCost instead of MinCost: MinCost (4) makes offline brute-forcing
	// cheap. bcrypt encodes the cost inside each hash, so hashes created at the
	// old cost still verify with CompareHashAndPassword.
	hashed, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if err != nil {
		return fmt.Errorf("Failed to hash password: %w", err)
	}
	const query = "UPDATE user SET password_hash = ?, password_from_admin = 0 WHERE id = ?"
	_, err = db.Connection.Exec(query, hashed, u.Id)
	return err
}