lishwist/db/user.go

259 lines
5.3 KiB
Go

package db
import (
"database/sql"
"fmt"
"github.com/google/uuid"
)
// User is a row of the user table as seen by the rest of the application.
// Id is the row's primary key in string form, Name is the user's name and
// Reference is the random UUID string assigned when the user was created.
type User struct {
	Id, Name, Reference string
}
// Gift is one entry on a user's list. ClaimantId and ClaimantName identify the
// user who has claimed the gift (empty when unclaimed), and Sent records
// whether the gift has been marked as sent.
type Gift struct {
	Id, Name                 string
	ClaimantId, ClaimantName string
	Sent                     bool
}
// queryForUser runs query with args and scans the single resulting row into a
// User. The query must select exactly three columns: id, name and reference,
// in that order. It returns (nil, nil) when the query matches no row, and
// (nil, err) for any other failure.
func queryForUser(query string, args ...any) (*User, error) {
	var u User
	err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference)
	switch {
	case err == sql.ErrNoRows:
		// "No such user" is not an error for callers; they check for nil.
		return nil, nil
	case err != nil:
		return nil, err
	default:
		return &u, nil
	}
}
// GetUserByName looks up the user whose name equals username.
// It returns (nil, nil) when no such user exists.
func GetUserByName(username string) (*User, error) {
	return queryForUser(
		"SELECT user.id, user.name, user.reference FROM user WHERE user.name = ?",
		username,
	)
}
// GetUserByReference looks up the user whose reference column equals
// reference, returning (nil, nil) if there is none.
func GetUserByReference(reference string) (*User, error) {
	const query = "SELECT user.id, user.name, user.reference FROM user WHERE user.reference = ?"
	return queryForUser(query, reference)
}
// CreateUser inserts a new user with the given name and password hash and an
// empty motto. It assigns the user a freshly generated random UUID reference.
// It returns the new user with Id, Name and Reference populated.
func CreateUser(username string, passHash []byte) (*User, error) {
	stmt := "INSERT INTO user (name, motto, reference, password_hash) VALUES (?, '', ?, ?)"
	reference, err := uuid.NewRandom()
	if err != nil {
		return nil, err
	}
	// Store and return the same textual form so the returned User matches
	// what a later GetUserByReference lookup would find.
	ref := reference.String()
	result, err := database.Exec(stmt, username, ref, passHash)
	if err != nil {
		return nil, err
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}
	return &User{
		Id:        fmt.Sprintf("%d", id),
		Name:      username,
		Reference: ref,
	}, nil
}
// GetPassHash reads the stored password hash for u. Any query error is
// returned unchanged, including sql.ErrNoRows when u.Id matches no user.
func (u *User) GetPassHash() ([]byte, error) {
	const query = "SELECT user.password_hash FROM user WHERE user.id = ?"
	var hash string
	if err := database.QueryRow(query, u.Id).Scan(&hash); err != nil {
		return nil, err
	}
	return []byte(hash), nil
}
// GetGifts returns every gift whose recipient is u, ordered by gift name
// descending. For gifts nobody has claimed, ClaimantId and ClaimantName are
// empty strings. The result is a non-nil (possibly empty) slice on success.
func (u *User) GetGifts() ([]Gift, error) {
	stmt := "SELECT gift.id, gift.name, claimant.id, claimant.name, gift.sent FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE user.id = ? ORDER BY gift.name DESC"
	rows, err := database.Query(stmt, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	gifts := []Gift{}
	for rows.Next() {
		var (
			gift         Gift
			claimantID   sql.NullString
			claimantName sql.NullString
		)
		// The LEFT JOIN yields NULL claimant columns for unclaimed gifts.
		// Scanning NULL into a plain string fails (and would leave later
		// columns such as sent unread), so scan through sql.NullString.
		if err := rows.Scan(&gift.Id, &gift.Name, &claimantID, &claimantName, &gift.Sent); err != nil {
			return nil, err
		}
		gift.ClaimantId = claimantID.String
		gift.ClaimantName = claimantName.String
		gifts = append(gifts, gift)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return gifts, nil
}
// AddGift adds a gift called name to u's own list, recording u as both the
// gift's recipient and its creator.
func (u *User) AddGift(name string) error {
	_, err := database.Exec(
		"INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)",
		name, u.Id, u.Id,
	)
	return err
}
// deleteGifts deletes, inside tx, each gift in ids that was created by u.
// It stops at the first id whose DELETE fails or matches no gift created by u
// and returns an error for it. Committing or rolling back tx is left to the
// caller.
func (u *User) deleteGifts(tx *sql.Tx, ids []string) error {
	const query = "DELETE FROM gift WHERE gift.creator_id = ? AND gift.id = ?"
	for i := 0; i < len(ids); i++ {
		res, err := tx.Exec(query, u.Id, ids[i])
		if err != nil {
			return err
		}
		n, err := res.RowsAffected()
		switch {
		case err != nil:
			return err
		case n < 1:
			return fmt.Errorf("Gift deletion failed for '%s'", ids[i])
		}
	}
	return nil
}
// RemoveGifts deletes the gifts with the given ids, each of which must have
// been created by u, inside a single transaction. The deletions are committed
// only if every one succeeds; otherwise the transaction is rolled back. It
// returns an error if ids is empty.
func (u *User) RemoveGifts(ids ...string) error {
	if len(ids) < 1 {
		return fmt.Errorf("Attempt to remove zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.deleteGifts(tx, ids); err != nil {
		// Surface a failed rollback instead of silently dropping it. Keep the
		// original cause first in the chain so errors.Is/As still find it.
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}
	return tx.Commit()
}
// executeClaims, inside tx, makes u the claimant of every gift in claims and
// releases u's claim on every gift in unclaims, processing claims first. It
// stops at the first statement that fails or matches no row and returns an
// error for it. Committing or rolling back tx is left to the caller.
//
// A claim only succeeds on a gift that is unclaimed or already claimed by u.
// An unclaim only clears a claim held by u. This way one user can neither take
// over nor cancel another user's claim.
func (u *User) executeClaims(tx *sql.Tx, claims, unclaims []string) error {
	claimStmt := "UPDATE gift SET claimant_id = ? WHERE id = ? AND (claimant_id IS NULL OR claimant_id = ?)"
	unclaimStmt := "UPDATE gift SET claimant_id = NULL WHERE id = ? AND claimant_id = ?"

	// execOne runs stmt with args and reports an error naming action and id
	// when the statement fails or affects no row.
	execOne := func(action, id, stmt string, args ...any) error {
		r, err := tx.Exec(stmt, args...)
		if err != nil {
			return err
		}
		n, err := r.RowsAffected()
		if err != nil {
			return err
		}
		if n < 1 {
			return fmt.Errorf("Gift %s failed for '%s'", action, id)
		}
		return nil
	}

	for _, id := range claims {
		if err := execOne("claim", id, claimStmt, u.Id, id, u.Id); err != nil {
			return err
		}
	}
	for _, id := range unclaims {
		if err := execOne("unclaim", id, unclaimStmt, id, u.Id); err != nil {
			return err
		}
	}
	return nil
}
// ClaimGifts claims the gifts in claims for u and releases the gifts in
// unclaims, all inside a single transaction. The changes are committed only if
// every one succeeds; otherwise the transaction is rolled back. It returns an
// error if both lists are empty.
func (u *User) ClaimGifts(claims, unclaims []string) error {
	if len(claims) < 1 && len(unclaims) < 1 {
		return fmt.Errorf("Attempt to claim/unclaim zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.executeClaims(tx, claims, unclaims); err != nil {
		// Surface a failed rollback instead of silently dropping it. Keep the
		// original cause first in the chain so errors.Is/As still find it.
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}
	return tx.Commit()
}
// executeCompletions marks, inside tx, each gift in claims as sent. Only gifts
// currently claimed by u can be marked, so one user cannot complete another
// user's claim. It stops at the first id whose update fails or matches no such
// gift and returns an error for it. Committing or rolling back tx is left to
// the caller.
func (u *User) executeCompletions(tx *sql.Tx, claims []string) error {
	completeStmt := "UPDATE gift SET sent = 1 WHERE id = ? AND claimant_id = ?"
	for _, id := range claims {
		r, err := tx.Exec(completeStmt, id, u.Id)
		if err != nil {
			return err
		}
		n, err := r.RowsAffected()
		if err != nil {
			return err
		}
		if n < 1 {
			return fmt.Errorf("Gift completion failed for '%s'", id)
		}
	}
	return nil
}
// CompleteGifts marks the gifts in claims as sent inside a single transaction.
// The changes are committed only if every one succeeds; otherwise the
// transaction is rolled back. It returns an error if claims is empty.
func (u *User) CompleteGifts(claims []string) error {
	if len(claims) < 1 {
		return fmt.Errorf("Attempt to complete zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.executeCompletions(tx, claims); err != nil {
		// Surface a failed rollback instead of silently dropping it. Keep the
		// original cause first in the chain so errors.Is/As still find it.
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("%w (rollback failed: %v)", err, rbErr)
		}
		return err
	}
	return tx.Commit()
}