diff --git a/server/api/register.go b/server/api/register.go index 77a4df1..88b5312 100644 --- a/server/api/register.go +++ b/server/api/register.go @@ -77,7 +77,7 @@ func Register(username, newPassword, confirmPassword string) *RegisterProps { existingUser, _ := db.GetUserByName(username) if existingUser != nil { - log.Printf("Username is taken: %q\n", username) + log.Printf("Username is taken: %q\n", existingUser.NormalName) props.Username.Error = "Username is taken" return props } diff --git a/server/db/group.go b/server/db/group.go index 5b1efdc..9ba69d3 100644 --- a/server/db/group.go +++ b/server/db/group.go @@ -3,6 +3,7 @@ package db import ( "database/sql" "fmt" + "lishwist/normalize" "strconv" ) @@ -66,7 +67,7 @@ func queryForGroups(query string, args ...any) ([]Group, error) { } func queryForGroupMembers(groupId string) ([]User, error) { - query := "SELECT user.id, user.name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id" + query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id" members, err := queryForUsers(query, groupId) if err != nil { return members, fmt.Errorf("Failed to get members: %w", err) @@ -85,6 +86,7 @@ func GetAllGroups() ([]Group, error) { } func CreateGroup(name string, reference string) (*Group, error) { + name = normalize.Name(name) stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)" result, err := database.Exec(stmt, name, reference) if err != nil { diff --git a/server/db/init.sql b/server/db/init.sql index 5f1aeff..2780036 100644 --- a/server/db/init.sql +++ b/server/db/init.sql @@ -2,6 +2,7 @@ BEGIN TRANSACTION; CREATE TABLE IF NOT EXISTS "user" ( "id" INTEGER NOT NULL UNIQUE, "name" TEXT NOT NULL UNIQUE, + "display_name" TEXT NOT NULL UNIQUE, "reference" TEXT NOT NULL UNIQUE, "motto" TEXT NOT NULL DEFAULT "", "password_hash" TEXT NOT NULL, diff --git a/server/db/user.go b/server/db/user.go index 79975f9..ea40c88 100644 --- a/server/db/user.go +++ b/server/db/user.go @@ -3,16 +3,18 @@ package db import ( "database/sql" "fmt" + "lishwist/normalize" "github.com/google/uuid" ) type User struct { - Id string - Name string - Reference string - IsAdmin bool - IsLive bool + Id string + NormalName string + Name string + Reference string + IsAdmin bool + IsLive bool } type Gift struct { @@ -30,7 +32,7 @@ type Gift struct { func queryForUser(query string, args ...any) (*User, error) { var u User - err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) + err := database.QueryRow(query, args...).Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) if err == sql.ErrNoRows { return nil, nil } else if err != nil { @@ -48,7 +50,7 @@ func queryForUsers(query string, args ...any) ([]User, error) { users := []User{} for rows.Next() { var u User - err = rows.Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) + err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive) if err != nil { return nil, err } @@ -62,27 +64,28 @@ func queryForUsers(query string, args ...any) ([]User, error) { } func GetAllUsers() ([]User, error) { - stmt := "SELECT id, name, reference, is_admin, is_live FROM user" + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user" return queryForUsers(stmt) } func GetUser(id string) (*User, error) { - stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE id = ?" + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE id = ?" return queryForUser(stmt, id) } func GetUserByName(username string) (*User, error) { - stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE name = ?" + username = normalize.Name(username) + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?" return queryForUser(stmt, username) } func GetUserByReference(reference string) (*User, error) { - stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE reference = ?" + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?" return queryForUser(stmt, reference) } func GetAnyUserByReference(reference string) (*User, error) { - stmt := "SELECT id, name, reference, is_admin, is_live FROM user WHERE reference = ?" + stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user WHERE reference = ?" return queryForUser(stmt, reference) } @@ -96,13 +99,14 @@ func (u *User) SetLive(setting bool) error { return err } -func CreateUser(username string, passHash []byte) (*User, error) { - stmt := "INSERT INTO user (name, reference, password_hash) VALUES (?, ?, ?)" +func CreateUser(name string, passHash []byte) (*User, error) { + username := normalize.Name(name) + stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)" reference, err := uuid.NewRandom() if err != nil { return nil, err } - result, err := database.Exec(stmt, username, reference, passHash) + result, err := database.Exec(stmt, username, name, reference, passHash) if err != nil { return nil, err } @@ -112,7 +116,7 @@ func CreateUser(username string, passHash []byte) (*User, error) { } user := User{ Id: fmt.Sprintf("%d", id), - Name: username, + Name: name, } return &user, nil } diff --git a/server/normalize/name.go b/server/normalize/name.go new file mode 100644 index 0000000..ccd8574 --- /dev/null +++ b/server/normalize/name.go @@ -0,0 +1,14 @@ +package normalize + +import ( + "strings" +) + +func Trim(s string) string { + return strings.Trim(s, " \t") +} + +func Name(name string) string { + name = Trim(name) + return strings.ToLower(name) +}