// Source: go-sample-webpage/vendor/github.com/lestrrat-go/jwx/jws/message.go
//
// (Original listing: 364 lines, 8.9 KiB, Go.)

package jws
import (
"bytes"
"context"
"github.com/lestrrat-go/jwx/internal/base64"
"github.com/lestrrat-go/jwx/internal/json"
"github.com/lestrrat-go/jwx/internal/pool"
"github.com/lestrrat-go/jwx/jwk"
"github.com/pkg/errors"
)
// NewSignature returns a new, empty Signature.
func NewSignature() *Signature {
	return &Signature{}
}

// PublicHeaders returns the signature's public (unprotected) header.
// It may be nil when no public header has been set.
func (s Signature) PublicHeaders() Headers {
	return s.headers
}

// SetPublicHeaders sets the signature's public (unprotected) header and
// returns s so calls can be chained.
func (s *Signature) SetPublicHeaders(v Headers) *Signature {
	s.headers = v
	return s
}

// ProtectedHeaders returns the signature's protected header.
// It may be nil when no protected header has been set.
func (s Signature) ProtectedHeaders() Headers {
	return s.protected
}

// SetProtectedHeaders sets the signature's protected header and returns s
// so calls can be chained.
func (s *Signature) SetProtectedHeaders(v Headers) *Signature {
	s.protected = v
	return s
}

// Signature returns the raw (binary, not base64-encoded) signature bytes.
// The returned slice is the one stored in s, not a copy.
func (s Signature) Signature() []byte {
	return s.signature
}

// SetSignature sets the raw signature bytes and returns s so calls can be
// chained. v is stored as-is; it is not copied.
func (s *Signature) SetSignature(v []byte) *Signature {
	s.signature = v
	return s
}
// Sign computes a signature over payload with the given signer and key, and
// stores the raw signature bytes in s.
//
// The JOSE header that is signed is produced by merging s's public and
// protected headers and then setting "alg" to signer.Algorithm(). When key
// is a jwk.Key with a non-empty key ID, "kid" is set to that ID as well.
//
// The payload is base64url-encoded in the signing input unless the merged
// header's "b64" value is false (RFC 7797). In that case it is used verbatim
// and must not contain a '.'.
//
// The first return value is the raw signature in binary format.
// The second return value is the full three-segment compact serialization
// (e.g. "eyXXXX.XXXXX.XXXX").
func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	hdrs, err := mergeHeaders(ctx, s.headers, s.protected)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to merge headers`)
	}

	if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil {
		return nil, nil, errors.Wrap(err, `failed to set "alg"`)
	}

	// A jwk.Key is handed to the signer unchanged (see signer.Sign below).
	// Here only its key ID is consulted, to populate "kid" in the header.
	if jwkKey, ok := key.(jwk.Key); ok {
		if kid := jwkKey.KeyID(); kid != "" {
			if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil {
				return nil, nil, errors.Wrap(err, `set key ID from jwk.Key`)
			}
		}
	}

	hdrbuf, err := json.Marshal(hdrs)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to marshal headers`)
	}

	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	// Signing input: BASE64URL(header) "." payload-segment
	buf.WriteString(base64.EncodeToString(hdrbuf))
	buf.WriteByte('.')
	if getB64Value(hdrs) {
		buf.WriteString(base64.EncodeToString(payload))
	} else {
		// An unencoded payload containing '.' would make the compact
		// serialization ambiguous, so it is rejected.
		if bytes.ContainsRune(payload, '.') {
			return nil, nil, errors.New(`payload must not contain a "." when b64 = false`)
		}
		buf.Write(payload)
	}

	signature, err := signer.Sign(buf.Bytes(), key)
	if err != nil {
		return nil, nil, errors.Wrap(err, `failed to sign payload`)
	}
	s.signature = signature

	// Append the signature segment to form the compact serialization.
	buf.WriteByte('.')
	buf.WriteString(base64.EncodeToString(signature))

	// buf goes back to the pool on return, so hand the caller a copy.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return signature, ret, nil
}
// NewMessage returns a new, empty Message.
func NewMessage() *Message {
	return &Message{}
}

// Payload returns the message payload.
//
// For a message parsed with UnmarshalJSON this is the base64url-decoded
// payload when "b64" is true. When "b64" is false it is the payload string
// exactly as it appeared in the JSON.
func (m Message) Payload() []byte {
	return m.payload
}

// SetPayload sets the (unencoded) payload and returns m so calls can be
// chained. v is stored as-is; it is not copied.
func (m *Message) SetPayload(v []byte) *Message {
	m.payload = v
	return m
}

// Signatures returns the message's signatures. The returned slice is the
// message's internal slice, not a copy.
func (m Message) Signatures() []*Signature {
	return m.signatures
}

// AppendSignature adds v to the message's signatures and returns m so calls
// can be chained.
func (m *Message) AppendSignature(v *Signature) *Message {
	m.signatures = append(m.signatures, v)
	return m
}

