diff --git a/.gitignore b/.gitignore index 691da12..191c9ff 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,4 @@ gin-bin lishwist.db .env*.local -db/init_sql.go +server/db/init_sql.go diff --git a/go.work b/go.work index d6397df..bd31234 100644 --- a/go.work +++ b/go.work @@ -1,3 +1,5 @@ -go 1.22.0 +go 1.23 + +toolchain go1.23.3 use ./server diff --git a/server/auth/auth.go b/server/auth/auth.go index 7df46f6..d2ee3eb 100644 --- a/server/auth/auth.go +++ b/server/auth/auth.go @@ -8,11 +8,11 @@ import ( "lishwist/db" "lishwist/env" - "github.com/gorilla/sessions" + "github.com/Teajey/sqlstore" ) type AuthMiddleware struct { - Store *sessions.CookieStore + Store *sqlstore.Store protectedHandler http.Handler publicHandler http.Handler } @@ -44,7 +44,12 @@ func (auth *AuthMiddleware) ExpectUser(r *http.Request) *db.User { func NewAuthMiddleware(protectedHandler http.Handler, publicHandler http.Handler) *AuthMiddleware { gob.Register(&RegisterProps{}) gob.Register(&LoginProps{}) - store := sessions.NewCookieStore([]byte(env.JwtSecret)) + store, err := db.NewSessionStore() + if err != nil { + log.Fatalln("Failed to create store:", err) + } store.Options.MaxAge = 86_400 + store.Options.Secure = !env.InDev + store.Options.HttpOnly = true return &AuthMiddleware{store, protectedHandler, publicHandler} } diff --git a/server/auth/login_post.go b/server/auth/login_post.go index 920b539..9d19a67 100644 --- a/server/auth/login_post.go +++ b/server/auth/login_post.go @@ -43,13 +43,9 @@ func (auth *AuthMiddleware) LoginPost(w http.ResponseWriter, r *http.Request) { return } - session, err := auth.Store.Get(r, "lishwist_user") - if err != nil { - log.Println("Couldn't get jwt:", err) - props.GeneralError = "Something went wrong. Error code: Sokka" - auth.RedirectWithFlash(w, r, "/", "login_props", &props) - return - } + // NOTE: Overwriting any existing cookie or session here. So we don't care if there's an error + session, _ := auth.Store.Get(r, "lishwist_user") + session.ID = "" session.Values["authorized"] = true session.Values["username"] = username if err := session.Save(r, w); err != nil { diff --git a/server/auth/logout.go b/server/auth/logout.go index ad0ef17..c4fcb0e 100644 --- a/server/auth/logout.go +++ b/server/auth/logout.go @@ -10,6 +10,8 @@ func (auth *AuthMiddleware) LogoutPost(w http.ResponseWriter, r *http.Request) { http.Error(w, "Something went wrong. Error code: Iroh", http.StatusInternalServerError) return } + + session.Options.MaxAge = 0 session.Values = nil if err := session.Save(r, w); err != nil { http.Error(w, "Something went wrong. Error code: Azula", http.StatusInternalServerError) diff --git a/server/db/db.go b/server/db/db.go index 7e1c893..109bfb5 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -4,7 +4,10 @@ package db import ( "database/sql" + "fmt" + "lishwist/env" + "github.com/Teajey/sqlstore" _ "github.com/glebarez/go-sqlite" ) @@ -26,3 +29,34 @@ func Init() error { } return nil } + +func NewSessionStore() (*sqlstore.Store, error) { + deleteStmt, err := database.Prepare("DELETE FROM session WHERE id = ?;") + if err != nil { + return nil, fmt.Errorf("Failed to prepare delete statement: %w", err) + } + + insertStmt, err := database.Prepare("INSERT INTO session (value) VALUES (?);") + if err != nil { + return nil, fmt.Errorf("Failed to prepare insert statement: %w", err) + } + + selectStmt, err := database.Prepare("SELECT value FROM session WHERE id = ?;") + if err != nil { + return nil, fmt.Errorf("Failed to prepare select statement: %w", err) + } + + updateStmt, err := database.Prepare("UPDATE session SET value = ?2 WHERE id = ?1;") + if err != nil { + return nil, fmt.Errorf("Failed to prepare update statement: %w", err) + } + + s := sqlstore.NewSqlStore(database, sqlstore.Statements{ + Delete: deleteStmt, + Insert: insertStmt, + Select: selectStmt, + Update: updateStmt, + }, []byte(env.JwtSecret)) + + return s, nil +} diff --git a/server/db/init.sql b/server/db/init.sql index d3b70d7..b9f5172 100644 --- a/server/db/init.sql +++ b/server/db/init.sql @@ -32,4 +32,9 @@ CREATE TABLE IF NOT EXISTS "group_member" ( FOREIGN KEY("group_id") REFERENCES "group"("id"), FOREIGN KEY("user_id") REFERENCES "user"("id") ); +CREATE TABLE IF NOT EXISTS "session" ( + "id" INTEGER NOT NULL UNIQUE, + "value" TEXT NOT NULL, + PRIMARY KEY("id" AUTOINCREMENT) +); COMMIT; diff --git a/server/db/user.go b/server/db/user.go index 97bf4cd..067c64b 100644 --- a/server/db/user.go +++ b/server/db/user.go @@ -190,7 +190,7 @@ func (u *User) GetTodo() ([]Gift, error) { var sent bool var recipientName string var recipientRef string - rows.Scan(&id, &name, &sent, &recipientName, &recipientRef) + _ = rows.Scan(&id, &name, &sent, &recipientName, &recipientRef) gift := Gift{ Id: id, Name: name, diff --git a/server/env/env.go b/server/env/env.go index e56dc12..3f9f00b 100644 --- a/server/env/env.go +++ b/server/env/env.go @@ -14,10 +14,11 @@ func GuaranteeEnv(key string) (variable string) { return } -var JwtSecret = GuaranteeEnv("LISHWIST_JWT_SECRET") +var JwtSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET") var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL") var HostPort = os.Getenv("LISHWIST_HOST_PORT") var ServePort = GuaranteeEnv("LISHWIST_SERVE_PORT") +var InDev = os.Getenv("LISHWIST_IN_DEV") != "" var HostUrl = func() *url.URL { rawUrl := HostRootUrl if HostPort != "" { diff --git a/server/go.mod b/server/go.mod index 28992fd..cebcb93 100644 --- a/server/go.mod +++ b/server/go.mod @@ -1,11 +1,14 @@ module lishwist -go 1.22.0 +go 1.23 + +toolchain go1.23.3 require ( + github.com/Teajey/sqlstore v0.0.6 github.com/glebarez/go-sqlite v1.22.0 github.com/google/uuid v1.6.0 - github.com/gorilla/sessions v1.2.2 + github.com/gorilla/sessions v1.4.0 golang.org/x/crypto v0.22.0 ) diff --git a/server/go.sum b/server/go.sum index 5b45ff9..8b650a2 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1,3 +1,11 @@ +github.com/Teajey/sqlstore v0.0.3 h1:6Y1jz9/yw1cj/Z/jrii0s87RAomKWr/07B1auDgw8pg= +github.com/Teajey/sqlstore v0.0.3/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= +github.com/Teajey/sqlstore v0.0.4 h1:ATe25BD8cV0FUw4w2qlccx5m0c5kQI0K4ksl/LnSHsc= +github.com/Teajey/sqlstore v0.0.4/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= +github.com/Teajey/sqlstore v0.0.5 h1:WZvu54baa8+9n1sKQe9GuxBVwSISw+xCkw4VFSwwIs8= +github.com/Teajey/sqlstore v0.0.5/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= +github.com/Teajey/sqlstore v0.0.6 h1:kUEpA+3BKFHZl128MuMeYY6zVcmq1QmOlNyofcFEJOA= +github.com/Teajey/sqlstore v0.0.6/go.mod h1:hjk0S593/2Q4QxkEXCgpThj9w5KWGTQi9JtgfziHXXk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= @@ -10,8 +18,8 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= -github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY= -github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ= +github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= +github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= diff --git a/server/main.go b/server/main.go index c5cb1c4..8ab54f4 100644 --- a/server/main.go +++ b/server/main.go @@ -45,5 +45,8 @@ func main() { http.Handle("/", authMiddleware) - http.ListenAndServe(":"+env.ServePort, nil) + err = http.ListenAndServe(":"+env.ServePort, nil) + if err != nil { + log.Fatalln("Failed to listen and server:", err) + } }