diff --git a/.gitignore b/.gitignore index 4fd0696..96baa7e 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,3 @@ +.DS_Store gin-bin +lishwist.db diff --git a/auth/auth.go b/auth/auth.go index 3367812..5915bf7 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -6,7 +6,6 @@ import ( "lishwist/db" "lishwist/env" - "lishwist/types" "github.com/gorilla/sessions" ) @@ -27,16 +26,16 @@ func (auth *AuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func (auth *AuthMiddleware) ExpectUser(r *http.Request) *types.UserData { +func (auth *AuthMiddleware) ExpectUser(r *http.Request) *db.User { session, _ := auth.Store.Get(r, "lishwist_user") username, ok := session.Values["username"].(string) if !ok { log.Fatalln("Failed to get username") } - user := db.GetUser(username) - if user == nil { - log.Fatalln("Failed to get user") + user, err := db.GetUser(username) + if err != nil { + log.Fatalf("Failed to get user: %s\n", err) } return user } diff --git a/auth/login.go b/auth/login.go index 211466f..3c10749 100644 --- a/auth/login.go +++ b/auth/login.go @@ -2,7 +2,6 @@ package auth import ( "lishwist/db" - "lishwist/types" "log" "net/http" "time" @@ -19,14 +18,20 @@ func (auth *AuthMiddleware) LoginPost(w http.ResponseWriter, r *http.Request) { username := r.Form.Get("username") password := r.Form.Get("password") - user, ok := db.Get("user:" + username).(types.UserData) - if !ok { - time.Sleep(2 * time.Second) + user, err := db.GetUser(username) + if err != nil { + time.Sleep(time.Second) http.Error(w, "Username or password invalid", http.StatusUnauthorized) return } - err := bcrypt.CompareHashAndPassword(user.PassHash, []byte(password)) + passHash, err := user.GetPassHash() + if err != nil { + http.Error(w, "Something went wrong. Error code: Momo", http.StatusInternalServerError) + return + } + + err = bcrypt.CompareHashAndPassword(passHash, []byte(password)) if err != nil { http.Error(w, "Username or password invalid", http.StatusUnauthorized) return diff --git a/auth/register.go b/auth/register.go index 30a5101..b745fc7 100644 --- a/auth/register.go +++ b/auth/register.go @@ -4,7 +4,6 @@ import ( "net/http" "lishwist/db" - "lishwist/types" "golang.org/x/crypto/bcrypt" ) @@ -19,7 +18,8 @@ func (auth *AuthMiddleware) RegisterPost(w http.ResponseWriter, r *http.Request) newPassword := r.Form.Get("newPassword") confirmPassword := r.Form.Get("confirmPassword") - if db.Exists("user:" + username) { + existingUser, _ := db.GetUser(username) + if existingUser != nil { http.Error(w, "Username is taken", http.StatusBadRequest) return } @@ -35,10 +35,11 @@ func (auth *AuthMiddleware) RegisterPost(w http.ResponseWriter, r *http.Request) return } - db.Set("user:"+username, types.UserData{ - Username: username, - PassHash: hashedPasswordBytes, - }) + _, err = db.CreateUser(username, hashedPasswordBytes) + if err != nil { + http.Error(w, "Something went wrong. Error code: Ozai", http.StatusInternalServerError) + return + } http.Redirect(w, r, "/", http.StatusSeeOther) } diff --git a/context/context.go b/context/context.go index 93f44d5..ff31d78 100644 --- a/context/context.go +++ b/context/context.go @@ -2,9 +2,8 @@ package context import ( "lishwist/auth" - "lishwist/db" + "log" "net/http" - "slices" ) type Context struct { @@ -17,12 +16,13 @@ func (ctx *Context) WishlistAdd(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - items := db.GetUserItems(user.Username) - newItem := r.Form.Get("item") - if newItem != "" { - items = append(items, newItem) + newGiftName := r.Form.Get("gift_name") + err := user.AddGift(newGiftName) + if err != nil { + log.Printf("Failed to add gift: %s\n", err) + http.Error(w, "Failed to add gift.", http.StatusInternalServerError) + return } - db.SetUserItems(user.Username, items) http.Redirect(w, r, "/", http.StatusSeeOther) } @@ -32,18 +32,16 @@ func (ctx *Context) WishlistDelete(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } - items := db.GetUserItems(user.Username) - target := r.Form.Get("item") + target := r.Form.Get("gift_id") if target == "" { - http.Error(w, "Item not provided"+target, http.StatusBadRequest) + http.Error(w, "Gift ID not provided"+target, http.StatusBadRequest) return } - idx := slices.Index(items, target) - if idx < 0 { - http.Error(w, "Couldn't find item: "+target, http.StatusBadRequest) + err := user.RemoveGift(target) + if err != nil { + log.Printf("Failed to remove gift: %s\n", err) + http.Error(w, "Failed to remove gift.", http.StatusInternalServerError) return } - items = append(items[:idx], items[idx+1:]...) - db.SetUserItems(user.Username, items) http.Redirect(w, r, "/", http.StatusSeeOther) } diff --git a/context/foreign_wishlist.go b/context/foreign_wishlist.go index b11ad46..6ad508b 100644 --- a/context/foreign_wishlist.go +++ b/context/foreign_wishlist.go @@ -3,25 +3,37 @@ package context import ( "lishwist/db" "lishwist/templates" + "log" "net/http" ) type ForeignWishlistProps struct { Username string - Items []string + Gifts []db.Gift } func (ctx *Context) ViewForeignWishlist(w http.ResponseWriter, r *http.Request) { otherUsername := r.PathValue("username") user := ctx.Auth.ExpectUser(r) - if user.Username == otherUsername { - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) + if user.Name == otherUsername { + http.Error(w, "You can't view your own list, silly ;)", http.StatusForbidden) return } - items := db.GetUserItems(otherUsername) - if items == nil { - http.Error(w, "User not found", http.StatusNotFound) + otherUser, err := db.GetUser(otherUsername) + if err != nil { + log.Printf("An error occurred while fetching a user: %s\n", err) + http.Error(w, "An error occurred while fetching this user :(", http.StatusInternalServerError) + return } - p := ForeignWishlistProps{Username: otherUsername, Items: items} + if otherUser == nil { + http.Error(w, "User not found", http.StatusNotFound) + return + } + gifts, err := otherUser.GetGifts() + if err != nil { + http.Error(w, "An error occurred while fetching this user's wishlist :(", http.StatusInternalServerError) + return + } + p := ForeignWishlistProps{Username: otherUsername, Gifts: gifts} templates.Execute(w, "foreign_wishlist.gotmpl", p) } diff --git a/context/home.go b/context/home.go index a9e4eed..ac3dd07 100644 --- a/context/home.go +++ b/context/home.go @@ -8,12 +8,16 @@ import ( ) type HomeProps struct { - Items []string + Gifts []db.Gift } func (ctx *Context) Home(w http.ResponseWriter, r *http.Request) { user := ctx.Auth.ExpectUser(r) - items := db.GetUserItems(user.Username) - p := HomeProps{Items: items} + gifts, err := user.GetGifts() + if err != nil { + http.Error(w, "An error occurred while fetching your wishlist :(", http.StatusInternalServerError) + return + } + p := HomeProps{Gifts: gifts} templates.Execute(w, "home.gotmpl", p) } diff --git a/db/db.go b/db/db.go index f285e40..db1082b 100644 --- a/db/db.go +++ b/db/db.go @@ -1,40 +1,31 @@ package db -import "fmt" +import ( + "database/sql" + "os" -var database map[string]any = map[string]any{} + _ "github.com/mattn/go-sqlite3" +) -func Add(key string, value any) error { - _, existing := database[key] - if existing { - return fmt.Errorf("A value already exists under '%s'", key) +var database *sql.DB + +func Open() error { + db, err := sql.Open("sqlite3", "./lishwist.db") + if err != nil { + return err } - database[key] = value + database = db return nil } -func Set(key string, value any) { - database[key] = value -} - -func Get(key string) any { - value, existing := database[key] - if !existing { - return fmt.Errorf("No value under '%s'", key) +func Init() error { + initStmt, err := os.ReadFile("./db/init.sql") + if err != nil { + return err } - return value -} - -func Remove(key string) any { - value, existing := database[key] - if !existing { - return fmt.Errorf("No value under '%s'", key) + _, err = database.Exec(string(initStmt)) + if err != nil { + return err } - delete(database, key) - return value -} - -func Exists(key string) bool { - _, existing := database[key] - return existing + return nil } diff --git a/db/init.sql b/db/init.sql new file mode 100644 index 0000000..b5581eb --- /dev/null +++ b/db/init.sql @@ -0,0 +1,22 @@ +BEGIN TRANSACTION; +CREATE TABLE IF NOT EXISTS "user" ( + "id" INTEGER NOT NULL UNIQUE, + "name" TEXT NOT NULL UNIQUE, + "reference" TEXT NOT NULL UNIQUE, + "motto" TEXT NOT NULL, + "password_hash" TEXT NOT NULL, + PRIMARY KEY("id" AUTOINCREMENT) +); +CREATE TABLE IF NOT EXISTS "gift" ( + "id" INTEGER NOT NULL UNIQUE, + "name" TEXT NOT NULL, + "recipient_id" INTEGER NOT NULL, + "claimant_id" INTEGER, + "creator_id" INTEGER NOT NULL, + "sent" INTEGER NOT NULL DEFAULT 0, + PRIMARY KEY("id" AUTOINCREMENT), + FOREIGN KEY("recipient_id") REFERENCES "user"("id"), + FOREIGN KEY("creator_id") REFERENCES "user"("id"), + FOREIGN KEY("claimant_id") REFERENCES "user"("id") +); +COMMIT; diff --git a/db/user.go b/db/user.go index 0e05ea2..9479285 100644 --- a/db/user.go +++ b/db/user.go @@ -1,39 +1,113 @@ package db import ( + "database/sql" "fmt" - "lishwist/types" + + "github.com/google/uuid" ) -func GetUser(username string) *types.UserData { - user, ok := Get("user:" + username).(types.UserData) - if !ok { - return nil - } - return &user +type User struct { + Id string + Name string } -func GetUserItems(username string) []string { - user := GetUser(username) - if user == nil { - return nil - } - - items, ok := Get("user_items:" + user.Username).([]string) - if !ok { - return nil - } - - return items +type Gift struct { + Id string + Name string } -func SetUserItems(username string, items []string) error { - user := GetUser(username) - if user == nil { - return fmt.Errorf("Didn't find user") +func GetUser(username string) (*User, error) { + stmt := "SELECT user.id, user.name FROM user WHERE user.name = ?" + var id string + var name string + err := database.QueryRow(stmt, username).Scan(&id, &name) + if err == sql.ErrNoRows { + return nil, nil + } else if err != nil { + return nil, err } + user := User{ + Id: id, + Name: name, + } + return &user, nil +} - Set("user_items:"+user.Username, items) +func CreateUser(username string, passHash []byte) (*User, error) { + stmt := "INSERT INTO user (name, motto, reference, password_hash) VALUES (?, '', ?, ?)" + reference, err := uuid.NewRandom() + if err != nil { + return nil, err + } + result, err := database.Exec(stmt, username, reference, passHash) + if err != nil { + return nil, err + } + id, err := result.LastInsertId() + if err != nil { + return nil, err + } + user := User{ + Id: fmt.Sprintf("%d", id), + Name: username, + } + return &user, nil +} +func (u *User) GetPassHash() ([]byte, error) { + stmt := "SELECT user.password_hash FROM user WHERE user.id = ?" + var passHash string + err := database.QueryRow(stmt, u.Id).Scan(&passHash) + if err != nil { + return nil, err + } + return []byte(passHash), nil +} + +func (u *User) GetGifts() ([]Gift, error) { + stmt := "SELECT gift.id, gift.name FROM gift JOIN user ON gift.recipient_id = user.id WHERE user.id = ?" + rows, err := database.Query(stmt, u.Id) + if err != nil { + return nil, err + } + defer rows.Close() + gifts := []Gift{} + for rows.Next() { + var id string + var name string + rows.Scan(&id, &name) + gift := Gift{ + Id: id, + Name: name, + } + gifts = append(gifts, gift) + } + err = rows.Err() + if err != nil { + return nil, err + } + return gifts, nil +} + +func (u *User) AddGift(name string) error { + stmt := "INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)" + _, err := database.Exec(stmt, name, u.Id, u.Id) + if err != nil { + return err + } + return nil +} + +func (u *User) RemoveGift(id string) error { + stmt := "DELETE FROM gift WHERE gift.creator_id = ? AND gift.id = ?" + result, err := database.Exec(stmt, u.Id, id) + if err != nil { + return err + } + affected, _ := result.RowsAffected() + if affected == 0 { + return fmt.Errorf("No gift match.") + } return nil } diff --git a/go.mod b/go.mod index 1ab14a5..cf5fd50 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,9 @@ module lishwist go 1.22.0 require ( + github.com/google/uuid v1.6.0 github.com/gorilla/sessions v1.2.2 + github.com/mattn/go-sqlite3 v1.14.22 golang.org/x/crypto v0.22.0 ) diff --git a/go.sum b/go.sum index 22712ff..cc81296 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,12 @@ github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= diff --git a/main.go b/main.go index 1fd55c0..c912f39 100644 --- a/main.go +++ b/main.go @@ -1,14 +1,25 @@ package main import ( + "log" "net/http" "lishwist/auth" "lishwist/context" + "lishwist/db" "lishwist/templates" ) func main() { + err := db.Open() + if err != nil { + log.Fatalf("Failed to open DB: %s\n", err) + } + err = db.Init() + if err != nil { + log.Fatalf("Failed to init DB: %s\n", err) + } + publicMux := http.NewServeMux() protectedMux := http.NewServeMux() diff --git a/templates/foreign_wishlist.gotmpl b/templates/foreign_wishlist.gotmpl index 66d1020..c30ba96 100644 --- a/templates/foreign_wishlist.gotmpl +++ b/templates/foreign_wishlist.gotmpl @@ -1,4 +1,5 @@ {{define "body"}} +

Lishwist

-

Lishwist

{{.Username}}'s list

-{{end}} \ No newline at end of file +{{end}} diff --git a/templates/home.gotmpl b/templates/home.gotmpl index 1aeed4d..67bdc4d 100644 --- a/templates/home.gotmpl +++ b/templates/home.gotmpl @@ -5,17 +5,17 @@

Lishwist

Your list

- +
-{{end}} \ No newline at end of file +{{end}} diff --git a/types/user.go b/types/user.go deleted file mode 100644 index a1ea732..0000000 --- a/types/user.go +++ /dev/null @@ -1,6 +0,0 @@ -package types - -type UserData struct { - Username string - PassHash []byte -}