feat: internally managed session
This commit is contained in:
parent
57e18ae0ce
commit
abb9c54036
|
|
@ -45,12 +45,14 @@ func GetGroupByReference(reference string) (*Group, error)
|
|||
func (g *Group) MemberIndex(userId string) int
|
||||
|
||||
type Session struct {
|
||||
Key string
|
||||
Expiry time.Time
|
||||
// Has unexported fields.
|
||||
}
|
||||
|
||||
func Login(username, password string) (*Session, error)
|
||||
func Login(username, password string, sessionMaxAge time.Duration) (*Session, error)
|
||||
|
||||
func SessionFromUsername(username string) (*Session, error)
|
||||
func SessionFromKey(key string) (*Session, error)
|
||||
|
||||
func (s *Session) Admin() *Admin
|
||||
|
||||
|
|
|
|||
|
|
@ -15,12 +15,12 @@ var Connection *sql.DB
|
|||
func Init(dataSourceName string) error {
|
||||
db, err := sql.Open("sqlite3", dataSourceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to open db connection: %w", err)
|
||||
return fmt.Errorf("failed to open db connection: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(initQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to initialize db: %w", err)
|
||||
return fmt.Errorf("failed to initialize db: %w", err)
|
||||
}
|
||||
|
||||
Connection = db
|
||||
|
|
|
|||
|
|
@ -37,8 +37,11 @@ CREATE TABLE IF NOT EXISTS "group_member" (
|
|||
);
|
||||
CREATE TABLE IF NOT EXISTS "session" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"value" TEXT NOT NULL,
|
||||
PRIMARY KEY("id" AUTOINCREMENT)
|
||||
"key" TEXT NOT NULL UNIQUE,
|
||||
"user_id" INTEGER NOT NULL,
|
||||
"expiry" TEXT NOT NULL,
|
||||
PRIMARY KEY("id" AUTOINCREMENT),
|
||||
FOREIGN KEY("user_id") REFERENCES "user"("id")
|
||||
);
|
||||
|
||||
DROP VIEW IF EXISTS "v_user";
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package fixtures
|
|||
import (
|
||||
"log"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
|
||||
|
|
@ -26,7 +27,7 @@ func Login(t *testing.T, username, password string) *lishwist.Session {
|
|||
log.Fatalf("Failed to register on login fixture: %s\n", err)
|
||||
}
|
||||
|
||||
session, err := lishwist.Login(username, password)
|
||||
session, err := lishwist.Login(username, password, time.Hour*24)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to login on fixture: %s\n", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
package id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Generate() string {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
|
@ -2,13 +2,14 @@ package lishwist
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type ErrorInvalidCredentials error
|
||||
|
||||
func Login(username, password string) (*Session, error) {
|
||||
func Login(username, password string, sessionMaxAge time.Duration) (*Session, error) {
|
||||
user, err := getUserByName(username)
|
||||
if err != nil {
|
||||
return nil, ErrorInvalidCredentials(fmt.Errorf("Failed to fetch user: %w", err))
|
||||
|
|
@ -27,5 +28,10 @@ func Login(username, password string) (*Session, error) {
|
|||
return nil, ErrorInvalidCredentials(fmt.Errorf("Password compare failed: %w", err))
|
||||
}
|
||||
|
||||
return &Session{*user}, nil
|
||||
session, err := insertSession(*user, sessionMaxAge)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to insert session: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package lishwist_test
|
|||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
"lishwist/core/internal/fixtures"
|
||||
|
|
@ -18,7 +19,7 @@ func TestLogin(t *testing.T) {
|
|||
t.Fatalf("Failed to register: %s\n", err)
|
||||
}
|
||||
|
||||
_, err = lishwist.Login("thomas", "123")
|
||||
_, err = lishwist.Login("thomas", "123", time.Hour*24)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to login: %s\n", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,19 @@
|
|||
package lishwist
|
||||
|
||||
import "fmt"
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"lishwist/core/internal/db"
|
||||
"lishwist/core/internal/id"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
user User
|
||||
Key string
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// Returns a copy of the user associated with this session
|
||||
|
|
@ -11,10 +21,46 @@ func (s *Session) User() User {
|
|||
return s.user
|
||||
}
|
||||
|
||||
func SessionFromUsername(username string) (*Session, error) {
|
||||
user, err := getUserByName(username)
|
||||
func SessionFromKey(key string) (*Session, error) {
|
||||
s := Session{}
|
||||
query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live, session.key, session.expiry FROM v_user as user JOIN session ON user.id = session.user_id WHERE session.key = ?"
|
||||
var expiry string
|
||||
err := db.Connection.QueryRow(query, key).Scan(
|
||||
&s.user.Id,
|
||||
&s.user.Name,
|
||||
&s.user.NormalName,
|
||||
&s.user.Reference,
|
||||
&s.user.IsAdmin,
|
||||
&s.user.IsLive,
|
||||
&s.Key,
|
||||
&expiry,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to get user: %w", err)
|
||||
return nil, fmt.Errorf("failed to fetch session: %w", err)
|
||||
}
|
||||
return &Session{*user}, nil
|
||||
s.Expiry, err = time.Parse(time.RFC3339Nano, expiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse session expiry: %w", err)
|
||||
}
|
||||
if time.Now().After(s.Expiry) {
|
||||
return nil, nil
|
||||
}
|
||||
return &s, err
|
||||
}
|
||||
|
||||
func insertSession(user User, maxAge time.Duration) (*Session, error) {
|
||||
s := Session{
|
||||
user: user,
|
||||
Key: id.Generate(),
|
||||
Expiry: time.Now().Add(maxAge),
|
||||
}
|
||||
stmt := "INSERT INTO session (key, user_id, expiry) VALUES (?, ?, ?)"
|
||||
_, err := db.Connection.Exec(stmt, &s.Key, &user.Id, &s.Expiry)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute query: %w", err)
|
||||
}
|
||||
return &s, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,39 +0,0 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"lishwist/core/internal/db"
|
||||
|
||||
"github.com/Teajey/sqlstore"
|
||||
)
|
||||
|
||||
func NewStore(keyPairs ...[]byte) (*sqlstore.Store, error) {
|
||||
deleteStmt, err := db.Connection.Prepare("DELETE FROM session WHERE id = ?;")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to prepare delete statement: %w", err)
|
||||
}
|
||||
|
||||
insertStmt, err := db.Connection.Prepare("INSERT INTO session (value) VALUES (?);")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to prepare insert statement: %w", err)
|
||||
}
|
||||
|
||||
selectStmt, err := db.Connection.Prepare("SELECT value FROM session WHERE id = ?;")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to prepare select statement: %w", err)
|
||||
}
|
||||
|
||||
updateStmt, err := db.Connection.Prepare("UPDATE session SET value = ?2 WHERE id = ?1;")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to prepare update statement: %w", err)
|
||||
}
|
||||
|
||||
s := sqlstore.NewSqlStore(db.Connection, sqlstore.Statements{
|
||||
Delete: deleteStmt,
|
||||
Insert: insertStmt,
|
||||
Select: selectStmt,
|
||||
Update: updateStmt,
|
||||
}, keyPairs...)
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
|
@ -6,12 +6,12 @@ import (
|
|||
"os"
|
||||
)
|
||||
|
||||
func GuaranteeEnv(key string) (variable string) {
|
||||
func GuaranteeEnv(key string) string {
|
||||
variable, ok := os.LookupEnv(key)
|
||||
if !ok || variable == "" {
|
||||
log.Fatalln("Missing environment variable:", key)
|
||||
}
|
||||
return
|
||||
return variable
|
||||
}
|
||||
|
||||
var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
package id
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
)
|
||||
|
||||
func Generate() string {
|
||||
bytes := make([]byte, 16)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
|
@ -6,11 +6,11 @@ import (
|
|||
"net/http"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
"lishwist/core/session"
|
||||
"lishwist/http/api"
|
||||
"lishwist/http/env"
|
||||
"lishwist/http/router"
|
||||
"lishwist/http/routing"
|
||||
"lishwist/http/session"
|
||||
)
|
||||
|
||||
func main() {
|
||||
|
|
@ -22,11 +22,8 @@ func main() {
|
|||
log.Fatalf("Failed to init Lishwist: %s\n", err)
|
||||
}
|
||||
|
||||
store, err := session.NewStore([]byte(env.SessionSecret))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to initialize session store: %s\n", err)
|
||||
}
|
||||
store.Options.MaxAge = 86_400
|
||||
store := session.NewInMemoryStore([]byte(env.SessionSecret))
|
||||
store.Options.MaxAge = 86_400 // 24 hours in seconds
|
||||
store.Options.Secure = !env.InDev
|
||||
store.Options.HttpOnly = true
|
||||
|
||||
|
|
|
|||
|
|
@ -1,20 +1,21 @@
|
|||
package response
|
||||
|
||||
import (
|
||||
"lishwist/http/templates"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"lishwist/http/session"
|
||||
"lishwist/http/templates"
|
||||
|
||||
"github.com/Teajey/rsvp"
|
||||
"github.com/Teajey/sqlstore"
|
||||
)
|
||||
|
||||
type ServeMux struct {
|
||||
inner *rsvp.ServeMux
|
||||
store *sqlstore.Store
|
||||
store *session.Store
|
||||
}
|
||||
|
||||
func NewServeMux(store *sqlstore.Store) *ServeMux {
|
||||
func NewServeMux(store *session.Store) *ServeMux {
|
||||
mux := rsvp.NewServeMux()
|
||||
mux.Config.HtmlTemplate = templates.Template
|
||||
return &ServeMux{
|
||||
|
|
|
|||
|
|
@ -1,21 +1,22 @@
|
|||
package router
|
||||
|
||||
import (
|
||||
"lishwist/http/response"
|
||||
"net/http"
|
||||
|
||||
"github.com/Teajey/sqlstore"
|
||||
"lishwist/http/response"
|
||||
"lishwist/http/session"
|
||||
)
|
||||
|
||||
type VisibilityRouter struct {
|
||||
Store *sqlstore.Store
|
||||
store *session.Store
|
||||
Public *response.ServeMux
|
||||
Private *response.ServeMux
|
||||
}
|
||||
|
||||
func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
session, _ := s.Store.Get(r, "lishwist_user")
|
||||
authorized, _ := session.Values["authorized"].(bool)
|
||||
session, _ := s.store.Get(r, "lishwist_user")
|
||||
_, authorized := session.Values["sessionKey"]
|
||||
|
||||
|
||||
if authorized {
|
||||
s.Private.ServeHTTP(w, r)
|
||||
|
|
@ -24,9 +25,9 @@ func (s *VisibilityRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
|
||||
func New(store *sqlstore.Store) *VisibilityRouter {
|
||||
func New(store *session.Store) *VisibilityRouter {
|
||||
return &VisibilityRouter{
|
||||
Store: store,
|
||||
store: store,
|
||||
Public: response.NewServeMux(store),
|
||||
Private: response.NewServeMux(store),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,15 +13,19 @@ import (
|
|||
|
||||
func ExpectAppSession(next func(*lishwist.Session, http.Header, *http.Request) rsvp.Response) response.HandlerFunc {
|
||||
return func(session *response.Session, h http.Header, r *http.Request) rsvp.Response {
|
||||
username, ok := session.GetValue("username").(string)
|
||||
sessionKey, ok := session.GetValue("sessionKey").(string)
|
||||
if !ok {
|
||||
log.Printf("Failed to get username from session\n")
|
||||
log.Printf("Failed to get key from session\n")
|
||||
return response.Error(http.StatusInternalServerError, "Something went wrong.")
|
||||
}
|
||||
|
||||
appSession, err := lishwist.SessionFromUsername(username)
|
||||
appSession, err := lishwist.SessionFromKey(sessionKey)
|
||||
if err != nil {
|
||||
log.Printf("Failed to get session by username %q: %s\n", username, err)
|
||||
log.Printf("Failed to get session by key %v: %s\n", sessionKey, err)
|
||||
return response.Error(http.StatusInternalServerError, "Something went wrong.")
|
||||
}
|
||||
if appSession == nil {
|
||||
log.Printf("Session not found under key: %s\n", sessionKey)
|
||||
return response.Error(http.StatusInternalServerError, "Something went wrong.")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
"lishwist/http/api"
|
||||
|
|
@ -54,10 +56,11 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R
|
|||
return resp
|
||||
}
|
||||
|
||||
app, err := lishwist.Login(username, password)
|
||||
appSession, err := lishwist.Login(username, password, time.Hour*24)
|
||||
if err != nil {
|
||||
switch err.(type) {
|
||||
case lishwist.ErrorInvalidCredentials:
|
||||
var targ lishwist.ErrorInvalidCredentials
|
||||
switch {
|
||||
case errors.As(err, &targ):
|
||||
props.GeneralError = "Username or password invalid"
|
||||
session.FlashSet(&props)
|
||||
log.Printf("Invalid credentials: %s: %#v\n", err, props)
|
||||
|
|
@ -70,10 +73,8 @@ func LoginPost(session *response.Session, h http.Header, r *http.Request) rsvp.R
|
|||
}
|
||||
}
|
||||
|
||||
user := app.User()
|
||||
session.SetID("")
|
||||
session.SetValue("authorized", true)
|
||||
session.SetValue("username", user.Name)
|
||||
session.SetValue("sessionKey", appSession.Key)
|
||||
|
||||
return rsvp.SeeOther(r.URL.Path, "Login successful!")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"lishwist/http/internal/id"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
var inMemStore = make(map[string]string)
|
||||
|
||||
var errNotFound = errors.New("not found")
|
||||
|
||||
func NewInMemoryStore(keyPairs ...[]byte) *Store {
|
||||
return &Store{
|
||||
callbacks: Callbacks{
|
||||
Delete: func(key string) error {
|
||||
delete(inMemStore, key)
|
||||
return nil
|
||||
},
|
||||
Insert: func(encodedValues string) (string, error) {
|
||||
key := id.Generate()
|
||||
inMemStore[key] = encodedValues
|
||||
return key, nil
|
||||
},
|
||||
Select: func(key string) (string, error) {
|
||||
encodedValues, ok := inMemStore[key]
|
||||
if !ok {
|
||||
return "", errNotFound
|
||||
}
|
||||
return encodedValues, nil
|
||||
},
|
||||
Update: func(key string, encodedValues string) error {
|
||||
inMemStore[key] = encodedValues
|
||||
return nil
|
||||
},
|
||||
},
|
||||
Codecs: securecookie.CodecsFromPairs(keyPairs...),
|
||||
Options: &sessions.Options{},
|
||||
}
|
||||
}
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
package sesh
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
func GetFirstFlash(w http.ResponseWriter, r *http.Request, session *sessions.Session, key ...string) (any, error) {
|
||||
flashes := session.Flashes(key...)
|
||||
|
||||
if len(flashes) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
flash := flashes[0]
|
||||
|
||||
if err := session.Save(r, w); err != nil {
|
||||
log.Println("Couldn't save session:", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return flash, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
package session
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
type Callbacks struct {
|
||||
Delete func(id string) error
|
||||
Insert func(encodedValues string) (string, error)
|
||||
Select func(id string) (string, error)
|
||||
Update func(id, encodedValues string) error
|
||||
}
|
||||
|
||||
type Store struct {
|
||||
callbacks Callbacks
|
||||
Codecs []securecookie.Codec
|
||||
Options *sessions.Options
|
||||
}
|
||||
|
||||
func NewGenericStore(cb Callbacks, keyPairs ...[]byte) *Store {
|
||||
return &Store{
|
||||
callbacks: cb,
|
||||
Codecs: securecookie.CodecsFromPairs(keyPairs...),
|
||||
Options: &sessions.Options{},
|
||||
}
|
||||
}
|
||||
|
||||
// Get should return a cached session.
|
||||
func (m *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
|
||||
return sessions.GetRegistry(r).Get(m, name)
|
||||
}
|
||||
|
||||
// New should create and return a new session.
|
||||
//
|
||||
// Note that New should never return a nil session, even in the case of
|
||||
// an error if using the Registry infrastructure to cache the session.
|
||||
func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) {
|
||||
session := sessions.NewSession(s, name)
|
||||
opts := *s.Options
|
||||
session.Options = &opts
|
||||
session.IsNew = true
|
||||
|
||||
var err error
|
||||
|
||||
c, errCookie := r.Cookie(name)
|
||||
if errCookie != nil {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...)
|
||||
if err != nil {
|
||||
return session, fmt.Errorf("failed to decode session id: %w", err)
|
||||
}
|
||||
|
||||
sessionValue, err := s.callbacks.Select(session.ID)
|
||||
if err != nil {
|
||||
return session, fmt.Errorf("failed to get session value: %w", err)
|
||||
}
|
||||
|
||||
err = securecookie.DecodeMulti(name, string(sessionValue), &session.Values, s.Codecs...)
|
||||
if err == nil {
|
||||
session.IsNew = false
|
||||
} else {
|
||||
err = fmt.Errorf("failed to decode session values: %w", err)
|
||||
}
|
||||
|
||||
return session, err
|
||||
}
|
||||
|
||||
// Save should persist session to the underlying store implementation.
|
||||
func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
|
||||
// Delete if max-age is <= 0
|
||||
if session.Options.MaxAge <= 0 {
|
||||
err := s.callbacks.Delete(session.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete session: %w", err)
|
||||
}
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options))
|
||||
return nil
|
||||
}
|
||||
|
||||
encodedValues, err := securecookie.EncodeMulti(session.Name(), session.Values,
|
||||
s.Codecs...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode cookie value: %w", err)
|
||||
}
|
||||
|
||||
if session.ID == "" {
|
||||
i, err := s.callbacks.Insert(encodedValues)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert session: %w", err)
|
||||
}
|
||||
|
||||
session.ID = i
|
||||
} else {
|
||||
err := s.callbacks.Update(session.ID, encodedValues)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update session: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
encodedId, err := securecookie.EncodeMulti(session.Name(), session.ID,
|
||||
s.Codecs...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encode cookie value: %w", err)
|
||||
}
|
||||
|
||||
http.SetCookie(w, sessions.NewCookie(session.Name(), encodedId, session.Options))
|
||||
|
||||
return nil
|
||||
}
|
||||
Loading…
Reference in New Issue