Use RSVP #15

Merged
Teajey merged 14 commits from rsvp-lib into main 2025-09-13 03:06:22 +12:00
19 changed files with 294 additions and 110 deletions
Showing only changes of commit abb9c54036 - Show all commits

View File

@ -45,12 +45,14 @@ func GetGroupByReference(reference string) (*Group, error)
func (g *Group) MemberIndex(userId string) int
type Session struct {
Key string
Expiry time.Time
// Has unexported fields.
}
func Login(username, password string) (*Session, error)
func Login(username, password string, sessionMaxAge time.Duration) (*Session, error)
func SessionFromUsername(username string) (*Session, error)
func SessionFromKey(key string) (*Session, error)
func (s *Session) Admin() *Admin

View File

@ -15,12 +15,12 @@ var Connection *sql.DB
func Init(dataSourceName string) error {
db, err := sql.Open("sqlite3", dataSourceName)
if err != nil {
return fmt.Errorf("Failed to open db connection: %w", err)
return fmt.Errorf("failed to open db connection: %w", err)
}
_, err = db.Exec(initQuery)
if err != nil {
return fmt.Errorf("Failed to initialize db: %w", err)
return fmt.Errorf("failed to initialize db: %w", err)
}
Connection = db

View File

@ -37,8 +37,11 @@ CREATE TABLE IF NOT EXISTS "group_member" (
);
CREATE TABLE IF NOT EXISTS "session" (
"id" INTEGER NOT NULL UNIQUE,
"value" TEXT NOT NULL,
PRIMARY KEY("id" AUTOINCREMENT)
"key" TEXT NOT NULL UNIQUE,
"user_id" INTEGER NOT NULL,
"expiry" TEXT NOT NULL,
PRIMARY KEY("id" AUTOINCREMENT),
FOREIGN KEY("user_id") REFERENCES "user"("id")
);
DROP VIEW IF EXISTS "v_user";

View File

@ -3,6 +3,7 @@ package fixtures
import (
"log"
"testing"
"time"
lishwist "lishwist/core"
@ -26,7 +27,7 @@ func Login(t *testing.T, username, password string) *lishwist.Session {
log.Fatalf("Failed to register on login fixture: %s\n", err)
}
session, err := lishwist.Login(username, password)
session, err := lishwist.Login(username, password, time.Hour*24)
if err != nil {
log.Fatalf("Failed to login on fixture: %s\n", err)
}

View File

@ -0,0 +1,14 @@
package id
import (
"crypto/rand"
"encoding/hex"
)
func Generate() string {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
panic(err)
}
return hex.EncodeToString(bytes)
}

View File

@ -2,13 +2,14 @@ package lishwist
import (
"fmt"
"time"
"golang.org/x/crypto/bcrypt"
)
type ErrorInvalidCredentials error
func Login(username, password string) (*Session, error) {
func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) {
user, err := getUserByName(username)
if err != nil {
return nil, ErrorInvalidCredentials(fmt.Errorf("Failed to fetch user: %w", err))
@ -27,5 +28,10 @@ func Login(username, password string) (*Session, error) {
return nil, ErrorInvalidCredentials(fmt.Errorf("Password compare failed: %w", err))
}
return &Session{*user}, nil
session, err := insertSession(*user, sessionMaxAge)
if err != nil {
return nil, fmt.Errorf("failed to insert session: %w", err)
}
return session, nil
}

View File

@ -2,6 +2,7 @@ package lishwist_test
import (
"testing"
"time"
lishwist "lishwist/core"
"lishwist/core/internal/fixtures"
@ -18,7 +19,7 @@ func TestLogin(t *testing.T) {
t.Fatalf("Failed to register: %s\n", err)
}
_, err = lishwist.Login("thomas", "123")
_, err = lishwist.Login("thomas", "123", time.Hour*24)
if err != nil {
t.Fatalf("Failed to login: %s\n", err)
}

View File

@ -1,9 +1,19 @@
package lishwist
import "fmt"
import (
"database/sql"
"errors"
"fmt"
"time"
"lishwist/core/internal/db"
"lishwist/core/internal/id"
)
type Session struct {
user User
user User
Key string
Expiry time.Time
}
// Returns a copy of the user associated with this session
@ -11,10 +21,46 @@ func (s *Session) User() User {
return s.user
}
func SessionFromUsername(username string) (*Session, error) {
user, err := getUserByName(username)
if err != nil {
return nil, fmt.Errorf("Failed to get user: %w", err)
func SessionFromKey(key string) (*Session, error) {
s := Session{}
query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live, session.key, session.expiry FROM v_user as user JOIN session ON user.id = session.user_id WHERE session.key = ?"
var expiry string
err := db.Connection.QueryRow(query, key).Scan(
&s.user.Id,
&s.user.Name,
&s.user.NormalName,
&s.user.Reference,
&s.user.IsAdmin,
&s.user.IsLive,
&s.Key,
&expiry,
)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return &Session{*user}, nil
if err != nil {
return nil, fmt.Errorf("failed to fetch session: %w", err)
}
s.Expiry, err = time.Parse(time.RFC3339Nano, expiry)
if err != nil {
return nil, fmt.Errorf("failed to parse session expiry: %w", err)
}
if time.Now().After(s.Expiry) {
return nil, nil
}
return &s, err
}
func insertSession(user User, maxAge time.Duration) (*Session, error) {
s := Session{
user: user,
Key: id.Generate(),
Expiry: time.Now().Add(maxAge),
}
stmt := "INSERT INTO session (key, user_id, expiry) VALUES (?, ?, ?)"
_, err := db.Connection.Exec(stmt, &s.Key, &user.Id, &s.Expiry)
if err != nil {
return nil, fmt.Errorf("failed to execute query: %w", err)
}
return &s, nil
}

View File

@ -1,39 +0,0 @@
package session
import (
"fmt"
"lishwist/core/internal/db"
"github.com/Teajey/sqlstore"
)
func NewStore(keyPairs ...[]byte) (*sqlstore.Store, error) {
deleteStmt, err := db.Connection.Prepare("DELETE FROM session WHERE id = ?;")
if err != nil {
return nil, fmt.Errorf("Failed to prepare delete statement: %w", err)
}
insertStmt, err := db.Connection.Prepare("INSERT INTO session (value) VALUES (?);")
if err != nil {
return nil, fmt.Errorf("Failed to prepare insert statement: %w", err)
}
selectStmt, err := db.Connection.Prepare("SELECT value FROM session WHERE id = ?;")
if err != nil {
return nil, fmt.Errorf("Failed to prepare select statement: %w", err)
}
updateStmt, err := db.Connection.Prepare("UPDATE session SET value = ?2 WHERE id = ?1;")
if err != nil {
return nil, fmt.Errorf("Failed to prepare update statement: %w", err)
}
s := sqlstore.NewSqlStore(db.Connection, sqlstore.Statements{
Delete: deleteStmt,
Insert: insertStmt,
Select: selectStmt,
Update: updateStmt,
}, keyPairs...)
return s, nil
}

4
http/env/env.go vendored
View File

@ -6,12 +6,12 @@ import (
"os"
)
func GuaranteeEnv(key string) (variable string) {
func GuaranteeEnv(key string) string {
variable, ok := os.LookupEnv(key)
if !ok || variable == "" {
log.Fatalln("Missing environment variable:", key)
}
return
return variable
}
var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE")

View File

@ -0,0 +1,14 @@
package id
import (
"crypto/rand"
"encoding/hex"
)
func Generate() string {
bytes := make([]byte, 16)
if _, err := rand.Read(bytes); err != nil {
panic(err)
}
return hex.EncodeToString(bytes)
}

View File

@ -6,11 +6,11 @@ import (
"net/http"
lishwist "lishwist/core"
"lishwist/core/session"
"lishwist/http/api"
"lishwist/http/env"
"lishwist/http/router"
"lishwist/http/routing"
"lishwist/http/session"
)
func main() {
@ -22,11 +22,8 @@ func main() {
log.Fatalf("Failed to init Lishwist: %s\n", err)
}
store, err := session.NewStore([]byte(env.SessionSecret))
if err != nil {
log.Fatalf("Failed to initialize session store: %s\n", err)
}
store.Options.MaxAge = 86_400
store := session.NewInMemoryStore([]byte(env.SessionSecret))
store.Options.MaxAge = 86_400 // 24 hours in seconds
store.Options.Secure = !env.InDev
store.Options.HttpOnly = true

View File

@ -1,20 +1,21 @@
package response
import (
"lishwist/http/templates"
"log"
"net/http"
"lishwist/http/session"
"lishwist/http/templates"
"github.com/Teajey/rsvp"
"github.com/Teajey/sqlstore"
)
type ServeMux struct {
inner *rsvp.ServeMux
store *sqlstore.Store
store *session.Store
}
func NewServeMux(store *sqlstore.Store) *ServeMux {
func NewServeMux(store *session.Store) *ServeMux {
mux := rsvp.NewServeMux()
mux.Config.HtmlTemplate = templates.Template
return &ServeMux{

View File

@ -1,21 +1,22 @@
package router
import (
"lishwist/http/response"
"net/http"
"github.com/Teajey/sqlstore"
"lishwist/http/response"
"lishwist/http/session"
)
type VisibilityRouter struct {
Store *sqlstore.Store
store *session.Store
Public *response.ServeMux
Private *response.ServeMux
}
func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
session, _ := s.Store.Get(r, "lishwist_user")
authorized, _ := session.Values["authorized"].(bool)
session, _ := s.store.Get(r, "lishwist_user")
_, authorized := session.Values["sessionKey"]
if authorized {
s.Private.ServeHTTP(w, r)
@ -24,9 +25,9 @@ func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
func New(store *sqlstore.Store) *VisibilityRouter {
func New(store *session.Store) *VisibilityRouter {
return &VisibilityRouter{
Store: store,
store: store,
Public: response.NewServeMux(store),
Private: response.NewServeMux(store),
}

View File

@ -13,15 +13,19 @@ import (
func ExpectAppSession(next func(*lishwist.Session, http.Header, *http.Request) rsvp.Response) response.HandlerFunc {
return func(session *response.Session, h http.Header, r *http.Request) rsvp.Response {
username, ok := session.GetValue("username").(string)
sessionKey, ok := session.GetValue("sessionKey").(string)
if !ok {
log.Printf("Failed to get username from session\n")
log.Printf("Failed to get key from session\n")
return response.Error(http.StatusInternalServerError, "Something went wrong.")
}
appSession, err := lishwist.SessionFromUsername(username)
appSession, err := lishwist.SessionFromKey(sessionKey)
if err != nil {
log.Printf("Failed to get session by username %q: %s\n", username, err)
log.Printf("Failed to get session by key %v: %s\n", sessionKey, err)
return response.Error(http.StatusInternalServerError, "Something went wrong.")
}
if appSession == nil {
log.Printf("Session not found under key: %s\n", sessionKey)
return response.Error(http.StatusInternalServerError, "Something went wrong.")
}

View File

@ -1,8 +1,10 @@
package routing
import (
"errors"
"log"
"net/http"
"time"
lishwist "lishwist/core"
"lishwist/http/api"
@ -54,10 +56,11 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R
return resp
}
app, err := lishwist.Login(username, password)
appSession, err := lishwist.Login(username, password, time.Hour*24)
if err != nil {
switch err.(type) {
case lishwist.ErrorInvalidCredentials:
var targ lishwist.ErrorInvalidCredentials
switch {
case errors.As(err, &targ):
props.GeneralError = "Username or password invalid"
session.FlashSet(&props)
log.Printf("Invalid credentials: %s: %#v\n", err, props)
@ -70,10 +73,8 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R
}
}
user := app.User()
session.SetID("")
session.SetValue("authorized", true)
session.SetValue("username", user.Name)
session.SetValue("sessionKey", appSession.Key)
return rsvp.SeeOther(r.URL.Path, "Login successful!")
}

42
http/session/inmemory.go Normal file
View File

@ -0,0 +1,42 @@
package session
import (
"errors"
"lishwist/http/internal/id"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
var inMemStore = make(map[string]string)
var errNotFound = errors.New("not found")
func NewInMemoryStore(keyPairs ...[]byte) *Store {
return &Store{
callbacks: Callbacks{
Delete: func(key string) error {
delete(inMemStore, key)
return nil
},
Insert: func(encodedValues string) (string, error) {
key := id.Generate()
inMemStore[key] = encodedValues
return key, nil
},
Select: func(key string) (string, error) {
encodedValues, ok := inMemStore[key]
if !ok {
return "", errNotFound
}
return encodedValues, nil
},
Update: func(key string, encodedValues string) error {
inMemStore[key] = encodedValues
return nil
},
},
Codecs: securecookie.CodecsFromPairs(keyPairs...),
Options: &sessions.Options{},
}
}

View File

@ -1,25 +0,0 @@
package sesh
import (
"log"
"net/http"
"github.com/gorilla/sessions"
)
func GetFirstFlash(w http.ResponseWriter, r *http.Request, session *sessions.Session, key ...string) (any, error) {
flashes := session.Flashes(key...)
if len(flashes) < 1 {
return nil, nil
}
flash := flashes[0]
if err := session.Save(r, w); err != nil {
log.Println("Couldn't save session:", err)
return nil, err
}
return flash, nil
}

115
http/session/store.go Normal file
View File

@ -0,0 +1,115 @@
package session
import (
"fmt"
"net/http"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
type Callbacks struct {
Delete func(id string) error
Insert func(encodedValues string) (string, error)
Select func(id string) (string, error)
Update func(id, encodedValues string) error
}
type Store struct {
callbacks Callbacks
Codecs []securecookie.Codec
Options *sessions.Options
}
func NewGenericStore(cb Callbacks, keyPairs ...[]byte) *Store {
return &Store{
callbacks: cb,
Codecs: securecookie.CodecsFromPairs(keyPairs...),
Options: &sessions.Options{},
}
}
// Get should return a cached session.
func (m *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
return sessions.GetRegistry(r).Get(m, name)
}
// New should create and return a new session.
//
// Note that New should never return a nil session, even in the case of
// an error if using the Registry infrastructure to cache the session.
func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) {
session := sessions.NewSession(s, name)
opts := *s.Options
session.Options = &opts
session.IsNew = true
var err error
c, errCookie := r.Cookie(name)
if errCookie != nil {
return session, nil
}
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
if err != nil {
return session, fmt.Errorf("failed to decode session id: %w", err)
}
sessionValue, err := s.callbacks.Select(session.ID)
if err != nil {
return session, fmt.Errorf("failed to get session value: %w", err)
}
err = securecookie.DecodeMulti(name, string(sessionValue), &session.Values, s.Codecs...)
if err == nil {
session.IsNew = false
} else {
err = fmt.Errorf("failed to decode session values: %w", err)
}
return session, err
}
// Save should persist session to the underlying store implementation.
func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
// Delete if max-age is <= 0
if session.Options.MaxAge <= 0 {
err := s.callbacks.Delete(session.ID)
if err != nil {
return fmt.Errorf("failed to delete session: %w", err)
}
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
return nil
}
encodedValues, err := securecookie.EncodeMulti(session.Name(), session.Values,
s.Codecs...)
if err != nil {
return fmt.Errorf("failed to encode cookie value: %w", err)
}
if session.ID == "" {
i, err := s.callbacks.Insert(encodedValues)
if err != nil {
return fmt.Errorf("failed to insert session: %w", err)
}
session.ID = i
} else {
err := s.callbacks.Update(session.ID, encodedValues)
if err != nil {
return fmt.Errorf("failed to update session: %w", err)
}
}
encodedId, err := securecookie.EncodeMulti(session.Name(), session.ID,
s.Codecs...)
if err != nil {
return fmt.Errorf("failed to encode cookie value: %w", err)
}
http.SetCookie(w, sessions.NewCookie(session.Name(), encodedId, session.Options))
return nil
}