Compare commits
6 Commits
b471a2e084
...
fbc6c9ca20
| Author | SHA1 | Date |
|---|---|---|
|
|
fbc6c9ca20 | |
|
|
fefc4ac3db | |
|
|
c8d179e297 | |
|
|
67abba1f67 | |
|
|
f2e67761ff | |
|
|
7f284d5003 |
|
|
@ -1,6 +1,6 @@
|
|||
.DS_Store
|
||||
gin-bin
|
||||
lishwist.db
|
||||
*lishwist.db
|
||||
.env*.local
|
||||
server/db/init_sql.go
|
||||
server/api/db/init_sql.go
|
||||
.ignored/
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import (
|
|||
var database *sql.DB
|
||||
|
||||
func Open() error {
|
||||
db, err := sql.Open("sqlite", "./lishwist.db")
|
||||
db, err := sql.Open("sqlite", env.DatabaseFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -23,7 +23,7 @@ func Open() error {
|
|||
}
|
||||
|
||||
func Init() error {
|
||||
_, err := database.Exec(InitQuery)
|
||||
_, err := database.Exec(initQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -11,7 +11,7 @@ import (
|
|||
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
|
||||
"package db\n" +
|
||||
"\n" +
|
||||
"const InitQuery = `{{.}}`\n",
|
||||
"const initQuery = `{{.}}`\n",
|
||||
))
|
||||
|
||||
func main() {
|
||||
|
|
@ -1,8 +1,8 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"lishwist/normalize"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
|
@ -22,52 +22,47 @@ func (g *Group) MemberIndex(userId string) int {
|
|||
return -1
|
||||
}
|
||||
|
||||
func queryForGroup(query string, args ...any) (*Group, error) {
|
||||
var group Group
|
||||
err := database.QueryRow(query, args...).Scan(&group.Id, &group.Name, &group.Reference)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
members, err := queryForGroupMembers(group.Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
group.Members = members
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
func queryForGroups(query string, args ...any) ([]Group, error) {
|
||||
func queryManyGroups(query string, args ...any) ([]Group, error) {
|
||||
groups := []Group{}
|
||||
rows, err := database.Query(query, args...)
|
||||
if err != nil {
|
||||
return groups, fmt.Errorf("Query failed: %w", err)
|
||||
return nil, fmt.Errorf("Query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var group Group
|
||||
err := rows.Scan(&group.Id, &group.Name, &group.Reference)
|
||||
if err != nil {
|
||||
return groups, fmt.Errorf("Failed to scan row: %w", err)
|
||||
return nil, fmt.Errorf("Failed to scan row: %w", err)
|
||||
}
|
||||
members, err := queryForGroupMembers(group.Id)
|
||||
members, err := queryManyGroupMembers(group.Id)
|
||||
if err != nil {
|
||||
return groups, fmt.Errorf("Failed to query for group members: %w", err)
|
||||
return nil, fmt.Errorf("Failed to query for group members: %w", err)
|
||||
}
|
||||
group.Members = members
|
||||
groups = append(groups, group)
|
||||
}
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return groups, fmt.Errorf("Rows error: %w", err)
|
||||
return nil, fmt.Errorf("Rows error: %w", err)
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func queryForGroupMembers(groupId string) ([]User, error) {
|
||||
query := "SELECT user.id, user.name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id"
|
||||
members, err := queryForUsers(query, groupId)
|
||||
func queryOneGroup(query string, args ...any) (*Group, error) {
|
||||
groups, err := queryManyGroups(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(groups) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
return &groups[0], nil
|
||||
}
|
||||
|
||||
func queryManyGroupMembers(groupId string) ([]User, error) {
|
||||
query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id"
|
||||
members, err := queryManyUsers(query, groupId)
|
||||
if err != nil {
|
||||
return members, fmt.Errorf("Failed to get members: %w", err)
|
||||
}
|
||||
|
|
@ -76,15 +71,16 @@ func queryForGroupMembers(groupId string) ([]User, error) {
|
|||
|
||||
func GetGroupByReference(reference string) (*Group, error) {
|
||||
query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?"
|
||||
return queryForGroup(query, reference)
|
||||
return queryOneGroup(query, reference)
|
||||
}
|
||||
|
||||
func GetAllGroups() ([]Group, error) {
|
||||
query := "SELECT id, name, reference FROM [group];"
|
||||
return queryForGroups(query)
|
||||
return queryManyGroups(query)
|
||||
}
|
||||
|
||||
func CreateGroup(name string, reference string) (*Group, error) {
|
||||
name = normalize.Name(name)
|
||||
stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)"
|
||||
result, err := database.Exec(stmt, name, reference)
|
||||
if err != nil {
|
||||
|
|
@ -2,6 +2,7 @@ BEGIN TRANSACTION;
|
|||
CREATE TABLE IF NOT EXISTS "user" (
|
||||
"id" INTEGER NOT NULL UNIQUE,
|
||||
"name" TEXT NOT NULL UNIQUE,
|
||||
"display_name" TEXT NOT NULL UNIQUE,
|
||||
"reference" TEXT NOT NULL UNIQUE,
|
||||
"motto" TEXT NOT NULL DEFAULT "",
|
||||
"password_hash" TEXT NOT NULL,
|
||||
|
|
@ -3,16 +3,18 @@ package db
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"lishwist/normalize"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
Id string
|
||||
Name string
|
||||
Reference string
|
||||
IsAdmin bool
|
||||
IsLive bool
|
||||
Id string
|
||||
NormalName string
|
||||
Name string
|
||||
Reference string
|
||||
IsAdmin bool
|
||||
IsLive bool
|
||||
}
|
||||
|
||||
type Gift struct {
|
||||
|
|
@ -28,18 +30,7 @@ type Gift struct {
|
|||
CreatorName string `json:",omitempty"`
|
||||
}
|
||||
|
||||
func queryForUser(query string, args ...any) (*User, error) {
|
||||
var u User
|
||||
err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
func queryForUsers(query string, args ...any) ([]User, error) {
|
||||
func queryManyUsers(query string, args ...any) ([]User, error) {
|
||||
rows, err := database.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -48,7 +39,7 @@ func queryForUsers(query string, args ...any) ([]User, error) {
|
|||
users := []User{}
|
||||
for rows.Next() {
|
||||
var u User
|
||||
err = rows.Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||
err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -61,29 +52,41 @@ func queryForUsers(query string, args ...any) ([]User, error) {
|
|||
return users, nil
|
||||
}
|
||||
|
||||
func queryOneUser(query string, args ...any) (*User, error) {
|
||||
users, err := queryManyUsers(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(users) < 1 {
|
||||
return nil, nil
|
||||
}
|
||||
return &users[0], nil
|
||||
}
|
||||
|
||||
func GetAllUsers() ([]User, error) {
|
||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM user"
|
||||
return queryForUsers(stmt)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
|
||||
return queryManyUsers(stmt)
|
||||
}
|
||||
|
||||
func GetUser(id string) (*User, error) {
|
||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE id = ?"
|
||||
return queryForUser(stmt, id)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE id = ?"
|
||||
return queryOneUser(stmt, id)
|
||||
}
|
||||
|
||||
func GetUserByName(username string) (*User, error) {
|
||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||
return queryForUser(stmt, username)
|
||||
username = normalize.Name(username)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||
return queryOneUser(stmt, username)
|
||||
}
|
||||
|
||||
func GetUserByReference(reference string) (*User, error) {
|
||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
||||
return queryForUser(stmt, reference)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
||||
return queryOneUser(stmt, reference)
|
||||
}
|
||||
|
||||
func GetAnyUserByReference(reference string) (*User, error) {
|
||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM user WHERE reference = ?"
|
||||
return queryForUser(stmt, reference)
|
||||
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user WHERE reference = ?"
|
||||
return queryOneUser(stmt, reference)
|
||||
}
|
||||
|
||||
func (u *User) SetLive(setting bool) error {
|
||||
|
|
@ -96,13 +99,14 @@ func (u *User) SetLive(setting bool) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func CreateUser(username string, passHash []byte) (*User, error) {
|
||||
stmt := "INSERT INTO user (name, reference, password_hash) VALUES (?, ?, ?)"
|
||||
func CreateUser(name string, passHash []byte) (*User, error) {
|
||||
username := normalize.Name(name)
|
||||
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
||||
reference, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result, err := database.Exec(stmt, username, reference, passHash)
|
||||
result, err := database.Exec(stmt, username, name, reference, passHash)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -112,7 +116,7 @@ func CreateUser(username string, passHash []byte) (*User, error) {
|
|||
}
|
||||
user := User{
|
||||
Id: fmt.Sprintf("%d", id),
|
||||
Name: username,
|
||||
Name: name,
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
|
@ -406,10 +410,10 @@ func (u *User) AddGiftToUser(otherUserReference string, giftName string) error {
|
|||
|
||||
func (u *User) GetGroups() ([]Group, error) {
|
||||
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?"
|
||||
return queryForGroups(stmt, u.Id)
|
||||
return queryManyGroups(stmt, u.Id)
|
||||
}
|
||||
|
||||
func (u *User) GetGroupByReference(reference string) (*Group, error) {
|
||||
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?"
|
||||
return queryForGroup(stmt, reference, u.Id)
|
||||
return queryOneGroup(stmt, reference, u.Id)
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/templates"
|
||||
"log"
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package api
|
|||
import (
|
||||
"log"
|
||||
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/templates"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
|
@ -77,7 +77,7 @@ func Register(username, newPassword, confirmPassword string) *RegisterProps {
|
|||
|
||||
existingUser, _ := db.GetUserByName(username)
|
||||
if existingUser != nil {
|
||||
log.Printf("Username is taken: %q\n", username)
|
||||
log.Printf("Username is taken: %q\n", existingUser.NormalName)
|
||||
props.Username.Error = "Username is taken"
|
||||
return props
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ func GuaranteeEnv(key string) (variable string) {
|
|||
return
|
||||
}
|
||||
|
||||
var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE")
|
||||
var SessionSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET")
|
||||
var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL")
|
||||
var HostPort = os.Getenv("LISHWIST_HOST_PORT")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,8 @@ import (
|
|||
"net/http"
|
||||
|
||||
"lishwist/api"
|
||||
"lishwist/db"
|
||||
// TODO: lishwist/api/db ought not to be used outside lishwist/api
|
||||
"lishwist/api/db"
|
||||
"lishwist/env"
|
||||
"lishwist/router"
|
||||
"lishwist/routing"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,14 @@
|
|||
package normalize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func Trim(s string) string {
|
||||
return strings.Trim(s, " \t")
|
||||
}
|
||||
|
||||
func Name(name string) string {
|
||||
name = Trim(name)
|
||||
return strings.ToLower(name)
|
||||
}
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
"net/http"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
"net/http"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import (
|
|||
"net/http"
|
||||
"slices"
|
||||
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ package routing
|
|||
import (
|
||||
"net/http"
|
||||
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/env"
|
||||
"lishwist/rsvp"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
|
|||
|
||||
props := api.NewLoginProps("", "")
|
||||
|
||||
flash := session.FlashGet("login_props")
|
||||
flash := session.FlashGet()
|
||||
flashProps, ok := flash.(*api.LoginProps)
|
||||
if ok {
|
||||
props.Username.Value = flashProps.Username.Value
|
||||
|
|
@ -21,7 +21,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
|
|||
props.Password.Error = flashProps.Password.Error
|
||||
}
|
||||
|
||||
flash = session.FlashGet("successful_registration")
|
||||
flash = session.FlashGet()
|
||||
successfulReg, _ := flash.(bool)
|
||||
if successfulReg {
|
||||
props.SuccessfulRegistration = true
|
||||
|
|
@ -39,7 +39,7 @@ func LoginPost(h http.Header, r *rsvp.Request) rsvp.Response {
|
|||
|
||||
props := api.Login(username, password)
|
||||
if props != nil {
|
||||
session.FlashSet(&props, "login_props")
|
||||
session.FlashSet(&props)
|
||||
return rsvp.SeeOther("/").SaveSession(session)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ func Register(h http.Header, r *rsvp.Request) rsvp.Response {
|
|||
props := api.NewRegisterProps("", "", "")
|
||||
|
||||
session := r.GetSession()
|
||||
flash := session.FlashGet("register_props")
|
||||
flash := session.FlashGet()
|
||||
|
||||
flashProps, _ := flash.(*api.RegisterProps)
|
||||
if flashProps != nil {
|
||||
|
|
@ -36,10 +36,10 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response {
|
|||
s := r.GetSession()
|
||||
|
||||
if props != nil {
|
||||
s.FlashSet(&props, "register_props")
|
||||
s.FlashSet(&props)
|
||||
return rsvp.SeeOther("/register").SaveSession(s)
|
||||
}
|
||||
|
||||
s.FlashSet(true, "successful_registration")
|
||||
s.FlashSet(true)
|
||||
return rsvp.SeeOther("/").SaveSession(s)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
"net/http"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
"net/http"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
package routing
|
||||
|
||||
import (
|
||||
"lishwist/db"
|
||||
"lishwist/api/db"
|
||||
"lishwist/rsvp"
|
||||
"net/http"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,13 @@ func (res *Response) Write(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
if res.SeeOther != "" {
|
||||
http.Redirect(w, r, res.SeeOther, http.StatusSeeOther)
|
||||
flash := res.Session.FlashPeek()
|
||||
if flash != nil {
|
||||
err := json.NewEncoder(w).Encode(flash)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ type Session struct {
|
|||
inner *sessions.Session
|
||||
}
|
||||
|
||||
func (s *Session) FlashGet(key ...string) any {
|
||||
list := s.inner.Flashes(key...)
|
||||
func (s *Session) FlashGet() any {
|
||||
list := s.inner.Flashes()
|
||||
if len(list) < 1 {
|
||||
return nil
|
||||
} else {
|
||||
|
|
@ -17,8 +17,17 @@ func (s *Session) FlashGet(key ...string) any {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Session) FlashSet(value any, key ...string) {
|
||||
s.inner.AddFlash(value, key...)
|
||||
func (s *Session) FlashPeek() any {
|
||||
list, _ := s.inner.Values["_flash"].([]any)
|
||||
if len(list) < 1 {
|
||||
return nil
|
||||
} else {
|
||||
return list[0]
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Session) FlashSet(value any) {
|
||||
s.inner.AddFlash(value)
|
||||
}
|
||||
|
||||
func (s *Session) SetID(value string) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue