239 lines
5.8 KiB
Go
239 lines
5.8 KiB
Go
package jwt
|
|
|
|
import (
|
|
"fmt"
|
|
|
|
"github.com/lestrrat-go/jwx/internal/json"
|
|
"github.com/lestrrat-go/jwx/jwa"
|
|
"github.com/lestrrat-go/jwx/jwe"
|
|
"github.com/lestrrat-go/jwx/jws"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// SerializeCtx exposes the state of an in-progress serialization to a
// SerializeStep.
type SerializeCtx interface {
	// Step returns the zero-based index of the step being executed.
	// Serializer.Serialize always runs JSON marshaling as step 0, so the
	// first user-registered step observes 1.
	Step() int

	// Nested reports whether more than one step has been registered on the
	// Serializer that is driving the serialization.
	Nested() bool
}
|
|
|
|
// serializeCtx is the concrete serialization context that
// Serializer.Serialize hands to each step.
type serializeCtx struct {
	step   int  // index of the step currently being executed
	nested bool // true when more than one step is registered
}

// Step returns the index of the step currently being executed.
func (ctx *serializeCtx) Step() int {
	return ctx.step
}

// Nested reports whether the serialization consists of more than one
// registered step.
func (ctx *serializeCtx) Nested() bool {
	return ctx.nested
}
|
|
|
|
type SerializeStep interface {
|
|
Serialize(SerializeCtx, interface{}) (interface{}, error)
|
|
}
|
|
|
|
// Serializer is a generic serializer for JWTs. Whereas other conveinience
|
|
// functions can only do one thing (such as generate a JWS signed JWT),
|
|
// Using this construct you can serialize the token however you want.
|
|
//
|
|
// By default the serializer only marshals the token into a JSON payload.
|
|
// You must set up the rest of the steps that should be taken by the
|
|
// serializer.
|
|
//
|
|
// For example, to marshal the token into JSON, then apply JWS and JWE
|
|
// in that order, you would do:
|
|
//
|
|
// serialized, err := jwt.NewSerialer().
|
|
// Sign(jwa.RS256, key).
|
|
// Encrypt(jwa.RSA_OAEP, key.PublicKey).
|
|
// Serialize(token)
|
|
//
|
|
// The `jwt.Sign()` function is equivalent to
|
|
//
|
|
// serialized, err := jwt.NewSerializer().
|
|
// Sign(...args...).
|
|
// Serialize(token)
|
|
type Serializer struct {
|
|
steps []SerializeStep
|
|
}
|
|
|
|
// NewSerializer creates a new empty serializer.
|
|
func NewSerializer() *Serializer {
|
|
return &Serializer{}
|
|
}
|
|
|
|
// Reset clears all of the registered steps.
|
|
func (s *Serializer) Reset() *Serializer {
|
|
s.steps = nil
|
|
return s
|
|
}
|
|
|
|
// Step adds a new Step to the serialization process
|
|
func (s *Serializer) Step(step SerializeStep) *Serializer {
|
|
s.steps = append(s.steps, step)
|
|
return s
|
|
}
|
|
|
|
type jsonSerializer struct{}
|
|
|
|
func (jsonSerializer) Serialize(_ SerializeCtx, v interface{}) (interface{}, error) {
|
|
token, ok := v.(Token)
|
|
if !ok {
|
|
return nil, errors.Errorf(`invalid input: expected jwt.Token`)
|
|
}
|
|
|
|
buf, err := json.Marshal(token)
|
|
if err != nil {
|
|
return nil, errors.Errorf(`failed to serialize as JSON`)
|
|
}
|
|
return buf, nil
|
|
}
|
|
|
|
// genericHeader is the minimal header API that setTypeOrCty needs. Both
// jws.Headers and jwe.Headers values are passed through it in this file.
type genericHeader interface {
	// Get returns the value stored under the given key. setTypeOrCty treats
	// the boolean result as "the key is present".
	Get(string) (interface{}, bool)

	// Set stores a value under the given key, returning an error if the
	// value is rejected.
	Set(string, interface{}) error
}
|
|
|
|
func setTypeOrCty(ctx SerializeCtx, hdrs genericHeader) error {
|
|
// cty and typ are common between JWE/JWS, so we don't use
|
|
// the constants in jws/jwe package here
|
|
const typKey = `typ`
|
|
const ctyKey = `cty`
|
|
|
|
if ctx.Step() == 1 {
|
|
// We are executed immediately after json marshaling
|
|
if _, ok := hdrs.Get(typKey); !ok {
|
|
if err := hdrs.Set(typKey, `JWT`); err != nil {
|
|
return errors.Wrapf(err, `failed to set %s key to "JWT"`, typKey)
|
|
}
|
|
}
|
|
} else {
|
|
if ctx.Nested() {
|
|
// If this is part of a nested sequence, we should set cty = 'JWT'
|
|
// https://datatracker.ietf.org/doc/html/rfc7519#section-5.2
|
|
if err := hdrs.Set(ctyKey, `JWT`); err != nil {
|
|
return errors.Wrapf(err, `failed to set %s key to "JWT"`, ctyKey)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
type jwsSerializer struct {
|
|
alg jwa.SignatureAlgorithm
|
|
key interface{}
|
|
options []SignOption
|
|
}
|
|
|
|
func (s *jwsSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
|
|
payload, ok := v.([]byte)
|
|
if !ok {
|
|
return nil, errors.New(`expected []byte as input`)
|
|
}
|
|
|
|
var hdrs jws.Headers
|
|
//nolint:forcetypeassert
|
|
for _, option := range s.options {
|
|
switch option.Ident() {
|
|
case identJwsHeaders{}:
|
|
hdrs = option.Value().(jws.Headers)
|
|
}
|
|
}
|
|
|
|
if hdrs == nil {
|
|
hdrs = jws.NewHeaders()
|
|
}
|
|
|
|
if err := setTypeOrCty(ctx, hdrs); err != nil {
|
|
return nil, err // this is already wrapped
|
|
}
|
|
|
|
// JWTs MUST NOT use b64 = false
|
|
// https://datatracker.ietf.org/doc/html/rfc7797#section-7
|
|
if v, ok := hdrs.Get("b64"); ok {
|
|
if bval, bok := v.(bool); bok {
|
|
if !bval { // b64 = false
|
|
return nil, errors.New(`b64 cannot be false for JWTs`)
|
|
}
|
|
}
|
|
}
|
|
return jws.Sign(payload, s.alg, s.key, jws.WithHeaders(hdrs))
|
|
}
|
|
|
|
func (s *Serializer) Sign(alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) *Serializer {
|
|
return s.Step(&jwsSerializer{
|
|
alg: alg,
|
|
key: key,
|
|
options: options,
|
|
})
|
|
}
|
|
|
|
type jweSerializer struct {
|
|
keyalg jwa.KeyEncryptionAlgorithm
|
|
key interface{}
|
|
contentalg jwa.ContentEncryptionAlgorithm
|
|
compressalg jwa.CompressionAlgorithm
|
|
options []EncryptOption
|
|
}
|
|
|
|
func (s *jweSerializer) Serialize(ctx SerializeCtx, v interface{}) (interface{}, error) {
|
|
payload, ok := v.([]byte)
|
|
if !ok {
|
|
return nil, fmt.Errorf(`expected []byte as input`)
|
|
}
|
|
|
|
var hdrs jwe.Headers
|
|
//nolint:forcetypeassert
|
|
for _, option := range s.options {
|
|
switch option.Ident() {
|
|
case identJweHeaders{}:
|
|
hdrs = option.Value().(jwe.Headers)
|
|
}
|
|
}
|
|
|
|
if hdrs == nil {
|
|
hdrs = jwe.NewHeaders()
|
|
}
|
|
|
|
if err := setTypeOrCty(ctx, hdrs); err != nil {
|
|
return nil, err // this is already wrapped
|
|
}
|
|
return jwe.Encrypt(payload, s.keyalg, s.key, s.contentalg, s.compressalg, jwe.WithProtectedHeaders(hdrs))
|
|
}
|
|
|
|
func (s *Serializer) Encrypt(keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) *Serializer {
|
|
return s.Step(&jweSerializer{
|
|
keyalg: keyalg,
|
|
key: key,
|
|
contentalg: contentalg,
|
|
compressalg: compressalg,
|
|
options: options,
|
|
})
|
|
}
|
|
|
|
func (s *Serializer) Serialize(t Token) ([]byte, error) {
|
|
steps := make([]SerializeStep, len(s.steps)+1)
|
|
steps[0] = jsonSerializer{}
|
|
for i, step := range s.steps {
|
|
steps[i+1] = step
|
|
}
|
|
|
|
var ctx serializeCtx
|
|
ctx.nested = len(s.steps) > 1
|
|
var payload interface{} = t
|
|
for i, step := range steps {
|
|
ctx.step = i
|
|
v, err := step.Serialize(&ctx, payload)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, `failed to serialize token at step #%d`, i+1)
|
|
}
|
|
payload = v
|
|
}
|
|
|
|
res, ok := payload.([]byte)
|
|
if !ok {
|
|
return nil, errors.New(`invalid serialization produced`)
|
|
}
|
|
|
|
return res, nil
|
|
}
|