diff --git a/core/api.snap.txt b/core/api.snap.txt index 04aafcd..bb64f85 100644 --- a/core/api.snap.txt +++ b/core/api.snap.txt @@ -45,12 +45,14 @@ func GetGroupByReference(reference string) (*Group, error) func (g *Group) MemberIndex(userId string) int type Session struct { + Key string + Expiry time.Time // Has unexported fields. } -func Login(username, password string) (*Session, error) +func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) -func SessionFromUsername(username string) (*Session, error) +func SessionFromKey(key string) (*Session, error) func (s *Session) Admin() *Admin diff --git a/core/internal/db/db.go b/core/internal/db/db.go index 938dc12..7c3bf16 100644 --- a/core/internal/db/db.go +++ b/core/internal/db/db.go @@ -15,12 +15,12 @@ var Connection *sql.DB func Init(dataSourceName string) error { db, err := sql.Open("sqlite3", dataSourceName) if err != nil { - return fmt.Errorf("Failed to open db connection: %w", err) + return fmt.Errorf("failed to open db connection: %w", err) } _, err = db.Exec(initQuery) if err != nil { - return fmt.Errorf("Failed to initialize db: %w", err) + return fmt.Errorf("failed to initialize db: %w", err) } Connection = db diff --git a/core/internal/db/init.sql b/core/internal/db/init.sql index e6d94b6..4a687e2 100644 --- a/core/internal/db/init.sql +++ b/core/internal/db/init.sql @@ -37,8 +37,11 @@ CREATE TABLE IF NOT EXISTS "group_member" ( ); CREATE TABLE IF NOT EXISTS "session" ( "id" INTEGER NOT NULL UNIQUE, - "value" TEXT NOT NULL, - PRIMARY KEY("id" AUTOINCREMENT) + "key" TEXT NOT NULL UNIQUE, + "user_id" INTEGER NOT NULL, + "expiry" TEXT NOT NULL, + PRIMARY KEY("id" AUTOINCREMENT), + FOREIGN KEY("user_id") REFERENCES "user"("id") ); DROP VIEW IF EXISTS "v_user"; diff --git a/core/internal/fixtures/login.go b/core/internal/fixtures/login.go index c730027..5b94868 100644 --- a/core/internal/fixtures/login.go +++ b/core/internal/fixtures/login.go @@ -3,6 +3,7 @@ package fixtures import ( "log" "testing" + "time" lishwist "lishwist/core" @@ -26,7 +27,7 @@ func Login(t *testing.T, username, password string) *lishwist.Session { log.Fatalf("Failed to register on login fixture: %s\n", err) } - session, err := lishwist.Login(username, password) + session, err := lishwist.Login(username, password, time.Hour*24) if err != nil { log.Fatalf("Failed to login on fixture: %s\n", err) } diff --git a/core/internal/id/generate.go b/core/internal/id/generate.go new file mode 100644 index 0000000..21efd68 --- /dev/null +++ b/core/internal/id/generate.go @@ -0,0 +1,14 @@ +package id + +import ( + "crypto/rand" + "encoding/hex" +) + +func Generate() string { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes) +} diff --git a/core/login.go b/core/login.go index 47659a8..48bca3f 100644 --- a/core/login.go +++ b/core/login.go @@ -2,13 +2,14 @@ package lishwist import ( "fmt" + "time" "golang.org/x/crypto/bcrypt" ) type ErrorInvalidCredentials error -func Login(username, password string) (*Session, error) { +func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) { user, err := getUserByName(username) if err != nil { return nil, ErrorInvalidCredentials(fmt.Errorf("Failed to fetch user: %w", err)) @@ -27,5 +28,10 @@ func Login(username, password string) (*Session, error) { return nil, ErrorInvalidCredentials(fmt.Errorf("Password compare failed: %w", err)) } - return &Session{*user}, nil + session, err := insertSession(*user, sessionMaxAge) + if err != nil { + return nil, fmt.Errorf("failed to insert session: %w", err) + } + + return session, nil } diff --git a/core/login_test.go b/core/login_test.go index 9603a09..1d3ad00 100644 --- a/core/login_test.go +++ b/core/login_test.go @@ -2,6 +2,7 @@ package lishwist_test import ( "testing" + "time" lishwist "lishwist/core" "lishwist/core/internal/fixtures" @@ -18,7 +19,7 @@ func TestLogin(t *testing.T) { t.Fatalf("Failed to register: %s\n", err) } - _, err = lishwist.Login("thomas", "123") + _, err = lishwist.Login("thomas", "123", time.Hour*24) if err != nil { t.Fatalf("Failed to login: %s\n", err) } diff --git a/core/session.go b/core/session.go index 2e6eab6..188d328 100644 --- a/core/session.go +++ b/core/session.go @@ -1,9 +1,19 @@ package lishwist -import "fmt" +import ( + "database/sql" + "errors" + "fmt" + "time" + + "lishwist/core/internal/db" + "lishwist/core/internal/id" +) type Session struct { - user User + user User + Key string + Expiry time.Time } // Returns a copy of the user associated with this session @@ -11,10 +21,46 @@ func (s *Session) User() User { return s.user } -func SessionFromUsername(username string) (*Session, error) { - user, err := getUserByName(username) - if err != nil { - return nil, fmt.Errorf("Failed to get user: %w", err) +func SessionFromKey(key string) (*Session, error) { + s := Session{} + query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live, session.key, session.expiry FROM v_user as user JOIN session ON user.id = session.user_id WHERE session.key = ?" + var expiry string + err := db.Connection.QueryRow(query, key).Scan( + &s.user.Id, + &s.user.Name, + &s.user.NormalName, + &s.user.Reference, + &s.user.IsAdmin, + &s.user.IsLive, + &s.Key, + &expiry, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil } - return &Session{*user}, nil + if err != nil { + return nil, fmt.Errorf("failed to fetch session: %w", err) + } + s.Expiry, err = time.Parse(time.RFC3339Nano, expiry) + if err != nil { + return nil, fmt.Errorf("failed to parse session expiry: %w", err) + } + if time.Now().After(s.Expiry) { + return nil, nil + } + return &s, err +} + +func insertSession(user User, maxAge time.Duration) (*Session, error) { + s := Session{ + user: user, + Key: id.Generate(), + Expiry: time.Now().Add(maxAge), + } + stmt := "INSERT INTO session (key, user_id, expiry) VALUES (?, ?, ?)" + _, err := db.Connection.Exec(stmt, &s.Key, &user.Id, &s.Expiry) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + return &s, nil } diff --git a/core/session/store.go b/core/session/store.go deleted file mode 100644 index 3045fc3..0000000 --- a/core/session/store.go +++ /dev/null @@ -1,39 +0,0 @@ -package session - -import ( - "fmt" - "lishwist/core/internal/db" - - "github.com/Teajey/sqlstore" -) - -func NewStore(keyPairs ...[]byte) (*sqlstore.Store, error) { - deleteStmt, err := db.Connection.Prepare("DELETE FROM session WHERE id = ?;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare delete statement: %w", err) - } - - insertStmt, err := db.Connection.Prepare("INSERT INTO session (value) VALUES (?);") - if err != nil { - return nil, fmt.Errorf("Failed to prepare insert statement: %w", err) - } - - selectStmt, err := db.Connection.Prepare("SELECT value FROM session WHERE id = ?;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare select statement: %w", err) - } - - updateStmt, err := db.Connection.Prepare("UPDATE session SET value = ?2 WHERE id = ?1;") - if err != nil { - return nil, fmt.Errorf("Failed to prepare update statement: %w", err) - } - - s := sqlstore.NewSqlStore(db.Connection, sqlstore.Statements{ - Delete: deleteStmt, - Insert: insertStmt, - Select: selectStmt, - Update: updateStmt, - }, keyPairs...) - - return s, nil -} diff --git a/http/env/env.go b/http/env/env.go index ee9a036..10a998b 100644 --- a/http/env/env.go +++ b/http/env/env.go @@ -6,12 +6,12 @@ import ( "os" ) -func GuaranteeEnv(key string) (variable string) { +func GuaranteeEnv(key string) string { variable, ok := os.LookupEnv(key) if !ok || variable == "" { log.Fatalln("Missing environment variable:", key) } - return + return variable } var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE") diff --git a/http/internal/id/generate.go b/http/internal/id/generate.go new file mode 100644 index 0000000..21efd68 --- /dev/null +++ b/http/internal/id/generate.go @@ -0,0 +1,14 @@ +package id + +import ( + "crypto/rand" + "encoding/hex" +) + +func Generate() string { + bytes := make([]byte, 16) + if _, err := rand.Read(bytes); err != nil { + panic(err) + } + return hex.EncodeToString(bytes) +} diff --git a/http/main.go b/http/main.go index ad53076..912ae87 100644 --- a/http/main.go +++ b/http/main.go @@ -6,11 +6,11 @@ import ( "net/http" lishwist "lishwist/core" - "lishwist/core/session" "lishwist/http/api" "lishwist/http/env" "lishwist/http/router" "lishwist/http/routing" + "lishwist/http/session" ) func main() { @@ -22,11 +22,8 @@ func main() { log.Fatalf("Failed to init Lishwist: %s\n", err) } - store, err := session.NewStore([]byte(env.SessionSecret)) - if err != nil { - log.Fatalf("Failed to initialize session store: %s\n", err) - } - store.Options.MaxAge = 86_400 + store := session.NewInMemoryStore([]byte(env.SessionSecret)) + store.Options.MaxAge = 86_400 // 24 hours in seconds store.Options.Secure = !env.InDev store.Options.HttpOnly = true diff --git a/http/response/handler.go b/http/response/handler.go index 8749c13..3962387 100644 --- a/http/response/handler.go +++ b/http/response/handler.go @@ -1,20 +1,21 @@ package response import ( - "lishwist/http/templates" "log" "net/http" + "lishwist/http/session" + "lishwist/http/templates" + "github.com/Teajey/rsvp" - "github.com/Teajey/sqlstore" ) type ServeMux struct { inner *rsvp.ServeMux - store *sqlstore.Store + store *session.Store } -func NewServeMux(store *sqlstore.Store) *ServeMux { +func NewServeMux(store *session.Store) *ServeMux { mux := rsvp.NewServeMux() mux.Config.HtmlTemplate = templates.Template return &ServeMux{ diff --git a/http/router/router.go b/http/router/router.go index 4a0c75d..5f4386e 100644 --- a/http/router/router.go +++ b/http/router/router.go @@ -1,21 +1,22 @@ package router import ( - "lishwist/http/response" "net/http" - "github.com/Teajey/sqlstore" + "lishwist/http/response" + "lishwist/http/session" ) type VisibilityRouter struct { - Store *sqlstore.Store + store *session.Store Public *response.ServeMux Private *response.ServeMux } func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { - session, _ := s.Store.Get(r, "lishwist_user") - authorized, _ := session.Values["authorized"].(bool) + session, _ := s.store.Get(r, "lishwist_user") + _, authorized := session.Values["sessionKey"] + if authorized { s.Private.ServeHTTP(w, r) @@ -24,9 +25,9 @@ func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } -func New(store *sqlstore.Store) *VisibilityRouter { +func New(store *session.Store) *VisibilityRouter { return &VisibilityRouter{ - Store: store, + store: store, Public: response.NewServeMux(store), Private: response.NewServeMux(store), } diff --git a/http/routing/context.go b/http/routing/context.go index 646bcd2..d9faf46 100644 --- a/http/routing/context.go +++ b/http/routing/context.go @@ -13,15 +13,19 @@ import ( func ExpectAppSession(next func(*lishwist.Session, http.Header, *http.Request) rsvp.Response) response.HandlerFunc { return func(session *response.Session, h http.Header, r *http.Request) rsvp.Response { - username, ok := session.GetValue("username").(string) + sessionKey, ok := session.GetValue("sessionKey").(string) if !ok { - log.Printf("Failed to get username from session\n") + log.Printf("Failed to get key from session\n") return response.Error(http.StatusInternalServerError, "Something went wrong.") } - appSession, err := lishwist.SessionFromUsername(username) + appSession, err := lishwist.SessionFromKey(sessionKey) if err != nil { - log.Printf("Failed to get session by username %q: %s\n", username, err) + log.Printf("Failed to get session by key %v: %s\n", sessionKey, err) + return response.Error(http.StatusInternalServerError, "Something went wrong.") + } + if appSession == nil { + log.Printf("Session not found under key: %s\n", sessionKey) return response.Error(http.StatusInternalServerError, "Something went wrong.") } diff --git a/http/routing/login.go b/http/routing/login.go index 5021e67..e38895e 100644 --- a/http/routing/login.go +++ b/http/routing/login.go @@ -1,8 +1,10 @@ package routing import ( + "errors" "log" "net/http" + "time" lishwist "lishwist/core" "lishwist/http/api" @@ -54,10 +56,11 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R return resp } - app, err := lishwist.Login(username, password) + appSession, err := lishwist.Login(username, password, time.Hour*24) if err != nil { - switch err.(type) { - case lishwist.ErrorInvalidCredentials: + var targ lishwist.ErrorInvalidCredentials + switch { + case errors.As(err, &targ): props.GeneralError = "Username or password invalid" session.FlashSet(&props) log.Printf("Invalid credentials: %s: %#v\n", err, props) @@ -70,10 +73,8 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R } } - user := app.User() session.SetID("") - session.SetValue("authorized", true) - session.SetValue("username", user.Name) + session.SetValue("sessionKey", appSession.Key) return rsvp.SeeOther(r.URL.Path, "Login successful!") } diff --git a/http/session/inmemory.go b/http/session/inmemory.go new file mode 100644 index 0000000..1184f46 --- /dev/null +++ b/http/session/inmemory.go @@ -0,0 +1,42 @@ +package session + +import ( + "errors" + "lishwist/http/internal/id" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +var inMemStore = make(map[string]string) + +var errNotFound = errors.New("not found") + +func NewInMemoryStore(keyPairs ...[]byte) *Store { + return &Store{ + callbacks: Callbacks{ + Delete: func(key string) error { + delete(inMemStore, key) + return nil + }, + Insert: func(encodedValues string) (string, error) { + key := id.Generate() + inMemStore[key] = encodedValues + return key, nil + }, + Select: func(key string) (string, error) { + encodedValues, ok := inMemStore[key] + if !ok { + return "", errNotFound + } + return encodedValues, nil + }, + Update: func(key string, encodedValues string) error { + inMemStore[key] = encodedValues + return nil + }, + }, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{}, + } +} diff --git a/http/session/session.go b/http/session/session.go deleted file mode 100644 index 2d695f8..0000000 --- a/http/session/session.go +++ /dev/null @@ -1,25 +0,0 @@ -package sesh - -import ( - "log" - "net/http" - - "github.com/gorilla/sessions" -) - -func GetFirstFlash(w http.ResponseWriter, r *http.Request, session *sessions.Session, key ...string) (any, error) { - flashes := session.Flashes(key...) - - if len(flashes) < 1 { - return nil, nil - } - - flash := flashes[0] - - if err := session.Save(r, w); err != nil { - log.Println("Couldn't save session:", err) - return nil, err - } - - return flash, nil -} diff --git a/http/session/store.go b/http/session/store.go new file mode 100644 index 0000000..a68ceb4 --- /dev/null +++ b/http/session/store.go @@ -0,0 +1,115 @@ +package session + +import ( + "fmt" + "net/http" + + "github.com/gorilla/securecookie" + "github.com/gorilla/sessions" +) + +type Callbacks struct { + Delete func(id string) error + Insert func(encodedValues string) (string, error) + Select func(id string) (string, error) + Update func(id, encodedValues string) error +} + +type Store struct { + callbacks Callbacks + Codecs []securecookie.Codec + Options *sessions.Options +} + +func NewGenericStore(cb Callbacks, keyPairs ...[]byte) *Store { + return &Store{ + callbacks: cb, + Codecs: securecookie.CodecsFromPairs(keyPairs...), + Options: &sessions.Options{}, + } +} + +// Get should return a cached session. +func (m *Store) Get(r *http.Request, name string) (*sessions.Session, error) { + return sessions.GetRegistry(r).Get(m, name) +} + +// New should create and return a new session. +// +// Note that New should never return a nil session, even in the case of +// an error if using the Registry infrastructure to cache the session. +func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) { + session := sessions.NewSession(s, name) + opts := *s.Options + session.Options = &opts + session.IsNew = true + + var err error + + c, errCookie := r.Cookie(name) + if errCookie != nil { + return session, nil + } + + err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) + if err != nil { + return session, fmt.Errorf("failed to decode session id: %w", err) + } + + sessionValue, err := s.callbacks.Select(session.ID) + if err != nil { + return session, fmt.Errorf("failed to get session value: %w", err) + } + + err = securecookie.DecodeMulti(name, string(sessionValue), &session.Values, s.Codecs...) + if err == nil { + session.IsNew = false + } else { + err = fmt.Errorf("failed to decode session values: %w", err) + } + + return session, err +} + +// Save should persist session to the underlying store implementation. +func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { + // Delete if max-age is <= 0 + if session.Options.MaxAge <= 0 { + err := s.callbacks.Delete(session.ID) + if err != nil { + return fmt.Errorf("failed to delete session: %w", err) + } + http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) + return nil + } + + encodedValues, err := securecookie.EncodeMulti(session.Name(), session.Values, + s.Codecs...) + if err != nil { + return fmt.Errorf("failed to encode cookie value: %w", err) + } + + if session.ID == "" { + i, err := s.callbacks.Insert(encodedValues) + if err != nil { + return fmt.Errorf("failed to insert session: %w", err) + } + + session.ID = i + } else { + err := s.callbacks.Update(session.ID, encodedValues) + if err != nil { + return fmt.Errorf("failed to update session: %w", err) + } + } + + encodedId, err := securecookie.EncodeMulti(session.Name(), session.ID, + s.Codecs...) + if err != nil { + return fmt.Errorf("failed to encode cookie value: %w", err) + } + + http.SetCookie(w, sessions.NewCookie(session.Name(), encodedId, session.Options)) + + return nil +}