go-sample-webpage/vendor/github.com/lestrrat-go/jwx/jwt/validate.go

189 lines
4.4 KiB
Go
Raw Normal View History

2021-11-04 01:14:51 +00:00
package jwt
import (
"fmt"
"strconv"
"time"
"github.com/pkg/errors"
)
// Clock is the source of "current time" used during validation.
// Validate calls Now when it checks the `exp`, `iat` and `nbf` claims,
// and when a time-delta constraint names the current time as one side
// of the comparison.
type Clock interface {
// Now returns the instant that should be treated as the current time.
Now() time.Time
}
// ClockFunc adapts a plain `func() time.Time` so that it can be used
// wherever a Clock is expected. Validate uses `ClockFunc(time.Now)` as
// its default clock.
type ClockFunc func() time.Time

// Now invokes the wrapped function and returns whatever it returns.
func (fn ClockFunc) Now() time.Time {
	return fn()
}
// isSupportedTimeClaim returns nil when c names one of the time-valued
// claims that can take part in a time-delta constraint (ExpirationKey,
// IssuedAtKey or NotBeforeKey). For any other name it returns an error
// that includes the quoted claim name.
func isSupportedTimeClaim(c string) error {
	if c == ExpirationKey || c == IssuedAtKey || c == NotBeforeKey {
		return nil
	}
	return errors.Errorf(`unsupported time claim %s`, strconv.Quote(c))
}
// timeClaim resolves the claim name c to a point in time.
//
// The three supported time claims are read from the token t. The empty
// string is a placeholder for "the current time" and is answered by
// clock. Any other name yields the zero time.Time; callers are expected
// to have rejected such names beforehand with isSupportedTimeClaim.
func timeClaim(t Token, clock Clock, c string) time.Time {
	if c == "" {
		return clock.Now()
	}
	if c == ExpirationKey {
		return t.Expiration()
	}
	if c == IssuedAtKey {
		return t.IssuedAt()
	}
	if c == NotBeforeKey {
		return t.NotBefore()
	}

	// Should *never* be reached, but fall back to the zero value.
	var zero time.Time
	return zero
}
// Validate makes sure that the essential claims stand.
//
// Options are processed first; a time-delta option that names an
// unsupported claim makes Validate return immediately with an error.
// After that the checks run in the following order and the first
// failure is returned:
//
//  1. every required claim (requested explicitly, or implied by a
//     time-delta option) must be present in the token;
//  2. every time-delta constraint must hold, allowing for the skew;
//  3. `iss`, `jti`, `sub` and `aud` must match when a non-empty
//     expected value was supplied;
//  4. `exp`, `iat` and `nbf` are checked against the clock when the
//     token carries them, allowing for the skew;
//  5. every expected claim/value pair must match the token.
//
// Time comparisons are made after truncating both sides to whole
// seconds. When no clock option is given, time.Now is used.
//
// See the various `WithXXX` functions for optional parameters
// that can control the behavior of this method.
func Validate(t Token, options ...ValidateOption) error {
	var issuer string
	var subject string
	var audience string
	var jwtid string
	var clock Clock = ClockFunc(time.Now)
	var skew time.Duration
	var deltas []delta

	requiredMap := make(map[string]struct{})
	claimValues := make(map[string]interface{})

	for _, o := range options {
		//nolint:forcetypeassert
		switch o.Ident() {
		case identClock{}:
			clock = o.Value().(Clock)
		case identAcceptableSkew{}:
			skew = o.Value().(time.Duration)
		case identIssuer{}:
			issuer = o.Value().(string)
		case identSubject{}:
			subject = o.Value().(string)
		case identAudience{}:
			audience = o.Value().(string)
		case identJwtid{}:
			jwtid = o.Value().(string)
		case identRequiredClaim{}:
			requiredMap[o.Value().(string)] = struct{}{}
		case identTimeDelta{}:
			d := o.Value().(delta)
			deltas = append(deltas, d)
			// An empty name stands for "the current time" and needs no
			// claim. Any other name must be a supported time claim, and
			// becomes required so the delta check below can rely on it.
			for _, c := range []string{d.c1, d.c2} {
				if c == "" {
					continue
				}
				if err := isSupportedTimeClaim(c); err != nil {
					return err
				}
				requiredMap[c] = struct{}{}
			}
		case identClaim{}:
			claim := o.Value().(claimValue)
			claimValues[claim.name] = claim.value
		}
	}

	for c := range requiredMap {
		if _, ok := t.Get(c); !ok {
			return errors.Errorf(`required claim %s was not found`, c)
		}
	}

	for _, d := range deltas {
		// We don't check if the claims already exist, because we already did that
		// by piggybacking on `required` check.
		t1 := timeClaim(t, clock, d.c1).Truncate(time.Second)
		t2 := timeClaim(t, clock, d.c2).Truncate(time.Second)
		diff := t1.Sub(t2)

		if d.less {
			// Want t1 - t2 <= dur; the skew widens the upper bound.
			if diff > d.dur+skew {
				return errors.Errorf(`delta between %s and %s exceeds %s (skew %s)`, d.c1, d.c2, d.dur, skew)
			}
			continue
		}

		// Want t1 - t2 >= dur; the skew lowers the lower bound.
		if diff < d.dur-skew {
			return errors.Errorf(`delta between %s and %s is less than %s (skew %s)`, d.c1, d.c2, d.dur, skew)
		}
	}

	// check for iss
	if len(issuer) > 0 {
		if v := t.Issuer(); v != issuer {
			return errors.New(`iss not satisfied`)
		}
	}

	// check for jti
	if len(jwtid) > 0 {
		if v := t.JwtID(); v != jwtid {
			return errors.New(`jti not satisfied`)
		}
	}

	// check for sub
	if len(subject) > 0 {
		if v := t.Subject(); v != subject {
			return errors.New(`sub not satisfied`)
		}
	}

	// check for aud: the expected audience must be one of the token's audiences
	if len(audience) > 0 {
		var found bool
		for _, v := range t.Audience() {
			if v == audience {
				found = true
				break
			}
		}
		if !found {
			return errors.New(`aud not satisfied`)
		}
	}

	// check for exp: valid while now < exp + skew
	if tv := t.Expiration(); !tv.IsZero() && tv.Unix() != 0 {
		now := clock.Now().Truncate(time.Second)
		ttv := tv.Truncate(time.Second)
		if !now.Before(ttv.Add(skew)) {
			return errors.New(`exp not satisfied`)
		}
	}

	// check for iat: valid once now >= iat - skew
	if tv := t.IssuedAt(); !tv.IsZero() && tv.Unix() != 0 {
		now := clock.Now().Truncate(time.Second)
		ttv := tv.Truncate(time.Second)
		if now.Before(ttv.Add(-1 * skew)) {
			return errors.New(`iat not satisfied`)
		}
	}

	// check for nbf: valid once now >= nbf - skew
	if tv := t.NotBefore(); !tv.IsZero() && tv.Unix() != 0 {
		now := clock.Now().Truncate(time.Second)
		ttv := tv.Truncate(time.Second)
		// The lower bound is inclusive, exactly as for iat. This keeps the
		// boundary behaviour the same whether or not a skew is configured:
		// with zero skew now == nbf is accepted, and with a positive skew
		// now == nbf - skew is accepted as well.
		if now.Before(ttv.Add(-1 * skew)) {
			return errors.New(`nbf not satisfied`)
		}
	}

	for name, expectedValue := range claimValues {
		if v, ok := t.Get(name); !ok || !claimValuesMatch(v, expectedValue) {
			return fmt.Errorf(`%v not satisfied`, name)
		}
	}

	return nil
}

// claimValuesMatch reports whether got == want.
//
// Comparing two interface values whose identical dynamic type is not
// comparable (for example two slices or two maps) panics at run time.
// Such a pair is reported as a mismatch instead, so that validation
// fails with an ordinary error rather than crashing the caller.
func claimValuesMatch(got, want interface{}) (match bool) {
	defer func() {
		if r := recover(); r != nil {
			match = false
		}
	}()
	return got == want
}