193 lines
5.4 KiB
Go
193 lines
5.4 KiB
Go
|
// Copyright 2021 Flamego. All rights reserved.
|
||
|
// Use of this source code is governed by a MIT-style
|
||
|
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package session
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"net/http"
|
||
|
"reflect"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
|
||
|
"github.com/flamego/flamego"
|
||
|
)
|
||
|
|
||
|
// Session is a session for the current request.
|
||
|
type Session interface {
|
||
|
// ID returns the session ID.
|
||
|
ID() string
|
||
|
// Get returns the value of given key in the session. It returns nil if no such
|
||
|
// key exists.
|
||
|
Get(key interface{}) interface{}
|
||
|
// Set sets the value of given key in the session.
|
||
|
Set(key, val interface{})
|
||
|
// SetFlash sets the flash to be the given value in the session.
|
||
|
SetFlash(val interface{})
|
||
|
// Delete deletes a key from the session.
|
||
|
Delete(key interface{})
|
||
|
// Flush wipes out all existing data in the session.
|
||
|
Flush()
|
||
|
// Encode encodes session data to binary.
|
||
|
Encode() ([]byte, error)
|
||
|
}
|
||
|
|
||
|
// CookieOptions contains options for setting HTTP cookies.
|
||
|
type CookieOptions struct {
|
||
|
// Name is the name of the cookie. Default is "flamego_session".
|
||
|
Name string
|
||
|
// Path is the Path attribute of the cookie. Default is "/".
|
||
|
Path string
|
||
|
// Domain is the Domain attribute of the cookie. Default is not set.
|
||
|
Domain string
|
||
|
// MaxAge is the MaxAge attribute of the cookie. Default is not set.
|
||
|
MaxAge int
|
||
|
// Secure specifies whether to set Secure for the cookie.
|
||
|
Secure bool
|
||
|
// HTTPOnly specifies whether to set HTTPOnly for the cookie.
|
||
|
HTTPOnly bool
|
||
|
// SameSite is the SameSite attribute of the cookie. Default is
|
||
|
// http.SameSiteLaxMode.
|
||
|
SameSite http.SameSite
|
||
|
}
|
||
|
|
||
|
// Options contains options for the session.Sessioner middleware.
|
||
|
type Options struct {
|
||
|
// Initer is the initialization function of the session store. Default is
|
||
|
// session.MemoryIniter.
|
||
|
Initer Initer
|
||
|
// Config is the configuration object to be passed to the Initer for the session
|
||
|
// store.
|
||
|
Config interface{}
|
||
|
// Cookie is a set of options for setting HTTP cookies.
|
||
|
Cookie CookieOptions
|
||
|
// IDLength specifies the length of session IDs. Default is 16.
|
||
|
IDLength int
|
||
|
// GCInterval is the time interval for GC operations. Default is 5 minutes.
|
||
|
GCInterval time.Duration
|
||
|
// ErrorFunc is the function used to print errors when something went wrong on
|
||
|
// the background. Default is to drop errors silently.
|
||
|
ErrorFunc func(err error)
|
||
|
// ReadIDFunc is the function to read session ID from the request. Default is
|
||
|
// reading from cookie.
|
||
|
ReadIDFunc func(r *http.Request) string
|
||
|
// WriteIDFunc is the function to write session ID to the response. Default is
|
||
|
// writing to cookie. The `created` argument indicates whether a new session was
|
||
|
// created in the session store.
|
||
|
WriteIDFunc func(w http.ResponseWriter, r *http.Request, sid string, created bool)
|
||
|
}
|
||
|
|
||
|
const minimumSIDLength = 3
|
||
|
|
||
|
var ErrMinimumSIDLength = errors.Errorf("the SID does not have the minimum required length %d", minimumSIDLength)
|
||
|
|
||
|
// Sessioner returns a middleware handler that injects session.Session and
|
||
|
// session.Store into the request context, which are used for manipulating
|
||
|
// session data.
|
||
|
func Sessioner(opts ...Options) flamego.Handler {
|
||
|
var opt Options
|
||
|
if len(opts) > 0 {
|
||
|
opt = opts[0]
|
||
|
}
|
||
|
|
||
|
parseOptions := func(opts Options) Options {
|
||
|
if opts.Initer == nil {
|
||
|
opts.Initer = MemoryIniter()
|
||
|
}
|
||
|
|
||
|
if reflect.DeepEqual(opts.Cookie, CookieOptions{}) {
|
||
|
opts.Cookie = CookieOptions{
|
||
|
HTTPOnly: true,
|
||
|
}
|
||
|
}
|
||
|
if opts.Cookie.Name == "" {
|
||
|
opts.Cookie.Name = "flamego_session"
|
||
|
}
|
||
|
if opts.Cookie.SameSite < http.SameSiteDefaultMode || opts.Cookie.SameSite > http.SameSiteNoneMode {
|
||
|
opts.Cookie.SameSite = http.SameSiteLaxMode
|
||
|
}
|
||
|
if opts.Cookie.Path == "" {
|
||
|
opts.Cookie.Path = "/"
|
||
|
}
|
||
|
|
||
|
// NOTE: The file store requires at least 3 characters for the filename.
|
||
|
if opts.IDLength < minimumSIDLength {
|
||
|
opts.IDLength = 16
|
||
|
}
|
||
|
|
||
|
if opts.GCInterval.Seconds() < 1 {
|
||
|
opts.GCInterval = 5 * time.Minute
|
||
|
}
|
||
|
|
||
|
if opts.ErrorFunc == nil {
|
||
|
opts.ErrorFunc = func(error) {}
|
||
|
}
|
||
|
|
||
|
if opts.ReadIDFunc == nil {
|
||
|
opts.ReadIDFunc = func(r *http.Request) string {
|
||
|
cookie, err := r.Cookie(opts.Cookie.Name)
|
||
|
if err != nil {
|
||
|
return ""
|
||
|
}
|
||
|
return cookie.Value
|
||
|
}
|
||
|
}
|
||
|
if opts.WriteIDFunc == nil {
|
||
|
opts.WriteIDFunc = func(w http.ResponseWriter, r *http.Request, sid string, created bool) {
|
||
|
if !created {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
cookie := &http.Cookie{
|
||
|
Name: opts.Cookie.Name,
|
||
|
Value: sid,
|
||
|
Path: opts.Cookie.Path,
|
||
|
Domain: opts.Cookie.Domain,
|
||
|
MaxAge: opts.Cookie.MaxAge,
|
||
|
Secure: opts.Cookie.Secure,
|
||
|
HttpOnly: opts.Cookie.HTTPOnly,
|
||
|
SameSite: opts.Cookie.SameSite,
|
||
|
}
|
||
|
http.SetCookie(w, cookie)
|
||
|
r.AddCookie(cookie)
|
||
|
}
|
||
|
}
|
||
|
return opts
|
||
|
}
|
||
|
|
||
|
opt = parseOptions(opt)
|
||
|
ctx := context.Background()
|
||
|
|
||
|
store, err := opt.Initer(ctx, opt.Config)
|
||
|
if err != nil {
|
||
|
panic("session: " + err.Error())
|
||
|
}
|
||
|
|
||
|
mgr := newManager(store)
|
||
|
mgr.startGC(ctx, opt.GCInterval, opt.ErrorFunc)
|
||
|
|
||
|
return flamego.ContextInvoker(func(c flamego.Context) {
|
||
|
sid := opt.ReadIDFunc(c.Request().Request)
|
||
|
sess, created, err := mgr.load(c.Request().Request, sid, opt.IDLength)
|
||
|
if err != nil {
|
||
|
panic("session: load: " + err.Error())
|
||
|
}
|
||
|
|
||
|
opt.WriteIDFunc(c.ResponseWriter(), c.Request().Request, sess.ID(), created)
|
||
|
|
||
|
flash := sess.Get(flashKey)
|
||
|
sess.Delete(flashKey)
|
||
|
|
||
|
c.Map(store, sess)
|
||
|
c.MapTo(flash, (*Flash)(nil))
|
||
|
c.Next()
|
||
|
|
||
|
err = store.Save(c.Request().Context(), sess)
|
||
|
if err != nil {
|
||
|
panic("session: save: " + err.Error())
|
||
|
}
|
||
|
})
|
||
|
}
|