289 lines
8 KiB
Go
289 lines
8 KiB
Go
|
package jwtauth
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"errors"
|
||
|
"net/http"
|
||
|
"strings"
|
||
|
"time"
|
||
|
|
||
|
"github.com/lestrrat-go/jwx/jwa"
|
||
|
"github.com/lestrrat-go/jwx/jwt"
|
||
|
)
|
||
|
|
||
|
type JWTAuth struct {
|
||
|
alg jwa.SignatureAlgorithm
|
||
|
signKey interface{} // private-key
|
||
|
verifyKey interface{} // public-key, only used by RSA and ECDSA algorithms
|
||
|
verifier jwt.ParseOption
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
TokenCtxKey = &contextKey{"Token"}
|
||
|
ErrorCtxKey = &contextKey{"Error"}
|
||
|
)
|
||
|
|
||
|
var (
|
||
|
ErrUnauthorized = errors.New("token is unauthorized")
|
||
|
ErrExpired = errors.New("token is expired")
|
||
|
ErrNBFInvalid = errors.New("token nbf validation failed")
|
||
|
ErrIATInvalid = errors.New("token iat validation failed")
|
||
|
ErrNoTokenFound = errors.New("no token found")
|
||
|
ErrAlgoInvalid = errors.New("algorithm mismatch")
|
||
|
)
|
||
|
|
||
|
func New(alg string, signKey interface{}, verifyKey interface{}) *JWTAuth {
|
||
|
ja := &JWTAuth{alg: jwa.SignatureAlgorithm(alg), signKey: signKey, verifyKey: verifyKey}
|
||
|
|
||
|
if ja.verifyKey != nil {
|
||
|
ja.verifier = jwt.WithVerify(ja.alg, ja.verifyKey)
|
||
|
} else {
|
||
|
ja.verifier = jwt.WithVerify(ja.alg, ja.signKey)
|
||
|
}
|
||
|
|
||
|
return ja
|
||
|
}
|
||
|
|
||
|
// Verifier http middleware handler will verify a JWT string from a http request.
//
// Verifier will search for a JWT token in a http request, in the order:
//  1. 'Authorization: BEARER T' request header
//  2. Cookie 'jwt' value
//
// The 'jwt' URI query parameter is NOT consulted; to accept it, build your own
// middleware with Verify and TokenFromQuery.
//
// The first JWT string that is found is then decoded and verified by the
// lestrrat-go/jwx library, and the resulting jwt.Token is set on the request
// context. If no token is found, or decoding/validation fails, the error is
// set on the request context as well (see FromContext).
//
// The Verifier always calls the next http handler in sequence, which can either
// be the generic `jwtauth.Authenticator` middleware or your own custom handler
// which checks the request context jwt token and error to prepare a custom
// http response.
func Verifier(ja *JWTAuth) func(http.Handler) http.Handler {
	// Verify already returns a middleware constructor; no extra wrapping needed.
	return Verify(ja, TokenFromHeader, TokenFromCookie)
}
|
||
|
|
||
|
// Verify returns a middleware that looks up a token string with findTokenFns.
// The functions are tried in order and the first non-empty result wins. The
// token is checked with ja, and the resulting token and error are attached to
// the request context via NewContext. The wrapped handler is always invoked;
// rejecting the request is left to later handlers such as Authenticator.
func Verify(ja *JWTAuth, findTokenFns ...func(r *http.Request) string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			parent := r.Context()
			tok, verr := VerifyRequest(ja, r, findTokenFns...)
			next.ServeHTTP(w, r.WithContext(NewContext(parent, tok, verr)))
		})
	}
}
|
||
|
|
||
|
func VerifyRequest(ja *JWTAuth, r *http.Request, findTokenFns ...func(r *http.Request) string) (jwt.Token, error) {
|
||
|
var tokenString string
|
||
|
|
||
|
// Extract token string from the request by calling token find functions in
|
||
|
// the order they where provided. Further extraction stops if a function
|
||
|
// returns a non-empty string.
|
||
|
for _, fn := range findTokenFns {
|
||
|
tokenString = fn(r)
|
||
|
if tokenString != "" {
|
||
|
break
|
||
|
}
|
||
|
}
|
||
|
if tokenString == "" {
|
||
|
return nil, ErrNoTokenFound
|
||
|
}
|
||
|
|
||
|
return VerifyToken(ja, tokenString)
|
||
|
}
|
||
|
|
||
|
func VerifyToken(ja *JWTAuth, tokenString string) (jwt.Token, error) {
|
||
|
// Decode & verify the token
|
||
|
token, err := ja.Decode(tokenString)
|
||
|
if err != nil {
|
||
|
return token, ErrorReason(err)
|
||
|
}
|
||
|
|
||
|
if token == nil {
|
||
|
return nil, ErrUnauthorized
|
||
|
}
|
||
|
|
||
|
if err := jwt.Validate(token); err != nil {
|
||
|
return token, ErrorReason(err)
|
||
|
}
|
||
|
|
||
|
// Valid!
|
||
|
return token, nil
|
||
|
}
|
||
|
|
||
|
func (ja *JWTAuth) Encode(claims map[string]interface{}) (t jwt.Token, tokenString string, err error) {
|
||
|
t = jwt.New()
|
||
|
for k, v := range claims {
|
||
|
t.Set(k, v)
|
||
|
}
|
||
|
payload, err := ja.sign(t)
|
||
|
if err != nil {
|
||
|
return nil, "", err
|
||
|
}
|
||
|
tokenString = string(payload)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
func (ja *JWTAuth) Decode(tokenString string) (jwt.Token, error) {
|
||
|
return ja.parse([]byte(tokenString))
|
||
|
}
|
||
|
|
||
|
func (ja *JWTAuth) sign(token jwt.Token) ([]byte, error) {
|
||
|
return jwt.Sign(token, ja.alg, ja.signKey)
|
||
|
}
|
||
|
|
||
|
func (ja *JWTAuth) parse(payload []byte) (jwt.Token, error) {
|
||
|
return jwt.Parse(payload, ja.verifier)
|
||
|
}
|
||
|
|
||
|
// ErrorReason normalizes an error from the underlying jwt library into one of
// this package's sentinel errors:
//
//   - "exp not satisfied", or an ErrExpired error    -> ErrExpired
//   - "iat not satisfied", or an ErrIATInvalid error -> ErrIATInvalid
//   - "nbf not satisfied", or an ErrNBFInvalid error -> ErrNBFInvalid
//   - anything else                                  -> ErrUnauthorized
//
// Sentinels are recognized by exact message and, via errors.Is, when wrapped
// (e.g. fmt.Errorf("...: %w", ErrExpired)). A nil err yields nil.
func ErrorReason(err error) error {
	// Guard: err.Error() on a nil interface would panic.
	if err == nil {
		return nil
	}

	msg := err.Error()
	switch {
	case msg == "exp not satisfied" || msg == ErrExpired.Error() || errors.Is(err, ErrExpired):
		return ErrExpired
	case msg == "iat not satisfied" || msg == ErrIATInvalid.Error() || errors.Is(err, ErrIATInvalid):
		return ErrIATInvalid
	case msg == "nbf not satisfied" || msg == ErrNBFInvalid.Error() || errors.Is(err, ErrNBFInvalid):
		return ErrNBFInvalid
	default:
		return ErrUnauthorized
	}
}
|
||
|
|
||
|
// Authenticator is a default authentication middleware that enforces access
// based on the values the Verifier middleware stored in the request context.
//
// It replies 401 Unauthorized when verification recorded an error (using that
// error's message as the body), or when the token is missing or fails
// jwt.Validate. Otherwise the request is passed to next unchanged. Write your
// own equivalent when you need a customized client response.
func Authenticator(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		tok, _, verr := FromContext(r.Context())

		switch {
		case verr != nil:
			http.Error(w, verr.Error(), http.StatusUnauthorized)
		case tok == nil || jwt.Validate(tok) != nil:
			// || short-circuits, so Validate never sees a nil token.
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
		default:
			// Token is authenticated, pass it through.
			next.ServeHTTP(w, r)
		}
	})
}
|
||
|
|
||
|
// NewContext returns a child of ctx that carries t under TokenCtxKey and err
// under ErrorCtxKey. Both values are stored even when they are nil.
func NewContext(ctx context.Context, t jwt.Token, err error) context.Context {
	withToken := context.WithValue(ctx, TokenCtxKey, t)
	return context.WithValue(withToken, ErrorCtxKey, err)
}
|
||
|
|
||
|
// FromContext returns the token, its claims, and the verification error
// previously stored in ctx by NewContext.
//
// When no token is present, claims is an empty (non-nil) map. If converting
// the token to a claims map fails, that conversion error is returned (with a
// nil claims map) instead of the stored verification error.
func FromContext(ctx context.Context) (jwt.Token, map[string]interface{}, error) {
	tok, _ := ctx.Value(TokenCtxKey).(jwt.Token)

	var claims map[string]interface{}
	if tok == nil {
		claims = make(map[string]interface{})
	} else {
		// NOTE(review): context.Background() rather than ctx appears deliberate,
		// presumably so a cancelled request context cannot cut claim extraction
		// short. Confirm before switching to ctx.
		m, convErr := tok.AsMap(context.Background())
		if convErr != nil {
			return tok, nil, convErr
		}
		claims = m
	}

	storedErr, _ := ctx.Value(ErrorCtxKey).(error)
	return tok, claims, storedErr
}
|
||
|
|
||
|
// UnixTime returns tm as a NumericDate: the number of whole seconds (not
// milliseconds) elapsed since the Unix epoch, 1970-01-01T00:00:00Z.
// The result does not depend on tm's location, so no UTC conversion is needed.
func UnixTime(tm time.Time) int64 {
	return tm.Unix()
}
|
||
|
|
||
|
// EpochNow returns the current time as a NumericDate (whole seconds since the
// Unix epoch), the form the JWT spec uses for time-based claims.
func EpochNow() int64 {
	now := time.Now()
	return now.UTC().Unix()
}
|
||
|
|
||
|
// ExpireIn returns the NumericDate that lies tm from now, for use as an "exp"
// claim. Any fractional second in tm is truncated toward zero.
func ExpireIn(tm time.Duration) int64 {
	offset := int64(tm.Seconds())
	return EpochNow() + offset
}
|
||
|
|
||
|
// SetIssuedAt sets the "iat" (issued at) claim in claims to tm, expressed as
// whole Unix seconds. claims must be non-nil.
func SetIssuedAt(claims map[string]interface{}, tm time.Time) {
	issued := tm.UTC().Unix()
	claims["iat"] = issued
}
|
||
|
|
||
|
// Set issued at ("iat") to present time in the claims
|
||
|
func SetIssuedNow(claims map[string]interface{}) {
|
||
|
claims["iat"] = EpochNow()
|
||
|
}
|
||
|
|
||
|
// SetExpiry sets the "exp" (expiration) claim in claims to tm, expressed as
// whole Unix seconds. claims must be non-nil.
func SetExpiry(claims map[string]interface{}, tm time.Time) {
	expires := tm.UTC().Unix()
	claims["exp"] = expires
}
|
||
|
|
||
|
// SetExpiryIn sets the "exp" (expiration) claim in claims to tm from now,
// as computed by ExpireIn. claims must be non-nil.
func SetExpiryIn(claims map[string]interface{}, tm time.Duration) {
	expires := ExpireIn(tm)
	claims["exp"] = expires
}
|
||
|
|
||
|
// TokenFromCookie tries to retrieve the token string from a cookie named
// "jwt". It returns "" when the request carries no such cookie.
func TokenFromCookie(r *http.Request) string {
	if c, err := r.Cookie("jwt"); err == nil {
		return c.Value
	}
	return ""
}
|
||
|
|
||
|
// TokenFromHeader tries to retreive the token string from the
|
||
|
// "Authorization" reqeust header: "Authorization: BEARER T".
|
||
|
func TokenFromHeader(r *http.Request) string {
|
||
|
// Get token from authorization header.
|
||
|
bearer := r.Header.Get("Authorization")
|
||
|
if len(bearer) > 7 && strings.ToUpper(bearer[0:6]) == "BEARER" {
|
||
|
return bearer[7:]
|
||
|
}
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// TokenFromQuery tries to retrieve the token string from the "jwt" URI query
// parameter. It returns "" when the parameter is absent.
//
// Verifier does not use it. To accept query-string tokens, build your own
// middleware handler, such as:
//
//	func Verifier(ja *JWTAuth) func(http.Handler) http.Handler {
//		return func(next http.Handler) http.Handler {
//			return Verify(ja, TokenFromQuery, TokenFromHeader, TokenFromCookie)(next)
//		}
//	}
//
// NOTE: tokens placed in URLs tend to end up in server, proxy and browser
// history logs; prefer the header or cookie where possible.
func TokenFromQuery(r *http.Request) string {
	// Look up the query parameter named "jwt".
	params := r.URL.Query()
	return params.Get("jwt")
}
|
||
|
|
||
|
// contextKey is the key type this package uses with context.WithValue.
// Keys are handled as pointers so that storing one in an interface{} needs
// no extra allocation, and so they cannot collide with keys from other
// packages. The pattern follows net/http's context keys, introduced in Go 1.7.
type contextKey struct {
	name string
}

// String describes the key, which helps when debugging context contents.
func (k *contextKey) String() string {
	const label = "jwtauth context value "
	return label + k.name
}
|