diff --git a/core/admin.go b/core/admin.go index 8f87b2a..0667389 100644 --- a/core/admin.go +++ b/core/admin.go @@ -5,7 +5,7 @@ type Admin struct { } func (s *Session) Admin() *Admin { - if s.User.IsAdmin { + if s.User().IsAdmin { return &Admin{s} } else { return nil diff --git a/core/api.snap.txt b/core/api.snap.txt new file mode 100644 index 0000000..bb64f85 --- /dev/null +++ b/core/api.snap.txt @@ -0,0 +1,109 @@ +package lishwist // import "lishwist/core" + + +VARIABLES + +var ErrorUsernameTaken = errors.New("Username is taken") + +FUNCTIONS + +func Init(dataSourceName string) error +func PrintTables(d *sql.DB) +func PrintViews(d *sql.DB) + +TYPES + +type Admin struct { + // Has unexported fields. +} + +func (a *Admin) AddUserToGroup(userId, groupId string) error + +func (a *Admin) CreateGroup(name string, reference string) (*Group, error) + +func (*Admin) GetUser(id string) (*User, error) + +func (a *Admin) ListGroups() ([]Group, error) + +func (*Admin) ListUsers() ([]User, error) + +func (a *Admin) RemoveUserFromGroup(userId, groupId string) error + +func (u *Admin) UserSetLive(userReference string, setting bool) error + +type ErrorInvalidCredentials error + +type Group struct { + Id string + Name string + Reference string + Members []User +} + +func GetGroupByReference(reference string) (*Group, error) + +func (g *Group) MemberIndex(userId string) int + +type Session struct { + Key string + Expiry time.Time + // Has unexported fields. +} + +func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) + +func SessionFromKey(key string) (*Session, error) + +func (s *Session) Admin() *Admin + +func (s *Session) ClaimWishes(claims, unclaims []string) error + +func (s *Session) CompleteWishes(claims []string) error + +func (s *Session) GetGroupByReference(reference string) (*Group, error) + +func (u *Session) GetGroups() ([]Group, error) + +func (s *Session) GetOthersWishes(userReference string) ([]Wish, error) + +func (s *Session) GetWishes() ([]Wish, error) + +func (s *Session) MakeWish(name string) error + +func (s *Session) RecindWishesForUser(ids ...string) error + +func (s *Session) RevokeWishes(ids ...string) error + +func (u *Session) SuggestWishForUser(otherUserReference string, wishName string) error + +func (s *Session) User() User + +type User struct { + Id string + NormalName string + Name string + Reference string + IsAdmin bool + IsLive bool +} + +func GetUserByReference(reference string) (*User, error) + +func Register(username, newPassword string) (*User, error) + +func (u *User) GetTodo() ([]Wish, error) + +func (u *User) WishCount() (int, error) + +type Wish struct { + Id string + Name string + ClaimantId string `json:",omitempty"` + ClaimantName string `json:",omitempty"` + Sent bool + RecipientId string `json:",omitempty"` + RecipientName string `json:",omitempty"` + RecipientRef string `json:",omitempty"` + CreatorId string `json:",omitempty"` + CreatorName string `json:",omitempty"` +} diff --git a/core/group.go b/core/group.go index 230c475..2b09c7f 100644 --- a/core/group.go +++ b/core/group.go @@ -72,7 +72,7 @@ func queryManyGroupMembers(groupId string) ([]User, error) { func (s *Session) GetGroupByReference(reference string) (*Group, error) { stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?;" - return queryOneGroup(stmt, reference, s.User.Id) + return queryOneGroup(stmt, reference, s.User().Id) } func GetGroupByReference(reference string) (*Group, error) { @@ -126,5 +126,5 @@ func (a *Admin) RemoveUserFromGroup(userId, groupId string) error { // Get the groups the session user belongs to func (u *Session) GetGroups() ([]Group, error) { stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?" - return queryManyGroups(stmt, u.Id) + return queryManyGroups(stmt, u.User().Id) } diff --git a/core/group_test.go b/core/group_test.go index de96b63..e2b7ba6 100644 --- a/core/group_test.go +++ b/core/group_test.go @@ -26,7 +26,7 @@ func TestCantSeeSelfInGroup(t *testing.T) { group, err := s.Admin().CreateGroup(" My Friends ", " my-friends ") fixtures.FailIfErr(t, err, "Failed to create group") - err = s.Admin().AddUserToGroup(s.User.Id, group.Id) + err = s.Admin().AddUserToGroup(s.User().Id, group.Id) fixtures.FailIfErr(t, err, "Failed to add self to group") err = s.Admin().AddUserToGroup(caleb.Id, group.Id) diff --git a/core/internal/db/db.go b/core/internal/db/db.go index 938dc12..7c3bf16 100644 --- a/core/internal/db/db.go +++ b/core/internal/db/db.go @@ -15,12 +15,12 @@ var Connection *sql.DB func Init(dataSourceName string) error { db, err := sql.Open("sqlite3", dataSourceName) if err != nil { - return fmt.Errorf("Failed to open db connection: %w", err) + return fmt.Errorf("failed to open db connection: %w", err) } _, err = db.Exec(initQuery) if err != nil { - return fmt.Errorf("Failed to initialize db: %w", err) + return fmt.Errorf("failed to initialize db: %w", err) } Connection = db diff --git a/core/internal/db/init.sql b/core/internal/db/init.sql index e6d94b6..4a687e2 100644 --- a/core/internal/db/init.sql +++ b/core/internal/db/init.sql @@ -37,8 +37,11 @@ CREATE TABLE IF NOT EXISTS "group_member" ( ); CREATE TABLE IF NOT EXISTS "session" ( "id" INTEGER NOT NULL UNIQUE, - "value" TEXT NOT NULL, - PRIMARY KEY("id" AUTOINCREMENT) + "key" TEXT NOT NULL UNIQUE, + "user_id" INTEGER NOT NULL, + "expiry" TEXT NOT NULL, + PRIMARY KEY("id" AUTOINCREMENT), + FOREIGN KEY("user_id") REFERENCES "user"("id") ); DROP VIEW IF EXISTS "v_user"; diff --git a/core/internal/fixtures/login.go b/core/internal/fixtures/login.go index c730027..5b94868 100644 --- a/core/internal/fixtures/login.go +++ b/core/internal/fixtures/login.go @@ -3,6 +3,7 @@ package fixtures import ( "log" "testing" + "time" lishwist "lishwist/core" @@ -26,7 +27,7 @@ func Login(t *testing.T, username, password string) *lishwist.Session { log.Fatalf("Failed to register on login fixture: %s\n", err) } - session, err := lishwist.Login(username, password) + session, err := lishwist.Login(username, password, time.Hour*24) if err != nil { log.Fatalf("Failed to login on fixture: %s\n", err) } diff --git a/core/internal/id/generate.go b/core/internal/id/generate.go new file mode 100644 index 0000000..21efd68 --- /dev/null +++ b/core/internal/id/generate.go @@ -0,0 +1,14 @@ +package id + +import ( + "crypto/rand" + "encoding/hex" +) + +func Generate() string { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes) +} diff --git a/core/login.go b/core/login.go index 47659a8..48bca3f 100644 --- a/core/login.go +++ b/core/login.go @@ -2,13 +2,14 @@ package lishwist import ( "fmt" + "time" "golang.org/x/crypto/bcrypt" ) type ErrorInvalidCredentials error -func Login(username, password string) (*Session, error) { +func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) { user, err := getUserByName(username) if err != nil { return nil, ErrorInvalidCredentials(fmt.Errorf("Failed to fetch user: %w", err)) @@ -27,5 +28,10 @@ func Login(username, password string) (*Session, error) { return nil, ErrorInvalidCredentials(fmt.Errorf("Password compare failed: %w", err)) } - return &Session{*user}, nil + session, err := insertSession(*user, sessionMaxAge) + if err != nil { + return nil, fmt.Errorf("failed to insert session: %w", err) + } + + return session, nil } diff --git a/core/login_test.go b/core/login_test.go index 9603a09..1d3ad00 100644 --- a/core/login_test.go +++ b/core/login_test.go @@ -2,6 +2,7 @@ package lishwist_test import ( "testing" + "time" lishwist "lishwist/core" "lishwist/core/internal/fixtures" @@ -18,7 +19,7 @@ func TestLogin(t *testing.T) { t.Fatalf("Failed to register: %s\n", err) } - _, err = lishwist.Login("thomas", "123") + _, err = lishwist.Login("thomas", "123", time.Hour*24) if err != nil { t.Fatalf("Failed to login: %s\n", err) } diff --git a/core/session.go b/core/session.go index 1a7445b..188d328 100644 --- a/core/session.go +++ b/core/session.go @@ -1,15 +1,66 @@ package lishwist -import "fmt" +import ( + "database/sql" + "errors" + "fmt" + "time" + + "lishwist/core/internal/db" + "lishwist/core/internal/id" +) type Session struct { - User + user User + Key string + Expiry time.Time } -func SessionFromUsername(username string) (*Session, error) { - user, err := getUserByName(username) - if err != nil { - return nil, fmt.Errorf("Failed to get user: %w", err) - } - return &Session{*user}, nil +// Returns a copy of the user associated with this session +func (s *Session) User() User { + return s.user +} + +func SessionFromKey(key string) (*Session, error) { + s := Session{} + query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live, session.key, session.expiry FROM v_user as user JOIN session ON user.id = session.user_id WHERE session.key = ?" + var expiry string + err := db.Connection.QueryRow(query, key).Scan( + &s.user.Id, + &s.user.Name, + &s.user.NormalName, + &s.user.Reference, + &s.user.IsAdmin, + &s.user.IsLive, + &s.Key, + &expiry, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("failed to fetch session: %w", err) + } + s.Expiry, err = time.Parse(time.RFC3339Nano, expiry) + if err != nil { + return nil, fmt.Errorf("failed to parse session expiry: %w", err) + } + if time.Now().After(s.Expiry) { + return nil, nil + } + return &s, err +} + +func insertSession(user User, maxAge time.Duration) (*Session, error) { + s := Session{ + user: user, + Key: id.Generate(), + Expiry: time.Now().Add(maxAge), + } + stmt := "INSERT INTO session (key, user_id, expiry) VALUES (?, ?, ?)" + _, err := db.Connection.Exec(stmt, &s.Key, &user.Id, &s.Expiry) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + return &s, nil } diff --git a/core/session/store.go b/core/session/store.go deleted file mode 100644 index 3045fc3..0000000 --- a/core/session/store.go +++ /dev/null @@ -1,39 +0,0 @@ -package session - -import ( - "fmt" - "lishwist/core/internal/db" - - "github.com/Teajey/sqlstore" -) - -func NewStore(keyPairs ...[]byte) (*sqlstore.Store, error) { - deleteStmt, err := db.Connection.Prepare("DELETE FROM session WHERE id = ?;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare delete statement: %w", err) - } - - insertStmt, err := db.Connection.Prepare("INSERT INTO session (value) VALUES (?);") - if err != nil { - return nil, fmt.Errorf("Failed to prepare insert statement: %w", err) - } - - selectStmt, err := db.Connection.Prepare("SELECT value FROM session WHERE id = ?;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare select statement: %w", err) - } - - updateStmt, err := db.Connection.Prepare("UPDATE session SET value = ?2 WHERE id = ?1;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare update statement: %w", err) - } - - s := sqlstore.NewSqlStore(db.Connection, sqlstore.Statements{ - Delete: deleteStmt, - Insert: insertStmt, - Select: selectStmt, - Update: updateStmt, - }, keyPairs...) - - return s, nil -} diff --git a/core/wish.go b/core/wish.go index 3b7dffb..7cf705f 100644 --- a/core/wish.go +++ b/core/wish.go @@ -23,8 +23,8 @@ type Wish struct { } func (s *Session) GetWishes() ([]Wish, error) { - stmt := "SELECT wish.id, wish.name, wish.sent FROM wish WHERE wish.creator_id = ?1 AND wish.recipient_id = ?1" - rows, err := db.Connection.Query(stmt, s.User.Id) + stmt := "SELECT wish.id, wish.name, wish.sent FROM wish WHERE wish.creator_id = ?1 AND wish.recipient_id = ?1 ORDER BY wish.sent" + rows, err := db.Connection.Query(stmt, s.User().Id) if err != nil { return nil, fmt.Errorf("Query execution failed: %w", err) } @@ -54,17 +54,17 @@ func (s *Session) GetWishes() ([]Wish, error) { func (s *Session) MakeWish(name string) error { stmt := "INSERT INTO wish (name, recipient_id, creator_id) VALUES (?, ?, ?)" - _, err := db.Connection.Exec(stmt, strings.TrimSpace(name), s.User.Id, s.User.Id) + _, err := db.Connection.Exec(stmt, strings.TrimSpace(name), s.User().Id, s.User().Id) if err != nil { return fmt.Errorf("Query execution failed: %w", err) } return nil } -func (u *Session) deleteWishes(tx *sql.Tx, ids []string) error { +func (s *Session) deleteWishes(tx *sql.Tx, ids []string) error { stmt := "DELETE FROM wish WHERE wish.creator_id = ? AND wish.id = ?" for _, id := range ids { - r, err := tx.Exec(stmt, u.Id, id) + r, err := tx.Exec(stmt, s.User().Id, id) if err != nil { return err } @@ -107,10 +107,10 @@ func (s *Session) GetOthersWishes(userReference string) ([]Wish, error) { if err != nil { return nil, fmt.Errorf("Failed to get other user: %w", err) } - if otherUser.Id == s.User.Id { + if otherUser.Id == s.User().Id { return nil, errors.New("Use (s *Session) GetWishes() to view your own wishes") } - stmt := "SELECT wish.id, wish.name, claimant.id, claimant.name, wish.sent, wish.creator_id, creator.name, wish.recipient_id FROM wish JOIN v_user AS user ON wish.recipient_id = user.id LEFT JOIN v_user AS claimant ON wish.claimant_id = claimant.id LEFT JOIN v_user AS creator ON wish.creator_id = creator.id WHERE user.id = ?" + stmt := "SELECT wish.id, wish.name, claimant.id, claimant.name, wish.sent, wish.creator_id, creator.name, wish.recipient_id FROM wish JOIN v_user AS user ON wish.recipient_id = user.id LEFT JOIN v_user AS claimant ON wish.claimant_id = claimant.id LEFT JOIN v_user AS creator ON wish.creator_id = creator.id WHERE user.id = ? ORDER BY wish.sent" rows, err := db.Connection.Query(stmt, otherUser.Id) if err != nil { return nil, fmt.Errorf("Failed to execute query: %w", err) @@ -164,7 +164,7 @@ func (s *Session) executeClaims(tx *sql.Tx, claims, unclaims []string) error { claimStmt := "UPDATE wish SET claimant_id = ? WHERE id = ?" unclaimStmt := "UPDATE wish SET claimant_id = NULL WHERE id = ?" for _, id := range claims { - r, err := tx.Exec(claimStmt, s.Id, id) + r, err := tx.Exec(claimStmt, s.User().Id, id) if err != nil { return err } @@ -264,7 +264,7 @@ func (u *Session) SuggestWishForUser(otherUserReference string, wishName string) return err } stmt := "INSERT INTO wish (name, recipient_id, creator_id) VALUES (?, ?, ?)" - _, err = db.Connection.Exec(stmt, wishName, otherUser.Id, u.Id) + _, err = db.Connection.Exec(stmt, wishName, otherUser.Id, u.User().Id) if err != nil { return err } diff --git a/go.work b/go.work index 7328ab8..7afb314 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.23.3 +go 1.24.5 use ( ./core diff --git a/http/env/env.go b/http/env/env.go index ee9a036..10a998b 100644 --- a/http/env/env.go +++ b/http/env/env.go @@ -6,12 +6,12 @@ import ( "os" ) -func GuaranteeEnv(key string) (variable string) { +func GuaranteeEnv(key string) string { variable, ok := os.LookupEnv(key) if !ok || variable == "" { log.Fatalln("Missing environment variable:", key) } - return + return variable } var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE") diff --git a/http/go.mod b/http/go.mod index ca1fb6a..debf689 100644 --- a/http/go.mod +++ b/http/go.mod @@ -1,25 +1,18 @@ module lishwist/http -go 1.23 +go 1.23.3 -toolchain go1.23.3 +toolchain go1.24.5 require ( - github.com/Teajey/sqlstore v0.0.6 - github.com/glebarez/go-sqlite v1.22.0 - github.com/google/uuid v1.6.0 + github.com/Teajey/rsvp v0.13.1 github.com/gorilla/sessions v1.4.0 - golang.org/x/crypto v0.22.0 + golang.org/x/crypto v0.39.0 ) +require github.com/gorilla/securecookie v1.1.2 + require ( - github.com/dustin/go-humanize v1.0.1 // indirect - github.com/gorilla/securecookie v1.1.2 // indirect - github.com/mattn/go-isatty v0.0.20 // indirect - github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - golang.org/x/sys v0.19.0 // indirect - modernc.org/libc v1.37.6 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.7.2 // indirect - modernc.org/sqlite v1.28.0 // indirect + github.com/vmihailenco/msgpack/v5 v5.4.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect ) diff --git a/http/go.sum b/http/go.sum index 8b650a2..b36c5c0 100644 --- a/http/go.sum +++ b/http/go.sum @@ -1,39 +1,22 @@ -github.com/Teajey/sqlstore v0.0.3 h1:6Y1jz9/yw1cj/Z/jrii0s87RAomKWr/07B1auDgw8pg= -github.com/Teajey/sqlstore v0.0.3/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= -github.com/Teajey/sqlstore v0.0.4 h1:ATe25BD8cV0FUw4w2qlccx5m0c5kQI0K4ksl/LnSHsc= -github.com/Teajey/sqlstore v0.0.4/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= -github.com/Teajey/sqlstore v0.0.5 h1:WZvu54baa8+9n1sKQe9GuxBVwSISw+xCkw4VFSwwIs8= -github.com/Teajey/sqlstore v0.0.5/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= -github.com/Teajey/sqlstore v0.0.6 h1:kUEpA+3BKFHZl128MuMeYY6zVcmq1QmOlNyofcFEJOA= -github.com/Teajey/sqlstore v0.0.6/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= -github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= -github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= -github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= +github.com/Teajey/rsvp v0.13.1 h1:0lw+JosaWmdjSmXoKQYBRS9nptSZPInm60Y5GQ3llEU= +github.com/Teajey/rsvp v0.13.1/go.mod h1:z0L20VphVg+Ec2+hnpLFTG2MZTrWYFprav1kpxDba0Q= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= -github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= -github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= -github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= -golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= -golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= -golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -modernc.org/libc v1.37.6 h1:orZH3c5wmhIQFTXF+Nt+eeauyd+ZIt2BX6ARe+kD+aw= -modernc.org/libc v1.37.6/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= -modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= -modernc.org/sqlite v1.28.0 h1:Zx+LyDDmXczNnEQdvPuEfcFVA2ZPyaD7UCZDjef3BHQ= -modernc.org/sqlite v1.28.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +golang.org/x/crypto v0.39.0 h1:SHs+kF4LP+f+p14esP5jAoDpHU8Gu/v9lFRK6IT5imM= +golang.org/x/crypto v0.39.0/go.mod h1:L+Xg3Wf6HoL4Bn4238Z6ft6KfEpN0tJGo53AAPC632U= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/http/internal/id/generate.go b/http/internal/id/generate.go new file mode 100644 index 0000000..21efd68 --- /dev/null +++ b/http/internal/id/generate.go @@ -0,0 +1,14 @@ +package id + +import ( + "crypto/rand" + "encoding/hex" +) + +func Generate() string { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes) +} diff --git a/http/main.go b/http/main.go index ad53076..a4ce39e 100644 --- a/http/main.go +++ b/http/main.go @@ -1,67 +1,25 @@ package main import ( - "encoding/gob" "log" "net/http" lishwist "lishwist/core" - "lishwist/core/session" - "lishwist/http/api" "lishwist/http/env" - "lishwist/http/router" - "lishwist/http/routing" + "lishwist/http/server" ) func main() { - gob.Register(&api.RegisterProps{}) - gob.Register(&api.LoginProps{}) - err := lishwist.Init(env.DatabaseFile) if err != nil { log.Fatalf("Failed to init Lishwist: %s\n", err) } - store, err := session.NewStore([]byte(env.SessionSecret)) - if err != nil { - log.Fatalf("Failed to initialize session store: %s\n", err) - } - store.Options.MaxAge = 86_400 - store.Options.Secure = !env.InDev - store.Options.HttpOnly = true - - r := router.New(store) - - r.Public.HandleFunc("GET /", routing.Login) - r.Public.HandleFunc("GET /groups/{groupReference}", routing.PublicGroup) - r.Public.HandleFunc("GET /lists/{userReference}", routing.PublicWishlist) - r.Public.HandleFunc("GET /register", routing.Register) - r.Public.HandleFunc("POST /", routing.LoginPost) - r.Public.HandleFunc("POST /register", routing.RegisterPost) - - r.Private.HandleFunc("GET /", routing.NotFound) - r.Private.HandleFunc("GET /groups", routing.ExpectAppSession(routing.Groups)) - r.Private.HandleFunc("GET /groups/{groupReference}", routing.ExpectAppSession(routing.Group)) - r.Private.HandleFunc("GET /lists/{userReference}", routing.ExpectAppSession(routing.ForeignWishlist)) - r.Private.HandleFunc("GET /users", routing.ExpectAppSession(routing.Users)) - r.Private.HandleFunc("GET /users/{userReference}", routing.ExpectAppSession(routing.User)) - r.Private.HandleFunc("GET /{$}", routing.ExpectAppSession(routing.Home)) - r.Private.HandleFunc("POST /groups/{groupReference}", routing.ExpectAppSession(routing.GroupPost)) - r.Private.HandleFunc("POST /list/{userReference}", routing.ExpectAppSession(routing.ForeignWishlistPost)) - r.Private.HandleFunc("POST /logout", routing.LogoutPost) - r.Private.HandleFunc("POST /users/{userReference}", routing.ExpectAppSession(routing.UserPost)) - r.Private.HandleFunc("POST /{$}", routing.ExpectAppSession(routing.HomePost)) - - // Deprecated - r.Private.HandleFunc("GET /group/{groupReference}", routing.ExpectAppSession(routing.Group)) - r.Private.HandleFunc("GET /list/{userReference}", routing.ExpectAppSession(routing.ForeignWishlist)) - r.Public.HandleFunc("GET /group/{groupReference}", routing.PublicGroup) - r.Public.HandleFunc("GET /list/{userReference}", routing.PublicWishlist) - - http.Handle("/", r) + useSecureCookies := !env.InDev + r := server.Create(useSecureCookies) log.Printf("Running at http://127.0.0.1:%s\n", env.ServePort) - err = http.ListenAndServe(":"+env.ServePort, nil) + err = http.ListenAndServe(":"+env.ServePort, r) if err != nil { log.Fatalln("Failed to listen and server:", err) } diff --git a/http/response/handler.go b/http/response/handler.go new file mode 100644 index 0000000..3962387 --- /dev/null +++ b/http/response/handler.go @@ -0,0 +1,56 @@ +package response + +import ( + "log" + "net/http" + + "lishwist/http/session" + "lishwist/http/templates" + + "github.com/Teajey/rsvp" +) + +type ServeMux struct { + inner *rsvp.ServeMux + store *session.Store +} + +func NewServeMux(store *session.Store) *ServeMux { + mux := rsvp.NewServeMux() + mux.Config.HtmlTemplate = templates.Template + return &ServeMux{ + inner: mux, + store: store, + } +} + +func (m *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { + m.inner.ServeHTTP(w, r) +} + +type Handler interface { + ServeHTTP(*Session, http.Header, *http.Request) rsvp.Response +} + +type HandlerFunc func(*Session, http.Header, *http.Request) rsvp.Response + +func (m *ServeMux) HandleFunc(pattern string, handler HandlerFunc) { + m.inner.MiddleHandleFunc(pattern, func(w http.ResponseWriter, r *http.Request) rsvp.Response { + session := m.GetSession(r) + + response := handler(session, w.Header(), r) + + if session.written { + err := session.inner.Save(r, w) + if err != nil { + log.Printf("Failed to save session: %s\n", err) + } + } + + return response + }) +} + +func (m *ServeMux) Handle(pattern string, handler Handler) { + m.HandleFunc(pattern, handler.ServeHTTP) +} diff --git a/http/response/request.go b/http/response/request.go new file mode 100644 index 0000000..a14e51f --- /dev/null +++ b/http/response/request.go @@ -0,0 +1,10 @@ +package response + +import ( + "net/http" +) + +func (m *ServeMux) GetSession(r *http.Request) *Session { + session, _ := m.store.Get(r, "lishwist_user") + return &Session{inner: session} +} diff --git a/http/response/response.go b/http/response/response.go new file mode 100644 index 0000000..b62703e --- /dev/null +++ b/http/response/response.go @@ -0,0 +1,26 @@ +package response + +import ( + "fmt" + "net/http" + + "github.com/Teajey/rsvp" +) + +func NotFound() rsvp.Response { + return Error(http.StatusNotFound, "Page not found") +} + +func Error(status int, format string, a ...any) rsvp.Response { + return rsvp.Response{ + Body: fmt.Sprintf(format, a...), + Status: status, + } +} + +func Data(templateName string, body any) rsvp.Response { + return rsvp.Response{ + Body: body, + TemplateName: templateName, + } +} diff --git a/http/rsvp/session.go b/http/response/session.go similarity index 84% rename from http/rsvp/session.go rename to http/response/session.go index 4da6ed0..d5e70fc 100644 --- a/http/rsvp/session.go +++ b/http/response/session.go @@ -1,11 +1,12 @@ -package rsvp +package response import ( "github.com/gorilla/sessions" ) type Session struct { - inner *sessions.Session + inner *sessions.Session + written bool } func (s *Session) FlashGet() any { @@ -13,6 +14,7 @@ func (s *Session) FlashGet() any { if len(list) < 1 { return nil } else { + s.written = true return list[0] } } @@ -32,14 +34,17 @@ func (s *Session) FlashPeek() any { func (s *Session) FlashSet(value any) { s.inner.AddFlash(value) + s.written = true } func (s *Session) SetID(value string) { s.inner.ID = value + s.written = true } func (s *Session) SetValue(key any, value any) { s.inner.Values[key] = value + s.written = true } func (s *Session) GetValue(key any) any { @@ -48,6 +53,7 @@ func (s *Session) GetValue(key any) any { func (s *Session) ClearValues() { s.inner.Values = nil + s.written = true } func (s *Session) Options() *sessions.Options { diff --git a/http/router/router.go b/http/router/router.go index 8891f68..4387305 100644 --- a/http/router/router.go +++ b/http/router/router.go @@ -1,21 +1,21 @@ package router import ( - "lishwist/http/rsvp" "net/http" - "github.com/Teajey/sqlstore" + "lishwist/http/response" + "lishwist/http/session" ) type VisibilityRouter struct { - Store *sqlstore.Store - Public *rsvp.ServeMux - Private *rsvp.ServeMux + store *session.Store + Public *response.ServeMux + Private *response.ServeMux } func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - session, _ := s.Store.Get(r, "lishwist_user") - authorized, _ := session.Values["authorized"].(bool) + session, _ := s.store.Get(r, "lishwist_user") + _, authorized := session.Values["sessionKey"] if authorized { s.Private.ServeHTTP(w, r) @@ -24,10 +24,15 @@ func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func New(store *sqlstore.Store) *VisibilityRouter { +func New(store *session.Store) *VisibilityRouter { return &VisibilityRouter{ - Store: store, - Public: rsvp.NewServeMux(store), - Private: rsvp.NewServeMux(store), + store: store, + Public: response.NewServeMux(store), + Private: response.NewServeMux(store), } } + +func (r *VisibilityRouter) HandleFunc(pattern string, handler response.HandlerFunc) { + r.Public.HandleFunc(pattern, handler) + r.Private.HandleFunc(pattern, handler) +} diff --git a/http/routing/context.go b/http/routing/context.go index a45b9f1..d9faf46 100644 --- a/http/routing/context.go +++ b/http/routing/context.go @@ -1,24 +1,34 @@ package routing import ( - lishwist "lishwist/core" - "lishwist/http/rsvp" + "log" "net/http" + + lishwist "lishwist/core" + + "lishwist/http/response" + + "github.com/Teajey/rsvp" ) -func ExpectAppSession(next func(*lishwist.Session, http.Header, *rsvp.Request) rsvp.Response) rsvp.HandlerFunc { - return func(w http.Header, r *rsvp.Request) rsvp.Response { - session := r.GetSession() - username, ok := session.GetValue("username").(string) +func ExpectAppSession(next func(*lishwist.Session, http.Header, *http.Request) rsvp.Response) response.HandlerFunc { + return func(session *response.Session, h http.Header, r *http.Request) rsvp.Response { + sessionKey, ok := session.GetValue("sessionKey").(string) if !ok { - return rsvp.Error(http.StatusInternalServerError, "Something went wrong.").Log("Failed to get username from session") + log.Printf("Failed to get key from session\n") + return response.Error(http.StatusInternalServerError, "Something went wrong.") } - appSession, err := lishwist.SessionFromUsername(username) + appSession, err := lishwist.SessionFromKey(sessionKey) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Something went wrong.").Log("Failed to get session by username %q: %s", username, err) + log.Printf("Failed to get session by key %v: %s\n", sessionKey, err) + return response.Error(http.StatusInternalServerError, "Something went wrong.") + } + if appSession == nil { + log.Printf("Session not found under key: %s\n", sessionKey) + return response.Error(http.StatusInternalServerError, "Something went wrong.") } - return next(appSession, w, r) + return next(appSession, h, r) } } diff --git a/http/routing/error.go b/http/routing/error.go deleted file mode 100644 index 91fa5d0..0000000 --- a/http/routing/error.go +++ /dev/null @@ -1,17 +0,0 @@ -package routing - -import ( - "fmt" - "log" - "net/http" - "strings" -) - -func writeGeneralErrorJson(w http.ResponseWriter, status int, format string, a ...any) { - msg := fmt.Sprintf(format, a...) - log.Printf("General error: %s\n", msg) - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - escapedMsg := strings.ReplaceAll(msg, `"`, `\"`) - _, _ = w.Write([]byte(fmt.Sprintf(`{"GeneralError":"%s"}`, escapedMsg))) -} diff --git a/http/routing/foreign_wishlist.go b/http/routing/foreign_wishlist.go index 9a91c88..a2c4136 100644 --- a/http/routing/foreign_wishlist.go +++ b/http/routing/foreign_wishlist.go @@ -2,8 +2,11 @@ package routing import ( lishwist "lishwist/core" - "lishwist/http/rsvp" + "lishwist/http/response" + "log" "net/http" + + "github.com/Teajey/rsvp" ) type foreignWishlistProps struct { @@ -13,24 +16,27 @@ type foreignWishlistProps struct { Gifts []lishwist.Wish } -func ForeignWishlist(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func ForeignWishlist(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { userReference := r.PathValue("userReference") - if app.User.Reference == userReference { - return rsvp.SeeOther("/") + user := app.User() + if user.Reference == userReference { + return rsvp.Found("/", "You're not allowed to view your own wishlist!") } otherUser, err := lishwist.GetUserByReference(userReference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(").Log("Couldn't get user by reference %q: %s", userReference, err) + log.Printf("Couldn't get user by reference %q: %s\n", userReference, err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(") } if otherUser == nil { - return rsvp.Error(http.StatusInternalServerError, "User not found") + return response.Error(http.StatusInternalServerError, "User not found") } wishes, err := app.GetOthersWishes(userReference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(").Log("%q couldn't get wishes of other user %q: %s", app.User.Name, otherUser.Name, err) + log.Printf("%q couldn't get wishes of other user %q: %s\n", user.Name, otherUser.Name, err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(") } - p := foreignWishlistProps{CurrentUserId: app.User.Id, CurrentUserName: app.User.Name, Username: otherUser.Name, Gifts: wishes} - return rsvp.Data("foreign_wishlist.gotmpl", p) + p := foreignWishlistProps{CurrentUserId: user.Id, CurrentUserName: user.Name, Username: otherUser.Name, Gifts: wishes} + return response.Data("foreign_wishlist.gotmpl", p) } type publicWishlistProps struct { @@ -38,19 +44,21 @@ type publicWishlistProps struct { GiftCount int } -func PublicWishlist(h http.Header, r *rsvp.Request) rsvp.Response { +func PublicWishlist(s *response.Session, h http.Header, r *http.Request) rsvp.Response { userReference := r.PathValue("userReference") otherUser, err := lishwist.GetUserByReference(userReference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(").Log("Couldn't get user by reference %q on public wishlist: %s", userReference, err) + log.Printf("Couldn't get user by reference %q on public wishlist: %s\n", userReference, err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(") } if otherUser == nil { - return rsvp.Error(http.StatusInternalServerError, "User not found") + return response.Error(http.StatusInternalServerError, "User not found") } giftCount, err := otherUser.WishCount() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(").Log("Couldn't get wishes of user %q on public wishlist: %s", otherUser.Name, err) + log.Printf("Couldn't get wishes of user %q on public wishlist: %s\n", otherUser.Name, err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this user :(") } p := publicWishlistProps{Username: otherUser.Name, GiftCount: giftCount} - return rsvp.Data("public_foreign_wishlist.gotmpl", p) + return response.Data("public_foreign_wishlist.gotmpl", p) } diff --git a/http/routing/groups.go b/http/routing/groups.go index eefdac3..33f8749 100644 --- a/http/routing/groups.go +++ b/http/routing/groups.go @@ -1,11 +1,14 @@ package routing import ( + "log" "net/http" "slices" lishwist "lishwist/core" - "lishwist/http/rsvp" + "lishwist/http/response" + + "github.com/Teajey/rsvp" ) type GroupProps struct { @@ -13,97 +16,105 @@ type GroupProps struct { CurrentUsername string } -func AdminGroup(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func AdminGroup(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { reference := r.PathValue("groupReference") group, err := app.GetGroupByReference(reference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Couldn't get group: %s", err) + return response.Error(http.StatusInternalServerError, "Couldn't get group: %s", err) } if group == nil { - return rsvp.Error(http.StatusNotFound, "Group not found") + return response.Error(http.StatusNotFound, "Group not found") } - if !app.User.IsAdmin { - index := group.MemberIndex(app.User.Id) + user := app.User() + if !user.IsAdmin { + index := group.MemberIndex(user.Id) group.Members = slices.Delete(group.Members, index, index+1) } p := GroupProps{ Group: group, - CurrentUsername: app.User.Name, + CurrentUsername: user.Name, } - return rsvp.Data("group_page.gotmpl", p) + return response.Data("group_page.gotmpl", p) } -func Group(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - if app.User.IsAdmin { +func Group(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + user := app.User() + if user.IsAdmin { return AdminGroup(app, h, r) } groupReference := r.PathValue("groupReference") group, err := app.GetGroupByReference(groupReference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this group :(").Log("Couldn't get group: %s", err) + log.Printf("Couldn't get group: %s\n", err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this group :(") } if group == nil { - return rsvp.Error(http.StatusNotFound, "Group not found. (It might be because you're not a member)") + return response.Error(http.StatusNotFound, "Group not found. (It might be because you're not a member)") } - index := group.MemberIndex(app.User.Id) + index := group.MemberIndex(user.Id) group.Members = slices.Delete(group.Members, index, index+1) p := GroupProps{ Group: group, - CurrentUsername: app.User.Name, + CurrentUsername: user.Name, } - return rsvp.Data("group_page.gotmpl", p) + return response.Data("group_page.gotmpl", p) } -func PublicGroup(h http.Header, r *rsvp.Request) rsvp.Response { +func PublicGroup(s *response.Session, h http.Header, r *http.Request) rsvp.Response { groupReference := r.PathValue("groupReference") group, err := lishwist.GetGroupByReference(groupReference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching this group :(").Log("Couldn't get group: %s", err) + log.Printf("Couldn't get group: %s\n", err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching this group :(") } p := GroupProps{ Group: group, } - return rsvp.Data("public_group_page.gotmpl", p) + return response.Data("public_group_page.gotmpl", p) } -func GroupPost(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func GroupPost(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { admin := app.Admin() if admin == nil { - return NotFound(h, r) + return response.NotFound() + } + + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") } - form := r.ParseForm() var group *lishwist.Group reference := r.PathValue("groupReference") - name := form.Get("name") - addUsers := form["addUser"] - removeUsers := form["removeUser"] + name := r.Form.Get("name") + addUsers := r.Form["addUser"] + removeUsers := r.Form["removeUser"] if name != "" { createdGroup, err := admin.CreateGroup(name, reference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to create group: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to create group: %s", err) } group = createdGroup } else { existingGroup, err := lishwist.GetGroupByReference(reference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to get group: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to get group: %s", err) } if existingGroup == nil { - return rsvp.Error(http.StatusNotFound, "Group not found", err) + return response.Error(http.StatusNotFound, "Group not found: %s", err) } group = existingGroup for _, userId := range removeUsers { index := group.MemberIndex(userId) if index == -1 { - return rsvp.Error(http.StatusBadRequest, "Group %q does not contain a user with id %s", reference, userId) + return response.Error(http.StatusBadRequest, "Group %q does not contain a user with id %s", reference, userId) } err = admin.RemoveUserFromGroup(userId, group.Id) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "On group %q failed to remove user with id %s: %s", reference, userId, err) + return response.Error(http.StatusInternalServerError, "On group %q failed to remove user with id %s: %s", reference, userId, err) } group.Members = slices.Delete(group.Members, index, index+1) } @@ -112,31 +123,31 @@ func GroupPost(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Respo for _, userId := range addUsers { user, err := admin.GetUser(userId) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Groups exists, but a user with id %s could not be fetched: %s", userId, err) + return response.Error(http.StatusInternalServerError, "Groups exists, but a user with id %s could not be fetched: %s", userId, err) } if user == nil { - return rsvp.Error(http.StatusInternalServerError, "Groups exists, but a user with id %s does not exist", userId) + return response.Error(http.StatusInternalServerError, "Groups exists, but a user with id %s does not exist", userId) } err = admin.AddUserToGroup(user.Id, group.Id) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Groups exists, but failed to add user with id %s: %s", userId, err) + return response.Error(http.StatusInternalServerError, "Groups exists, but failed to add user with id %s: %s", userId, err) } group.Members = append(group.Members, *user) } - return rsvp.Data("", group) + return response.Data("", group) } -func Groups(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func Groups(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { admin := app.Admin() if admin == nil { - return NotFound(h, r) + return response.NotFound() } groups, err := admin.ListGroups() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to get groups: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to get groups: %s", err) } - return rsvp.Data("", groups) + return response.Data("", groups) } diff --git a/http/routing/home.go b/http/routing/home.go index 4e4229c..589b746 100644 --- a/http/routing/home.go +++ b/http/routing/home.go @@ -1,11 +1,14 @@ package routing import ( + "log" "net/http" lishwist "lishwist/core" "lishwist/http/env" - "lishwist/http/rsvp" + "lishwist/http/response" + + "github.com/Teajey/rsvp" ) type HomeProps struct { @@ -17,26 +20,34 @@ type HomeProps struct { Groups []lishwist.Group } -func Home(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func Home(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { gifts, err := app.GetWishes() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(").Log("Failed to get gifts: %s", err) + log.Printf("Failed to get gifts: %s\n", err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(") } - todo, err := app.GetTodo() + user := app.User() + todo, err := user.GetTodo() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(").Log("Failed to get todo: %s", err) + log.Printf("Failed to get todo: %s\n", err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(") } groups, err := app.GetGroups() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(").Log("Failed to get groups: %s", err) + log.Printf("Failed to get groups: %s\n", err) + return response.Error(http.StatusInternalServerError, "An error occurred while fetching your wishlist :(") } - p := HomeProps{Username: app.User.Name, Gifts: gifts, Todo: todo, Reference: app.User.Reference, HostUrl: env.HostUrl.String(), Groups: groups} - return rsvp.Data("home.gotmpl", p) + p := HomeProps{Username: user.Name, Gifts: gifts, Todo: todo, Reference: user.Reference, HostUrl: env.HostUrl.String(), Groups: groups} + return response.Data("home.gotmpl", p) } -func HomePost(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() - switch form.Get("intent") { +func HomePost(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } + + switch r.Form.Get("intent") { case "add_idea": return WishlistAdd(app, h, r) case "delete_idea": diff --git a/http/routing/login.go b/http/routing/login.go index bb5cf65..e38895e 100644 --- a/http/routing/login.go +++ b/http/routing/login.go @@ -1,18 +1,22 @@ package routing import ( + "errors" + "log" + "net/http" + "time" + lishwist "lishwist/core" "lishwist/http/api" - "lishwist/http/rsvp" - "net/http" + "lishwist/http/response" + + "github.com/Teajey/rsvp" ) -func Login(h http.Header, r *rsvp.Request) rsvp.Response { - session := r.GetSession() - +func Login(s *response.Session, h http.Header, r *http.Request) rsvp.Response { props := api.NewLoginProps("", "") - flash := session.FlashGet() + flash := s.FlashGet() flashProps, ok := flash.(*api.LoginProps) if ok { props.Username.Value = flashProps.Username.Value @@ -22,48 +26,55 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response { props.Password.Error = flashProps.Password.Error } - flash = session.FlashGet() + flash = s.FlashGet() successfulReg, _ := flash.(bool) if successfulReg { props.SuccessfulRegistration = true } - return rsvp.Data("login.gotmpl", props).SaveSession(session) + return rsvp.Response{TemplateName: "login.gotmpl", Body: props} } -func LoginPost(h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() - session := r.GetSession() +func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } - username := form.Get("username") - password := form.Get("password") + username := r.Form.Get("username") + password := r.Form.Get("password") props := api.NewLoginProps(username, password) + resp := rsvp.SeeOther(r.URL.Path, props) + valid := props.Validate() props.Password.Value = "" if !valid { session.FlashSet(&props) - return rsvp.SeeOther("/").SaveSession(session).Log("Invalid props: %#v\n", props) + log.Printf("Invalid props: %#v\n", props) + return resp } - app, err := lishwist.Login(username, password) + appSession, err := lishwist.Login(username, password, time.Hour*24) if err != nil { - switch err.(type) { - case lishwist.ErrorInvalidCredentials: + var targ lishwist.ErrorInvalidCredentials + switch { + case errors.As(err, &targ): props.GeneralError = "Username or password invalid" session.FlashSet(&props) - return rsvp.SeeOther("/").SaveSession(session).Log("Invalid credentials: %s: %#v\n", err, props) + log.Printf("Invalid credentials: %s: %#v\n", err, props) + return resp default: props.GeneralError = "Something went wrong." session.FlashSet(&props) - return rsvp.SeeOther("/").SaveSession(session).Log("Login error: %s\n", err) + log.Printf("Login error: %s\n", err) + return resp } } session.SetID("") - session.SetValue("authorized", true) - session.SetValue("username", app.User.Name) + session.SetValue("sessionKey", appSession.Key) - return rsvp.SeeOther(r.URL().Path).SaveSession(session) + return rsvp.SeeOther(r.URL.Path, "Login successful!") } diff --git a/http/routing/logout.go b/http/routing/logout.go index dc1bb66..f32ffd6 100644 --- a/http/routing/logout.go +++ b/http/routing/logout.go @@ -1,15 +1,15 @@ package routing import ( - "lishwist/http/rsvp" + "lishwist/http/response" "net/http" + + "github.com/Teajey/rsvp" ) -func LogoutPost(h http.Header, r *rsvp.Request) rsvp.Response { - session := r.GetSession() - +func LogoutPost(session *response.Session, h http.Header, r *http.Request) rsvp.Response { session.Options().MaxAge = 0 session.ClearValues() - return rsvp.SeeOther("/").SaveSession(session) + return rsvp.SeeOther("/", "Logout successful") } diff --git a/http/routing/not_found.go b/http/routing/not_found.go index 43a0efd..5ebc6c4 100644 --- a/http/routing/not_found.go +++ b/http/routing/not_found.go @@ -3,9 +3,11 @@ package routing import ( "net/http" - "lishwist/http/rsvp" + "lishwist/http/response" + + "github.com/Teajey/rsvp" ) -func NotFound(h http.Header, r *rsvp.Request) rsvp.Response { - return rsvp.Error(http.StatusNotFound, "Page not found") +func NotFound(s *response.Session, h http.Header, r *http.Request) rsvp.Response { + return response.Error(http.StatusNotFound, "Page not found") } diff --git a/http/routing/register.go b/http/routing/register.go index b7a8d4c..dc1b6c3 100644 --- a/http/routing/register.go +++ b/http/routing/register.go @@ -4,14 +4,16 @@ import ( "errors" lishwist "lishwist/core" "lishwist/http/api" - "lishwist/http/rsvp" + "lishwist/http/response" + "log" "net/http" + + "github.com/Teajey/rsvp" ) -func Register(h http.Header, r *rsvp.Request) rsvp.Response { +func Register(session *response.Session, h http.Header, r *http.Request) rsvp.Response { props := api.NewRegisterProps("", "", "") - session := r.GetSession() flash := session.FlashGet() flashProps, _ := flash.(*api.RegisterProps) @@ -23,16 +25,18 @@ func Register(h http.Header, r *rsvp.Request) rsvp.Response { props.ConfirmPassword.Error = flashProps.ConfirmPassword.Error } - return rsvp.Data("register.gotmpl", props).SaveSession(session) + return response.Data("register.gotmpl", props) } -func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() - s := r.GetSession() +func RegisterPost(s *response.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } - username := form.Get("username") - newPassword := form.Get("newPassword") - confirmPassword := form.Get("confirmPassword") + username := r.Form.Get("username") + newPassword := r.Form.Get("newPassword") + confirmPassword := r.Form.Get("confirmPassword") props := api.NewRegisterProps(username, newPassword, confirmPassword) @@ -41,10 +45,11 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response { props.ConfirmPassword.Value = "" if !valid { s.FlashSet(&props) - return rsvp.SeeOther("/").SaveSession(s).Log("Invalid props: %#v\n", props) + log.Printf("Invalid register props: %#v\n", props) + return rsvp.SeeOther(r.URL.Path, props) } - _, err := lishwist.Register(username, newPassword) + _, err = lishwist.Register(username, newPassword) if err != nil { if errors.Is(err, lishwist.ErrorUsernameTaken) { props.Username.Error = "Username is taken" @@ -52,9 +57,10 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response { props.GeneralError = "Something went wrong." } s.FlashSet(&props) - return rsvp.SeeOther("/register").SaveSession(s).Log("Registration failed: %s\n", err) + log.Printf("Registration failed: %s\n", err) + return rsvp.SeeOther(r.URL.Path, props) } s.FlashSet(true) - return rsvp.SeeOther("/").SaveSession(s) + return rsvp.SeeOther("/", "Registration successful!") } diff --git a/http/routing/todo.go b/http/routing/todo.go index 4f6764a..91a7823 100644 --- a/http/routing/todo.go +++ b/http/routing/todo.go @@ -1,29 +1,38 @@ package routing import ( - lishwist "lishwist/core" - "lishwist/http/rsvp" + "log" "net/http" + + "github.com/Teajey/rsvp" + + lishwist "lishwist/core" + "lishwist/http/response" ) -func TodoUpdate(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() +func TodoUpdate(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } - switch form.Get("intent") { + switch r.Form.Get("intent") { case "unclaim_todo": - unclaims := form["gift"] + unclaims := r.Form["gift"] err := app.ClaimWishes([]string{}, unclaims) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to update claim...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to update claim...") } case "complete_todo": - claims := form["gift"] + claims := r.Form["gift"] err := app.CompleteWishes(claims) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to complete gifts...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to complete gifts...") } default: - return rsvp.Error(http.StatusBadRequest, "Invalid intent") + return response.Error(http.StatusBadRequest, "Invalid intent") } - return rsvp.SeeOther("/") + return rsvp.SeeOther("/", "Update successful") } diff --git a/http/routing/users.go b/http/routing/users.go index becb2e3..0ab648f 100644 --- a/http/routing/users.go +++ b/http/routing/users.go @@ -2,72 +2,77 @@ package routing import ( lishwist "lishwist/core" - "lishwist/http/rsvp" + "lishwist/http/response" "net/http" + + "github.com/Teajey/rsvp" ) -func Users(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func Users(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { admin := app.Admin() if admin == nil { - return NotFound(h, r) + return response.NotFound() } users, err := admin.ListUsers() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to get users: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to get users: %s", err) } - return rsvp.Data("", users) + return response.Data("", users) } -func User(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func User(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { admin := app.Admin() if admin == nil { - return NotFound(h, r) + return response.NotFound() } reference := r.PathValue("userReference") user, err := lishwist.GetUserByReference(reference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to get user: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to get user: %s", err) } if user == nil { - return rsvp.Error(http.StatusNotFound, "User not found") + return response.Error(http.StatusNotFound, "User not found") } - return rsvp.Data("", user) + return response.Data("", user) } -func UserPost(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { +func UserPost(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { admin := app.Admin() if admin == nil { - return NotFound(h, r) + return response.NotFound() } - form := r.ParseForm() + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } reference := r.PathValue("userReference") - if reference == app.User.Reference { - return rsvp.Error(http.StatusForbidden, "You cannot delete yourself.") + if reference == app.User().Reference { + return response.Error(http.StatusForbidden, "You cannot delete yourself.") } user, err := lishwist.GetUserByReference(reference) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to get user: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to get user: %s", err) } if user == nil { - return rsvp.Error(http.StatusNotFound, "User not found") + return response.Error(http.StatusNotFound, "User not found") } - intent := form.Get("intent") + intent := r.Form.Get("intent") if intent != "" { err = admin.UserSetLive(reference, intent != "delete") if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to delete user: %s", err) + return response.Error(http.StatusInternalServerError, "Failed to delete user: %s", err) } } - return rsvp.Data("", user) + return response.Data("", user) } diff --git a/http/routing/wishlist.go b/http/routing/wishlist.go index dec0cf3..4b7c783 100644 --- a/http/routing/wishlist.go +++ b/http/routing/wishlist.go @@ -1,68 +1,95 @@ package routing import ( - lishwist "lishwist/core" - "lishwist/http/rsvp" + "log" "net/http" + + "github.com/Teajey/rsvp" + + lishwist "lishwist/core" + "lishwist/http/response" ) -func WishlistAdd(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() - newGiftName := form.Get("gift_name") - err := app.MakeWish(newGiftName) +func WishlistAdd(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to add gift.").LogError(err) + return response.Error(http.StatusBadRequest, "Failed to parse form") } - return rsvp.SeeOther("/") + + newGiftName := r.Form.Get("gift_name") + err = app.MakeWish(newGiftName) + if err != nil { + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to add gift.") + } + return rsvp.SeeOther("/", "Wish added!") } -func WishlistDelete(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() - targets := form["gift"] - err := app.RevokeWishes(targets...) +func WishlistDelete(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to remove gifts.").LogError(err) + return response.Error(http.StatusBadRequest, "Failed to parse form") } - return rsvp.SeeOther("/") + + targets := r.Form["gift"] + err = app.RevokeWishes(targets...) + if err != nil { + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to remove gifts.") + } + return rsvp.SeeOther("/", "Wish deleted") } -func ForeignWishlistPost(app *lishwist.Session, h http.Header, r *rsvp.Request) rsvp.Response { - form := r.ParseForm() +func ForeignWishlistPost(app *lishwist.Session, h http.Header, r *http.Request) rsvp.Response { + err := r.ParseForm() + if err != nil { + return response.Error(http.StatusBadRequest, "Failed to parse form") + } + userReference := r.PathValue("userReference") - intent := form.Get("intent") + resp := rsvp.SeeOther("/list/"+userReference, "Update successful") + intent := r.Form.Get("intent") switch intent { case "claim": - claims := form["unclaimed"] - unclaims := form["claimed"] + claims := r.Form["unclaimed"] + unclaims := r.Form["claimed"] err := app.ClaimWishes(claims, unclaims) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to update claim...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to update claim...") } + resp.Body = "Successfully claimed wishes" case "complete": - claims := form["claimed"] + claims := r.Form["claimed"] err := app.CompleteWishes(claims) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to complete gifts...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to complete gifts...") } + resp.Body = "Successfully completed wishes" case "add": - wishName := form.Get("gift_name") + wishName := r.Form.Get("gift_name") if wishName == "" { - return rsvp.Error(http.StatusBadRequest, "Gift name not provided") + return response.Error(http.StatusBadRequest, "Gift name not provided") } err := app.SuggestWishForUser(userReference, wishName) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to add gift idea to other user...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to add gift idea to other user...") } + resp.Body = "Successfully added wishes" case "delete": - claims := form["unclaimed"] - unclaims := form["claimed"] + claims := r.Form["unclaimed"] + unclaims := r.Form["claimed"] gifts := append(claims, unclaims...) err := app.RecindWishesForUser(gifts...) if err != nil { - return rsvp.Error(http.StatusInternalServerError, "Failed to remove gift idea for other user...").LogError(err) + log.Printf("%s\n", err) + return response.Error(http.StatusInternalServerError, "Failed to remove gift idea for other user...") } + resp.Body = "Successfully removed wishes" default: - return rsvp.Error(http.StatusBadRequest, "Invalid intent %q", intent) + return response.Error(http.StatusBadRequest, "Invalid intent %q", intent) } - return rsvp.SeeOther("/list/" + userReference) + return resp } diff --git a/http/rsvp/handler.go b/http/rsvp/handler.go deleted file mode 100644 index c172834..0000000 --- a/http/rsvp/handler.go +++ /dev/null @@ -1,56 +0,0 @@ -package rsvp - -import ( - "log" - "net/http" - - "github.com/Teajey/sqlstore" -) - -type ServeMux struct { - inner *http.ServeMux - store *sqlstore.Store -} - -func NewServeMux(store *sqlstore.Store) *ServeMux { - return &ServeMux{ - inner: http.NewServeMux(), - store: store, - } -} - -func (m *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { - m.inner.ServeHTTP(w, r) -} - -type Handler interface { - ServeHTTP(h http.Header, r *Request) Response -} - -type HandlerFunc func(h http.Header, r *Request) Response - -func (m *ServeMux) HandleFunc(pattern string, handler HandlerFunc) { - m.inner.HandleFunc(pattern, func(w http.ResponseWriter, stdReq *http.Request) { - r := wrapStdRequest(m.store, stdReq) - - response := handler(w.Header(), &r) - - err := response.Write(w, stdReq) - if err != nil { - response.Data = struct{ Message error }{err} - response.HtmlTemplateName = "error_page.gotmpl" - response.Status = http.StatusInternalServerError - } else { - return - } - err = response.Write(w, stdReq) - if err != nil { - log.Printf("Failed to write rsvp.Response to bytes: %s\n", err) - http.Error(w, "Failed to write response", http.StatusInternalServerError) - } - }) -} - -func (m *ServeMux) Handle(pattern string, handler Handler) { - m.HandleFunc(pattern, handler.ServeHTTP) -} diff --git a/http/rsvp/request.go b/http/rsvp/request.go deleted file mode 100644 index 7e387cd..0000000 --- a/http/rsvp/request.go +++ /dev/null @@ -1,42 +0,0 @@ -package rsvp - -import ( - "log" - "net/http" - "net/url" - - "github.com/Teajey/sqlstore" -) - -type Request struct { - inner *http.Request - store *sqlstore.Store -} - -func wrapStdRequest(store *sqlstore.Store, r *http.Request) Request { - return Request{ - inner: r, - store: store, - } -} - -func (r *Request) GetSession() Session { - session, _ := r.store.Get(r.inner, "lishwist_user") - return Session{session} -} - -func (r *Request) ParseForm() url.Values { - err := r.inner.ParseForm() - if err != nil { - log.Printf("Failed to parse form: %s\n", err) - } - return r.inner.Form -} - -func (r *Request) PathValue(name string) string { - return r.inner.PathValue(name) -} - -func (r *Request) URL() *url.URL { - return r.inner.URL -} diff --git a/http/rsvp/response.go b/http/rsvp/response.go deleted file mode 100644 index 2c380b5..0000000 --- a/http/rsvp/response.go +++ /dev/null @@ -1,119 +0,0 @@ -package rsvp - -import ( - "bytes" - "encoding/json" - "fmt" - "lishwist/http/templates" - "log" - "net/http" - "strings" -) - -type Response struct { - HtmlTemplateName string - Data any - SeeOther string - Session *Session - Status int - LogMessage string -} - -func (res *Response) Write(w http.ResponseWriter, r *http.Request) error { - if res.LogMessage != "" { - log.Printf("%s --- %s\n", res.Data, res.LogMessage) - } - - if res.Session != nil { - err := res.Session.inner.Save(r, w) - if err != nil { - return fmt.Errorf("Failed to write session: %w", err) - } - } - - if res.SeeOther != "" { - http.Redirect(w, r, res.SeeOther, http.StatusSeeOther) - if res.Session != nil { - flash := res.Session.FlashPeek() - if flash != nil { - err := json.NewEncoder(w).Encode(flash) - if err != nil { - return err - } - } - } - return nil - } - - bodyBytes := bytes.NewBuffer([]byte{}) - accept := r.Header.Get("Accept") - - if res.Status != 0 { - w.WriteHeader(res.Status) - } - - switch { - case strings.Contains(accept, "text/html"): - if res.HtmlTemplateName == "" { - err := json.NewEncoder(bodyBytes).Encode(res.Data) - if err != nil { - return err - } - } else { - err := templates.Execute(bodyBytes, res.HtmlTemplateName, res.Data) - if err != nil { - return err - } - } - case strings.Contains(accept, "application/json"): - err := json.NewEncoder(bodyBytes).Encode(res.Data) - if err != nil { - return err - } - default: - err := json.NewEncoder(bodyBytes).Encode(res.Data) - if err != nil { - return err - } - } - - _, err := w.Write(bodyBytes.Bytes()) - if err != nil { - log.Printf("Failed to write rsvp.Response to HTTP: %s\n", err) - } - return nil -} - -func Data(htmlTemplateName string, data any) Response { - return Response{ - HtmlTemplateName: htmlTemplateName, - Data: data, - } -} - -func (r Response) Log(format string, a ...any) Response { - r.LogMessage = fmt.Sprintf(format, a...) - return r -} - -func (r Response) LogError(err error) Response { - r.LogMessage = fmt.Sprintf("%s", err) - return r -} - -func (r Response) SaveSession(s Session) Response { - r.Session = &s - return r -} - -func SeeOther(url string) Response { - return Response{SeeOther: url} -} - -func Error(status int, format string, a ...any) Response { - return Response{ - Status: status, - HtmlTemplateName: "error_page.gotmpl", - Data: struct{ Message string }{fmt.Sprintf(format, a...)}, - } -} diff --git a/http/server/server.go b/http/server/server.go new file mode 100644 index 0000000..8b18eea --- /dev/null +++ b/http/server/server.go @@ -0,0 +1,61 @@ +package server + +import ( + "encoding/gob" + "net/http" + "strings" + + "lishwist/http/api" + "lishwist/http/env" + "lishwist/http/response" + "lishwist/http/router" + "lishwist/http/routing" + "lishwist/http/session" + + "github.com/Teajey/rsvp" +) + +func prefixMovedPermanently(before, after string) response.HandlerFunc { + return func(s *response.Session, h http.Header, r *http.Request) rsvp.Response { + suffix := strings.TrimPrefix(r.RequestURI, before) + return rsvp.MovedPermanently(after + suffix) + } +} + +func Create(useSecureCookies bool) *router.VisibilityRouter { + gob.Register(&api.RegisterProps{}) + gob.Register(&api.LoginProps{}) + + store := session.NewInMemoryStore([]byte(env.SessionSecret)) + store.Options.MaxAge = 86_400 // 24 hours in seconds + store.Options.Secure = useSecureCookies + store.Options.HttpOnly = true + + r := router.New(store) + + r.Public.HandleFunc("GET /", routing.Login) + r.Public.HandleFunc("GET /groups/{groupReference}", routing.PublicGroup) + r.Public.HandleFunc("GET /lists/{userReference}", routing.PublicWishlist) + r.Public.HandleFunc("GET /register", routing.Register) + r.Public.HandleFunc("POST /", routing.LoginPost) + r.Public.HandleFunc("POST /register", routing.RegisterPost) + + r.Private.HandleFunc("GET /", routing.NotFound) + r.Private.HandleFunc("GET /groups", routing.ExpectAppSession(routing.Groups)) + r.Private.HandleFunc("GET /groups/{groupReference}", routing.ExpectAppSession(routing.Group)) + r.Private.HandleFunc("GET /lists/{userReference}", routing.ExpectAppSession(routing.ForeignWishlist)) + r.Private.HandleFunc("GET /users", routing.ExpectAppSession(routing.Users)) + r.Private.HandleFunc("GET /users/{userReference}", routing.ExpectAppSession(routing.User)) + r.Private.HandleFunc("GET /{$}", routing.ExpectAppSession(routing.Home)) + r.Private.HandleFunc("POST /groups/{groupReference}", routing.ExpectAppSession(routing.GroupPost)) + r.Private.HandleFunc("POST /list/{userReference}", routing.ExpectAppSession(routing.ForeignWishlistPost)) + r.Private.HandleFunc("POST /logout", routing.LogoutPost) + r.Private.HandleFunc("POST /users/{userReference}", routing.ExpectAppSession(routing.UserPost)) + r.Private.HandleFunc("POST /{$}", routing.ExpectAppSession(routing.HomePost)) + + // Deprecated + r.HandleFunc("GET /group/{groupReference}", prefixMovedPermanently("/group/", "/groups/")) + r.HandleFunc("GET /list/{userReference}", prefixMovedPermanently("/list/", "/lists/")) + + return r +} diff --git a/http/session/inmemory.go b/http/session/inmemory.go new file mode 100644 index 0000000..1184f46 --- /dev/null +++ b/http/session/inmemory.go @@ -0,0 +1,42 @@ +package session + +import ( + "errors" + "lishwist/http/internal/id" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +var inMemStore = make(map[string]string) + +var errNotFound = errors.New("not found") + +func NewInMemoryStore(keyPairs ...[]byte) *Store { + return &Store{ + callbacks: Callbacks{ + Delete: func(key string) error { + delete(inMemStore, key) + return nil + }, + Insert: func(encodedValues string) (string, error) { + key := id.Generate() + inMemStore[key] = encodedValues + return key, nil + }, + Select: func(key string) (string, error) { + encodedValues, ok := inMemStore[key] + if !ok { + return "", errNotFound + } + return encodedValues, nil + }, + Update: func(key string, encodedValues string) error { + inMemStore[key] = encodedValues + return nil + }, + }, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{}, + } +} diff --git a/http/session/session.go b/http/session/session.go deleted file mode 100644 index 2d695f8..0000000 --- a/http/session/session.go +++ /dev/null @@ -1,25 +0,0 @@ -package sesh - -import ( - "log" - "net/http" - - "github.com/gorilla/sessions" -) - -func GetFirstFlash(w http.ResponseWriter, r *http.Request, session *sessions.Session, key ...string) (any, error) { - flashes := session.Flashes(key...) - - if len(flashes) < 1 { - return nil, nil - } - - flash := flashes[0] - - if err := session.Save(r, w); err != nil { - log.Println("Couldn't save session:", err) - return nil, err - } - - return flash, nil -} diff --git a/http/session/store.go b/http/session/store.go new file mode 100644 index 0000000..a68ceb4 --- /dev/null +++ b/http/session/store.go @@ -0,0 +1,115 @@ +package session + +import ( + "fmt" + "net/http" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +type Callbacks struct { + Delete func(id string) error + Insert func(encodedValues string) (string, error) + Select func(id string) (string, error) + Update func(id, encodedValues string) error +} + +type Store struct { + callbacks Callbacks + Codecs []securecookie.Codec + Options *sessions.Options +} + +func NewGenericStore(cb Callbacks, keyPairs ...[]byte) *Store { + return &Store{ + callbacks: cb, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{}, + } +} + +// Get should return a cached session. +func (m *Store) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(m, name) +} + +// New should create and return a new session. +// +// Note that New should never return a nil session, even in the case of +// an error if using the Registry infrastructure to cache the session. +func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(s, name) + opts := *s.Options + session.Options = &opts + session.IsNew = true + + var err error + + c, errCookie := r.Cookie(name) + if errCookie != nil { + return session, nil + } + + err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) + if err != nil { + return session, fmt.Errorf("failed to decode session id: %w", err) + } + + sessionValue, err := s.callbacks.Select(session.ID) + if err != nil { + return session, fmt.Errorf("failed to get session value: %w", err) + } + + err = securecookie.DecodeMulti(name, string(sessionValue), &session.Values, s.Codecs...) + if err == nil { + session.IsNew = false + } else { + err = fmt.Errorf("failed to decode session values: %w", err) + } + + return session, err +} + +// Save should persist session to the underlying store implementation. +func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + // Delete if max-age is <= 0 + if session.Options.MaxAge <= 0 { + err := s.callbacks.Delete(session.ID) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + return nil + } + + encodedValues, err := securecookie.EncodeMulti(session.Name(), session.Values, + s.Codecs...) + if err != nil { + return fmt.Errorf("failed to encode cookie value: %w", err) + } + + if session.ID == "" { + i, err := s.callbacks.Insert(encodedValues) + if err != nil { + return fmt.Errorf("failed to insert session: %w", err) + } + + session.ID = i + } else { + err := s.callbacks.Update(session.ID, encodedValues) + if err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + } + + encodedId, err := securecookie.EncodeMulti(session.Name(), session.ID, + s.Codecs...) + if err != nil { + return fmt.Errorf("failed to encode cookie value: %w", err) + } + + http.SetCookie(w, sessions.NewCookie(session.Name(), encodedId, session.Options)) + + return nil +} diff --git a/http/templates/base.gotmpl b/http/templates/base.gotmpl index 5407dc1..fcfa5f1 100644 --- a/http/templates/base.gotmpl +++ b/http/templates/base.gotmpl @@ -12,10 +12,7 @@ {{end}} {{end}} - - - -
+{{define "head"}}{{.Message}}
-{{.Message}}
+They don't have any gift ideas. Ask them to think of something, or add an idea yourself! 👇 (everyone - except them will be able to see it and claim it)
- {{end}} -There's nobody else in this group.
- {{end}} + +There's nobody else in this group.
+ {{end}} +Your list is empty. Think of some things to add!
- {{end}} -Registration successful. Now you can login.
-{{.}}
-Registration successful. Now you can login.
+{{.}}
+{{.Username}} hasn't written any gift ideas!
-{{template "login_prompt"}} to add some! :^)
- {{else}} - {{if eq .GiftCount 1}} -{{.Username}} has only written one gift idea.
-{{template "login_prompt"}} to claim it, or add more! :^)
- {{else}} -{{.Username}} has written {{.GiftCount}} gift ideas.
-{{template "login_prompt"}} to claim an idea, or add more! :^)
- {{end}} - {{end}} + +{{.Username}} hasn't written any gift ideas!
+{{template "login_prompt"}} to add some! :^)
+ {{else}} + {{if eq .GiftCount 1}} +{{.Username}} has only written one gift idea.
+{{template "login_prompt"}} to claim it, or add more! :^)
+ {{else}} +{{.Username}} has written {{.GiftCount}} gift ideas.
+{{template "login_prompt"}} to claim an idea, or add more! :^)
+ {{end}} + {{end}} +{{template "login_prompt"}} to see your groups
- {{with .Group.Members}} -There's nobody else in this group.
- {{end}} + +{{template "login_prompt"}} to see your groups
+ {{with .Group.Members}} +There's nobody else in this group.
+ {{end}} +Your password will be stored in a safe, responsible manner; but don't trust my programming skills!
-Maybe use a password here that you don't use for important things...
-{{.}}
-Your password will be stored in a safe, responsible manner; but don't trust my programming skills!
+Maybe use a password here that you don't use for important things...
+{{.}}
+