// ClearSignatures removes all signatures from the message and returns m so
// calls can be chained.
func (m *Message) ClearSignatures() *Message {
	m.signatures = nil
	return m
}
// LookupSignature returns every signature whose "kid" header value equals
// kid. The public header is checked first, then the protected header. Nil
// headers are skipped, and a signature is included at most once even if
// both of its headers match.
func (m Message) LookupSignature(kid string) []*Signature {
	var matches []*Signature
	for _, sig := range m.signatures {
		candidates := [...]Headers{sig.PublicHeaders(), sig.ProtectedHeaders()}
		for _, hdr := range candidates {
			if hdr == nil || hdr.KeyID() != kid {
				continue
			}
			matches = append(matches, sig)
			break
		}
	}
	return matches
}
// messageProxy mirrors the JSON wire format of a JWS message and is used
// as a temporary decoding target by (*Message).UnmarshalJSON. It accepts
// both the general serialization ("signatures" array) and the flattened
// serialization (top-level "header"/"protected"/"signature" members).
type messageProxy struct {
Payload string `json:"payload"` // base64url-encoded, or the raw payload when "b64" is false
Signatures []*signatureProxy `json:"signatures,omitempty"`
// These are only populated for the flattened JSON serialization. They are
// pointers so that an absent member can be told apart from an empty one.
// (Embedding *signatureProxy is not used because signatureProxy is
// unexported.)
Header *json.RawMessage `json:"header,omitempty"`
Protected *string `json:"protected,omitempty"`
Signature *string `json:"signature,omitempty"`
}
// signatureProxy is one entry of the "signatures" array (or the flattened
// equivalent) as it appears on the wire, before decoding.
type signatureProxy struct {
Header json.RawMessage `json:"header"` // public header as a plain JSON object
Protected string `json:"protected"` // base64url-encoded JSON protected header
Signature string `json:"signature"` // base64url-encoded raw signature
}
// UnmarshalJSON decodes a JWS message in either the general JSON
// serialization ("signatures" array) or the flattened JSON serialization
// (top-level "header"/"protected"/"signature" members).
//
// The "b64" value (RFC 7797) is read from each signature's protected
// header. A signature without a protected header uses the default (true).
// All signatures must agree. When b64 is true the payload must be a
// non-empty base64url string and is decoded. When b64 is false the payload
// string is used verbatim.
//
// On success the message's signatures, payload and b64 flag are replaced,
// mirroring how encoding/json resets slices. On error the message is left
// unmodified.
func (m *Message) UnmarshalJSON(buf []byte) error {
	var proxy messageProxy
	if err := json.Unmarshal(buf, &proxy); err != nil {
		return errors.Wrap(err, `failed to unmarshal into temporary structure`)
	}

	// Flattened serialization: fold the top-level members into a single
	// entry so both serializations go through the same loop below.
	if proxy.Signature != nil {
		if len(proxy.Signatures) > 0 {
			return errors.New(`invalid format ("signatures" and "signature" keys cannot both be present)`)
		}
		var sigproxy signatureProxy
		if hdr := proxy.Header; hdr != nil {
			sigproxy.Header = *hdr
		}
		if hdr := proxy.Protected; hdr != nil {
			sigproxy.Protected = *hdr
		}
		sigproxy.Signature = *proxy.Signature
		proxy.Signatures = append(proxy.Signatures, &sigproxy)
	}

	b64 := true
	var sigs []*Signature
	for i, sigproxy := range proxy.Signatures {
		// `"signatures":[null]` decodes to a nil entry; reject it rather
		// than dereferencing it.
		if sigproxy == nil {
			return errors.Errorf(`signature #%d must be a JSON object`, i+1)
		}

		sig := &Signature{}
		if len(sigproxy.Header) > 0 {
			sig.headers = NewHeaders()
			if err := json.Unmarshal(sigproxy.Header, sig.headers); err != nil {
				return errors.Wrapf(err, `failed to unmarshal "header" for signature #%d`, i+1)
			}
		}

		// "b64" is only honored in the protected header. Without one, the
		// RFC 7797 default (true) applies, and it must still match the others.
		sigB64 := true
		if len(sigproxy.Protected) > 0 {
			decoded, err := base64.DecodeString(sigproxy.Protected)
			if err != nil {
				return errors.Wrapf(err, `failed to decode "protected" for signature #%d`, i+1)
			}
			sig.protected = NewHeaders()
			if err := json.Unmarshal(decoded, sig.protected); err != nil {
				return errors.Wrapf(err, `failed to unmarshal "protected" for signature #%d`, i+1)
			}
			sigB64 = getB64Value(sig.protected)
		}
		if i == 0 {
			b64 = sigB64
		} else if sigB64 != b64 {
			return errors.New(`b64 value must be the same for all signatures`)
		}

		if len(sigproxy.Signature) == 0 {
			return errors.Errorf(`"signature" must be non-empty for signature #%d`, i+1)
		}
		rawSig, err := base64.DecodeString(sigproxy.Signature)
		if err != nil {
			return errors.Wrapf(err, `failed to decode "signature" for signature #%d`, i+1)
		}
		sig.signature = rawSig
		sigs = append(sigs, sig)
	}

	var payload []byte
	if b64 {
		// Everything in the proxy is base64 encoded, except for signatures.header
		if len(proxy.Payload) == 0 {
			return errors.New(`"payload" must be non-empty`)
		}
		decoded, err := base64.DecodeString(proxy.Payload)
		if err != nil {
			return errors.Wrap(err, `failed to decode payload`)
		}
		payload = decoded
	} else {
		// b64=false: the payload is carried unencoded as a JSON string.
		payload = []byte(proxy.Payload)
	}

	// Commit only after everything decoded successfully.
	m.signatures = sigs
	m.payload = payload
	m.b64 = b64
	return nil
}
// MarshalJSON encodes m using the flattened JWS JSON serialization when it
// carries exactly one signature. Otherwise, including when it has no
// signatures, it uses the general JWS JSON serialization ("signatures"
// array).
func (m Message) MarshalJSON() ([]byte, error) {
	if len(m.signatures) == 1 {
		return m.marshalFlattened()
	}
	return m.marshalFull()
}

