209 lines
3.8 KiB
Go
209 lines
3.8 KiB
Go
|
package jwk
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
|
||
|
"github.com/lestrrat-go/iter/arrayiter"
|
||
|
"github.com/lestrrat-go/jwx/internal/json"
|
||
|
"github.com/lestrrat-go/jwx/internal/pool"
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
// NewSet creates and empty `jwk.Set` object
|
||
|
func NewSet() Set {
|
||
|
return &set{}
|
||
|
}
|
||
|
|
||
|
func (s *set) Get(idx int) (Key, bool) {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
if idx >= 0 && idx < len(s.keys) {
|
||
|
return s.keys[idx], true
|
||
|
}
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
func (s *set) Len() int {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
return len(s.keys)
|
||
|
}
|
||
|
|
||
|
// indexNL is Index(), but without the locking
|
||
|
func (s *set) indexNL(key Key) int {
|
||
|
for i, k := range s.keys {
|
||
|
if k == key {
|
||
|
return i
|
||
|
}
|
||
|
}
|
||
|
return -1
|
||
|
}
|
||
|
|
||
|
func (s *set) Index(key Key) int {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
return s.indexNL(key)
|
||
|
}
|
||
|
|
||
|
func (s *set) Add(key Key) bool {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
if i := s.indexNL(key); i > -1 {
|
||
|
return false
|
||
|
}
|
||
|
s.keys = append(s.keys, key)
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
func (s *set) Remove(key Key) bool {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
for i, k := range s.keys {
|
||
|
if k == key {
|
||
|
switch i {
|
||
|
case 0:
|
||
|
s.keys = s.keys[1:]
|
||
|
case len(s.keys) - 1:
|
||
|
s.keys = s.keys[:i]
|
||
|
default:
|
||
|
s.keys = append(s.keys[:i], s.keys[i+1:]...)
|
||
|
}
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
func (s *set) Clear() {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
s.keys = nil
|
||
|
}
|
||
|
|
||
|
func (s *set) Iterate(ctx context.Context) KeyIterator {
|
||
|
ch := make(chan *KeyPair, s.Len())
|
||
|
go iterate(ctx, s.keys, ch)
|
||
|
return arrayiter.New(ch)
|
||
|
}
|
||
|
|
||
|
func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
|
||
|
defer close(ch)
|
||
|
|
||
|
for i, key := range keys {
|
||
|
pair := &KeyPair{Index: i, Value: key}
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
case ch <- pair:
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type keySetMarshalProxy struct {
|
||
|
Keys []json.RawMessage `json:"keys"`
|
||
|
}
|
||
|
|
||
|
func (s *set) MarshalJSON() ([]byte, error) {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
buf := pool.GetBytesBuffer()
|
||
|
defer pool.ReleaseBytesBuffer(buf)
|
||
|
enc := json.NewEncoder(buf)
|
||
|
|
||
|
buf.WriteString(`{"keys":[`)
|
||
|
for i, k := range s.keys {
|
||
|
if i > 0 {
|
||
|
buf.WriteByte(',')
|
||
|
}
|
||
|
if err := enc.Encode(k); err != nil {
|
||
|
return nil, errors.Wrapf(err, `failed to marshal key #%d`, i)
|
||
|
}
|
||
|
}
|
||
|
buf.WriteString("]}")
|
||
|
|
||
|
ret := make([]byte, buf.Len())
|
||
|
copy(ret, buf.Bytes())
|
||
|
return ret, nil
|
||
|
}
|
||
|
|
||
|
func (s *set) UnmarshalJSON(data []byte) error {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
|
||
|
var proxy keySetMarshalProxy
|
||
|
if err := json.Unmarshal(data, &proxy); err != nil {
|
||
|
return errors.Wrap(err, `failed to unmarshal into Key (proxy)`)
|
||
|
}
|
||
|
|
||
|
var options []ParseOption
|
||
|
if dc := s.dc; dc != nil {
|
||
|
if localReg := dc.Registry(); localReg != nil {
|
||
|
options = append(options, withLocalRegistry(localReg))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if len(proxy.Keys) == 0 {
|
||
|
k, err := ParseKey(data, options...)
|
||
|
if err != nil {
|
||
|
return errors.Wrap(err, `failed to unmarshal key from JSON headers`)
|
||
|
}
|
||
|
s.keys = append(s.keys, k)
|
||
|
} else {
|
||
|
for i, buf := range proxy.Keys {
|
||
|
k, err := ParseKey([]byte(buf), options...)
|
||
|
if err != nil {
|
||
|
return errors.Wrapf(err, `failed to unmarshal key #%d (total %d) from multi-key JWK set`, i+1, len(proxy.Keys))
|
||
|
}
|
||
|
s.keys = append(s.keys, k)
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *set) LookupKeyID(kid string) (Key, bool) {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
for iter := s.Iterate(context.TODO()); iter.Next(context.TODO()); {
|
||
|
pair := iter.Pair()
|
||
|
key := pair.Value.(Key) //nolint:forcetypeassert
|
||
|
if key.KeyID() == kid {
|
||
|
return key, true
|
||
|
}
|
||
|
}
|
||
|
return nil, false
|
||
|
}
|
||
|
|
||
|
func (s *set) DecodeCtx() DecodeCtx {
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
return s.dc
|
||
|
}
|
||
|
|
||
|
func (s *set) SetDecodeCtx(dc DecodeCtx) {
|
||
|
s.mu.Lock()
|
||
|
defer s.mu.Unlock()
|
||
|
s.dc = dc
|
||
|
}
|
||
|
|
||
|
func (s *set) Clone() (Set, error) {
|
||
|
s2 := &set{}
|
||
|
|
||
|
s.mu.RLock()
|
||
|
defer s.mu.RUnlock()
|
||
|
|
||
|
s2.keys = make([]Key, len(s.keys))
|
||
|
|
||
|
for i := 0; i < len(s.keys); i++ {
|
||
|
s2.keys[i] = s.keys[i]
|
||
|
}
|
||
|
return s2, nil
|
||
|
}
|