feat: first user is an admin

This commit is contained in:
Teajey 2025-06-19 20:32:52 +09:00
parent 5c13893f23
commit 5769d44576
Signed by: Teajey
GPG Key ID: 970E790FE834A713
6 changed files with 77 additions and 22 deletions

View File

@ -3,3 +3,11 @@ package lishwist
type Admin struct { type Admin struct {
user *User user *User
} }
func (s *Session) Admin() *Admin {
if s.User.IsAdmin {
return &Admin{s.User}
} else {
return nil
}
}

View File

@ -7,3 +7,15 @@ func AssertEq[C comparable](t *testing.T, context string, expected, actual C) {
t.Errorf("%s: %#v != %#v", context, expected, actual) t.Errorf("%s: %#v != %#v", context, expected, actual)
} }
} }
func Assert(t *testing.T, context string, condition bool) {
if !condition {
t.Errorf("%s", context)
}
}
func FailIfErr(t *testing.T, err error, context string) {
if err != nil {
t.Fatalf("%s: %s\n", context, err)
}
}

View File

@ -15,14 +15,14 @@ func Login(username, password string) *lishwist.Session {
lw := lishwist.NewSessionManager(time.Second*10, 32) lw := lishwist.NewSessionManager(time.Second*10, 32)
err = lishwist.Register("thomas", "123") err = lishwist.Register(username, password)
if err != nil { if err != nil {
log.Fatalf("Failed to register: %s\n", err) log.Fatalf("Failed to register on login fixture: %s\n", err)
} }
session, err := lw.Login("thomas", "123") session, err := lw.Login(username, password)
if err != nil { if err != nil {
log.Fatalf("Failed to login: %s\n", err) log.Fatalf("Failed to login on fixture: %s\n", err)
} }
return session return session

View File

@ -1,6 +1,7 @@
package lishwist package lishwist
import ( import (
"errors"
"fmt" "fmt"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
@ -8,15 +9,15 @@ import (
func Register(username, newPassword string) error { func Register(username, newPassword string) error {
if username == "" { if username == "" {
return fmt.Errorf("Username required") return errors.New("Username required")
} }
if newPassword == "" { if newPassword == "" {
return fmt.Errorf("newPassword required") return errors.New("newPassword required")
} }
existingUser, _ := getUserByName(username) existingUser, _ := getUserByName(username)
if existingUser != nil { if existingUser != nil {
return fmt.Errorf("Username is taken") return errors.New("Username is taken")
} }
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost) hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
@ -24,7 +25,12 @@ func Register(username, newPassword string) error {
return fmt.Errorf("Failed to hash password: %w", err) return fmt.Errorf("Failed to hash password: %w", err)
} }
_, err = createUser(username, hashedPasswordBytes) usersExist, err := hasUsers()
if err != nil {
return fmt.Errorf("Failed to count users: %w", err)
}
_, err = createUser(username, hashedPasswordBytes, !usersExist)
if err != nil { if err != nil {
return fmt.Errorf("Failed to create user: %w\n", err) return fmt.Errorf("Failed to create user: %w\n", err)
} }

View File

@ -51,34 +51,26 @@ func queryOneUser(query string, args ...any) (*User, error) {
return &users[0], nil return &users[0], nil
} }
func (u *User) GetAdmin() *Admin {
if u.IsAdmin {
return &Admin{u}
} else {
return nil
}
}
func getUserByName(username string) (*User, error) { func getUserByName(username string) (*User, error) {
username = normalize.Name(username) username = normalize.Name(username)
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
return queryOneUser(stmt, username) return queryOneUser(stmt, username)
} }
func createUser(name string, passHash []byte) (*User, error) { func createUser(name string, passHash []byte, isAdmin bool) (*User, error) {
username := normalize.Name(name) username := normalize.Name(name)
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)" stmt := "INSERT INTO user (name, display_name, reference, password_hash, is_admin) VALUES (?, ?, ?, ?, ?)"
reference, err := uuid.NewRandom() reference, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Failed to generate reference: %w")
} }
result, err := db.Connection.Exec(stmt, username, name, reference, passHash) result, err := db.Connection.Exec(stmt, username, name, reference, passHash, isAdmin)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Failed to execute query: %w")
} }
id, err := result.LastInsertId() id, err := result.LastInsertId()
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("Failed to get last insert id: %w")
} }
user := User{ user := User{
Id: fmt.Sprintf("%d", id), Id: fmt.Sprintf("%d", id),
@ -101,3 +93,18 @@ func getUserByReference(reference string) (*User, error) {
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?" stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
return queryOneUser(stmt, reference) return queryOneUser(stmt, reference)
} }
func hasUsers() (bool, error) {
stmt := "SELECT COUNT(id) FROM v_user LIMIT 1"
var userCount uint
err := db.Connection.QueryRow(stmt).Scan(&userCount)
if err != nil {
return false, err
}
return userCount > 0, nil
}
func (*Admin) ListUsers() ([]User, error) {
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
return queryManyUsers(stmt)
}

22
core/user_test.go Normal file
View File

@ -0,0 +1,22 @@
package lishwist_test
import (
"testing"
lishwist "lishwist/core"
"lishwist/core/internal/fixtures"
)
func TestFirstUserIsAdmin(t *testing.T) {
s := fixtures.Login("thomas", "123")
err := lishwist.Register("caleb", "123")
fixtures.FailIfErr(t, err, "Failed to register caleb")
users, err := s.Admin().ListUsers()
fixtures.FailIfErr(t, err, "Failed to list users")
fixtures.AssertEq(t, "Number of users", 2, len(users))
fixtures.Assert(t, "User 1 is admin", users[0].IsAdmin)
fixtures.Assert(t, "User 2 is not admin", !users[1].IsAdmin)
}