141 lines
3.7 KiB
Go
141 lines
3.7 KiB
Go
|
// Copyright 2021 Flamego. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
|
||
|
|
||
|
package session
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/rand"
|
||
|
"math/big"
|
||
|
"net/http"
|
||
|
"time"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
// Store is a session store with capabilities of checking, reading, saving,
// destroying and garbage-collecting (GC) sessions.
type Store interface {
	// Exist returns true if the session with given ID exists.
	Exist(ctx context.Context, sid string) bool
	// Read returns the session with given ID. If a session with the ID does not
	// exist, a new session with the same ID is created and returned.
	Read(ctx context.Context, sid string) (Session, error)
	// Destroy deletes session with given ID from the session store completely.
	Destroy(ctx context.Context, sid string) error
	// Save persists session data to the session store.
	Save(ctx context.Context, session Session) error
	// GC performs a GC operation on the session store. The manager calls it
	// repeatedly from its background goroutine (see startGC).
	GC(ctx context.Context) error
}
|
||
|
|
||
|
// Initer takes an arbitrary number of arguments needed for initialization and
// returns an initialized session store, or an error if initialization fails.
//
// NOTE(review): args are untyped, so the expected number and types are
// presumably defined by each concrete store - confirm against implementations.
type Initer func(ctx context.Context, args ...interface{}) (Store, error)
|
||
|
|
||
|
// manager is a wrapper for wiring HTTP requests and session stores. It runs
// periodic GC on the store (startGC) and loads or creates the session for a
// request (load).
type manager struct {
	store Store // The session store that is being managed.
}
|
||
|
|
||
|
// newManager returns a new manager with given session store.
|
||
|
func newManager(store Store) *manager {
|
||
|
return &manager{
|
||
|
store: store,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// startGC starts a background goroutine to trigger GC of the session store in
|
||
|
// given time interval. Errors are printed using the `errFunc`. It returns a
|
||
|
// send-only channel for stopping the background goroutine.
|
||
|
func (m *manager) startGC(ctx context.Context, interval time.Duration, errFunc func(error)) chan<- struct{} {
|
||
|
stop := make(chan struct{})
|
||
|
go func() {
|
||
|
ticker := time.NewTicker(interval)
|
||
|
for {
|
||
|
err := m.store.GC(ctx)
|
||
|
if err != nil {
|
||
|
errFunc(err)
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-stop:
|
||
|
ticker.Stop()
|
||
|
return
|
||
|
case <-ticker.C:
|
||
|
}
|
||
|
}
|
||
|
}()
|
||
|
return stop
|
||
|
}
|
||
|
|
||
|
// randomChars returns a string of n characters, each drawn from the lowercase
// alphanumeric alphabet [0-9a-z] using the cryptographically secure random
// source of crypto/rand.
//
// It returns an empty string when n is 0. An error is returned when n is
// negative or when reading from the random source fails.
func randomChars(n int) (string, error) {
	const alphanum = "0123456789abcdefghijklmnopqrstuvwxyz"

	// Reject negative lengths explicitly; make would otherwise panic.
	if n < 0 {
		return "", errors.New("negative length")
	}

	buffer := make([]byte, n)
	bound := big.NewInt(int64(len(alphanum)))
	for i := range buffer {
		// rand.Int yields a uniform value in [0, bound), so every character
		// of the alphabet is equally likely.
		r, err := rand.Int(rand.Reader, bound)
		if err != nil {
			return "", err
		}
		buffer[i] = alphanum[r.Int64()]
	}
	return string(buffer), nil
}
|
||
|
|
||
|
// isValidSessionID reports whether sid has exactly idLength bytes and consists
// solely of ASCII digits and lowercase ASCII letters.
func isValidSessionID(sid string, idLength int) bool {
	if len(sid) != idLength {
		return false
	}

	for i := 0; i < len(sid); i++ {
		c := sid[i]
		isDigit := c >= '0' && c <= '9'
		isLower := c >= 'a' && c <= 'z'
		if !isDigit && !isLower {
			return false
		}
	}
	return true
}
|
||
|
|
||
|
// load loads the session from the session store with session ID provided in the
|
||
|
// named cookie. If a session with the ID does not exist, it creates a new
|
||
|
// session with the same ID. A boolean value is returned to indicate whether a
|
||
|
// new session is created.
|
||
|
func (m *manager) load(r *http.Request, sid string, idLength int) (_ Session, created bool, err error) {
|
||
|
if isValidSessionID(sid, idLength) && m.store.Exist(r.Context(), sid) {
|
||
|
sess, err := m.store.Read(r.Context(), sid)
|
||
|
if err != nil {
|
||
|
return nil, false, errors.Wrap(err, "read")
|
||
|
}
|
||
|
return sess, false, nil
|
||
|
}
|
||
|
|
||
|
sid, err = randomChars(idLength)
|
||
|
if err != nil {
|
||
|
return nil, false, errors.Wrap(err, "new ID")
|
||
|
}
|
||
|
|
||
|
sess, err := m.store.Read(r.Context(), sid)
|
||
|
if err != nil {
|
||
|
return nil, false, errors.Wrap(err, "read")
|
||
|
}
|
||
|
return sess, true, nil
|
||
|
}
|