Changes made while flying Tokyo -> Auckland #11

Merged
Teajey merged 5 commits from tokyo-wellington-changes into main 2024-12-27 22:40:30 +13:00
23 changed files with 121 additions and 88 deletions

4
.gitignore vendored
View File

@ -1,6 +1,6 @@
.DS_Store .DS_Store
gin-bin gin-bin
lishwist.db *lishwist.db
.env*.local .env*.local
server/db/init_sql.go server/api/db/init_sql.go
.ignored/ .ignored/

View File

@ -14,7 +14,7 @@ import (
var database *sql.DB var database *sql.DB
func Open() error { func Open() error {
db, err := sql.Open("sqlite", "./lishwist.db") db, err := sql.Open("sqlite", env.DatabaseFile)
if err != nil { if err != nil {
return err return err
} }
@ -23,7 +23,7 @@ func Open() error {
} }
func Init() error { func Init() error {
_, err := database.Exec(InitQuery) _, err := database.Exec(initQuery)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,7 +11,7 @@ import (
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" + var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
"package db\n" + "package db\n" +
"\n" + "\n" +
"const InitQuery = `{{.}}`\n", "const initQuery = `{{.}}`\n",
)) ))
func main() { func main() {

View File

@ -1,8 +1,8 @@
package db package db
import ( import (
"database/sql"
"fmt" "fmt"
"lishwist/normalize"
"strconv" "strconv"
) )
@ -22,52 +22,47 @@ func (g *Group) MemberIndex(userId string) int {
return -1 return -1
} }
func queryForGroup(query string, args ...any) (*Group, error) { func queryManyGroups(query string, args ...any) ([]Group, error) {
var group Group
err := database.QueryRow(query, args...).Scan(&group.Id, &group.Name, &group.Reference)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
members, err := queryForGroupMembers(group.Id)
if err != nil {
return nil, err
}
group.Members = members
return &group, nil
}
func queryForGroups(query string, args ...any) ([]Group, error) {
groups := []Group{} groups := []Group{}
rows, err := database.Query(query, args...) rows, err := database.Query(query, args...)
if err != nil { if err != nil {
return groups, fmt.Errorf("Query failed: %w", err) return nil, fmt.Errorf("Query failed: %w", err)
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
var group Group var group Group
err := rows.Scan(&group.Id, &group.Name, &group.Reference) err := rows.Scan(&group.Id, &group.Name, &group.Reference)
if err != nil { if err != nil {
return groups, fmt.Errorf("Failed to scan row: %w", err) return nil, fmt.Errorf("Failed to scan row: %w", err)
} }
members, err := queryForGroupMembers(group.Id) members, err := queryManyGroupMembers(group.Id)
if err != nil { if err != nil {
return groups, fmt.Errorf("Failed to query for group members: %w", err) return nil, fmt.Errorf("Failed to query for group members: %w", err)
} }
group.Members = members group.Members = members
groups = append(groups, group) groups = append(groups, group)
} }
err = rows.Err() err = rows.Err()
if err != nil { if err != nil {
return groups, fmt.Errorf("Rows error: %w", err) return nil, fmt.Errorf("Rows error: %w", err)
} }
return groups, nil return groups, nil
} }
func queryForGroupMembers(groupId string) ([]User, error) { func queryOneGroup(query string, args ...any) (*Group, error) {
query := "SELECT user.id, user.name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id" groups, err := queryManyGroups(query, args...)
members, err := queryForUsers(query, groupId) if err != nil {
return nil, err
}
if len(groups) < 1 {
return nil, nil
}
return &groups[0], nil
}
func queryManyGroupMembers(groupId string) ([]User, error) {
query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id"
members, err := queryManyUsers(query, groupId)
if err != nil { if err != nil {
return members, fmt.Errorf("Failed to get members: %w", err) return members, fmt.Errorf("Failed to get members: %w", err)
} }
@ -76,15 +71,16 @@ func queryForGroupMembers(groupId string) ([]User, error) {
func GetGroupByReference(reference string) (*Group, error) { func GetGroupByReference(reference string) (*Group, error) {
query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?" query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?"
return queryForGroup(query, reference) return queryOneGroup(query, reference)
} }
func GetAllGroups() ([]Group, error) { func GetAllGroups() ([]Group, error) {
query := "SELECT id, name, reference FROM [group];" query := "SELECT id, name, reference FROM [group];"
return queryForGroups(query) return queryManyGroups(query)
} }
func CreateGroup(name string, reference string) (*Group, error) { func CreateGroup(name string, reference string) (*Group, error) {
name = normalize.Name(name)
stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)" stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)"
result, err := database.Exec(stmt, name, reference) result, err := database.Exec(stmt, name, reference)
if err != nil { if err != nil {

View File

@ -2,6 +2,7 @@ BEGIN TRANSACTION;
CREATE TABLE IF NOT EXISTS "user" ( CREATE TABLE IF NOT EXISTS "user" (
"id" INTEGER NOT NULL UNIQUE, "id" INTEGER NOT NULL UNIQUE,
"name" TEXT NOT NULL UNIQUE, "name" TEXT NOT NULL UNIQUE,
"display_name" TEXT NOT NULL UNIQUE,
"reference" TEXT NOT NULL UNIQUE, "reference" TEXT NOT NULL UNIQUE,
"motto" TEXT NOT NULL DEFAULT "", "motto" TEXT NOT NULL DEFAULT "",
"password_hash" TEXT NOT NULL, "password_hash" TEXT NOT NULL,

View File

@ -3,16 +3,18 @@ package db
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"lishwist/normalize"
"github.com/google/uuid" "github.com/google/uuid"
) )
type User struct { type User struct {
Id string Id string
Name string NormalName string
Reference string Name string
IsAdmin bool Reference string
IsLive bool IsAdmin bool
IsLive bool
} }
type Gift struct { type Gift struct {
@ -28,18 +30,7 @@ type Gift struct {
CreatorName string `json:",omitempty"` CreatorName string `json:",omitempty"`
} }
func queryForUser(query string, args ...any) (*User, error) { func queryManyUsers(query string, args ...any) ([]User, error) {
var u User
err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
}
return &u, nil
}
func queryForUsers(query string, args ...any) ([]User, error) {
rows, err := database.Query(query, args...) rows, err := database.Query(query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -48,7 +39,7 @@ func queryForUsers(query string, args ...any) ([]User, error) {
users := []User{} users := []User{}
for rows.Next() { for rows.Next() {
var u User var u User
err = rows.Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,29 +52,41 @@ func queryForUsers(query string, args ...any) ([]User, error) {
return users, nil return users, nil
} }
func queryOneUser(query string, args ...any) (*User, error) {
users, err := queryManyUsers(query, args...)
if err != nil {
return nil, err
}
if len(users) < 1 {
return nil, nil
}
return &users[0], nil
}
func GetAllUsers() ([]User, error) { func GetAllUsers() ([]User, error) {
stmt := "SELECT id, name, reference, is_admin, is_live FROM user" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
return queryForUsers(stmt) return queryManyUsers(stmt)
} }
func GetUser(id string) (*User, error) { func GetUser(id string) (*User, error) {
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE id = ?" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE id = ?"
return queryForUser(stmt, id) return queryOneUser(stmt, id)
} }
func GetUserByName(username string) (*User, error) { func GetUserByName(username string) (*User, error) {
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE name = ?" username = normalize.Name(username)
return queryForUser(stmt, username) stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
return queryOneUser(stmt, username)
} }
func GetUserByReference(reference string) (*User, error) { func GetUserByReference(reference string) (*User, error) {
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE reference = ?" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
return queryForUser(stmt, reference) return queryOneUser(stmt, reference)
} }
func GetAnyUserByReference(reference string) (*User, error) { func GetAnyUserByReference(reference string) (*User, error) {
stmt := "SELECT id, name, reference, is_admin, is_live FROM user WHERE reference = ?" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user WHERE reference = ?"
return queryForUser(stmt, reference) return queryOneUser(stmt, reference)
} }
func (u *User) SetLive(setting bool) error { func (u *User) SetLive(setting bool) error {
@ -96,13 +99,14 @@ func (u *User) SetLive(setting bool) error {
return err return err
} }
func CreateUser(username string, passHash []byte) (*User, error) { func CreateUser(name string, passHash []byte) (*User, error) {
stmt := "INSERT INTO user (name, reference, password_hash) VALUES (?, ?, ?)" username := normalize.Name(name)
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
reference, err := uuid.NewRandom() reference, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, err return nil, err
} }
result, err := database.Exec(stmt, username, reference, passHash) result, err := database.Exec(stmt, username, name, reference, passHash)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,7 +116,7 @@ func CreateUser(username string, passHash []byte) (*User, error) {
} }
user := User{ user := User{
Id: fmt.Sprintf("%d", id), Id: fmt.Sprintf("%d", id),
Name: username, Name: name,
} }
return &user, nil return &user, nil
} }
@ -406,10 +410,10 @@ func (u *User) AddGiftToUser(otherUserReference string, giftName string) error {
func (u *User) GetGroups() ([]Group, error) { func (u *User) GetGroups() ([]Group, error) {
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?" stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?"
return queryForGroups(stmt, u.Id) return queryManyGroups(stmt, u.Id)
} }
func (u *User) GetGroupByReference(reference string) (*Group, error) { func (u *User) GetGroupByReference(reference string) (*Group, error) {
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?" stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?"
return queryForGroup(stmt, reference, u.Id) return queryOneGroup(stmt, reference, u.Id)
} }

View File

@ -1,7 +1,7 @@
package api package api
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/templates" "lishwist/templates"
"log" "log"

View File

@ -3,7 +3,7 @@ package api
import ( import (
"log" "log"
"lishwist/db" "lishwist/api/db"
"lishwist/templates" "lishwist/templates"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -77,7 +77,7 @@ func Register(username, newPassword, confirmPassword string) *RegisterProps {
existingUser, _ := db.GetUserByName(username) existingUser, _ := db.GetUserByName(username)
if existingUser != nil { if existingUser != nil {
log.Printf("Username is taken: %q\n", username) log.Printf("Username is taken: %q\n", existingUser.NormalName)
props.Username.Error = "Username is taken" props.Username.Error = "Username is taken"
return props return props
} }

1
server/env/env.go vendored
View File

@ -14,6 +14,7 @@ func GuaranteeEnv(key string) (variable string) {
return return
} }
var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE")
var SessionSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET") var SessionSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET")
var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL") var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL")
var HostPort = os.Getenv("LISHWIST_HOST_PORT") var HostPort = os.Getenv("LISHWIST_HOST_PORT")

View File

@ -6,7 +6,8 @@ import (
"net/http" "net/http"
"lishwist/api" "lishwist/api"
"lishwist/db" // TODO: lishwist/api/db ought not to be used outside lishwist/api
"lishwist/api/db"
"lishwist/env" "lishwist/env"
"lishwist/router" "lishwist/router"
"lishwist/routing" "lishwist/routing"

14
server/normalize/name.go Normal file
View File

@ -0,0 +1,14 @@
package normalize
import (
"strings"
)
func Trim(s string) string {
return strings.Trim(s, " \t")
}
func Name(name string) string {
name = Trim(name)
return strings.ToLower(name)
}

View File

@ -1,7 +1,7 @@
package routing package routing
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
"net/http" "net/http"
) )

View File

@ -1,7 +1,7 @@
package routing package routing
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
"net/http" "net/http"
) )

View File

@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"slices" "slices"
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
) )

View File

@ -3,7 +3,7 @@ package routing
import ( import (
"net/http" "net/http"
"lishwist/db" "lishwist/api/db"
"lishwist/env" "lishwist/env"
"lishwist/rsvp" "lishwist/rsvp"
) )

View File

@ -11,7 +11,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
props := api.NewLoginProps("", "") props := api.NewLoginProps("", "")
flash := session.FlashGet("login_props") flash := session.FlashGet()
flashProps, ok := flash.(*api.LoginProps) flashProps, ok := flash.(*api.LoginProps)
if ok { if ok {
props.Username.Value = flashProps.Username.Value props.Username.Value = flashProps.Username.Value
@ -21,7 +21,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
props.Password.Error = flashProps.Password.Error props.Password.Error = flashProps.Password.Error
} }
flash = session.FlashGet("successful_registration") flash = session.FlashGet()
successfulReg, _ := flash.(bool) successfulReg, _ := flash.(bool)
if successfulReg { if successfulReg {
props.SuccessfulRegistration = true props.SuccessfulRegistration = true
@ -39,7 +39,7 @@ func LoginPost(h http.Header, r *rsvp.Request) rsvp.Response {
props := api.Login(username, password) props := api.Login(username, password)
if props != nil { if props != nil {
session.FlashSet(&props, "login_props") session.FlashSet(&props)
return rsvp.SeeOther("/").SaveSession(session) return rsvp.SeeOther("/").SaveSession(session)
} }

View File

@ -10,7 +10,7 @@ func Register(h http.Header, r *rsvp.Request) rsvp.Response {
props := api.NewRegisterProps("", "", "") props := api.NewRegisterProps("", "", "")
session := r.GetSession() session := r.GetSession()
flash := session.FlashGet("register_props") flash := session.FlashGet()
flashProps, _ := flash.(*api.RegisterProps) flashProps, _ := flash.(*api.RegisterProps)
if flashProps != nil { if flashProps != nil {
@ -36,10 +36,10 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response {
s := r.GetSession() s := r.GetSession()
if props != nil { if props != nil {
s.FlashSet(&props, "register_props") s.FlashSet(&props)
return rsvp.SeeOther("/register").SaveSession(s) return rsvp.SeeOther("/register").SaveSession(s)
} }
s.FlashSet(true, "successful_registration") s.FlashSet(true)
return rsvp.SeeOther("/").SaveSession(s) return rsvp.SeeOther("/").SaveSession(s)
} }

View File

@ -1,7 +1,7 @@
package routing package routing
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
"net/http" "net/http"
) )

View File

@ -1,7 +1,7 @@
package routing package routing
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
"net/http" "net/http"
) )

View File

@ -1,7 +1,7 @@
package routing package routing
import ( import (
"lishwist/db" "lishwist/api/db"
"lishwist/rsvp" "lishwist/rsvp"
"net/http" "net/http"
) )

View File

@ -29,6 +29,13 @@ func (res *Response) Write(w http.ResponseWriter, r *http.Request) error {
if res.SeeOther != "" { if res.SeeOther != "" {
http.Redirect(w, r, res.SeeOther, http.StatusSeeOther) http.Redirect(w, r, res.SeeOther, http.StatusSeeOther)
flash := res.Session.FlashPeek()
if flash != nil {
err := json.NewEncoder(w).Encode(flash)
if err != nil {
return err
}
}
return nil return nil
} }

View File

@ -8,8 +8,8 @@ type Session struct {
inner *sessions.Session inner *sessions.Session
} }
func (s *Session) FlashGet(key ...string) any { func (s *Session) FlashGet() any {
list := s.inner.Flashes(key...) list := s.inner.Flashes()
if len(list) < 1 { if len(list) < 1 {
return nil return nil
} else { } else {
@ -17,8 +17,17 @@ func (s *Session) FlashGet(key ...string) any {
} }
} }
func (s *Session) FlashSet(value any, key ...string) { func (s *Session) FlashPeek() any {
s.inner.AddFlash(value, key...) list, _ := s.inner.Values["_flash"].([]any)
if len(list) < 1 {
return nil
} else {
return list[0]
}
}
func (s *Session) FlashSet(value any) {
s.inner.AddFlash(value)
} }
func (s *Session) SetID(value string) { func (s *Session) SetID(value string) {