diff --git a/server/routing/login.go b/server/routing/login.go index e296ddb..dffb9b1 100644 --- a/server/routing/login.go +++ b/server/routing/login.go @@ -11,7 +11,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response { props := api.NewLoginProps("", "") - flash := session.FlashGet("login_props") + flash := session.FlashGet() flashProps, ok := flash.(*api.LoginProps) if ok { props.Username.Value = flashProps.Username.Value @@ -21,7 +21,7 @@ func Login(h http.Header, r *rsvp.Request) rsvp.Response { props.Password.Error = flashProps.Password.Error } - flash = session.FlashGet("successful_registration") + flash = session.FlashGet() successfulReg, _ := flash.(bool) if successfulReg { props.SuccessfulRegistration = true @@ -39,7 +39,7 @@ func LoginPost(h http.Header, r *rsvp.Request) rsvp.Response { props := api.Login(username, password) if props != nil { - session.FlashSet(&props, "login_props") + session.FlashSet(&props) return rsvp.SeeOther("/").SaveSession(session) } diff --git a/server/routing/register.go b/server/routing/register.go index 33aded9..c9118b3 100644 --- a/server/routing/register.go +++ b/server/routing/register.go @@ -10,7 +10,7 @@ func Register(h http.Header, r *rsvp.Request) rsvp.Response { props := api.NewRegisterProps("", "", "") session := r.GetSession() - flash := session.FlashGet("register_props") + flash := session.FlashGet() flashProps, _ := flash.(*api.RegisterProps) if flashProps != nil { @@ -36,10 +36,10 @@ func RegisterPost(h http.Header, r *rsvp.Request) rsvp.Response { s := r.GetSession() if props != nil { - s.FlashSet(&props, "register_props") + s.FlashSet(&props) return rsvp.SeeOther("/register").SaveSession(s) } - s.FlashSet(true, "successful_registration") + s.FlashSet(true) return rsvp.SeeOther("/").SaveSession(s) } diff --git a/server/rsvp/response.go b/server/rsvp/response.go index 921e79e..d5653db 100644 --- a/server/rsvp/response.go +++ b/server/rsvp/response.go @@ -29,6 +29,13 @@ func (res *Response) Write(w http.ResponseWriter, r *http.Request) error { if res.SeeOther != "" { http.Redirect(w, r, res.SeeOther, http.StatusSeeOther) + flash := res.Session.FlashPeek() + if flash != nil { + err := json.NewEncoder(w).Encode(flash) + if err != nil { + return err + } + } return nil } diff --git a/server/rsvp/session.go b/server/rsvp/session.go index 9761766..4389d5a 100644 --- a/server/rsvp/session.go +++ b/server/rsvp/session.go @@ -8,8 +8,8 @@ type Session struct { inner *sessions.Session } -func (s *Session) FlashGet(key ...string) any { - list := s.inner.Flashes(key...) +func (s *Session) FlashGet() any { + list := s.inner.Flashes() if len(list) < 1 { return nil } else { @@ -17,8 +17,17 @@ func (s *Session) FlashGet(key ...string) any { } } -func (s *Session) FlashSet(value any, key ...string) { - s.inner.AddFlash(value, key...) +func (s *Session) FlashPeek() any { + list, _ := s.inner.Values["_flash"].([]any) + if len(list) < 1 { + return nil + } else { + return list[0] + } +} + +func (s *Session) FlashSet(value any) { + s.inner.AddFlash(value) } func (s *Session) SetID(value string) {