diff --git a/server/api/db/db.go b/server/api/db/db.go index c3df638..b564228 100644 --- a/server/api/db/db.go +++ b/server/api/db/db.go @@ -23,7 +23,7 @@ func Open() error { } func Init() error { - _, err := database.Exec(InitQuery) + _, err := database.Exec(initQuery) if err != nil { return err } diff --git a/server/api/db/gen_init_sql.go b/server/api/db/gen_init_sql.go index 2e9d7e4..0e3da1b 100644 --- a/server/api/db/gen_init_sql.go +++ b/server/api/db/gen_init_sql.go @@ -11,7 +11,7 @@ import ( var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" + "package db\n" + "\n" + - "const InitQuery = `{{.}}`\n", + "const initQuery = `{{.}}`\n", )) func main() { diff --git a/server/api/db/group.go b/server/api/db/group.go index 9ba69d3..615fa4c 100644 --- a/server/api/db/group.go +++ b/server/api/db/group.go @@ -1,7 +1,6 @@ package db import ( - "database/sql" "fmt" "lishwist/normalize" "strconv" @@ -23,52 +22,47 @@ func (g *Group) MemberIndex(userId string) int { return -1 } -func queryForGroup(query string, args ...any) (*Group, error) { - var group Group - err := database.QueryRow(query, args...).Scan(&group.Id, &group.Name, &group.Reference) - if err == sql.ErrNoRows { - return nil, nil - } else if err != nil { - return nil, err - } - members, err := queryForGroupMembers(group.Id) - if err != nil { - return nil, err - } - group.Members = members - return &group, nil -} - -func queryForGroups(query string, args ...any) ([]Group, error) { +func queryManyGroups(query string, args ...any) ([]Group, error) { groups := []Group{} rows, err := database.Query(query, args...) if err != nil { - return groups, fmt.Errorf("Query failed: %w", err) + return nil, fmt.Errorf("Query failed: %w", err) } defer rows.Close() for rows.Next() { var group Group err := rows.Scan(&group.Id, &group.Name, &group.Reference) if err != nil { - return groups, fmt.Errorf("Failed to scan row: %w", err) + return nil, fmt.Errorf("Failed to scan row: %w", err) } - members, err := queryForGroupMembers(group.Id) + members, err := queryManyGroupMembers(group.Id) if err != nil { - return groups, fmt.Errorf("Failed to query for group members: %w", err) + return nil, fmt.Errorf("Failed to query for group members: %w", err) } group.Members = members groups = append(groups, group) } err = rows.Err() if err != nil { - return groups, fmt.Errorf("Rows error: %w", err) + return nil, fmt.Errorf("Rows error: %w", err) } return groups, nil } -func queryForGroupMembers(groupId string) ([]User, error) { +func queryOneGroup(query string, args ...any) (*Group, error) { + groups, err := queryManyGroups(query, args...) + if err != nil { + return nil, err + } + if len(groups) < 1 { + return nil, nil + } + return &groups[0], nil +} + +func queryManyGroupMembers(groupId string) ([]User, error) { query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id" - members, err := queryForUsers(query, groupId) + members, err := queryManyUsers(query, groupId) if err != nil { return members, fmt.Errorf("Failed to get members: %w", err) } @@ -77,12 +71,12 @@ func queryForGroupMembers(groupId string) ([]User, error) { func GetGroupByReference(reference string) (*Group, error) { query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?" - return queryForGroup(query, reference) + return queryOneGroup(query, reference) } func GetAllGroups() ([]Group, error) { query := "SELECT id, name, reference FROM [group];" - return queryForGroups(query) + return queryManyGroups(query) } func CreateGroup(name string, reference string) (*Group, error) { diff --git a/server/api/db/user.go b/server/api/db/user.go index ea40c88..34d2ba8 100644 --- a/server/api/db/user.go +++ b/server/api/db/user.go @@ -30,18 +30,7 @@ type Gift struct { CreatorName string `json:",omitempty"` } -func queryForUser(query string, args ...any) (*User, error) { - var u User - err := database.QueryRow(query, args...).Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) - if err == sql.ErrNoRows { - return nil, nil - } else if err != nil { - return nil, err - } - return &u, nil -} - -func queryForUsers(query string, args ...any) ([]User, error) { +func queryManyUsers(query string, args ...any) ([]User, error) { rows, err := database.Query(query, args...) if err != nil { return nil, err @@ -63,30 +52,41 @@ func queryForUsers(query string, args ...any) ([]User, error) { return users, nil } +func queryOneUser(query string, args ...any) (*User, error) { + users, err := queryManyUsers(query, args...) + if err != nil { + return nil, err + } + if len(users) < 1 { + return nil, nil + } + return &users[0], nil +} + func GetAllUsers() ([]User, error) { stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user" - return queryForUsers(stmt) + return queryManyUsers(stmt) } func GetUser(id string) (*User, error) { stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE id = ?" - return queryForUser(stmt, id) + return queryOneUser(stmt, id) } func GetUserByName(username string) (*User, error) { username = normalize.Name(username) stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?" - return queryForUser(stmt, username) + return queryOneUser(stmt, username) } func GetUserByReference(reference string) (*User, error) { stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?" - return queryForUser(stmt, reference) + return queryOneUser(stmt, reference) } func GetAnyUserByReference(reference string) (*User, error) { stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user WHERE reference = ?" - return queryForUser(stmt, reference) + return queryOneUser(stmt, reference) } func (u *User) SetLive(setting bool) error { @@ -410,10 +410,10 @@ func (u *User) AddGiftToUser(otherUserReference string, giftName string) error { func (u *User) GetGroups() ([]Group, error) { stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?" - return queryForGroups(stmt, u.Id) + return queryManyGroups(stmt, u.Id) } func (u *User) GetGroupByReference(reference string) (*Group, error) { stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?" - return queryForGroup(stmt, reference, u.Id) + return queryOneGroup(stmt, reference, u.Id) }