feat: first user is an admin
This commit is contained in:
parent
5c13893f23
commit
5769d44576
|
|
@ -3,3 +3,11 @@ package lishwist
|
|||
type Admin struct {
|
||||
user *User
|
||||
}
|
||||
|
||||
func (s *Session) Admin() *Admin {
|
||||
if s.User.IsAdmin {
|
||||
return &Admin{s.User}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,3 +7,15 @@ func AssertEq[C comparable](t *testing.T, context string, expected, actual C) {
|
|||
t.Errorf("%s: %#v != %#v", context, expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func Assert(t *testing.T, context string, condition bool) {
|
||||
if !condition {
|
||||
t.Errorf("%s", context)
|
||||
}
|
||||
}
|
||||
|
||||
func FailIfErr(t *testing.T, err error, context string) {
|
||||
if err != nil {
|
||||
t.Fatalf("%s: %s\n", context, err)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,14 +15,14 @@ func Login(username, password string) *lishwist.Session {
|
|||
|
||||
lw := lishwist.NewSessionManager(time.Second*10, 32)
|
||||
|
||||
err = lishwist.Register("thomas", "123")
|
||||
err = lishwist.Register(username, password)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to register: %s\n", err)
|
||||
log.Fatalf("Failed to register on login fixture: %s\n", err)
|
||||
}
|
||||
|
||||
session, err := lw.Login("thomas", "123")
|
||||
session, err := lw.Login(username, password)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to login: %s\n", err)
|
||||
log.Fatalf("Failed to login on fixture: %s\n", err)
|
||||
}
|
||||
|
||||
return session
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package lishwist
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
|
@ -8,15 +9,15 @@ import (
|
|||
|
||||
func Register(username, newPassword string) error {
|
||||
if username == "" {
|
||||
return fmt.Errorf("Username required")
|
||||
return errors.New("Username required")
|
||||
}
|
||||
if newPassword == "" {
|
||||
return fmt.Errorf("newPassword required")
|
||||
return errors.New("newPassword required")
|
||||
}
|
||||
|
||||
existingUser, _ := getUserByName(username)
|
||||
if existingUser != nil {
|
||||
return fmt.Errorf("Username is taken")
|
||||
return errors.New("Username is taken")
|
||||
}
|
||||
|
||||
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
|
||||
|
|
@ -24,7 +25,12 @@ func Register(username, newPassword string) error {
|
|||
return fmt.Errorf("Failed to hash password: %w", err)
|
||||
}
|
||||
|
||||
_, err = createUser(username, hashedPasswordBytes)
|
||||
usersExist, err := hasUsers()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to count users: %w", err)
|
||||
}
|
||||
|
||||
_, err = createUser(username, hashedPasswordBytes, !usersExist)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create user: %w\n", err)
|
||||
}
|
||||
|
|
|
|||
35
core/user.go
35
core/user.go
|
|
@ -51,34 +51,26 @@ func queryOneUser(query string, args ...any) (*User, error) {
|
|||
return &users[0], nil
|
||||
}
|
||||
|
||||
func (u *User) GetAdmin() *Admin {
|
||||
if u.IsAdmin {
|
||||
return &Admin{u}
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getUserByName(username string) (*User, error) {
|
||||
username = normalize.Name(username)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||
return queryOneUser(stmt, username)
|
||||
}
|
||||
|
||||
func createUser(name string, passHash []byte) (*User, error) {
|
||||
func createUser(name string, passHash []byte, isAdmin bool) (*User, error) {
|
||||
username := normalize.Name(name)
|
||||
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
||||
stmt := "INSERT INTO user (name, display_name, reference, password_hash, is_admin) VALUES (?, ?, ?, ?, ?)"
|
||||
reference, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Failed to generate reference: %w")
|
||||
}
|
||||
result, err := db.Connection.Exec(stmt, username, name, reference, passHash)
|
||||
result, err := db.Connection.Exec(stmt, username, name, reference, passHash, isAdmin)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Failed to execute query: %w")
|
||||
}
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("Failed to get last insert id: %w")
|
||||
}
|
||||
user := User{
|
||||
Id: fmt.Sprintf("%d", id),
|
||||
|
|
@ -101,3 +93,18 @@ func getUserByReference(reference string) (*User, error) {
|
|||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
||||
return queryOneUser(stmt, reference)
|
||||
}
|
||||
|
||||
func hasUsers() (bool, error) {
|
||||
stmt := "SELECT COUNT(id) FROM v_user LIMIT 1"
|
||||
var userCount uint
|
||||
err := db.Connection.QueryRow(stmt).Scan(&userCount)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return userCount > 0, nil
|
||||
}
|
||||
|
||||
func (*Admin) ListUsers() ([]User, error) {
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
|
||||
return queryManyUsers(stmt)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,22 @@
|
|||
package lishwist_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
lishwist "lishwist/core"
|
||||
"lishwist/core/internal/fixtures"
|
||||
)
|
||||
|
||||
func TestFirstUserIsAdmin(t *testing.T) {
|
||||
s := fixtures.Login("thomas", "123")
|
||||
|
||||
err := lishwist.Register("caleb", "123")
|
||||
fixtures.FailIfErr(t, err, "Failed to register caleb")
|
||||
|
||||
users, err := s.Admin().ListUsers()
|
||||
fixtures.FailIfErr(t, err, "Failed to list users")
|
||||
|
||||
fixtures.AssertEq(t, "Number of users", 2, len(users))
|
||||
fixtures.Assert(t, "User 1 is admin", users[0].IsAdmin)
|
||||
fixtures.Assert(t, "User 2 is not admin", !users[1].IsAdmin)
|
||||
}
|
||||
Loading…
Reference in New Issue