diff --git a/context/context.go b/context/context.go index 79a0c56..a576f94 100644 --- a/context/context.go +++ b/context/context.go @@ -46,20 +46,42 @@ func (ctx *Context) WishlistDelete(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusSeeOther) } -func (ctx *Context) UpdateClaim(w http.ResponseWriter, r *http.Request) { +func (ctx *Context) updateClaims(w http.ResponseWriter, r *http.Request) { user := ctx.Auth.ExpectUser(r) - if err := r.ParseForm(); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } userReference := r.PathValue("userReference") - claims := r.Form["claim"] - unclaims := r.Form["unclaim"] + claims := r.Form["unclaimed"] + unclaims := r.Form["claimed"] err := user.ClaimGifts(claims, unclaims) - println("err?", err) if err != nil { http.Error(w, "Failed to update claim...", http.StatusInternalServerError) return } http.Redirect(w, r, "/"+userReference, http.StatusSeeOther) } + +func (ctx *Context) completeGifts(w http.ResponseWriter, r *http.Request) { + user := ctx.Auth.ExpectUser(r) + userReference := r.PathValue("userReference") + claims := r.Form["claimed"] + err := user.CompleteGifts(claims) + if err != nil { + http.Error(w, "Failed to complete gifts...", http.StatusInternalServerError) + return + } + http.Redirect(w, r, "/"+userReference, http.StatusSeeOther) +} + +func (ctx *Context) UpdateForeignWishlist(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + switch r.Form.Get("mode") { + case "claim": + ctx.updateClaims(w, r) + case "complete": + ctx.completeGifts(w, r) + default: + http.Error(w, "Invalid mode", http.StatusBadRequest) + } +} diff --git a/db/user.go b/db/user.go index 48c27f2..06f7804 100644 --- a/db/user.go +++ b/db/user.go @@ -18,6 +18,7 @@ type Gift struct { Name string ClaimantId string ClaimantName string + Sent bool } func queryForUser(query string, args ...any) (*User, error) { @@ -80,7 +81,7 @@ func (u *User) GetPassHash() ([]byte, error) { } func (u *User) GetGifts() ([]Gift, error) { - stmt := "SELECT gift.id, gift.name, claimant.id, claimant.name FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE user.id = ? ORDER BY gift.name DESC" + stmt := "SELECT gift.id, gift.name, claimant.id, claimant.name, gift.sent FROM gift JOIN user ON gift.recipient_id = user.id LEFT JOIN user AS claimant ON gift.claimant_id = claimant.id WHERE user.id = ? ORDER BY gift.name DESC" rows, err := database.Query(stmt, u.Id) if err != nil { return nil, err @@ -92,12 +93,14 @@ func (u *User) GetGifts() ([]Gift, error) { var name string var claimantId string var claimantName string - rows.Scan(&id, &name, &claimantId, &claimantName) + var sent bool + rows.Scan(&id, &name, &claimantId, &claimantName, &sent) gift := Gift{ Id: id, Name: name, ClaimantId: claimantId, ClaimantName: claimantName, + Sent: sent, } gifts = append(gifts, gift) } @@ -134,7 +137,6 @@ func (u *User) executeClaims(tx *sql.Tx, claims, unclaims []string) error { claimStmt := "UPDATE gift SET claimant_id = ? WHERE id = ?" unclaimStmt := "UPDATE gift SET claimant_id = NULL WHERE id = ?" for _, id := range claims { - println("exec claim:", claimStmt, u.Id, id) _, err := tx.Exec(claimStmt, u.Id, id) if err != nil { return err @@ -164,3 +166,30 @@ func (u *User) ClaimGifts(claims, unclaims []string) error { err = tx.Commit() return err } + +func (u *User) executeCompletions(tx *sql.Tx, claims []string) error { + claimStmt := "UPDATE gift SET sent = 1 WHERE id = ?" + for _, id := range claims { + _, err := tx.Exec(claimStmt, id) + if err != nil { + return err + } + } + return nil +} + +func (u *User) CompleteGifts(claims []string) error { + tx, err := database.Begin() + if err != nil { + return err + } + + err = u.executeCompletions(tx, claims) + if err != nil { + err = tx.Rollback() + return err + } + + err = tx.Commit() + return err +} diff --git a/main.go b/main.go index 18a2e28..2f13e14 100644 --- a/main.go +++ b/main.go @@ -36,7 +36,7 @@ func main() { protectedMux.HandleFunc("GET /", ctx.Home) protectedMux.HandleFunc("GET /{userReference}", ctx.ViewForeignWishlist) - protectedMux.HandleFunc("POST /{userReference}/update_claim", ctx.UpdateClaim) + protectedMux.HandleFunc("POST /{userReference}/update", ctx.UpdateForeignWishlist) protectedMux.HandleFunc("POST /wishlist/add", ctx.WishlistAdd) protectedMux.HandleFunc("POST /wishlist/delete", ctx.WishlistDelete) protectedMux.HandleFunc("POST /logout", authMiddleware.LogoutPost) diff --git a/templates/foreign_wishlist.gotmpl b/templates/foreign_wishlist.gotmpl index d1ff8e8..179d943 100644 --- a/templates/foreign_wishlist.gotmpl +++ b/templates/foreign_wishlist.gotmpl @@ -8,31 +8,35 @@