lishwist/server/db/user.go

431 lines
10 KiB
Go

package db
import (
"database/sql"
"fmt"
"github.com/google/uuid"
)
// User is a row of the user table as exposed to the rest of the server.
// All three fields are carried as strings, exactly as they are scanned from
// the id, name and reference columns.
type User struct {
	Id, Name, Reference string
}
// Gift is a gift row combined with the names of the users related to it.
//
// The query methods on User each select a different subset of columns, so a
// Gift returned by one of them may leave some of these fields at their zero
// value.
type Gift struct {
	// Identity of the gift itself.
	Id, Name string
	// User who has claimed the gift; both are empty when the claimant
	// columns scanned as NULL.
	ClaimantId, ClaimantName string
	// Sent mirrors the gift.sent column.
	Sent bool
	// User the gift is for.
	RecipientId, RecipientName, RecipientRef string
	// User who added the gift.
	CreatorId, CreatorName string
}
// queryForUser runs query with args and scans the three columns of the first
// row into a User, in the order Id, Name, Reference.
//
// It returns (nil, nil) when the scan reports sql.ErrNoRows, and (nil, err)
// for any other scan error.
func queryForUser(query string, args ...any) (*User, error) {
	found := new(User)
	row := database.QueryRow(query, args...)
	switch err := row.Scan(&found.Id, &found.Name, &found.Reference); err {
	case nil:
		return found, nil
	case sql.ErrNoRows:
		// "No such user" is not treated as a failure.
		return nil, nil
	default:
		return nil, err
	}
}
// GetUserByName looks up the user whose name column equals username.
// It returns whatever queryForUser returns, which is a nil user and a nil
// error when no row matches.
func GetUserByName(username string) (*User, error) {
	const byName = "SELECT user.id, user.name, user.reference FROM user WHERE user.name = ?"
	return queryForUser(byName, username)
}
// GetUserByReference looks up the user whose reference column equals
// reference. It returns whatever queryForUser returns, which is a nil user
// and a nil error when no row matches.
func GetUserByReference(reference string) (*User, error) {
	const byReference = "SELECT user.id, user.name, user.reference FROM user WHERE user.reference = ?"
	return queryForUser(byReference, reference)
}
// CreateUser inserts a new user row with the given name and password hash,
// an empty motto, and a freshly generated random UUID as its reference.
//
// On success it returns a User carrying the new row id (formatted as a
// decimal string), the name, and the generated reference. Any error from
// UUID generation, the insert, or reading the last insert id is returned
// unchanged with a nil user.
func CreateUser(username string, passHash []byte) (*User, error) {
	stmt := "INSERT INTO user (name, motto, reference, password_hash) VALUES (?, '', ?, ?)"
	reference, err := uuid.NewRandom()
	if err != nil {
		return nil, err
	}
	result, err := database.Exec(stmt, username, reference, passHash)
	if err != nil {
		return nil, err
	}
	id, err := result.LastInsertId()
	if err != nil {
		return nil, err
	}
	user := User{
		Id:   fmt.Sprintf("%d", id),
		Name: username,
		// Hand the generated reference back to the caller; without it the
		// returned User would carry an empty Reference even though the row
		// has one.
		Reference: reference.String(),
	}
	return &user, nil
}
// GetPassHash reads the password_hash column of the row whose id is u.Id and
// returns it as bytes. Unlike queryForUser, every scan error is returned to
// the caller as-is.
func (u *User) GetPassHash() ([]byte, error) {
	const query = "SELECT user.password_hash FROM user WHERE user.id = ?"
	var stored string
	if err := database.QueryRow(query, u.Id).Scan(&stored); err != nil {
		return nil, err
	}
	return []byte(stored), nil
}
// CountGifts returns the number of gifts for which this user is both the
// recipient and the creator. It returns 0 together with the error if the
// query or scan fails.
func (u *User) CountGifts() (int, error) {
	const query = "SELECT COUNT(gift.id) AS gift_count FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE gift.creator_id = user.id AND user.id = ?"
	total := 0
	if err := database.QueryRow(query, u.Id).Scan(&total); err != nil {
		return 0, err
	}
	return total, nil
}
// GetGifts lists the gifts for which this user is both the recipient and the
// creator. Each Gift has Id, Name, ClaimantId, ClaimantName and Sent set; the
// claimant fields are empty strings when the joined claimant columns are
// NULL.
//
// The returned slice is non-nil even when there are no rows. Any query, scan
// or iteration error aborts the call and is returned with a nil slice.
func (u *User) GetGifts() ([]Gift, error) {
	const query = "SELECT gift.id, gift.name, claimant.id, claimant.name, gift.sent FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE gift.creator_id = user.id AND user.id = ?"
	rows, err := database.Query(query, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	wishlist := make([]Gift, 0)
	for rows.Next() {
		var (
			entry                 Gift
			claimedBy, claimedAs  sql.NullString
		)
		if err := rows.Scan(&entry.Id, &entry.Name, &claimedBy, &claimedAs, &entry.Sent); err != nil {
			return nil, err
		}
		// NullString.String is "" for NULL, so unclaimed gifts get empty claimant fields.
		entry.ClaimantId = claimedBy.String
		entry.ClaimantName = claimedAs.String
		wishlist = append(wishlist, entry)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return wishlist, nil
}
// GetOtherUserGifts lists every gift whose recipient is the user identified
// by otherUserReference. Each Gift has Id, Name, ClaimantId, ClaimantName,
// Sent, CreatorId, CreatorName and RecipientId set; the claimant fields are
// empty strings when the joined claimant columns are NULL.
//
// It returns an error when the lookup fails, when no user has that
// reference, or when the reference resolves to u itself. The returned slice
// is non-nil even when there are no rows.
func (u *User) GetOtherUserGifts(otherUserReference string) ([]Gift, error) {
	otherUser, err := GetUserByReference(otherUserReference)
	if err != nil {
		return nil, err
	}
	// The lookup reports an unknown reference as (nil, nil); without this
	// guard the Id comparison below would dereference a nil pointer.
	if otherUser == nil {
		return nil, fmt.Errorf("User not found for reference '%s'", otherUserReference)
	}
	if otherUser.Id == u.Id {
		return nil, fmt.Errorf("Not allowed to view own foreign wishlist")
	}
	stmt := "SELECT gift.id, gift.name, claimant.id, claimant.name, gift.sent, gift.creator_id, creator.name, gift.recipient_id FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id LEFT JOIN user AS creator ON gift.creator_id = creator.id WHERE user.id = ?"
	rows, err := database.Query(stmt, otherUser.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	gifts := []Gift{}
	for rows.Next() {
		var id string
		var name string
		var claimantId sql.NullString
		var claimantName sql.NullString
		var sent bool
		var creatorId string
		var creatorName string
		var recipientId string
		err = rows.Scan(&id, &name, &claimantId, &claimantName, &sent, &creatorId, &creatorName, &recipientId)
		if err != nil {
			return nil, err
		}
		gift := Gift{
			Id:           id,
			Name:         name,
			ClaimantId:   claimantId.String,
			ClaimantName: claimantName.String,
			Sent:         sent,
			CreatorId:    creatorId,
			CreatorName:  creatorName,
			RecipientId:  recipientId,
		}
		gifts = append(gifts, gift)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return gifts, nil
}
// GetTodo lists the gifts this user has claimed, ordered by gift.sent
// ascending and then by gift name. Each Gift has Id, Name, Sent,
// RecipientName and RecipientRef set.
//
// The returned slice is non-nil even when there are no rows. Any query, scan
// or iteration error aborts the call and is returned with a nil slice.
func (u *User) GetTodo() ([]Gift, error) {
	stmt := "SELECT gift.id, gift.name, gift.sent, recipient.name, recipient.reference FROM gift JOIN user ON gift.claimant_id = user.id JOIN user AS recipient ON gift.recipient_id = recipient.id WHERE user.id = ? ORDER BY gift.sent ASC, gift.name"
	rows, err := database.Query(stmt, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	gifts := []Gift{}
	for rows.Next() {
		var id string
		var name string
		var sent bool
		var recipientName string
		var recipientRef string
		// A failed scan must not be appended as a zero-valued or partial
		// Gift, so surface the error like the other list queries do.
		err = rows.Scan(&id, &name, &sent, &recipientName, &recipientRef)
		if err != nil {
			return nil, err
		}
		gift := Gift{
			Id:            id,
			Name:          name,
			Sent:          sent,
			RecipientName: recipientName,
			RecipientRef:  recipientRef,
		}
		gifts = append(gifts, gift)
	}
	err = rows.Err()
	if err != nil {
		return nil, err
	}
	return gifts, nil
}
// AddGift inserts a gift called name with this user as both its recipient
// and its creator. It returns the error from the insert, if any.
func (u *User) AddGift(name string) error {
	const insert = "INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)"
	_, err := database.Exec(insert, name, u.Id, u.Id)
	return err
}
// deleteGifts deletes each gift in ids inside tx, restricted to rows whose
// creator_id is u.Id. It stops at the first failure: a statement error, a
// RowsAffected error, or a delete that affected no rows. It never commits or
// rolls back tx itself.
func (u *User) deleteGifts(tx *sql.Tx, ids []string) error {
	const query = "DELETE FROM gift WHERE gift.creator_id = ? AND gift.id = ?"
	for i := range ids {
		giftId := ids[i]
		res, err := tx.Exec(query, u.Id, giftId)
		if err != nil {
			return err
		}
		switch affected, err := res.RowsAffected(); {
		case err != nil:
			return err
		case affected < 1:
			// Nothing matched: the id does not exist or u did not create it.
			return fmt.Errorf("Gift deletion failed for '%s'", giftId)
		}
	}
	return nil
}
// RemoveGifts deletes the given gifts in a single transaction via
// deleteGifts. Either every id is deleted and the transaction is committed,
// or the transaction is rolled back and the first failure is returned.
//
// It returns an error when ids is empty. If the rollback itself fails, the
// returned error wraps the original failure and mentions the rollback error,
// so errors.Is/As against the original cause keep working.
func (u *User) RemoveGifts(ids ...string) error {
	if len(ids) < 1 {
		return fmt.Errorf("Attempt to remove zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.deleteGifts(tx, ids); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			// Do not hide a failed rollback behind the original error.
			return fmt.Errorf("%w (rollback also failed: %v)", err, rollbackErr)
		}
		return err
	}
	return tx.Commit()
}
// executeClaims applies claim and unclaim updates inside tx. Every id in
// claims has its claimant_id set to u.Id; every id in unclaims has its
// claimant_id set to NULL. Claims are processed before unclaims, and the
// first failure (statement error, RowsAffected error, or an update that
// affected no rows) ends the call. It never commits or rolls back tx itself.
//
// NOTE(review): neither statement checks the current claimant, so a claim
// overwrites an existing one and an unclaim clears a claim held by anyone.
// Confirm that callers enforce ownership.
func (u *User) executeClaims(tx *sql.Tx, claims, unclaims []string) error {
	// touched runs one statement and reports whether it changed any row.
	touched := func(query string, args ...any) (bool, error) {
		res, err := tx.Exec(query, args...)
		if err != nil {
			return false, err
		}
		affected, err := res.RowsAffected()
		if err != nil {
			return false, err
		}
		return affected >= 1, nil
	}

	for _, giftId := range claims {
		ok, err := touched("UPDATE gift SET claimant_id = ? WHERE id = ?", u.Id, giftId)
		if err != nil {
			return err
		}
		if !ok {
			return fmt.Errorf("Gift claim failed for '%s'", giftId)
		}
	}
	for _, giftId := range unclaims {
		ok, err := touched("UPDATE gift SET claimant_id = NULL WHERE id = ?", giftId)
		if err != nil {
			return err
		}
		if !ok {
			return fmt.Errorf("Gift unclaim failed for '%s'", giftId)
		}
	}
	return nil
}
// ClaimGifts applies the given claims and unclaims in a single transaction
// via executeClaims. Either every update succeeds and the transaction is
// committed, or the transaction is rolled back and the first failure is
// returned.
//
// It returns an error when both slices are empty. If the rollback itself
// fails, the returned error wraps the original failure and mentions the
// rollback error, so errors.Is/As against the original cause keep working.
func (u *User) ClaimGifts(claims, unclaims []string) error {
	if len(claims) < 1 && len(unclaims) < 1 {
		return fmt.Errorf("Attempt to claim/unclaim zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.executeClaims(tx, claims, unclaims); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			// Do not hide a failed rollback behind the original error.
			return fmt.Errorf("%w (rollback also failed: %v)", err, rollbackErr)
		}
		return err
	}
	return tx.Commit()
}
// executeCompletions sets sent = 1 on every gift id in claims, inside tx. It
// stops at the first failure: a statement error, a RowsAffected error, or an
// update that affected no rows. It never commits or rolls back tx itself.
//
// NOTE(review): the receiver is not consulted, so the update is not limited
// to gifts claimed by u. Confirm that callers enforce this.
func (u *User) executeCompletions(tx *sql.Tx, claims []string) error {
	const markSent = "UPDATE gift SET sent = 1 WHERE id = ?"
	for i := range claims {
		giftId := claims[i]
		res, err := tx.Exec(markSent, giftId)
		if err != nil {
			return err
		}
		switch affected, err := res.RowsAffected(); {
		case err != nil:
			return err
		case affected < 1:
			return fmt.Errorf("Gift completion failed for '%s'", giftId)
		}
	}
	return nil
}
// CompleteGifts marks the given gifts as sent in a single transaction via
// executeCompletions. Either every update succeeds and the transaction is
// committed, or the transaction is rolled back and the first failure is
// returned.
//
// It returns an error when claims is empty. If the rollback itself fails,
// the returned error wraps the original failure and mentions the rollback
// error, so errors.Is/As against the original cause keep working.
func (u *User) CompleteGifts(claims []string) error {
	if len(claims) < 1 {
		return fmt.Errorf("Attempt to complete zero gifts")
	}
	tx, err := database.Begin()
	if err != nil {
		return err
	}
	if err := u.executeCompletions(tx, claims); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			// Do not hide a failed rollback behind the original error.
			return fmt.Errorf("%w (rollback also failed: %v)", err, rollbackErr)
		}
		return err
	}
	return tx.Commit()
}
// AddGiftToUser inserts a gift called giftName whose recipient is the user
// identified by otherUserReference and whose creator is u.
//
// It returns an error when the lookup fails, when no user has that
// reference, or when the insert fails.
func (u *User) AddGiftToUser(otherUserReference string, giftName string) error {
	otherUser, err := GetUserByReference(otherUserReference)
	if err != nil {
		return err
	}
	// The lookup reports an unknown reference as (nil, nil); without this
	// guard reading otherUser.Id would dereference a nil pointer.
	if otherUser == nil {
		return fmt.Errorf("User not found for reference '%s'", otherUserReference)
	}
	stmt := "INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)"
	_, err = database.Exec(stmt, giftName, otherUser.Id, u.Id)
	return err
}
// GetGroups lists the groups that have a group_member row for this user.
// Each Group has Id, Name and Reference set.
//
// The returned slice is non-nil even when there are no rows. Any query, scan
// or iteration error aborts the call and is returned with a nil slice.
func (u *User) GetGroups() ([]Group, error) {
	const query = "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN user ON user.id = group_member.user_id WHERE user.id = ?"
	rows, err := database.Query(query, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	memberships := make([]Group, 0)
	for rows.Next() {
		var groupId, groupName, groupRef string
		if scanErr := rows.Scan(&groupId, &groupName, &groupRef); scanErr != nil {
			return nil, scanErr
		}
		memberships = append(memberships, Group{Id: groupId, Name: groupName, Reference: groupRef})
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return memberships, nil
}
// GetPeers lists the members of the group with id groupId, excluding u
// itself. Each User has Id, Name and Reference set.
//
// The returned slice is non-nil even when there are no rows. Any query, scan
// or iteration error aborts the call and is returned with a nil slice.
func (u *User) GetPeers(groupId string) ([]User, error) {
	const query = "SELECT user.id, user.name, user.reference FROM user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? AND user.id != ?"
	rows, err := database.Query(query, groupId, u.Id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	peers := make([]User, 0)
	for rows.Next() {
		var peer User
		if scanErr := rows.Scan(&peer.Id, &peer.Name, &peer.Reference); scanErr != nil {
			return nil, scanErr
		}
		peers = append(peers, peer)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, iterErr
	}
	return peers, nil
}
// GetGroupByReference looks up the group with the given reference, but only
// if a group_member row ties it to this user. The lookup and its handling of
// "no rows" are delegated to queryForGroup.
func (u *User) GetGroupByReference(reference string) (*Group, error) {
	const byReferenceForMember = "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?"
	return queryForGroup(byReferenceForMember, reference, u.Id)
}