// marshalPayloadValue returns the complete JSON value, including the
// surrounding quotes, for the "payload" member.
//
// The payload is base64url-encoded unless the first signature's protected
// header has "b64" set to false (RFC 7797). A missing protected header, or
// no signatures at all, means the default (true). With b64=false the raw
// payload is emitted as a JSON string, so it can be read back verbatim.
// Note that JSON strings cannot carry invalid UTF-8 losslessly.
func (m Message) marshalPayloadValue() ([]byte, error) {
	b64 := true
	if len(m.signatures) > 0 {
		if sig := m.signatures[0]; sig != nil && sig.protected != nil {
			b64 = getB64Value(sig.protected)
		}
	}
	if !b64 {
		return json.Marshal(string(m.payload))
	}

	encoded := base64.EncodeToString(m.payload)
	value := make([]byte, 0, len(encoded)+2)
	value = append(value, '"')
	value = append(value, encoded...)
	return append(value, '"'), nil
}

// marshalFlattened encodes m, which must have exactly one signature, as
// {"header":...,"payload":...,"protected":...,"signature":...}. "header"
// and "protected" are omitted when the corresponding header is nil.
func (m Message) marshalFlattened() ([]byte, error) {
	payload, err := m.marshalPayloadValue()
	if err != nil {
		return nil, errors.Wrap(err, `failed to marshal "payload" (flattened format)`)
	}

	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	sig := m.signatures[0]

	buf.WriteRune('{')
	if hdr := sig.headers; hdr != nil {
		hdrjs, err := hdr.MarshalJSON()
		if err != nil {
			return nil, errors.Wrap(err, `failed to marshal "header" (flattened format)`)
		}
		buf.WriteString(`"header":`)
		buf.Write(hdrjs)
		buf.WriteRune(',')
	}

	buf.WriteString(`"payload":`)
	buf.Write(payload)

	if protected := sig.protected; protected != nil {
		protectedbuf, err := protected.MarshalJSON()
		if err != nil {
			return nil, errors.Wrap(err, `failed to marshal "protected" (flattened format)`)
		}
		buf.WriteString(`,"protected":"`)
		buf.WriteString(base64.EncodeToString(protectedbuf))
		buf.WriteRune('"')
	}

	buf.WriteString(`,"signature":"`)
	buf.WriteString(base64.EncodeToString(sig.signature))
	buf.WriteRune('"')
	buf.WriteRune('}')

	// buf goes back to the pool on return, so hand the caller a copy.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return ret, nil
}

// marshalFull encodes m as {"payload":...,"signatures":[...]}. Each entry
// holds optional "header" and "protected" members and a mandatory
// "signature".
func (m Message) marshalFull() ([]byte, error) {
	payload, err := m.marshalPayloadValue()
	if err != nil {
		return nil, errors.Wrap(err, `failed to marshal "payload"`)
	}

	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	buf.WriteString(`{"payload":`)
	buf.Write(payload)
	buf.WriteString(`,"signatures":[`)
	for i, sig := range m.signatures {
		if i > 0 {
			buf.WriteRune(',')
		}
		buf.WriteRune('{')
		// wrote tracks whether a member was already emitted, so the next
		// one needs a leading comma.
		var wrote bool
		if hdr := sig.headers; hdr != nil {
			hdrbuf, err := hdr.MarshalJSON()
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "header" for signature #%d`, i+1)
			}
			buf.WriteString(`"header":`)
			buf.Write(hdrbuf)
			wrote = true
		}

		if protected := sig.protected; protected != nil {
			protectedbuf, err := protected.MarshalJSON()
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
			}
			if wrote {
				buf.WriteRune(',')
			}
			buf.WriteString(`"protected":"`)
			buf.WriteString(base64.EncodeToString(protectedbuf))
			buf.WriteRune('"')
			wrote = true
		}

		if wrote {
			buf.WriteRune(',')
		}
		buf.WriteString(`"signature":"`)
		buf.WriteString(base64.EncodeToString(sig.signature))
		buf.WriteString(`"}`)
	}
	buf.WriteString(`]}`)

	// buf goes back to the pool on return, so hand the caller a copy.
	ret := make([]byte, buf.Len())
	copy(ret, buf.Bytes())
	return ret, nil
}