diff --git a/core/admin.go b/core/admin.go index 133cde3..51ab1a1 100644 --- a/core/admin.go +++ b/core/admin.go @@ -3,3 +3,11 @@ package lishwist type Admin struct { user *User } + +func (s *Session) Admin() *Admin { + if s.User.IsAdmin { + return &Admin{s.User} + } else { + return nil + } +} diff --git a/core/internal/fixtures/assert.go b/core/internal/fixtures/assert.go index 3d654d8..35d288f 100644 --- a/core/internal/fixtures/assert.go +++ b/core/internal/fixtures/assert.go @@ -7,3 +7,15 @@ func AssertEq[C comparable](t *testing.T, context string, expected, actual C) { t.Errorf("%s: %#v != %#v", context, expected, actual) } } + +func Assert(t *testing.T, context string, condition bool) { + if !condition { + t.Errorf("%s", context) + } +} + +func FailIfErr(t *testing.T, err error, context string) { + if err != nil { + t.Fatalf("%s: %s\n", context, err) + } +} diff --git a/core/internal/fixtures/login.go b/core/internal/fixtures/login.go index 99583de..abff84c 100644 --- a/core/internal/fixtures/login.go +++ b/core/internal/fixtures/login.go @@ -15,14 +15,14 @@ func Login(username, password string) *lishwist.Session { lw := lishwist.NewSessionManager(time.Second*10, 32) - err = lishwist.Register("thomas", "123") + err = lishwist.Register(username, password) if err != nil { - log.Fatalf("Failed to register: %s\n", err) + log.Fatalf("Failed to register on login fixture: %s\n", err) } - session, err := lw.Login("thomas", "123") + session, err := lw.Login(username, password) if err != nil { - log.Fatalf("Failed to login: %s\n", err) + log.Fatalf("Failed to login on fixture: %s\n", err) } return session diff --git a/core/register.go b/core/register.go index ff7fa6b..5ec8432 100644 --- a/core/register.go +++ b/core/register.go @@ -1,6 +1,7 @@ package lishwist import ( + "errors" "fmt" "golang.org/x/crypto/bcrypt" @@ -8,15 +9,15 @@ import ( func Register(username, newPassword string) error { if username == "" { - return fmt.Errorf("Username required") + return errors.New("Username required") } if newPassword == "" { - return fmt.Errorf("newPassword required") + return errors.New("newPassword required") } existingUser, _ := getUserByName(username) if existingUser != nil { - return fmt.Errorf("Username is taken") + return errors.New("Username is taken") } hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost) @@ -24,7 +25,12 @@ func Register(username, newPassword string) error { return fmt.Errorf("Failed to hash password: %w", err) } - _, err = createUser(username, hashedPasswordBytes) + usersExist, err := hasUsers() + if err != nil { + return fmt.Errorf("Failed to count users: %w", err) + } + + _, err = createUser(username, hashedPasswordBytes, !usersExist) if err != nil { return fmt.Errorf("Failed to create user: %w\n", err) } diff --git a/core/user.go b/core/user.go index 6b8b5f6..33edcd5 100644 --- a/core/user.go +++ b/core/user.go @@ -51,34 +51,26 @@ func queryOneUser(query string, args ...any) (*User, error) { return &users[0], nil } -func (u *User) GetAdmin() *Admin { - if u.IsAdmin { - return &Admin{u} - } else { - return nil - } -} - func getUserByName(username string) (*User, error) { username = normalize.Name(username) stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?" return queryOneUser(stmt, username) } -func createUser(name string, passHash []byte) (*User, error) { +func createUser(name string, passHash []byte, isAdmin bool) (*User, error) { username := normalize.Name(name) - stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)" + stmt := "INSERT INTO user (name, display_name, reference, password_hash, is_admin) VALUES (?, ?, ?, ?, ?)" reference, err := uuid.NewRandom() if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to generate reference: %w") } - result, err := db.Connection.Exec(stmt, username, name, reference, passHash) + result, err := db.Connection.Exec(stmt, username, name, reference, passHash, isAdmin) if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to execute query: %w") } id, err := result.LastInsertId() if err != nil { - return nil, err + return nil, fmt.Errorf("Failed to get last insert id: %w") } user := User{ Id: fmt.Sprintf("%d", id), @@ -101,3 +93,18 @@ func getUserByReference(reference string) (*User, error) { stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?" return queryOneUser(stmt, reference) } + +func hasUsers() (bool, error) { + stmt := "SELECT COUNT(id) FROM v_user LIMIT 1" + var userCount uint + err := db.Connection.QueryRow(stmt).Scan(&userCount) + if err != nil { + return false, err + } + return userCount > 0, nil +} + +func (*Admin) ListUsers() ([]User, error) { + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user" + return queryManyUsers(stmt) +} diff --git a/core/user_test.go b/core/user_test.go new file mode 100644 index 0000000..6d53d5b --- /dev/null +++ b/core/user_test.go @@ -0,0 +1,22 @@ +package lishwist_test + +import ( + "testing" + + lishwist "lishwist/core" + "lishwist/core/internal/fixtures" +) + +func TestFirstUserIsAdmin(t *testing.T) { + s := fixtures.Login("thomas", "123") + + err := lishwist.Register("caleb", "123") + fixtures.FailIfErr(t, err, "Failed to register caleb") + + users, err := s.Admin().ListUsers() + fixtures.FailIfErr(t, err, "Failed to list users") + + fixtures.AssertEq(t, "Number of users", 2, len(users)) + fixtures.Assert(t, "User 1 is admin", users[0].IsAdmin) + fixtures.Assert(t, "User 2 is not admin", !users[1].IsAdmin) +}