feat: register and login
This commit is contained in:
parent
d89b855299
commit
bffa68c9f7
|
|
@ -2,5 +2,5 @@
|
|||
gin-bin
|
||||
*lishwist.db
|
||||
.env*.local
|
||||
http/api/db/init_sql.go
|
||||
init_sql.go
|
||||
.ignored/
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
package lishwist
|
||||
|
||||
type Admin struct {
|
||||
user *User
|
||||
}
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
module lishwist/core
|
||||
|
||||
go 1.23
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"lishwist/core/internal/db"
|
||||
)
|
||||
|
||||
func Init(dataSourceName string) error {
|
||||
return db.Init(dataSourceName)
|
||||
}
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
//go:generate go run gen_init_sql.go
|
||||
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/glebarez/go-sqlite"
|
||||
)
|
||||
|
||||
var Connection *sql.DB
|
||||
|
||||
func Init(dataSourceName string) error {
|
||||
db, err := sql.Open("sqlite", dataSourceName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to open db connection: %w", err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(initQuery)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to initialize db: %w", err)
|
||||
}
|
||||
|
||||
Connection = db
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
//go:build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
|
||||
"package db\n" +
|
||||
"\n" +
|
||||
"const initQuery = `{{.}}`\n",
|
||||
))
|
||||
|
||||
func main() {
|
||||
initStmt, err := os.ReadFile("./init.sql")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
f, err := os.Create("./init_sql.go")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
initTemplate.Execute(f, string(initStmt))
|
||||
}
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
BEGIN TRANSACTION;
|
||||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"name" TEXT NOT NULL UNIQUE,
|
||||
"display_name" TEXT NOT NULL UNIQUE,
|
||||
"reference" TEXT NOT NULL UNIQUE,
|
||||
"motto" TEXT NOT NULL DEFAULT "",
|
||||
"password_hash" TEXT NOT NULL,
|
||||
"is_admin" INTEGER NOT NULL DEFAULT 0,
|
||||
"is_live" INTEGER NOT NULL DEFAULT 1,
|
||||
PRIMARY KEY("id" AUTOINCREMENT)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "gift" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"name" TEXT NOT NULL,
|
||||
"recipient_id" INTEGER NOT NULL,
|
||||
"claimant_id" INTEGER,
|
||||
"creator_id" INTEGER NOT NULL,
|
||||
"sent" INTEGER NOT NULL DEFAULT 0,
|
||||
PRIMARY KEY("id" AUTOINCREMENT),
|
||||
FOREIGN KEY("recipient_id") REFERENCES "user"("id"),
|
||||
FOREIGN KEY("creator_id") REFERENCES "user"("id"),
|
||||
FOREIGN KEY("claimant_id") REFERENCES "user"("id")
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "group" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"name" TEXT NOT NULL UNIQUE,
|
||||
"reference" TEXT NOT NULL UNIQUE,
|
||||
PRIMARY KEY("id" AUTOINCREMENT)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "group_member" (
|
||||
"group_id" INTEGER NOT NULL,
|
||||
"user_id" INTEGER NOT NULL,
|
||||
UNIQUE("user_id","group_id"),
|
||||
FOREIGN KEY("group_id") REFERENCES "group"("id"),
|
||||
FOREIGN KEY("user_id") REFERENCES "user"("id")
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS "session" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"user_id" INTEGER NOT NULL,
|
||||
PRIMARY KEY("id" AUTOINCREMENT),
|
||||
FOREIGN KEY("user_id") REFERENCES "user"("id")
|
||||
);
|
||||
|
||||
DROP VIEW IF EXISTS "v_user";
|
||||
CREATE VIEW "v_user"
|
||||
AS
|
||||
SELECT * FROM user WHERE user.is_live = 1;
|
||||
|
||||
-- DROP VIEW IF EXISTS "v_wish";
|
||||
-- CREATE VIEW "v_wish"
|
||||
-- AS
|
||||
-- SELECT gift.id, gift.name, gift.sent FROM gift JOIN user AS recipient;
|
||||
|
||||
COMMIT;
|
||||
|
|
@ -0,0 +1,14 @@
|
|||
package normalize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Trim(s string) string {
|
||||
return strings.Trim(s, " \t")
|
||||
}
|
||||
|
||||
func Name(name string) string {
|
||||
name = Trim(name)
|
||||
return strings.ToLower(name)
|
||||
}
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func (sm *SessionManager) Login(username, password string) (*Session, error) {
|
||||
user, err := getUserByName(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to fetch user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("User not found by name: %s", username)
|
||||
}
|
||||
|
||||
passHash, err := user.getPassHash()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to get password hash: %w", err)
|
||||
}
|
||||
|
||||
err = bcrypt.CompareHashAndPassword(passHash, []byte(password))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session, err := sm.createSession(user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Couldn't create session: %w", err)
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
package lishwist_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
)
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
err := lishwist.Init(":memory:")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to init db: %s\n", err)
|
||||
}
|
||||
|
||||
lw := lishwist.NewSessionManager(time.Second*10, 32)
|
||||
|
||||
err = lishwist.Register("thomas", "123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to register: %s\n", err)
|
||||
}
|
||||
|
||||
_, err = lw.Login("thomas", "123")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to login: %s\n", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func Register(username, newPassword string) error {
|
||||
if username == "" {
|
||||
return fmt.Errorf("Username required")
|
||||
}
|
||||
if newPassword == "" {
|
||||
return fmt.Errorf("newPassword required")
|
||||
}
|
||||
|
||||
existingUser, _ := getUserByName(username)
|
||||
if existingUser != nil {
|
||||
return fmt.Errorf("Username is taken")
|
||||
}
|
||||
|
||||
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
_, err = createUser(username, hashedPasswordBytes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create user: %w\n", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"lishwist/core/internal/db"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
Id string
|
||||
Token string
|
||||
User *User
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
sessionDuration time.Duration
|
||||
sessionTokenLength uint
|
||||
}
|
||||
|
||||
func NewSessionManager(sessionDuration time.Duration, sessionTokenLength uint) SessionManager {
|
||||
return SessionManager{
|
||||
sessionDuration,
|
||||
sessionTokenLength,
|
||||
}
|
||||
}
|
||||
|
||||
func generateSecureToken(size uint) (string, error) {
|
||||
bytes := make([]byte, size)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) createSession(user *User) (*Session, error) {
|
||||
stmt := "INSERT INTO session (user_id) VALUES (?);"
|
||||
result, err := db.Connection.Exec(stmt, user.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
token, err := generateSecureToken(sm.sessionTokenLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Failed to generate secure token: %w", err)
|
||||
}
|
||||
|
||||
session := Session{
|
||||
Id: fmt.Sprintf("%s", id),
|
||||
Token: token,
|
||||
User: user,
|
||||
ExpiresAt: time.Now().Add(sm.sessionDuration),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"lishwist/core/internal/db"
|
||||
"lishwist/core/internal/normalize"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id string
|
||||
NormalName string
|
||||
Name string
|
||||
Reference string
|
||||
IsAdmin bool
|
||||
IsLive bool
|
||||
}
|
||||
|
||||
func queryManyUsers(query string, args ...any) ([]User, error) {
|
||||
rows, err := db.Connection.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
users := []User{}
|
||||
for rows.Next() {
|
||||
var u User
|
||||
err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func queryOneUser(query string, args ...any) (*User, error) {
|
||||
users, err := queryManyUsers(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(users) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
return &users[0], nil
|
||||
}
|
||||
|
||||
func (u *User) GetAdmin() *Admin {
|
||||
if u.IsAdmin {
|
||||
return &Admin{u}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getUserByName(username string) (*User, error) {
|
||||
username = normalize.Name(username)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||
return queryOneUser(stmt, username)
|
||||
}
|
||||
|
||||
func createUser(name string, passHash []byte) (*User, error) {
|
||||
username := normalize.Name(name)
|
||||
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
||||
reference, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := db.Connection.Exec(stmt, username, name, reference, passHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
user := User{
|
||||
Id: fmt.Sprintf("%d", id),
|
||||
Name: name,
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (u *User) getPassHash() ([]byte, error) {
|
||||
stmt := "SELECT password_hash FROM v_user WHERE id = ?"
|
||||
var passHash string
|
||||
err := db.Connection.QueryRow(stmt, u.Id).Scan(&passHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return []byte(passHash), nil
|
||||
}
|
||||
Loading…
Reference in New Issue