package session import ( "fmt" "net/http" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" ) type Callbacks struct { Delete func(id string) error Insert func(encodedValues string) (string, error) Select func(id string) (string, error) Update func(id, encodedValues string) error } type Store struct { callbacks Callbacks Codecs []securecookie.Codec Options *sessions.Options } func NewGenericStore(cb Callbacks, keyPairs ...[]byte) *Store { return &Store{ callbacks: cb, Codecs: securecookie.CodecsFromPairs(keyPairs...), Options: &sessions.Options{}, } } // Get should return a cached session. func (m *Store) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(m, name) } // New should create and return a new session. // // Note that New should never return a nil session, even in the case of // an error if using the Registry infrastructure to cache the session. func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) opts := *s.Options session.Options = &opts session.IsNew = true var err error c, errCookie := r.Cookie(name) if errCookie != nil { return session, nil } err = securecookie.DecodeMulti(name, c.Value, &session.ID, s.Codecs...) if err != nil { return session, fmt.Errorf("failed to decode session id: %w", err) } sessionValue, err := s.callbacks.Select(session.ID) if err != nil { return session, fmt.Errorf("failed to get session value: %w", err) } err = securecookie.DecodeMulti(name, string(sessionValue), &session.Values, s.Codecs...) if err == nil { session.IsNew = false } else { err = fmt.Errorf("failed to decode session values: %w", err) } return session, err } // Save should persist session to the underlying store implementation. func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { // Delete if max-age is <= 0 if session.Options.MaxAge <= 0 { err := s.callbacks.Delete(session.ID) if err != nil { return fmt.Errorf("failed to delete session: %w", err) } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } encodedValues, err := securecookie.EncodeMulti(session.Name(), session.Values, s.Codecs...) if err != nil { return fmt.Errorf("failed to encode cookie value: %w", err) } if session.ID == "" { i, err := s.callbacks.Insert(encodedValues) if err != nil { return fmt.Errorf("failed to insert session: %w", err) } session.ID = i } else { err := s.callbacks.Update(session.ID, encodedValues) if err != nil { return fmt.Errorf("failed to update session: %w", err) } } encodedId, err := securecookie.EncodeMulti(session.Name(), session.ID, s.Codecs...) if err != nil { return fmt.Errorf("failed to encode cookie value: %w", err) } http.SetCookie(w, sessions.NewCookie(session.Name(), encodedId, session.Options)) return nil }