package db import ( "database/sql" "fmt" "github.com/google/uuid" ) type User struct { Id string Name string Reference string } type Gift struct { Id string Name string ClaimantId string ClaimantName string Sent bool } func queryForUser(query string, args ...any) (*User, error) { var id string var name string var reference string err := database.QueryRow(query, args...).Scan(&id, &name, &reference) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } user := User{ Id: id, Name: name, Reference: reference, } return &user, nil } func GetUserByName(username string) (*User, error) { stmt := "SELECT user.id, user.name, user.reference FROM user WHERE user.name = ?" return queryForUser(stmt, username) } func GetUserByReference(reference string) (*User, error) { stmt := "SELECT user.id, user.name, user.reference FROM user WHERE user.reference = ?" return queryForUser(stmt, reference) } func CreateUser(username string, passHash []byte) (*User, error) { stmt := "INSERT INTO user (name, motto, reference, password_hash) VALUES (?, '', ?, ?)" reference, err := uuid.NewRandom() if err != nil { return nil, err } result, err := database.Exec(stmt, username, reference, passHash) if err != nil { return nil, err } id, err := result.LastInsertId() if err != nil { return nil, err } user := User{ Id: fmt.Sprintf("%d", id), Name: username, } return &user, nil } func (u *User) GetPassHash() ([]byte, error) { stmt := "SELECT user.password_hash FROM user WHERE user.id = ?" var passHash string err := database.QueryRow(stmt, u.Id).Scan(&passHash) if err != nil { return nil, err } return []byte(passHash), nil } func (u *User) GetGifts() ([]Gift, error) { stmt := "SELECT gift.id, gift.name, claimant.id, claimant.name, gift.sent FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE user.id = ? ORDER BY gift.name DESC" rows, err := database.Query(stmt, u.Id) if err != nil { return nil, err } defer rows.Close() gifts := []Gift{} for rows.Next() { var id string var name string var claimantId string var claimantName string var sent bool rows.Scan(&id, &name, &claimantId, &claimantName, &sent) gift := Gift{ Id: id, Name: name, ClaimantId: claimantId, ClaimantName: claimantName, Sent: sent, } gifts = append(gifts, gift) } err = rows.Err() if err != nil { return nil, err } return gifts, nil } func (u *User) AddGift(name string) error { stmt := "INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)" _, err := database.Exec(stmt, name, u.Id, u.Id) if err != nil { return err } return nil } func (u *User) RemoveGift(id string) error { stmt := "DELETE FROM gift WHERE gift.creator_id = ? AND gift.id = ?" result, err := database.Exec(stmt, u.Id, id) if err != nil { return err } affected, _ := result.RowsAffected() if affected == 0 { return fmt.Errorf("No gift match.") } return nil } func (u *User) executeClaims(tx *sql.Tx, claims, unclaims []string) error { claimStmt := "UPDATE gift SET claimant_id = ? WHERE id = ?" unclaimStmt := "UPDATE gift SET claimant_id = NULL WHERE id = ?" for _, id := range claims { _, err := tx.Exec(claimStmt, u.Id, id) if err != nil { return err } } for _, id := range unclaims { _, err := tx.Exec(unclaimStmt, id) if err != nil { return err } } return nil } func (u *User) ClaimGifts(claims, unclaims []string) error { tx, err := database.Begin() if err != nil { return err } err = u.executeClaims(tx, claims, unclaims) if err != nil { err = tx.Rollback() return err } err = tx.Commit() return err } func (u *User) executeCompletions(tx *sql.Tx, claims []string) error { claimStmt := "UPDATE gift SET sent = 1 WHERE id = ?" for _, id := range claims { _, err := tx.Exec(claimStmt, id) if err != nil { return err } } return nil } func (u *User) CompleteGifts(claims []string) error { tx, err := database.Begin() if err != nil { return err } err = u.executeCompletions(tx, claims) if err != nil { err = tx.Rollback() return err } err = tx.Commit() return err }