Merge pull request 'Changes made while flying Tokyo -> Auckland' (#11) from tokyo-wellington-changes into main
Reviewed-on: #11
This commit is contained in:
commit
fbc6c9ca20
|
|
@ -1,6 +1,6 @@
|
||||||
.DS_Store
|
.DS_Store
|
||||||
gin-bin
|
gin-bin
|
||||||
lishwist.db
|
*lishwist.db
|
||||||
.env*.local
|
.env*.local
|
||||||
server/db/init_sql.go
|
server/api/db/init_sql.go
|
||||||
.ignored/
|
.ignored/
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ import (
|
||||||
var database *sql.DB
|
var database *sql.DB
|
||||||
|
|
||||||
func Open() error {
|
func Open() error {
|
||||||
db, err := sql.Open("sqlite", "./lishwist.db")
|
db, err := sql.Open("sqlite", env.DatabaseFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -23,7 +23,7 @@ func Open() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Init() error {
|
func Init() error {
|
||||||
_, err := database.Exec(InitQuery)
|
_, err := database.Exec(initQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
@ -11,7 +11,7 @@ import (
|
||||||
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
|
var initTemplate = template.Must(template.New("").Parse("// Code generated DO NOT EDIT.\n" +
|
||||||
"package db\n" +
|
"package db\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"const InitQuery = `{{.}}`\n",
|
"const initQuery = `{{.}}`\n",
|
||||||
))
|
))
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"lishwist/normalize"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -22,52 +22,47 @@ func (g *Group) MemberIndex(userId string) int {
|
||||||
return -1
|
return -1
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryForGroup(query string, args ...any) (*Group, error) {
|
func queryManyGroups(query string, args ...any) ([]Group, error) {
|
||||||
var group Group
|
|
||||||
err := database.QueryRow(query, args...).Scan(&group.Id, &group.Name, &group.Reference)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
members, err := queryForGroupMembers(group.Id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
group.Members = members
|
|
||||||
return &group, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryForGroups(query string, args ...any) ([]Group, error) {
|
|
||||||
groups := []Group{}
|
groups := []Group{}
|
||||||
rows, err := database.Query(query, args...)
|
rows, err := database.Query(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return groups, fmt.Errorf("Query failed: %w", err)
|
return nil, fmt.Errorf("Query failed: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var group Group
|
var group Group
|
||||||
err := rows.Scan(&group.Id, &group.Name, &group.Reference)
|
err := rows.Scan(&group.Id, &group.Name, &group.Reference)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return groups, fmt.Errorf("Failed to scan row: %w", err)
|
return nil, fmt.Errorf("Failed to scan row: %w", err)
|
||||||
}
|
}
|
||||||
members, err := queryForGroupMembers(group.Id)
|
members, err := queryManyGroupMembers(group.Id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return groups, fmt.Errorf("Failed to query for group members: %w", err)
|
return nil, fmt.Errorf("Failed to query for group members: %w", err)
|
||||||
}
|
}
|
||||||
group.Members = members
|
group.Members = members
|
||||||
groups = append(groups, group)
|
groups = append(groups, group)
|
||||||
}
|
}
|
||||||
err = rows.Err()
|
err = rows.Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return groups, fmt.Errorf("Rows error: %w", err)
|
return nil, fmt.Errorf("Rows error: %w", err)
|
||||||
}
|
}
|
||||||
return groups, nil
|
return groups, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryForGroupMembers(groupId string) ([]User, error) {
|
func queryOneGroup(query string, args ...any) (*Group, error) {
|
||||||
query := "SELECT user.id, user.name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id"
|
groups, err := queryManyGroups(query, args...)
|
||||||
members, err := queryForUsers(query, groupId)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(groups) < 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &groups[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryManyGroupMembers(groupId string) ([]User, error) {
|
||||||
|
query := "SELECT user.id, user.name, user.display_name, user.reference, user.is_admin, user.is_live FROM v_user AS user JOIN group_member ON group_member.user_id = user.id JOIN [group] ON [group].id = group_member.group_id WHERE [group].id = ? ORDER BY group_member.user_id"
|
||||||
|
members, err := queryManyUsers(query, groupId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return members, fmt.Errorf("Failed to get members: %w", err)
|
return members, fmt.Errorf("Failed to get members: %w", err)
|
||||||
}
|
}
|
||||||
|
|
@ -76,15 +71,16 @@ func queryForGroupMembers(groupId string) ([]User, error) {
|
||||||
|
|
||||||
func GetGroupByReference(reference string) (*Group, error) {
|
func GetGroupByReference(reference string) (*Group, error) {
|
||||||
query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?"
|
query := "SELECT [group].id, [group].name, [group].reference FROM [group] WHERE [group].reference = ?"
|
||||||
return queryForGroup(query, reference)
|
return queryOneGroup(query, reference)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAllGroups() ([]Group, error) {
|
func GetAllGroups() ([]Group, error) {
|
||||||
query := "SELECT id, name, reference FROM [group];"
|
query := "SELECT id, name, reference FROM [group];"
|
||||||
return queryForGroups(query)
|
return queryManyGroups(query)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateGroup(name string, reference string) (*Group, error) {
|
func CreateGroup(name string, reference string) (*Group, error) {
|
||||||
|
name = normalize.Name(name)
|
||||||
stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)"
|
stmt := "INSERT INTO [group] (name, reference) VALUES (?, ?)"
|
||||||
result, err := database.Exec(stmt, name, reference)
|
result, err := database.Exec(stmt, name, reference)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -2,6 +2,7 @@ BEGIN TRANSACTION;
|
||||||
CREATE TABLE IF NOT EXISTS "user" (
|
CREATE TABLE IF NOT EXISTS "user" (
|
||||||
"id" INTEGER NOT NULL UNIQUE,
|
"id" INTEGER NOT NULL UNIQUE,
|
||||||
"name" TEXT NOT NULL UNIQUE,
|
"name" TEXT NOT NULL UNIQUE,
|
||||||
|
"display_name" TEXT NOT NULL UNIQUE,
|
||||||
"reference" TEXT NOT NULL UNIQUE,
|
"reference" TEXT NOT NULL UNIQUE,
|
||||||
"motto" TEXT NOT NULL DEFAULT "",
|
"motto" TEXT NOT NULL DEFAULT "",
|
||||||
"password_hash" TEXT NOT NULL,
|
"password_hash" TEXT NOT NULL,
|
||||||
|
|
@ -3,16 +3,18 @@ package db
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"lishwist/normalize"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
Id string
|
Id string
|
||||||
Name string
|
NormalName string
|
||||||
Reference string
|
Name string
|
||||||
IsAdmin bool
|
Reference string
|
||||||
IsLive bool
|
IsAdmin bool
|
||||||
|
IsLive bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type Gift struct {
|
type Gift struct {
|
||||||
|
|
@ -28,18 +30,7 @@ type Gift struct {
|
||||||
CreatorName string `json:",omitempty"`
|
CreatorName string `json:",omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryForUser(query string, args ...any) (*User, error) {
|
func queryManyUsers(query string, args ...any) ([]User, error) {
|
||||||
var u User
|
|
||||||
err := database.QueryRow(query, args...).Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &u, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryForUsers(query string, args ...any) ([]User, error) {
|
|
||||||
rows, err := database.Query(query, args...)
|
rows, err := database.Query(query, args...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
@ -48,7 +39,7 @@ func queryForUsers(query string, args ...any) ([]User, error) {
|
||||||
users := []User{}
|
users := []User{}
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var u User
|
var u User
|
||||||
err = rows.Scan(&u.Id, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
err = rows.Scan(&u.Id, &u.NormalName, &u.Name, &u.Reference, &u.IsAdmin, &u.IsLive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -61,29 +52,41 @@ func queryForUsers(query string, args ...any) ([]User, error) {
|
||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func queryOneUser(query string, args ...any) (*User, error) {
|
||||||
|
users, err := queryManyUsers(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(users) < 1 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return &users[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
func GetAllUsers() ([]User, error) {
|
func GetAllUsers() ([]User, error) {
|
||||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM user"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user"
|
||||||
return queryForUsers(stmt)
|
return queryManyUsers(stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUser(id string) (*User, error) {
|
func GetUser(id string) (*User, error) {
|
||||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE id = ?"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE id = ?"
|
||||||
return queryForUser(stmt, id)
|
return queryOneUser(stmt, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByName(username string) (*User, error) {
|
func GetUserByName(username string) (*User, error) {
|
||||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
username = normalize.Name(username)
|
||||||
return queryForUser(stmt, username)
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE name = ?"
|
||||||
|
return queryOneUser(stmt, username)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetUserByReference(reference string) (*User, error) {
|
func GetUserByReference(reference string) (*User, error) {
|
||||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM v_user WHERE reference = ?"
|
||||||
return queryForUser(stmt, reference)
|
return queryOneUser(stmt, reference)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAnyUserByReference(reference string) (*User, error) {
|
func GetAnyUserByReference(reference string) (*User, error) {
|
||||||
stmt := "SELECT id, name, reference, is_admin, is_live FROM user WHERE reference = ?"
|
stmt := "SELECT id, name, display_name, reference, is_admin, is_live FROM user WHERE reference = ?"
|
||||||
return queryForUser(stmt, reference)
|
return queryOneUser(stmt, reference)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) SetLive(setting bool) error {
|
func (u *User) SetLive(setting bool) error {
|
||||||
|
|
@ -96,13 +99,14 @@ func (u *User) SetLive(setting bool) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUser(username string, passHash []byte) (*User, error) {
|
func CreateUser(name string, passHash []byte) (*User, error) {
|
||||||
stmt := "INSERT INTO user (name, reference, password_hash) VALUES (?, ?, ?)"
|
username := normalize.Name(name)
|
||||||
|
stmt := "INSERT INTO user (name, display_name, reference, password_hash) VALUES (?, ?, ?, ?)"
|
||||||
reference, err := uuid.NewRandom()
|
reference, err := uuid.NewRandom()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
result, err := database.Exec(stmt, username, reference, passHash)
|
result, err := database.Exec(stmt, username, name, reference, passHash)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
@ -112,7 +116,7 @@ func CreateUser(username string, passHash []byte) (*User, error) {
|
||||||
}
|
}
|
||||||
user := User{
|
user := User{
|
||||||
Id: fmt.Sprintf("%d", id),
|
Id: fmt.Sprintf("%d", id),
|
||||||
Name: username,
|
Name: name,
|
||||||
}
|
}
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
@ -406,10 +410,10 @@ func (u *User) AddGiftToUser(otherUserReference string, giftName string) error {
|
||||||
|
|
||||||
func (u *User) GetGroups() ([]Group, error) {
|
func (u *User) GetGroups() ([]Group, error) {
|
||||||
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?"
|
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON group_member.group_id = [group].id JOIN v_user AS user ON user.id = group_member.user_id WHERE user.id = ?"
|
||||||
return queryForGroups(stmt, u.Id)
|
return queryManyGroups(stmt, u.Id)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *User) GetGroupByReference(reference string) (*Group, error) {
|
func (u *User) GetGroupByReference(reference string) (*Group, error) {
|
||||||
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?"
|
stmt := "SELECT [group].id, [group].name, [group].reference FROM [group] JOIN group_member ON [group].id == group_member.group_id WHERE [group].reference = ? AND group_member.user_id = ?"
|
||||||
return queryForGroup(stmt, reference, u.Id)
|
return queryOneGroup(stmt, reference, u.Id)
|
||||||
}
|
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/templates"
|
"lishwist/templates"
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package api
|
||||||
import (
|
import (
|
||||||
"log"
|
"log"
|
||||||
|
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/templates"
|
"lishwist/templates"
|
||||||
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
|
@ -77,7 +77,7 @@ func Register(username, newPassword, confirmPassword string) *RegisterProps {
|
||||||
|
|
||||||
existingUser, _ := db.GetUserByName(username)
|
existingUser, _ := db.GetUserByName(username)
|
||||||
if existingUser != nil {
|
if existingUser != nil {
|
||||||
log.Printf("Username is taken: %q\n", username)
|
log.Printf("Username is taken: %q\n", existingUser.NormalName)
|
||||||
props.Username.Error = "Username is taken"
|
props.Username.Error = "Username is taken"
|
||||||
return props
|
return props
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ func GuaranteeEnv(key string) (variable string) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var DatabaseFile = GuaranteeEnv("LISHWIST_DATABASE_FILE")
|
||||||
var SessionSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET")
|
var SessionSecret = GuaranteeEnv("LISHWIST_SESSION_SECRET")
|
||||||
var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL")
|
var HostRootUrl = GuaranteeEnv("LISHWIST_HOST_ROOT_URL")
|
||||||
var HostPort = os.Getenv("LISHWIST_HOST_PORT")
|
var HostPort = os.Getenv("LISHWIST_HOST_PORT")
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,8 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"lishwist/api"
|
"lishwist/api"
|
||||||
"lishwist/db"
|
// TODO: lishwist/api/db ought not to be used outside lishwist/api
|
||||||
|
"lishwist/api/db"
|
||||||
"lishwist/env"
|
"lishwist/env"
|
||||||
"lishwist/router"
|
"lishwist/router"
|
||||||
"lishwist/routing"
|
"lishwist/routing"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
package normalize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Trim(s string) string {
|
||||||
|
return strings.Trim(s, " \t")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Name(name string) string {
|
||||||
|
name = Trim(name)
|
||||||
|
return strings.ToLower(name)
|
||||||
|
}
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ package routing
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/env"
|
"lishwist/env"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
|
||||||
|
|
||||||
props := api.NewLoginProps("", "")
|
props := api.NewLoginProps("", "")
|
||||||
|
|
||||||
flash := session.FlashGet("login_props")
|
flash := session.FlashGet()
|
||||||
flashProps, ok := flash.(*api.LoginProps)
|
flashProps, ok := flash.(*api.LoginProps)
|
||||||
if ok {
|
if ok {
|
||||||
props.Username.Value = flashProps.Username.Value
|
props.Username.Value = flashProps.Username.Value
|
||||||
|
|
@ -21,7 +21,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response {
|
||||||
props.Password.Error = flashProps.Password.Error
|
props.Password.Error = flashProps.Password.Error
|
||||||
}
|
}
|
||||||
|
|
||||||
flash = session.FlashGet("successful_registration")
|
flash = session.FlashGet()
|
||||||
successfulReg, _ := flash.(bool)
|
successfulReg, _ := flash.(bool)
|
||||||
if successfulReg {
|
if successfulReg {
|
||||||
props.SuccessfulRegistration = true
|
props.SuccessfulRegistration = true
|
||||||
|
|
@ -39,7 +39,7 @@ func LoginPost(h http.Header, r *rsvp.Request) rsvp.Response {
|
||||||
|
|
||||||
props := api.Login(username, password)
|
props := api.Login(username, password)
|
||||||
if props != nil {
|
if props != nil {
|
||||||
session.FlashSet(&props, "login_props")
|
session.FlashSet(&props)
|
||||||
return rsvp.SeeOther("/").SaveSession(session)
|
return rsvp.SeeOther("/").SaveSession(session)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ func Register(h http.Header, r *rsvp.Request) rsvp.Response {
|
||||||
props := api.NewRegisterProps("", "", "")
|
props := api.NewRegisterProps("", "", "")
|
||||||
|
|
||||||
session := r.GetSession()
|
session := r.GetSession()
|
||||||
flash := session.FlashGet("register_props")
|
flash := session.FlashGet()
|
||||||
|
|
||||||
flashProps, _ := flash.(*api.RegisterProps)
|
flashProps, _ := flash.(*api.RegisterProps)
|
||||||
if flashProps != nil {
|
if flashProps != nil {
|
||||||
|
|
@ -36,10 +36,10 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response {
|
||||||
s := r.GetSession()
|
s := r.GetSession()
|
||||||
|
|
||||||
if props != nil {
|
if props != nil {
|
||||||
s.FlashSet(&props, "register_props")
|
s.FlashSet(&props)
|
||||||
return rsvp.SeeOther("/register").SaveSession(s)
|
return rsvp.SeeOther("/register").SaveSession(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.FlashSet(true, "successful_registration")
|
s.FlashSet(true)
|
||||||
return rsvp.SeeOther("/").SaveSession(s)
|
return rsvp.SeeOther("/").SaveSession(s)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"lishwist/db"
|
"lishwist/api/db"
|
||||||
"lishwist/rsvp"
|
"lishwist/rsvp"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,13 @@ func (res *Response) Write(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
||||||
if res.SeeOther != "" {
|
if res.SeeOther != "" {
|
||||||
http.Redirect(w, r, res.SeeOther, http.StatusSeeOther)
|
http.Redirect(w, r, res.SeeOther, http.StatusSeeOther)
|
||||||
|
flash := res.Session.FlashPeek()
|
||||||
|
if flash != nil {
|
||||||
|
err := json.NewEncoder(w).Encode(flash)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,8 +8,8 @@ type Session struct {
|
||||||
inner *sessions.Session
|
inner *sessions.Session
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) FlashGet(key ...string) any {
|
func (s *Session) FlashGet() any {
|
||||||
list := s.inner.Flashes(key...)
|
list := s.inner.Flashes()
|
||||||
if len(list) < 1 {
|
if len(list) < 1 {
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -17,8 +17,17 @@ func (s *Session) FlashGet(key ...string) any {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) FlashSet(value any, key ...string) {
|
func (s *Session) FlashPeek() any {
|
||||||
s.inner.AddFlash(value, key...)
|
list, _ := s.inner.Values["_flash"].([]any)
|
||||||
|
if len(list) < 1 {
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return list[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Session) FlashSet(value any) {
|
||||||
|
s.inner.AddFlash(value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Session) SetID(value string) {
|
func (s *Session) SetID(value string) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue