feat: sqlite

This commit is contained in:
Teajey 2024-05-06 09:02:52 +12:00
parent cf3e84202b
commit 1de4893f8e
Signed by: Teajey
GPG Key ID: 970E790FE834A713
16 changed files with 228 additions and 109 deletions

2
.gitignore vendored
View File

@ -1 +1,3 @@
.DS_Store
gin-bin
lishwist.db

View File

@ -6,7 +6,6 @@ import (
"lishwist/db"
"lishwist/env"
"lishwist/types"
"github.com/gorilla/sessions"
)
@ -27,16 +26,16 @@ func (auth *AuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func (auth *AuthMiddleware) ExpectUser(r *http.Request) *types.UserData {
func (auth *AuthMiddleware) ExpectUser(r *http.Request) *db.User {
session, _ := auth.Store.Get(r, "lishwist_user")
username, ok := session.Values["username"].(string)
if !ok {
log.Fatalln("Failed to get username")
}
user := db.GetUser(username)
if user == nil {
log.Fatalln("Failed to get user")
user, err := db.GetUser(username)
if err != nil {
log.Fatalf("Failed to get user: %s\n", err)
}
return user
}

View File

@ -2,7 +2,6 @@ package auth
import (
"lishwist/db"
"lishwist/types"
"log"
"net/http"
"time"
@ -19,14 +18,20 @@ func (auth *AuthMiddleware) LoginPost(w http.ResponseWriter, r *http.Request) {
username := r.Form.Get("username")
password := r.Form.Get("password")
user, ok := db.Get("user:" + username).(types.UserData)
if !ok {
time.Sleep(2 * time.Second)
user, err := db.GetUser(username)
if err != nil {
time.Sleep(time.Second)
http.Error(w, "Username or password invalid", http.StatusUnauthorized)
return
}
err := bcrypt.CompareHashAndPassword(user.PassHash, []byte(password))
passHash, err := user.GetPassHash()
if err != nil {
http.Error(w, "Something went wrong. Error code: Momo", http.StatusInternalServerError)
return
}
err = bcrypt.CompareHashAndPassword(passHash, []byte(password))
if err != nil {
http.Error(w, "Username or password invalid", http.StatusUnauthorized)
return

View File

@ -4,7 +4,6 @@ import (
"net/http"
"lishwist/db"
"lishwist/types"
"golang.org/x/crypto/bcrypt"
)
@ -19,7 +18,8 @@ func (auth *AuthMiddleware) RegisterPost(w http.ResponseWriter, r *http.Request)
newPassword := r.Form.Get("newPassword")
confirmPassword := r.Form.Get("confirmPassword")
if db.Exists("user:" + username) {
existingUser, _ := db.GetUser(username)
if existingUser != nil {
http.Error(w, "Username is taken", http.StatusBadRequest)
return
}
@ -35,10 +35,11 @@ func (auth *AuthMiddleware) RegisterPost(w http.ResponseWriter, r *http.Request)
return
}
db.Set("user:"+username, types.UserData{
Username: username,
PassHash: hashedPasswordBytes,
})
_, err = db.CreateUser(username, hashedPasswordBytes)
if err != nil {
http.Error(w, "Something went wrong. Error code: Ozai", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusSeeOther)
}

View File

@ -2,9 +2,8 @@ package context
import (
"lishwist/auth"
"lishwist/db"
"log"
"net/http"
"slices"
)
type Context struct {
@ -17,12 +16,13 @@ func (ctx *Context) WishlistAdd(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
items := db.GetUserItems(user.Username)
newItem := r.Form.Get("item")
if newItem != "" {
items = append(items, newItem)
newGiftName := r.Form.Get("gift_name")
err := user.AddGift(newGiftName)
if err != nil {
log.Printf("Failed to add gift: %s\n", err)
http.Error(w, "Failed to add gift.", http.StatusInternalServerError)
return
}
db.SetUserItems(user.Username, items)
http.Redirect(w, r, "/", http.StatusSeeOther)
}
@ -32,18 +32,16 @@ func (ctx *Context) WishlistDelete(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
items := db.GetUserItems(user.Username)
target := r.Form.Get("item")
target := r.Form.Get("gift_id")
if target == "" {
http.Error(w, "Item not provided"+target, http.StatusBadRequest)
http.Error(w, "Gift ID not provided"+target, http.StatusBadRequest)
return
}
idx := slices.Index(items, target)
if idx < 0 {
http.Error(w, "Couldn't find item: "+target, http.StatusBadRequest)
err := user.RemoveGift(target)
if err != nil {
log.Printf("Failed to remove gift: %s\n", err)
http.Error(w, "Failed to remove gift.", http.StatusInternalServerError)
return
}
items = append(items[:idx], items[idx+1:]...)
db.SetUserItems(user.Username, items)
http.Redirect(w, r, "/", http.StatusSeeOther)
}

View File

@ -3,25 +3,37 @@ package context
import (
"lishwist/db"
"lishwist/templates"
"log"
"net/http"
)
type ForeignWishlistProps struct {
Username string
Items []string
Gifts []db.Gift
}
func (ctx *Context) ViewForeignWishlist(w http.ResponseWriter, r *http.Request) {
otherUsername := r.PathValue("username")
user := ctx.Auth.ExpectUser(r)
if user.Username == otherUsername {
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
if user.Name == otherUsername {
http.Error(w, "You can't view your own list, silly ;)", http.StatusForbidden)
return
}
items := db.GetUserItems(otherUsername)
if items == nil {
http.Error(w, "User not found", http.StatusNotFound)
otherUser, err := db.GetUser(otherUsername)
if err != nil {
log.Printf("An error occurred while fetching a user: %s\n", err)
http.Error(w, "An error occurred while fetching this user :(", http.StatusInternalServerError)
return
}
p := ForeignWishlistProps{Username: otherUsername, Items: items}
if otherUser == nil {
http.Error(w, "User not found", http.StatusNotFound)
return
}
gifts, err := otherUser.GetGifts()
if err != nil {
http.Error(w, "An error occurred while fetching this user's wishlist :(", http.StatusInternalServerError)
return
}
p := ForeignWishlistProps{Username: otherUsername, Gifts: gifts}
templates.Execute(w, "foreign_wishlist.gotmpl", p)
}

View File

@ -8,12 +8,16 @@ import (
)
type HomeProps struct {
Items []string
Gifts []db.Gift
}
func (ctx *Context) Home(w http.ResponseWriter, r *http.Request) {
user := ctx.Auth.ExpectUser(r)
items := db.GetUserItems(user.Username)
p := HomeProps{Items: items}
gifts, err := user.GetGifts()
if err != nil {
http.Error(w, "An error occurred while fetching your wishlist :(", http.StatusInternalServerError)
return
}
p := HomeProps{Gifts: gifts}
templates.Execute(w, "home.gotmpl", p)
}

View File

@ -1,40 +1,31 @@
package db
import "fmt"
import (
"database/sql"
"os"
var database map[string]any = map[string]any{}
_ "github.com/mattn/go-sqlite3"
)
func Add(key string, value any) error {
_, existing := database[key]
if existing {
return fmt.Errorf("A value already exists under '%s'", key)
var database *sql.DB
func Open() error {
db, err := sql.Open("sqlite3", "./lishwist.db")
if err != nil {
return err
}
database[key] = value
database = db
return nil
}
func Set(key string, value any) {
database[key] = value
}
func Get(key string) any {
value, existing := database[key]
if !existing {
return fmt.Errorf("No value under '%s'", key)
func Init() error {
initStmt, err := os.ReadFile("./db/init.sql")
if err != nil {
return err
}
return value
}
func Remove(key string) any {
value, existing := database[key]
if !existing {
return fmt.Errorf("No value under '%s'", key)
_, err = database.Exec(string(initStmt))
if err != nil {
return err
}
delete(database, key)
return value
}
func Exists(key string) bool {
_, existing := database[key]
return existing
return nil
}

22
db/init.sql Normal file
View File

@ -0,0 +1,22 @@
BEGIN TRANSACTION;
CREATE TABLE IF NOT EXISTS "user" (
"id" INTEGER NOT NULL UNIQUE,
"name" TEXT NOT NULL UNIQUE,
"reference" TEXT NOT NULL UNIQUE,
"motto" TEXT NOT NULL,
"password_hash" TEXT NOT NULL,
PRIMARY KEY("id" AUTOINCREMENT)
);
CREATE TABLE IF NOT EXISTS "gift" (
"id" INTEGER NOT NULL UNIQUE,
"name" TEXT NOT NULL,
"recipient_id" INTEGER NOT NULL,
"claimant_id" INTEGER,
"creator_id" INTEGER NOT NULL,
"sent" INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY("id" AUTOINCREMENT),
FOREIGN KEY("recipient_id") REFERENCES "user"("id"),
FOREIGN KEY("creator_id") REFERENCES "user"("id"),
FOREIGN KEY("claimant_id") REFERENCES "user"("id")
);
COMMIT;

View File

@ -1,39 +1,113 @@
package db
import (
"database/sql"
"fmt"
"lishwist/types"
"github.com/google/uuid"
)
func GetUser(username string) *types.UserData {
user, ok := Get("user:" + username).(types.UserData)
if !ok {
return nil
}
return &user
type User struct {
Id string
Name string
}
func GetUserItems(username string) []string {
user := GetUser(username)
if user == nil {
return nil
}
items, ok := Get("user_items:" + user.Username).([]string)
if !ok {
return nil
}
return items
type Gift struct {
Id string
Name string
}
func SetUserItems(username string, items []string) error {
user := GetUser(username)
if user == nil {
return fmt.Errorf("Didn't find user")
func GetUser(username string) (*User, error) {
stmt := "SELECT user.id, user.name FROM user WHERE user.name = ?"
var id string
var name string
err := database.QueryRow(stmt, username).Scan(&id, &name)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
user := User{
Id: id,
Name: name,
}
return &user, nil
}
Set("user_items:"+user.Username, items)
func CreateUser(username string, passHash []byte) (*User, error) {
stmt := "INSERT INTO user (name, motto, reference, password_hash) VALUES (?, '', ?, ?)"
reference, err := uuid.NewRandom()
if err != nil {
return nil, err
}
result, err := database.Exec(stmt, username, reference, passHash)
if err != nil {
return nil, err
}
id, err := result.LastInsertId()
if err != nil {
return nil, err
}
user := User{
Id: fmt.Sprintf("%d", id),
Name: username,
}
return &user, nil
}
func (u *User) GetPassHash() ([]byte, error) {
stmt := "SELECT user.password_hash FROM user WHERE user.id = ?"
var passHash string
err := database.QueryRow(stmt, u.Id).Scan(&passHash)
if err != nil {
return nil, err
}
return []byte(passHash), nil
}
func (u *User) GetGifts() ([]Gift, error) {
stmt := "SELECT gift.id, gift.name FROM gift JOIN user ON gift.recipient_id = user.id WHERE user.id = ?"
rows, err := database.Query(stmt, u.Id)
if err != nil {
return nil, err
}
defer rows.Close()
gifts := []Gift{}
for rows.Next() {
var id string
var name string
rows.Scan(&id, &name)
gift := Gift{
Id: id,
Name: name,
}
gifts = append(gifts, gift)
}
err = rows.Err()
if err != nil {
return nil, err
}
return gifts, nil
}
func (u *User) AddGift(name string) error {
stmt := "INSERT INTO gift (name, recipient_id, creator_id) VALUES (?, ?, ?)"
_, err := database.Exec(stmt, name, u.Id, u.Id)
if err != nil {
return err
}
return nil
}
func (u *User) RemoveGift(id string) error {
stmt := "DELETE FROM gift WHERE gift.creator_id = ? AND gift.id = ?"
result, err := database.Exec(stmt, u.Id, id)
if err != nil {
return err
}
affected, _ := result.RowsAffected()
if affected == 0 {
return fmt.Errorf("No gift match.")
}
return nil
}

2
go.mod
View File

@ -3,7 +3,9 @@ module lishwist
go 1.22.0
require (
github.com/google/uuid v1.6.0
github.com/gorilla/sessions v1.2.2
github.com/mattn/go-sqlite3 v1.14.22
golang.org/x/crypto v0.22.0
)

4
go.sum
View File

@ -1,8 +1,12 @@
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA=
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=

11
main.go
View File

@ -1,14 +1,25 @@
package main
import (
"log"
"net/http"
"lishwist/auth"
"lishwist/context"
"lishwist/db"
"lishwist/templates"
)
func main() {
err := db.Open()
if err != nil {
log.Fatalf("Failed to open DB: %s\n", err)
}
err = db.Init()
if err != nil {
log.Fatalf("Failed to init DB: %s\n", err)
}
publicMux := http.NewServeMux()
protectedMux := http.NewServeMux()

View File

@ -1,4 +1,5 @@
{{define "body"}}
<h1>Lishwist</h1>
<nav>
<ul>
<li>
@ -6,11 +7,10 @@
</li>
</ul>
</nav>
<h1>Lishwist</h1>
<h2>{{.Username}}'s list</h2>
<ul>
{{range .Items}}
<li>{{.}}</li>
{{range .Gifts}}
<li>{{.Name}}</li>
{{end}}
</ul>
{{end}}

View File

@ -5,17 +5,17 @@
<h1>Lishwist</h1>
<h2>Your list</h2>
<ul>
{{range .Items}}
<li>{{.}}
{{range .Gifts}}
<li>{{.Name}}
<form method="post" action="wishlist/delete">
<input type="hidden" name="item" value="{{.}}">
<input type="hidden" name="gift_id" value="{{.Id}}">
<input type="submit" value="Delete">
</form>
</li>
{{end}}
</ul>
<form method="post" action="/wishlist/add">
<input name="item" required>
<input name="gift_name" required>
<input type="submit">
</form>
{{end}}

View File

@ -1,6 +0,0 @@
package types
type UserData struct {
Username string
PassHash []byte
}