130 lines
2.8 KiB
Go
130 lines
2.8 KiB
Go
package jws
|
|
|
|
import (
|
|
"crypto"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
|
|
"github.com/lestrrat-go/jwx/internal/keyconv"
|
|
"github.com/lestrrat-go/jwx/jwa"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
var rsaSigners map[jwa.SignatureAlgorithm]*rsaSigner
|
|
var rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
|
|
|
|
func init() {
|
|
algs := map[jwa.SignatureAlgorithm]struct {
|
|
Hash crypto.Hash
|
|
PSS bool
|
|
}{
|
|
jwa.RS256: {
|
|
Hash: crypto.SHA256,
|
|
},
|
|
jwa.RS384: {
|
|
Hash: crypto.SHA384,
|
|
},
|
|
jwa.RS512: {
|
|
Hash: crypto.SHA512,
|
|
},
|
|
jwa.PS256: {
|
|
Hash: crypto.SHA256,
|
|
PSS: true,
|
|
},
|
|
jwa.PS384: {
|
|
Hash: crypto.SHA384,
|
|
PSS: true,
|
|
},
|
|
jwa.PS512: {
|
|
Hash: crypto.SHA512,
|
|
PSS: true,
|
|
},
|
|
}
|
|
|
|
rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner)
|
|
rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier)
|
|
for alg, item := range algs {
|
|
rsaSigners[alg] = &rsaSigner{
|
|
alg: alg,
|
|
hash: item.Hash,
|
|
pss: item.PSS,
|
|
}
|
|
rsaVerifiers[alg] = &rsaVerifier{
|
|
alg: alg,
|
|
hash: item.Hash,
|
|
pss: item.PSS,
|
|
}
|
|
}
|
|
}
|
|
|
|
type rsaSigner struct {
|
|
alg jwa.SignatureAlgorithm
|
|
hash crypto.Hash
|
|
pss bool
|
|
}
|
|
|
|
func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
|
|
return rsaSigners[alg]
|
|
}
|
|
|
|
func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
|
|
return rs.alg
|
|
}
|
|
|
|
func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
|
|
if key == nil {
|
|
return nil, errors.New(`missing private key while signing payload`)
|
|
}
|
|
|
|
signer, ok := key.(crypto.Signer)
|
|
if !ok {
|
|
var privkey rsa.PrivateKey
|
|
if err := keyconv.RSAPrivateKey(&privkey, key); err != nil {
|
|
return nil, errors.Wrapf(err, `failed to retrieve rsa.PrivateKey out of %T`, key)
|
|
}
|
|
signer = &privkey
|
|
}
|
|
|
|
h := rs.hash.New()
|
|
if _, err := h.Write(payload); err != nil {
|
|
return nil, errors.Wrap(err, "failed to write payload to hash")
|
|
}
|
|
if rs.pss {
|
|
return signer.Sign(rand.Reader, h.Sum(nil), &rsa.PSSOptions{
|
|
Hash: rs.hash,
|
|
SaltLength: rsa.PSSSaltLengthEqualsHash,
|
|
})
|
|
}
|
|
return signer.Sign(rand.Reader, h.Sum(nil), rs.hash)
|
|
}
|
|
|
|
type rsaVerifier struct {
|
|
alg jwa.SignatureAlgorithm
|
|
hash crypto.Hash
|
|
pss bool
|
|
}
|
|
|
|
func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
|
|
return rsaVerifiers[alg]
|
|
}
|
|
|
|
func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
|
|
if key == nil {
|
|
return errors.New(`missing public key while verifying payload`)
|
|
}
|
|
|
|
var pubkey rsa.PublicKey
|
|
if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
|
|
return errors.Wrapf(err, `failed to retrieve rsa.PublicKey out of %T`, key)
|
|
}
|
|
|
|
h := rv.hash.New()
|
|
if _, err := h.Write(payload); err != nil {
|
|
return errors.Wrap(err, "failed to write payload to hash")
|
|
}
|
|
|
|
if rv.pss {
|
|
return rsa.VerifyPSS(&pubkey, rv.hash, h.Sum(nil), signature, nil)
|
|
}
|
|
return rsa.VerifyPKCS1v15(&pubkey, rv.hash, h.Sum(nil), signature)
|
|
}
|