Core separation #13
|
|
@ -2,5 +2,5 @@
|
||||||
gin-bin
|
gin-bin
|
||||||
*lishwist.db
|
*lishwist.db
|
||||||
.env*.local
|
.env*.local
|
||||||
http/api/db/init_sql.go
|
init_sql.go
|
||||||
.ignored/
|
.ignored/
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
type Admin struct {
|
||||||
|
user *User
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
module lishwist/core
|
||||||
|
|
||||||
|
go 1.23
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"lishwist/core/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Init(dataSourceName string) error {
|
||||||
|
return db.Init(dataSourceName)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,28 @@
|
||||||
|
//go:generate go run gen_init_sql.go
|
||||||
|
|
||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
_ "github.com/glebarez/go-sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
var Connection *sql.DB
|
||||||
|
|
||||||
|
func Init(dataSourceName string) error {
|
||||||
|
db, err := sql.Open("sqlite", dataSourceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to open db connection: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(initQuery)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to initialize db: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
Connection = db
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
//go:build ignore
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
|
||||||
|
"package db\n" +
|
||||||
|
"\n" +
|
||||||
|
"const initQuery = `{{.}}`\n",
|
||||||
|
))
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
initStmt, err := os.ReadFile("./init.sql")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create("./init_sql.go")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
initTemplate.Execute(f, string(initStmt))
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
BEGIN TRANSACTION;
|
||||||
|
CREATE TABLE IF NOT EXISTS "user" (
|
||||||
|
"id" INTEGER NOT NULL UNIQUE,
|
||||||
|
"name" TEXT NOT NULL UNIQUE,
|
||||||
|
"display_name" TEXT NOT NULL UNIQUE,
|
||||||
|
"reference" TEXT NOT NULL UNIQUE,
|
||||||
|
"motto" TEXT NOT NULL DEFAULT "",
|
||||||
|
"password_hash" TEXT NOT NULL,
|
||||||
|
"is_admin" INTEGER NOT NULL DEFAULT 0,
|
||||||
|
"is_live" INTEGER NOT NULL DEFAULT 1,
|
||||||
|
PRIMARY KEY("id" AUTOINCREMENT)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "gift" (
|
||||||
|
"id" INTEGER NOT NULL UNIQUE,
|
||||||
|
"name" TEXT NOT NULL,
|
||||||
|
"recipient_id" INTEGER NOT NULL,
|
||||||
|
"claimant_id" INTEGER,
|
||||||
|
"creator_id" INTEGER NOT NULL,
|
||||||
|
"sent" INTEGER NOT NULL DEFAULT 0,
|
||||||
|
PRIMARY KEY("id" AUTOINCREMENT),
|
||||||
|
FOREIGN KEY("recipient_id") REFERENCES "user"("id"),
|
||||||
|
FOREIGN KEY("creator_id") REFERENCES "user"("id"),
|
||||||
|
FOREIGN KEY("claimant_id") REFERENCES "user"("id")
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "group" (
|
||||||
|
"id" INTEGER NOT NULL UNIQUE,
|
||||||
|
"name" TEXT NOT NULL UNIQUE,
|
||||||
|
"reference" TEXT NOT NULL UNIQUE,
|
||||||
|
PRIMARY KEY("id" AUTOINCREMENT)
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "group_member" (
|
||||||
|
"group_id" INTEGER NOT NULL,
|
||||||
|
"user_id" INTEGER NOT NULL,
|
||||||
|
UNIQUE("user_id","group_id"),
|
||||||
|
FOREIGN KEY("group_id") REFERENCES "group"("id"),
|
||||||
|
FOREIGN KEY("user_id") REFERENCES "user"("id")
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS "session" (
|
||||||
|
"id" INTEGER NOT NULL UNIQUE,
|
||||||
|
"user_id" INTEGER NOT NULL,
|
||||||
|
PRIMARY KEY("id" AUTOINCREMENT),
|
||||||
|
FOREIGN KEY("user_id") REFERENCES "user"("id")
|
||||||
|
);
|
||||||
|
|
||||||
|
DROP VIEW IF EXISTS "v_user";
|
||||||
|
CREATE VIEW "v_user"
|
||||||
|
AS
|
||||||
|
SELECT * FROM user WHERE user.is_live = 1;
|
||||||
|
|
||||||
|
-- DROP VIEW IF EXISTS "v_wish";
|
||||||
|
-- CREATE VIEW "v_wish"
|
||||||
|
-- AS
|
||||||
|
-- SELECT gift.id, gift.name, gift.sent FROM gift JOIN user AS recipient;
|
||||||
|
|
||||||
|
COMMIT;
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
package normalize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Trim(s string) string {
|
||||||
|
return strings.Trim(s, " \t")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Name(name string) string {
|
||||||
|
name = Trim(name)
|
||||||
|
return strings.ToLower(name)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (sm *SessionManager) Login(username, password string) (*Session, error) {
|
||||||
|
user, err := getUserByName(username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to fetch user: %w", err)
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil, fmt.Errorf("User not found by name: %s", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
passHash, err := user.getPassHash()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to get password hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = bcrypt.CompareHashAndPassword(passHash, []byte(password))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
session, err := sm.createSession(user)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Couldn't create session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return session, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
package lishwist_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
lishwist "lishwist/core"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLogin(t *testing.T) {
|
||||||
|
err := lishwist.Init(":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to init db: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lw := lishwist.NewSessionManager(time.Second*10, 32)
|
||||||
|
|
||||||
|
err = lishwist.Register("thomas", "123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to register: %s\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = lw.Login("thomas", "123")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to login: %s\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Register(username, newPassword string) error {
|
||||||
|
if username == "" {
|
||||||
|
return fmt.Errorf("Username required")
|
||||||
|
}
|
||||||
|
if newPassword == "" {
|
||||||
|
return fmt.Errorf("newPassword required")
|
||||||
|
}
|
||||||
|
|
||||||
|
existingUser, _ := getUserByName(username)
|
||||||
|
if existingUser != nil {
|
||||||
|
return fmt.Errorf("Username is taken")
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = createUser(username, hashedPasswordBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to create user: %w\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"lishwist/core/internal/db"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
Id string
|
||||||
|
Token string
|
||||||
|
User *User
|
||||||
|
ExpiresAt time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionManager struct {
|
||||||
|
sessionDuration time.Duration
|
||||||
|
sessionTokenLength uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSessionManager(sessionDuration time.Duration, sessionTokenLength uint) SessionManager {
|
||||||
|
return SessionManager{
|
||||||
|
sessionDuration,
|
||||||
|
sessionTokenLength,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateSecureToken(size uint) (string, error) {
|
||||||
|
bytes := make([]byte, size)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sm *SessionManager) createSession(user *User) (*Session, error) {
|
||||||
|
stmt := "INSERT INTO session (user_id) VALUES (?);"
|
||||||
|
result, err := db.Connection.Exec(stmt, user.Id)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := generateSecureToken(sm.sessionTokenLength)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("Failed to generate secure token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
session := Session{
|
||||||
|
Id: fmt.Sprintf("%s", id),
|
||||||
|
Token: token,
|
||||||
|
User: user,
|
||||||
|
ExpiresAt: time.Now().Add(sm.sessionDuration),
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return &session, nil
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,98 @@
|
||||||
|
package lishwist
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
"lishwist/core/internal/db"
|
||||||
|
"lishwist/core/internal/normalize"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
Id string
|
||||||
|
NormalName string
|
||||||
|
Name string
|
||||||
|
Reference string
|
||||||
|
IsAdmin bool
|
||||||
|
IsLive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryManyUsers(query string, args ...any) ([]User, error) {
|
||||||
|
rows, err := db.Connection.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
users := []User{}
|
||||||
|
for rows.Next() {
|
||||||
|
var u User
|
||||||
|
err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
users = append(users, u)
|
||||||
|
}
|
||||||
|
err = rows.Err()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryOneUser(query string, args ...any) (*User, error) {
|
||||||
|
users, err := queryManyUsers(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(users) < 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) GetAdmin() *Admin {
|
||||||
|
if u.IsAdmin {
|
||||||
|
return &Admin{u}
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserByName(username string) (*User, error) {
|
||||||
|
username = normalize.Name(username)
|
||||||
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||||
|
return queryOneUser(stmt, username)
|
||||||
|
}
|
||||||
|
|
||||||
|
func createUser(name string, passHash []byte) (*User, error) {
|
||||||
|
username := normalize.Name(name)
|
||||||
|
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
||||||
|
reference, err := uuid.NewRandom()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result, err := db.Connection.Exec(stmt, username, name, reference, passHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
user := User{
|
||||||
|
Id: fmt.Sprintf("%d", id),
|
||||||
|
Name: name,
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) getPassHash() ([]byte, error) {
|
||||||
|
stmt := "SELECT password_hash FROM v_user WHERE id = ?"
|
||||||
|
var passHash string
|
||||||
|
err := db.Connection.QueryRow(stmt, u.Id).Scan(&passHash)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []byte(passHash), nil
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue