feat: first user is an admin
This commit is contained in:
parent
5c13893f23
commit
5769d44576
|
|
@ -3,3 +3,11 @@ package lishwist
|
||||||
type Admin struct {
|
type Admin struct {
|
||||||
user *User
|
user *User
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Session) Admin() *Admin {
|
||||||
|
if s.User.IsAdmin {
|
||||||
|
return &Admin{s.User}
|
||||||
|
} else {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,3 +7,15 @@ func AssertEq[C comparable](t *testing.T, context string, expected, actual C) {
|
||||||
t.Errorf("%s: %#v != %#v", context, expected, actual)
|
t.Errorf("%s: %#v != %#v", context, expected, actual)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Assert(t *testing.T, context string, condition bool) {
|
||||||
|
if !condition {
|
||||||
|
t.Errorf("%s", context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FailIfErr(t *testing.T, err error, context string) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%s: %s\n", context, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,14 +15,14 @@ func Login(username, password string) *lishwist.Session {
|
||||||
|
|
||||||
lw := lishwist.NewSessionManager(time.Second*10, 32)
|
lw := lishwist.NewSessionManager(time.Second*10, 32)
|
||||||
|
|
||||||
err = lishwist.Register("thomas", "123")
|
err = lishwist.Register(username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to register: %s\n", err)
|
log.Fatalf("Failed to register on login fixture: %s\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := lw.Login("thomas", "123")
|
session, err := lw.Login(username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("Failed to login: %s\n", err)
|
log.Fatalf("Failed to login on fixture: %s\n", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return session
|
return session
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
package lishwist
|
package lishwist
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
@ -8,15 +9,15 @@ import (
|
||||||
|
|
||||||
func Register(username, newPassword string) error {
|
func Register(username, newPassword string) error {
|
||||||
if username == "" {
|
if username == "" {
|
||||||
return fmt.Errorf("Username required")
|
return errors.New("Username required")
|
||||||
}
|
}
|
||||||
if newPassword == "" {
|
if newPassword == "" {
|
||||||
return fmt.Errorf("newPassword required")
|
return errors.New("newPassword required")
|
||||||
}
|
}
|
||||||
|
|
||||||
existingUser, _ := getUserByName(username)
|
existingUser, _ := getUserByName(username)
|
||||||
if existingUser != nil {
|
if existingUser != nil {
|
||||||
return fmt.Errorf("Username is taken")
|
return errors.New("Username is taken")
|
||||||
}
|
}
|
||||||
|
|
||||||
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
|
hashedPasswordBytes, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.MinCost)
|
||||||
|
|
@ -24,7 +25,12 @@ func Register(username, newPassword string) error {
|
||||||
return fmt.Errorf("Failed to hash password: %w", err)
|
return fmt.Errorf("Failed to hash password: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = createUser(username, hashedPasswordBytes)
|
usersExist, err := hasUsers()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Failed to count users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = createUser(username, hashedPasswordBytes, !usersExist)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("Failed to create user: %w\n", err)
|
return fmt.Errorf("Failed to create user: %w\n", err)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
35
core/user.go
35
core/user.go
|
|
@ -51,34 +51,26 @@ func queryOneUser(query string, args ...any) (*User, error) {
|
||||||
return &users[0], nil
|
return &users[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) GetAdmin() *Admin {
|
|
||||||
if u.IsAdmin {
|
|
||||||
return &Admin{u}
|
|
||||||
} else {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getUserByName(username string) (*User, error) {
|
func getUserByName(username string) (*User, error) {
|
||||||
username = normalize.Name(username)
|
username = normalize.Name(username)
|
||||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||||
return queryOneUser(stmt, username)
|
return queryOneUser(stmt, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUser(name string, passHash []byte) (*User, error) {
|
func createUser(name string, passHash []byte, isAdmin bool) (*User, error) {
|
||||||
username := normalize.Name(name)
|
username := normalize.Name(name)
|
||||||
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
stmt := "INSERT INTO user (name, display_name, reference, password_hash, is_admin) VALUES (?, ?, ?, ?, ?)"
|
||||||
reference, err := uuid.NewRandom()
|
reference, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("Failed to generate reference: %w")
|
||||||
}
|
}
|
||||||
result, err := db.Connection.Exec(stmt, username, name, reference, passHash)
|
result, err := db.Connection.Exec(stmt, username, name, reference, passHash, isAdmin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("Failed to execute query: %w")
|
||||||
}
|
}
|
||||||
id, err := result.LastInsertId()
|
id, err := result.LastInsertId()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("Failed to get last insert id: %w")
|
||||||
}
|
}
|
||||||
user := User{
|
user := User{
|
||||||
Id: fmt.Sprintf("%d", id),
|
Id: fmt.Sprintf("%d", id),
|
||||||
|
|
@ -101,3 +93,18 @@ func getUserByReference(reference string) (*User, error) {
|
||||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
||||||
return queryOneUser(stmt, reference)
|
return queryOneUser(stmt, reference)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hasUsers() (bool, error) {
|
||||||
|
stmt := "SELECT COUNT(id) FROM v_user LIMIT 1"
|
||||||
|
var userCount uint
|
||||||
|
err := db.Connection.QueryRow(stmt).Scan(&userCount)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return userCount > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*Admin) ListUsers() ([]User, error) {
|
||||||
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
|
||||||
|
return queryManyUsers(stmt)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
package lishwist_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
lishwist "lishwist/core"
|
||||||
|
"lishwist/core/internal/fixtures"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFirstUserIsAdmin(t *testing.T) {
|
||||||
|
s := fixtures.Login("thomas", "123")
|
||||||
|
|
||||||
|
err := lishwist.Register("caleb", "123")
|
||||||
|
fixtures.FailIfErr(t, err, "Failed to register caleb")
|
||||||
|
|
||||||
|
users, err := s.Admin().ListUsers()
|
||||||
|
fixtures.FailIfErr(t, err, "Failed to list users")
|
||||||
|
|
||||||
|
fixtures.AssertEq(t, "Number of users", 2, len(users))
|
||||||
|
fixtures.Assert(t, "User 1 is admin", users[0].IsAdmin)
|
||||||
|
fixtures.Assert(t, "User 2 is not admin", !users[1].IsAdmin)
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue