From 271163a889b4ae4ba40ffc9c14748c74aac80597 Mon Sep 17 00:00:00 2001 From: Teajey <21069848+Teajey@users.noreply.github.com> Date: Thu, 21 Nov 2024 13:01:34 +0900 Subject: [PATCH] feat: groups json interface --- server/db/group.go | 123 ++++++++++++++++++++++++++++++++------- server/db/user.go | 41 +++++-------- server/main.go | 3 + server/routing/error.go | 3 +- server/routing/groups.go | 108 ++++++++++++++++++++++++++++++++++ server/routing/users.go | 2 +- 6 files changed, 228 insertions(+), 52 deletions(-) diff --git a/server/db/group.go b/server/db/group.go index fc17647..de4d2c3 100644 --- a/server/db/group.go +++ b/server/db/group.go @@ -1,38 +1,58 @@ package db -import "database/sql" +import ( + "database/sql" + "fmt" + "strconv" +) type Group struct { Id string Name string Reference string + Users []User +} + +func (g *Group) MemberIndex(userId string) int { + for i, u := range g.Users { + if u.Id == userId { + return i + } + } + return -1 } func queryForGroup(query string, args ...any) (*Group, error) { - var id string - var name string - var reference string - err := database.QueryRow(query, args...).Scan(&id, &name, &reference) + var group Group + err := database.QueryRow(query, args...).Scan(&group.Id, &group.Name, &group.Reference) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } - group := Group{ - Id: id, - Name: name, - Reference: reference, - } return &group, nil } func GetGroupByReference(reference string) (*Group, error) { - stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?" - return queryForGroup(stmt, reference) + query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?" + return queryForGroup(query, reference) +} + +func GetGroupByReferenceWithMembers(reference string) (*Group, error) { + group, err := GetGroupByReference(reference) + if err != nil { + return group, err + } + members, err := group.GetMembers() + if err != nil { + return group, fmt.Errorf("Failed to get members: %w\n", err) + } + group.Users = members + return group, err } func (g *Group) GetMembers() ([]User, error) { - stmt := "SELECT user.id, user.name, user.reference FROM user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ?" + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ?" rows, err := database.Query(stmt, g.Id) users := []User{} if err != nil { @@ -40,22 +60,81 @@ func (g *Group) GetMembers() ([]User, error) { } defer rows.Close() for rows.Next() { - var id string - var name string - var reference string - err := rows.Scan(&id, &name, &reference) + var user User + err := rows.Scan(&user.Id, &user.Name, &user.Reference, &user.IsAdmin) if err != nil { return users, err } - users = append(users, User{ - Id: id, - Name: name, - Reference: reference, - }) + users = append(users, user) } err = rows.Err() if err != nil { return users, err } + g.Users = users return users, nil } + +func GetAllGroups() ([]Group, error) { + query := "SELECT id, name, reference FROM [group];" + groups := []Group{} + rows, err := database.Query(query) + if err != nil { + return groups, err + } + defer rows.Close() + for rows.Next() { + var group Group + err := rows.Scan(&group.Id, &group.Name, &group.Reference) + if err != nil { + return groups, err + } + users, err := group.GetMembers() + if err != nil { + return groups, fmt.Errorf("Failed to get a member: %w", err) + } + group.Users = users + groups = append(groups, group) + } + err = rows.Err() + if err != nil { + return groups, err + } + return groups, nil +} + +func CreateGroup(name string, reference string) (*Group, error) { + stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)" + result, err := database.Exec(stmt, name, reference) + if err != nil { + return nil, err + } + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + group := Group{ + Id: strconv.FormatInt(id, 10), + Name: name, + Reference: reference, + } + return &group, nil +} + +func (g *Group) AddUser(userId string) error { + stmt := "INSERT INTO group_member (group_id, user_id) VALUES (?, ?)" + _, err := database.Exec(stmt, g.Id, userId) + if err != nil { + return err + } + return nil +} + +func (g *Group) RemoveUser(userId string) error { + stmt := "DELETE FROM group_member WHERE group_id = ? AND user_id = ?" + _, err := database.Exec(stmt, g.Id, userId) + if err != nil { + return err + } + return nil +} diff --git a/server/db/user.go b/server/db/user.go index 74bbb21..18525ee 100644 --- a/server/db/user.go +++ b/server/db/user.go @@ -38,9 +38,8 @@ func queryForUser(query string, args ...any) (*User, error) { return &u, nil } -func GetAllUsers() ([]User, error) { - stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user" - rows, err := database.Query(stmt) +func queryForUsers(query string, args ...any) ([]User, error) { + rows, err := database.Query(query, args...) if err != nil { return nil, err } @@ -61,6 +60,16 @@ func GetAllUsers() ([]User, error) { return users, nil } +func GetAllUsers() ([]User, error) { + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user" + return queryForUsers(stmt) +} + +func GetUser(id string) (*User, error) { + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user WHERE user.id = ?" + return queryForUser(stmt, id) +} + func GetUserByName(username string) (*User, error) { stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user WHERE user.name = ?" return queryForUser(stmt, username) @@ -414,31 +423,7 @@ func (u *User) GetGroups() ([]Group, error) { func (u *User) GetPeers(groupId string) ([]User, error) { stmt := "SELECT user.id, user.name, user.reference FROM user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? AND user.id != ?" - rows, err := database.Query(stmt, groupId, u.Id) - if err != nil { - return nil, err - } - defer rows.Close() - users := []User{} - for rows.Next() { - var id string - var name string - var reference string - err := rows.Scan(&id, &name, &reference) - if err != nil { - return nil, err - } - users = append(users, User{ - Id: id, - Name: name, - Reference: reference, - }) - } - err = rows.Err() - if err != nil { - return nil, err - } - return users, nil + return queryForUsers(stmt, groupId, u.Id) } func (u *User) GetGroupByReference(reference string) (*Group, error) { diff --git a/server/main.go b/server/main.go index 4319a7b..3a5f343 100644 --- a/server/main.go +++ b/server/main.go @@ -54,6 +54,9 @@ func main() { r.Json.Public.HandleFunc("GET /", routing.NotFoundJson) r.Json.Private.Handle("GET /users", route.ExpectUser(route.UsersJson)) + r.Json.Private.Handle("GET /groups", route.ExpectUser(route.GroupsJson)) + r.Json.Private.Handle("POST /groups/{groupReference}", route.ExpectUser(route.GroupPost)) + r.Json.Private.Handle("GET /groups/{groupReference}", route.ExpectUser(route.Group)) r.Json.Private.HandleFunc("GET /", routing.NotFoundJson) http.Handle("/", r) diff --git a/server/routing/error.go b/server/routing/error.go index 830a103..cc0235e 100644 --- a/server/routing/error.go +++ b/server/routing/error.go @@ -7,7 +7,8 @@ import ( "strings" ) -func writeGeneralError(w http.ResponseWriter, msg string, status int) { +func writeGeneralErrorJson(w http.ResponseWriter, status int, format string, a ...any) { + msg := fmt.Sprintf(format, a...) log.Printf("General error: %s\n", msg) w.WriteHeader(status) escapedMsg := strings.ReplaceAll(msg, `"`, `\"`) diff --git a/server/routing/groups.go b/server/routing/groups.go index c2d3f17..8c9b04f 100644 --- a/server/routing/groups.go +++ b/server/routing/groups.go @@ -1,7 +1,9 @@ package routing import ( + "encoding/json" "net/http" + "slices" "lishwist/db" "lishwist/error" @@ -56,3 +58,109 @@ func (ctx *Context) PublicGroupPage(w http.ResponseWriter, r *http.Request) { } templates.Execute(w, "public_group_page.gotmpl", p) } + +func (ctx *Context) GroupPost(currentUser *db.User, w http.ResponseWriter, r *http.Request) { + if !currentUser.IsAdmin { + NotFoundJson(w, r) + return + } + if err := r.ParseForm(); err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Failed to parse form: "+err.Error()) + return + } + + var group *db.Group + + reference := r.PathValue("groupReference") + name := r.Form.Get("name") + addUsers := r.Form["addUser"] + removeUsers := r.Form["removeUser"] + + if name != "" { + createdGroup, err := db.CreateGroup(name, reference) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Failed to create group: "+err.Error()) + return + } + group = createdGroup + } else { + existingGroup, err := db.GetGroupByReferenceWithMembers(reference) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Failed to get group: "+err.Error()) + return + } + if existingGroup == nil { + writeGeneralErrorJson(w, http.StatusNotFound, "Group not found") + return + } + group = existingGroup + + for _, userId := range removeUsers { + index := group.MemberIndex(userId) + if index == -1 { + writeGeneralErrorJson(w, http.StatusBadRequest, "Group %q does not contain a user with id %s", reference, userId) + return + } + err = group.RemoveUser(userId) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "On group %q failed to remove user with id %s: %s", reference, userId, err) + return + } + group.Users = slices.Delete(group.Users, index, index) + } + } + + for _, userId := range addUsers { + user, err := db.GetUser(userId) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Groups exists, but a user with id %s could not be fetched: %s", userId, err) + return + } + if user == nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Groups exists, but a user with id %s does not exist", userId) + return + } + err = group.AddUser(user.Id) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Groups exists, but failed to add user with id %s: %s", userId, err) + return + } + group.Users = append(group.Users, *currentUser) + } + + _ = json.NewEncoder(w).Encode(group) +} + +func (ctx *Context) GroupsJson(user *db.User, w http.ResponseWriter, r *http.Request) { + if !user.IsAdmin { + NotFoundJson(w, r) + return + } + + groups, err := db.GetAllGroups() + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Failed to get groups: "+err.Error()) + return + } + + _ = json.NewEncoder(w).Encode(groups) +} + +func (ctx *Context) Group(user *db.User, w http.ResponseWriter, r *http.Request) { + if !user.IsAdmin { + NotFoundJson(w, r) + return + } + groupReference := r.PathValue("groupReference") + group, err := db.GetGroupByReferenceWithMembers(groupReference) + if err != nil { + writeGeneralErrorJson(w, http.StatusBadRequest, "Couldn't get group: %s", err) + return + } + if group == nil { + writeGeneralErrorJson(w, http.StatusNotFound, "Group not found.") + return + } + + _ = json.NewEncoder(w).Encode(group) +} diff --git a/server/routing/users.go b/server/routing/users.go index f4dcd67..83f42c6 100644 --- a/server/routing/users.go +++ b/server/routing/users.go @@ -14,7 +14,7 @@ func (ctx *Context) UsersJson(user *db.User, w http.ResponseWriter, r *http.Requ users, err := db.GetAllUsers() if err != nil { - writeGeneralError(w, "Failed to get users: "+err.Error(), http.StatusBadRequest) + writeGeneralErrorJson(w, http.StatusBadRequest, "Failed to get users: "+err.Error()) return }