364 lines
8.9 KiB
Go
364 lines
8.9 KiB
Go
|
package jws
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
|
||
|
"github.com/lestrrat-go/jwx/internal/base64"
|
||
|
"github.com/lestrrat-go/jwx/internal/json"
|
||
|
"github.com/lestrrat-go/jwx/internal/pool"
|
||
|
"github.com/lestrrat-go/jwx/jwk"
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
func NewSignature() *Signature {
|
||
|
return &Signature{}
|
||
|
}
|
||
|
|
||
|
func (s Signature) PublicHeaders() Headers {
|
||
|
return s.headers
|
||
|
}
|
||
|
|
||
|
func (s *Signature) SetPublicHeaders(v Headers) *Signature {
|
||
|
s.headers = v
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func (s Signature) ProtectedHeaders() Headers {
|
||
|
return s.protected
|
||
|
}
|
||
|
|
||
|
func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
|
||
|
s.protected = v
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
func (s Signature) Signature() []byte {
|
||
|
return s.signature
|
||
|
}
|
||
|
|
||
|
func (s *Signature) SetSignature(v []byte) *Signature {
|
||
|
s.signature = v
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
// Sign populates the signature field, with a signature generated by
|
||
|
// given the signer object and payload.
|
||
|
//
|
||
|
// The first return value is the raw signature in binary format.
|
||
|
// The second return value s the full three-segment signature
|
||
|
// (e.g. "eyXXXX.XXXXX.XXXX")
|
||
|
func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
|
||
|
ctx, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
|
||
|
if err != nil {
|
||
|
return nil, nil, errors.Wrap(err, `failed to merge headers`)
|
||
|
}
|
||
|
|
||
|
if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
|
||
|
return nil, nil, errors.Wrap(err, `failed to set "alg"`)
|
||
|
}
|
||
|
|
||
|
// If the key is a jwk.Key instance, obtain the raw key
|
||
|
if jwkKey, ok := key.(jwk.Key); ok {
|
||
|
// If we have a key ID specified by this jwk.Key, use that in the header
|
||
|
if kid := jwkKey.KeyID(); kid != "" {
|
||
|
if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
|
||
|
return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
hdrbuf, err := json.Marshal(hdrs)
|
||
|
if err != nil {
|
||
|
return nil, nil, errors.Wrap(err, `failed to marshal headers`)
|
||
|
}
|
||
|
|
||
|
buf := pool.GetBytesBuffer()
|
||
|
defer pool.ReleaseBytesBuffer(buf)
|
||
|
|
||
|
buf.WriteString(base64.EncodeToString(hdrbuf))
|
||
|
buf.WriteByte('.')
|
||
|
if getB64Value(hdrs) {
|
||
|
buf.WriteString(base64.EncodeToString(payload))
|
||
|
} else {
|
||
|
if bytes.ContainsRune(payload, '.') {
|
||
|
return nil, nil, errors.New(`payload must not contain a "." when b64 = false`)
|
||
|
}
|
||
|
buf.Write(payload)
|
||
|
}
|
||
|
|
||
|
signature, err := signer.Sign(buf.Bytes(), key)
|
||
|
if err != nil {
|
||
|
return nil, nil, errors.Wrap(err, `failed to sign payload`)
|
||
|
}
|
||
|
s.signature = signature
|
||
|
|
||
|
buf.WriteByte('.')
|
||
|
buf.WriteString(base64.EncodeToString(signature))
|
||
|
ret := make([]byte, buf.Len())
|
||
|
copy(ret, buf.Bytes())
|
||
|
|
||
|
return signature, ret, nil
|
||
|
}
|
||
|
|
||
|
func NewMessage() *Message {
|
||
|
return &Message{}
|
||
|
}
|
||
|
|
||
|
// Payload returns the decoded payload
|
||
|
func (m Message) Payload() []byte {
|
||
|
return m.payload
|
||
|
}
|
||
|
|
||
|
func (m *Message) SetPayload(v []byte) *Message {
|
||
|
m.payload = v
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (m Message) Signatures() []*Signature {
|
||
|
return m.signatures
|
||
|
}
|
||
|
|
||
|
func (m *Message) AppendSignature(v *Signature) *Message {
|
||
|
m.signatures = append(m.signatures, v)
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
func (m *Message) ClearSignatures() *Message {
|
||
|
m.signatures = nil
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
// LookupSignature looks up a particular signature entry using
|
||
|
// the `kid` value
|
||
|
func (m Message) LookupSignature(kid string) []*Signature {
|
||
|
var sigs []*Signature
|
||
|
for _, sig := range m.signatures {
|
||
|
if hdr := sig.PublicHeaders(); hdr != nil {
|
||
|
hdrKeyID := hdr.KeyID()
|
||
|
if hdrKeyID == kid {
|
||
|
sigs = append(sigs, sig)
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if hdr := sig.ProtectedHeaders(); hdr != nil {
|
||
|
hdrKeyID := hdr.KeyID()
|
||
|
if hdrKeyID == kid {
|
||
|
sigs = append(sigs, sig)
|
||
|
continue
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return sigs
|
||
|
}
|
||
|
|
||
|
type messageProxy struct {
|
||
|
Payload string `json:"payload"` // base64 URL encoded
|
||
|
Signatures []*signatureProxy `json:"signatures,omitempty"`
|
||
|
|
||
|
// These are only available when we're using flattened JSON
|
||
|
// (normally I would embed *signatureProxy, but because
|
||
|
// signatureProxy is not exported, we can't use that)
|
||
|
Header *json.RawMessage `json:"header,omitempty"`
|
||
|
Protected *string `json:"protected,omitempty"`
|
||
|
Signature *string `json:"signature,omitempty"`
|
||
|
}
|
||
|
|
||
|
// signatureProxy is the intermediate structure for one signature entry of
// a JSON-serialized message. Header is kept as raw JSON so it can be
// decoded later; Protected and Signature hold the still-encoded string
// values exactly as they appear in the document.
type signatureProxy struct {
	Header    json.RawMessage `json:"header"`
	Protected string          `json:"protected"`
	Signature string          `json:"signature"`
}
|
||
|
|
||
|
func (m *Message) UnmarshalJSON(buf []byte) error {
|
||
|
var proxy messageProxy
|
||
|
if err := json.Unmarshal(buf, &proxy); err != nil {
|
||
|
return errors.Wrap(err, `failed to unmarshal into temporary structure`)
|
||
|
}
|
||
|
|
||
|
if proxy.Signature != nil {
|
||
|
if len(proxy.Signatures) > 0 {
|
||
|
return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
|
||
|
}
|
||
|
|
||
|
var sigproxy signatureProxy
|
||
|
if hdr := proxy.Header; hdr != nil {
|
||
|
sigproxy.Header = *hdr
|
||
|
}
|
||
|
if hdr := proxy.Protected; hdr != nil {
|
||
|
sigproxy.Protected = *hdr
|
||
|
}
|
||
|
sigproxy.Signature = *proxy.Signature
|
||
|
|
||
|
proxy.Signatures = append(proxy.Signatures, &sigproxy)
|
||
|
}
|
||
|
|
||
|
b64 := true
|
||
|
for i, sigproxy := range proxy.Signatures {
|
||
|
var sig Signature
|
||
|
|
||
|
if len(sigproxy.Header) > 0 {
|
||
|
sig.headers = NewHeaders()
|
||
|
if err := json.Unmarshal(sigproxy.Header, sig.headers); err != nil {
|
||
|
return errors.Wrapf(err, `failed to unmarshal "header" for signature #%d`, i+1)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(sigproxy.Protected) > 0 {
|
||
|
buf, err := base64.DecodeString(sigproxy.Protected)
|
||
|
if err != nil {
|
||
|
return errors.Wrapf(err, `failed to decode "protected" for signature #%d`, i+1)
|
||
|
}
|
||
|
sig.protected = NewHeaders()
|
||
|
if err := json.Unmarshal(buf, sig.protected); err != nil {
|
||
|
return errors.Wrapf(err, `failed to unmarshal "protected" for signature #%d`, i+1)
|
||
|
}
|
||
|
|
||
|
if i == 0 {
|
||
|
b64 = getB64Value(sig.protected)
|
||
|
} else {
|
||
|
if b64 != getB64Value(sig.protected) {
|
||
|
return errors.Errorf(`b64 value must be the same for all signatures`)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(sigproxy.Signature) == 0 {
|
||
|
return errors.Errorf(`"signature" must be non-empty for signature #%d`, i+1)
|
||
|
}
|
||
|
|
||
|
buf, err := base64.DecodeString(sigproxy.Signature)
|
||
|
if err != nil {
|
||
|
return errors.Wrapf(err, `failed to decode "signature" for signature #%d`, i+1)
|
||
|
}
|
||
|
sig.signature = buf
|
||
|
m.signatures = append(m.signatures, &sig)
|
||
|
}
|
||
|
|
||
|
if !b64 {
|
||
|
m.payload = []byte(proxy.Payload)
|
||
|
} else {
|
||
|
// Everything in the proxy is base64 encoded, except for signatures.header
|
||
|
if len(proxy.Payload) == 0 {
|
||
|
return errors.New(`"payload" must be non-empty`)
|
||
|
}
|
||
|
|
||
|
buf, err := base64.DecodeString(proxy.Payload)
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, `failed to decode payload`)
|
||
|
}
|
||
|
m.payload = buf
|
||
|
}
|
||
|
m.b64 = b64
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m Message) MarshalJSON() ([]byte, error) {
|
||
|
if len(m.signatures) == 1 {
|
||
|
return m.marshalFlattened()
|
||
|
}
|
||
|
return m.marshalFull()
|
||
|
}
|
||
|
|
||
|
func (m Message) marshalFlattened() ([]byte, error) {
|
||
|
buf := pool.GetBytesBuffer()
|
||
|
defer pool.ReleaseBytesBuffer(buf)
|
||
|
|
||
|
sig := m.signatures[0]
|
||
|
|
||
|
buf.WriteRune('{')
|
||
|
var wrote bool
|
||
|
|
||
|
if hdr := sig.headers; hdr != nil {
|
||
|
hdrjs, err := hdr.MarshalJSON()
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
|
||
|
}
|
||
|
buf.WriteString(`"header":`)
|
||
|
buf.Write(hdrjs)
|
||
|
wrote = true
|
||
|
}
|
||
|
|
||
|
if wrote {
|
||
|
buf.WriteRune(',')
|
||
|
}
|
||
|
buf.WriteString(`"payload":"`)
|
||
|
buf.WriteString(base64.EncodeToString(m.payload))
|
||
|
buf.WriteRune('"')
|
||
|
|
||
|
if protected := sig.protected; protected != nil {
|
||
|
protectedbuf, err := protected.MarshalJSON()
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
|
||
|
}
|
||
|
buf.WriteString(`,"protected":"`)
|
||
|
buf.WriteString(base64.EncodeToString(protectedbuf))
|
||
|
buf.WriteRune('"')
|
||
|
}
|
||
|
|
||
|
buf.WriteString(`,"signature":"`)
|
||
|
buf.WriteString(base64.EncodeToString(sig.signature))
|
||
|
buf.WriteRune('"')
|
||
|
buf.WriteRune('}')
|
||
|
|
||
|
ret := make([]byte, buf.Len())
|
||
|
copy(ret, buf.Bytes())
|
||
|
return ret, nil
|
||
|
}
|
||
|
|
||
|
func (m Message) marshalFull() ([]byte, error) {
|
||
|
buf := pool.GetBytesBuffer()
|
||
|
defer pool.ReleaseBytesBuffer(buf)
|
||
|
|
||
|
buf.WriteString(`{"payload":"`)
|
||
|
buf.WriteString(base64.EncodeToString(m.payload))
|
||
|
buf.WriteString(`","signatures":[`)
|
||
|
for i, sig := range m.signatures {
|
||
|
if i > 0 {
|
||
|
buf.WriteRune(',')
|
||
|
}
|
||
|
|
||
|
buf.WriteRune('{')
|
||
|
var wrote bool
|
||
|
if hdr := sig.headers; hdr != nil {
|
||
|
hdrbuf, err := hdr.MarshalJSON()
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
|
||
|
}
|
||
|
buf.WriteString(`"header":`)
|
||
|
buf.Write(hdrbuf)
|
||
|
wrote = true
|
||
|
}
|
||
|
|
||
|
if protected := sig.protected; protected != nil {
|
||
|
protectedbuf, err := protected.MarshalJSON()
|
||
|
if err != nil {
|
||
|
return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
|
||
|
}
|
||
|
if wrote {
|
||
|
buf.WriteRune(',')
|
||
|
}
|
||
|
buf.WriteString(`"protected":"`)
|
||
|
buf.WriteString(base64.EncodeToString(protectedbuf))
|
||
|
buf.WriteRune('"')
|
||
|
wrote = true
|
||
|
}
|
||
|
|
||
|
if wrote {
|
||
|
buf.WriteRune(',')
|
||
|
}
|
||
|
buf.WriteString(`"signature":"`)
|
||
|
buf.WriteString(base64.EncodeToString(sig.signature))
|
||
|
buf.WriteString(`"}`)
|
||
|
}
|
||
|
buf.WriteString(`]}`)
|
||
|
|
||
|
ret := make([]byte, buf.Len())
|
||
|
copy(ret, buf.Bytes())
|
||
|
return ret, nil
|
||
|
}
|