diff --git a/server/api/register.go b/server/api/register.go index 2c83c97..77a4df1 100644 --- a/server/api/register.go +++ b/server/api/register.go @@ -17,6 +17,8 @@ type RegisterProps struct { } func (p *RegisterProps) Validate() (valid bool) { + valid = true + if p.Password.Value != p.ConfirmPassword.Value { p.ConfirmPassword.Error = "Passwords didn't match" valid = false @@ -69,24 +71,27 @@ func Register(username, newPassword, confirmPassword string) *RegisterProps { props.Password.Value = "" props.ConfirmPassword.Value = "" if !valid { + log.Printf("Invalid props: %#v\n", props) return props } existingUser, _ := db.GetUserByName(username) if existingUser != nil { + log.Printf("Username is taken: %q\n", username) props.Username.Error = "Username is taken" return props } hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost) if err != nil { + log.Printf("Failed to hash password: %s\n", err) props.GeneralError = "Something went wrong. Error code: Aang" return props } _, err = db.CreateUser(username, hashedPasswordBytes) if err != nil { - log.Println("Registration error:", err) + log.Printf("Failed to create user: %s\n", err) props.GeneralError = "Something went wrong. Error code: Ozai" return props } diff --git a/server/db/init.sql b/server/db/init.sql index b9f5172..1d69a55 100644 --- a/server/db/init.sql +++ b/server/db/init.sql @@ -5,6 +5,7 @@ CREATE TABLE IF NOT EXISTS "user" ( "reference" TEXT NOT NULL UNIQUE, "motto" TEXT NOT NULL, "password_hash" TEXT NOT NULL, + "is_admin" INTEGER NOT NULL DEFAULT 0, PRIMARY KEY("id" AUTOINCREMENT) ); CREATE TABLE IF NOT EXISTS "gift" ( diff --git a/server/db/user.go b/server/db/user.go index 067c64b..74bbb21 100644 --- a/server/db/user.go +++ b/server/db/user.go @@ -11,6 +11,7 @@ type User struct { Id string Name string Reference string + IsAdmin bool } type Gift struct { @@ -27,30 +28,46 @@ type Gift struct { } func queryForUser(query string, args ...any) (*User, error) { - var id string - var name string - var reference string - err := database.QueryRow(query, args...).Scan(&id, &name, &reference) + var u User + err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin) if err == sql.ErrNoRows { return nil, nil } else if err != nil { return nil, err } - user := User{ - Id: id, - Name: name, - Reference: reference, + return &u, nil +} + +func GetAllUsers() ([]User, error) { + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user" + rows, err := database.Query(stmt) + if err != nil { + return nil, err } - return &user, nil + defer rows.Close() + users := []User{} + for rows.Next() { + var u User + err = rows.Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin) + if err != nil { + return nil, err + } + users = append(users, u) + } + err = rows.Err() + if err != nil { + return nil, err + } + return users, nil } func GetUserByName(username string) (*User, error) { - stmt := "SELECT user.id, user.name, user.reference FROM user WHERE user.name = ?" + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user WHERE user.name = ?" return queryForUser(stmt, username) } func GetUserByReference(reference string) (*User, error) { - stmt := "SELECT user.id, user.name, user.reference FROM user WHERE user.reference = ?" + stmt := "SELECT user.id, user.name, user.reference, user.is_admin FROM user WHERE user.reference = ?" return queryForUser(stmt, reference) } diff --git a/server/main.go b/server/main.go index d316b4f..d5d280d 100644 --- a/server/main.go +++ b/server/main.go @@ -27,7 +27,7 @@ func main() { store, err := db.NewSessionStore() if err != nil { - log.Fatalf("Failed to ") + log.Fatalf("Failed to initialize session store: %s\n", err) } store.Options.MaxAge = 86_400 store.Options.Secure = !env.InDev @@ -44,14 +44,17 @@ func main() { r.Html.Public.HandleFunc("GET /list/{userReference}", route.PublicWishlist) r.Html.Public.HandleFunc("GET /group/{groupReference}", route.PublicGroupPage) - r.Html.Private.HandleFunc("GET /{$}", route.Home) - r.Html.Private.HandleFunc("POST /{$}", route.HomePost) - r.Html.Private.HandleFunc("GET /list/{userReference}", route.ForeignWishlist) - r.Html.Private.HandleFunc("POST /list/{userReference}", route.ForeignWishlistPost) - r.Html.Private.HandleFunc("GET /group/{groupReference}", route.GroupPage) + r.Html.Private.Handle("GET /{$}", route.ExpectUser(route.Home)) + r.Html.Private.Handle("POST /{$}", route.ExpectUser(route.HomePost)) + r.Html.Private.Handle("GET /list/{userReference}", route.ExpectUser(route.ForeignWishlist)) + r.Html.Private.Handle("POST /list/{userReference}", route.ExpectUser(route.ForeignWishlistPost)) + r.Html.Private.Handle("GET /group/{groupReference}", route.ExpectUser(route.GroupPage)) r.Html.Private.HandleFunc("POST /logout", route.LogoutPost) r.Json.Public.HandleFunc("POST /register", route.RegisterPostJson) + r.Json.Public.HandleFunc("GET /", routing.NotFoundJson) + + r.Json.Private.Handle("GET /users", route.ExpectUser(route.UsersJson)) http.Handle("/", r) diff --git a/server/routing/context.go b/server/routing/context.go index 6480b81..b5c20b2 100644 --- a/server/routing/context.go +++ b/server/routing/context.go @@ -18,16 +18,23 @@ func NewContext(store *sqlstore.Store) *Context { } } -func (auth *Context) ExpectUser(r *http.Request) *db.User { - session, _ := auth.store.Get(r, "lishwist_user") - username, ok := session.Values["username"].(string) - if !ok { - log.Fatalln("Failed to get username") - } +func (ctx *Context) ExpectUser(next func(*db.User, http.ResponseWriter, *http.Request)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, _ := ctx.store.Get(r, "lishwist_user") + username, ok := session.Values["username"].(string) + if !ok { + log.Println("Failed to get username") + http.Error(w, "", http.StatusInternalServerError) + return + } - user, err := db.GetUserByName(username) - if err != nil { - log.Fatalf("Failed to get user: %s\n", err) - } - return user + user, err := db.GetUserByName(username) + if err != nil { + log.Printf("Failed to get user: %s\n", err) + http.Error(w, "", http.StatusInternalServerError) + return + } + + next(user, w, r) + }) } diff --git a/server/routing/error.go b/server/routing/error.go index 79d1f65..830a103 100644 --- a/server/routing/error.go +++ b/server/routing/error.go @@ -2,11 +2,13 @@ package routing import ( "fmt" + "log" "net/http" "strings" ) func writeGeneralError(w http.ResponseWriter, msg string, status int) { + log.Printf("General error: %s\n", msg) w.WriteHeader(status) escapedMsg := strings.ReplaceAll(msg, `"`, `\"`) _, _ = w.Write([]byte(fmt.Sprintf(`{"GeneralError":"%s"}`, escapedMsg))) diff --git a/server/routing/foreign_wishlist.go b/server/routing/foreign_wishlist.go index d974125..b552b4e 100644 --- a/server/routing/foreign_wishlist.go +++ b/server/routing/foreign_wishlist.go @@ -14,9 +14,8 @@ type foreignWishlistProps struct { Gifts []db.Gift } -func (ctx *Context) ForeignWishlist(w http.ResponseWriter, r *http.Request) { +func (ctx *Context) ForeignWishlist(user *db.User, w http.ResponseWriter, r *http.Request) { userReference := r.PathValue("userReference") - user := ctx.ExpectUser(r) if user.Reference == userReference { http.Redirect(w, r, "/", http.StatusSeeOther) return diff --git a/server/routing/groups.go b/server/routing/groups.go index 1573c7b..c2d3f17 100644 --- a/server/routing/groups.go +++ b/server/routing/groups.go @@ -14,8 +14,7 @@ type GroupProps struct { CurrentUsername string } -func (ctx *Context) GroupPage(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) GroupPage(user *db.User, w http.ResponseWriter, r *http.Request) { groupReference := r.PathValue("groupReference") group, err := user.GetGroupByReference(groupReference) if err != nil { diff --git a/server/routing/home.go b/server/routing/home.go index 18ac4a4..46292c0 100644 --- a/server/routing/home.go +++ b/server/routing/home.go @@ -18,8 +18,7 @@ type HomeProps struct { Groups []db.Group } -func (ctx *Context) Home(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) Home(user *db.User, w http.ResponseWriter, r *http.Request) { gifts, err := user.GetGifts() if err != nil { error.Page(w, "An error occurred while fetching your wishlist :(", http.StatusInternalServerError, err) @@ -39,20 +38,20 @@ func (ctx *Context) Home(w http.ResponseWriter, r *http.Request) { templates.Execute(w, "home.gotmpl", p) } -func (ctx *Context) HomePost(w http.ResponseWriter, r *http.Request) { +func (ctx *Context) HomePost(user *db.User, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, "Couldn't parse form", http.StatusBadRequest) return } switch r.Form.Get("intent") { case "add_idea": - ctx.WishlistAdd(w, r) + ctx.WishlistAdd(user, w, r) return case "delete_idea": - ctx.WishlistDelete(w, r) + ctx.WishlistDelete(user, w, r) return default: - ctx.TodoUpdate(w, r) + ctx.TodoUpdate(user, w, r) return } } diff --git a/server/routing/login.go b/server/routing/login.go index 1fd9135..3094a48 100644 --- a/server/routing/login.go +++ b/server/routing/login.go @@ -6,7 +6,6 @@ import ( "lishwist/templates" "log" "net/http" - "time" "golang.org/x/crypto/bcrypt" ) @@ -79,8 +78,14 @@ func (ctx *Context) LoginPost(w http.ResponseWriter, r *http.Request) { props.Username.Value = username user, err := db.GetUserByName(username) - if user == nil || err != nil { - time.Sleep(time.Second) + if err != nil { + log.Printf("Failed to fetch user: %s\n", err) + props.GeneralError = "Username or password invalid" + ctx.RedirectWithFlash(w, r, "/", "login_props", &props) + return + } + if user == nil { + log.Printf("User not found by name: %q\n", username) props.GeneralError = "Username or password invalid" ctx.RedirectWithFlash(w, r, "/", "login_props", &props) return @@ -88,6 +93,7 @@ func (ctx *Context) LoginPost(w http.ResponseWriter, r *http.Request) { passHash, err := user.GetPassHash() if err != nil { + log.Println("Failed to get password hash: " + err.Error()) props.GeneralError = "Something went wrong. Error code: Momo" ctx.RedirectWithFlash(w, r, "/", "login_props", &props) return @@ -95,6 +101,7 @@ func (ctx *Context) LoginPost(w http.ResponseWriter, r *http.Request) { err = bcrypt.CompareHashAndPassword(passHash, []byte(password)) if err != nil { + log.Println("Username or password invalid: " + err.Error()) props.GeneralError = "Username or password invalid" ctx.RedirectWithFlash(w, r, "/", "login_props", &props) return diff --git a/server/routing/not_found.go b/server/routing/not_found.go new file mode 100644 index 0000000..10cb0f9 --- /dev/null +++ b/server/routing/not_found.go @@ -0,0 +1,11 @@ +package routing + +import ( + "net/http" +) + +func NotFoundJson(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"GeneralError":"Not Found"}`)) + w.Header().Add("Content-Type", "application/json") +} diff --git a/server/routing/register.go b/server/routing/register.go index c043606..6eea8dd 100644 --- a/server/routing/register.go +++ b/server/routing/register.go @@ -13,6 +13,7 @@ func (ctx *Context) Register(w http.ResponseWriter, r *http.Request) { session, _ := ctx.store.Get(r, "lishwist_user") if flashes := session.Flashes("register_props"); len(flashes) > 0 { + log.Printf("Register found flashes: %#v\n", flashes) flashProps, _ := flashes[0].(*api.RegisterProps) props.Username.Value = flashProps.Username.Value diff --git a/server/routing/todo.go b/server/routing/todo.go index 98ac86a..554cbff 100644 --- a/server/routing/todo.go +++ b/server/routing/todo.go @@ -1,12 +1,12 @@ package routing import ( + "lishwist/db" "log" "net/http" ) -func (ctx *Context) TodoUpdate(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) TodoUpdate(user *db.User, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return diff --git a/server/routing/users.go b/server/routing/users.go new file mode 100644 index 0000000..f4dcd67 --- /dev/null +++ b/server/routing/users.go @@ -0,0 +1,22 @@ +package routing + +import ( + "encoding/json" + "lishwist/db" + "net/http" +) + +func (ctx *Context) UsersJson(user *db.User, w http.ResponseWriter, r *http.Request) { + if !user.IsAdmin { + NotFoundJson(w, r) + return + } + + users, err := db.GetAllUsers() + if err != nil { + writeGeneralError(w, "Failed to get users: "+err.Error(), http.StatusBadRequest) + return + } + + _ = json.NewEncoder(w).Encode(users) +} diff --git a/server/routing/wishlist.go b/server/routing/wishlist.go index d78fba7..08d9006 100644 --- a/server/routing/wishlist.go +++ b/server/routing/wishlist.go @@ -1,12 +1,12 @@ package routing import ( + "lishwist/db" "lishwist/error" "net/http" ) -func (ctx *Context) WishlistAdd(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) WishlistAdd(user *db.User, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -20,8 +20,7 @@ func (ctx *Context) WishlistAdd(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusSeeOther) } -func (ctx *Context) WishlistDelete(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) WishlistDelete(user *db.User, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -35,8 +34,7 @@ func (ctx *Context) WishlistDelete(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusSeeOther) } -func (ctx *Context) ForeignWishlistPost(w http.ResponseWriter, r *http.Request) { - user := ctx.ExpectUser(r) +func (ctx *Context) ForeignWishlistPost(user *db.User, w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { error.Page(w, "Failed to parse form...", http.StatusBadRequest, err) return