Compare commits

...

4 commits

Author SHA1 Message Date
554cf0d3eb chore(deps): update module github.com/stretchr/testify to v1.9.0
Some checks failed
Dev Version / Release (push) Has been cancelled
2024-09-14 01:25:45 +00:00
1a7b182350 chore(deps): update module github.com/rs/zerolog to v1.33.0 (#3)
Some checks are pending
Dev Version / Release (push) Waiting to run
This PR contains the following updates:

| Package | Type | Update | Change |
|---|---|---|---|
| [github.com/rs/zerolog](https://github.com/rs/zerolog) | require | minor | `v1.31.0` -> `v1.33.0` |

>  **Important**
>
> Release Notes retrieval for this PR were skipped because no github.com credentials were available.
> If you are self-hosted, please see [this instruction](https://github.com/renovatebot/renovate/blob/master/docs/usage/examples/self-hosting.md#githubcom-token-for-release-notes).

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update again.

---

 - [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check this box

---

This PR has been generated by [Renovate Bot](https://github.com/renovatebot/renovate).
<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiIzOC4yMS4yIiwidXBkYXRlZEluVmVyIjoiMzguMjEuMiIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Reviewed-on: #3
Co-authored-by: Renovate Bot <renovate@keks.cloud>
Co-committed-by: Renovate Bot <renovate@keks.cloud>
2024-09-14 01:23:35 +00:00
b792a8f374 chore(deps): update module github.com/urfave/cli/v2 to v2.27.4 (#5)
Some checks are pending
Dev Version / Release (push) Waiting to run
This PR contains the following updates:

| Package | Type | Update | Change |
|---|---|---|---|
| [github.com/urfave/cli/v2](https://github.com/urfave/cli) | require | minor | `v2.25.7` -> `v2.27.4` |

>  **Important**
>
> Release Notes retrieval for this PR were skipped because no github.com credentials were available.
> If you are self-hosted, please see [this instruction](https://github.com/renovatebot/renovate/blob/master/docs/usage/examples/self-hosting.md#githubcom-token-for-release-notes).

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update again.

---

 - [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check this box

---

This PR has been generated by [Renovate Bot](https://github.com/renovatebot/renovate).
<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiIzOC4yMS4yIiwidXBkYXRlZEluVmVyIjoiMzguMjEuMiIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Reviewed-on: #5
Co-authored-by: Renovate Bot <renovate@keks.cloud>
Co-committed-by: Renovate Bot <renovate@keks.cloud>
2024-09-14 01:23:23 +00:00
df8ce1f61a chore(deps): update module github.com/aws/aws-sdk-go to v1.55.5 (#2)
Some checks are pending
Dev Version / Release (push) Waiting to run
This PR contains the following updates:

| Package | Type | Update | Change |
|---|---|---|---|
| [github.com/aws/aws-sdk-go](https://github.com/aws/aws-sdk-go) | require | minor | `v1.45.25` -> `v1.55.5` |

>  **Important**
>
> Release Notes retrieval for this PR were skipped because no github.com credentials were available.
> If you are self-hosted, please see [this instruction](https://github.com/renovatebot/renovate/blob/master/docs/usage/examples/self-hosting.md#githubcom-token-for-release-notes).

---

### Configuration

📅 **Schedule**: Branch creation - At any time (no schedule defined), Automerge - At any time (no schedule defined).

🚦 **Automerge**: Disabled by config. Please merge this manually once you are satisfied.

♻ **Rebasing**: Whenever PR becomes conflicted, or you tick the rebase/retry checkbox.

🔕 **Ignore**: Close this PR and you won't be reminded about this update again.

---

 - [ ] <!-- rebase-check -->If you want to rebase/retry this PR, check this box

---

This PR has been generated by [Renovate Bot](https://github.com/renovatebot/renovate).
<!--renovate-debug:eyJjcmVhdGVkSW5WZXIiOiIzOC4yMS4yIiwidXBkYXRlZEluVmVyIjoiMzguMjEuMiIsInRhcmdldEJyYW5jaCI6Im1haW4iLCJsYWJlbHMiOltdfQ==-->

Reviewed-on: #2
Co-authored-by: Renovate Bot <renovate@keks.cloud>
Co-committed-by: Renovate Bot <renovate@keks.cloud>
2024-09-14 01:23:02 +00:00
77 changed files with 14148 additions and 3099 deletions

12
go.mod
View file

@ -3,17 +3,17 @@ module idun
go 1.21 go 1.21
require ( require (
github.com/aws/aws-sdk-go v1.45.25 github.com/aws/aws-sdk-go v1.55.5
github.com/pkg/sftp v1.13.6 github.com/pkg/sftp v1.13.6
github.com/rs/zerolog v1.31.0 github.com/rs/zerolog v1.33.0
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.9.0
github.com/urfave/cli/v2 v2.25.7 github.com/urfave/cli/v2 v2.27.4
golang.org/x/crypto v0.14.0 golang.org/x/crypto v0.14.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
github.com/cpuguy83/go-md2man/v2 v2.0.2 // indirect github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/kr/fs v0.1.0 // indirect github.com/kr/fs v0.1.0 // indirect
@ -21,6 +21,6 @@ require (
github.com/mattn/go-isatty v0.0.19 // indirect github.com/mattn/go-isatty v0.0.19 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect
golang.org/x/sys v0.13.0 // indirect golang.org/x/sys v0.13.0 // indirect
) )

12
go.sum
View file

@ -1,8 +1,12 @@
github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4= github.com/aws/aws-sdk-go v1.45.25 h1:c4fLlh5sLdK2DCRTY1z0hyuJZU4ygxX8m1FswL6/nF4=
github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go v1.45.25/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
github.com/aws/aws-sdk-go v1.55.5 h1:KKUZBfBoyqy5d3swXyiC7Q76ic40rYcbqH7qjh59kzU=
github.com/aws/aws-sdk-go v1.55.5/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w=
github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/cpuguy83/go-md2man/v2 v2.0.4 h1:wfIWP927BUkWJb2NmU/kNDYIBTh/ziUX91+lVfRxZq4=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -26,6 +30,8 @@ github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZN
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -34,10 +40,16 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/urfave/cli/v2 v2.25.7 h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs= github.com/urfave/cli/v2 v2.25.7 h1:VAzn5oq403l5pHjc4OhD54+XGO9cdKVL/7lDjF+iKUs=
github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ= github.com/urfave/cli/v2 v2.25.7/go.mod h1:8qnjx1vcq5s2/wpsqoZFndg2CE5tNFyrTvS6SinrnYQ=
github.com/urfave/cli/v2 v2.27.4 h1:o1owoI+02Eb+K107p27wEX9Bb8eqIoZCfLXloLUSWJ8=
github.com/urfave/cli/v2 v2.27.4/go.mod h1:m4QzxcD2qpra4z7WhzEGn74WZLViBnMpb1ToCAKdGRQ=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=

View file

@ -442,6 +442,17 @@ func (c *Config) WithUseDualStack(enable bool) *Config {
return c return c
} }
// WithUseFIPSEndpoint sets a config UseFIPSEndpoint value returning a Config
// pointer for chaining.
func (c *Config) WithUseFIPSEndpoint(enable bool) *Config {
if enable {
c.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled
} else {
c.UseFIPSEndpoint = endpoints.FIPSEndpointStateDisabled
}
return c
}
// WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value // WithEC2MetadataDisableTimeoutOverride sets a config EC2MetadataDisableTimeoutOverride value
// returning a Config pointer for chaining. // returning a Config pointer for chaining.
func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config { func (c *Config) WithEC2MetadataDisableTimeoutOverride(enable bool) *Config {

View file

@ -31,6 +31,8 @@ package endpointcreds
import ( import (
"encoding/json" "encoding/json"
"fmt"
"strings"
"time" "time"
"github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws"
@ -69,7 +71,37 @@ type Provider struct {
// Optional authorization token value if set will be used as the value of // Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request. // the Authorization header of the endpoint credential request.
//
// When constructed from environment, the provider will use the value of
// AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token
//
// Will be overridden if AuthorizationTokenProvider is configured
AuthorizationToken string AuthorizationToken string
// Optional auth provider func to dynamically load the auth token from a file
// everytime a credential is retrieved
//
// When constructed from environment, the provider will read and use the content
// of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable
// as the auth token everytime credentials are retrieved
//
// Will override AuthorizationToken if configured
AuthorizationTokenProvider AuthTokenProvider
}
// AuthTokenProvider defines an interface to dynamically load a value to be passed
// for the Authorization header of a credentials request.
type AuthTokenProvider interface {
GetToken() (string, error)
}
// TokenProviderFunc is a func type implementing AuthTokenProvider interface
// and enables customizing token provider behavior
type TokenProviderFunc func() (string, error)
// GetToken func retrieves auth token according to TokenProviderFunc implementation
func (p TokenProviderFunc) GetToken() (string, error) {
return p()
} }
// NewProviderClient returns a credentials Provider for retrieving AWS credentials // NewProviderClient returns a credentials Provider for retrieving AWS credentials
@ -164,7 +196,20 @@ func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error
req := p.Client.NewRequest(op, nil, out) req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx) req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json") req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {
authToken := p.AuthorizationToken
var err error
if p.AuthorizationTokenProvider != nil {
authToken, err = p.AuthorizationTokenProvider.GetToken()
if err != nil {
return nil, fmt.Errorf("get authorization token: %v", err)
}
}
if strings.ContainsAny(authToken, "\r\n") {
return nil, fmt.Errorf("authorization token contains invalid newline sequence")
}
if len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken) req.HTTPRequest.Header.Set("Authorization", authToken)
} }

View file

@ -9,6 +9,7 @@ package defaults
import ( import (
"fmt" "fmt"
"io/ioutil"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -115,9 +116,31 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro
const ( const (
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN" httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI" httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
) )
// direct representation of the IPv4 address for the ECS container
// "169.254.170.2"
var ecsContainerIPv4 net.IP = []byte{
169, 254, 170, 2,
}
// direct representation of the IPv4 address for the EKS container
// "169.254.170.23"
var eksContainerIPv4 net.IP = []byte{
169, 254, 170, 23,
}
// direct representation of the IPv6 address for the EKS container
// "fd00:ec2::23"
var eksContainerIPv6 net.IP = []byte{
0xFD, 0, 0xE, 0xC2,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0x23,
}
// RemoteCredProvider returns a credentials provider for the default remote // RemoteCredProvider returns a credentials provider for the default remote
// endpoints such as EC2 or ECS Roles. // endpoints such as EC2 or ECS Roles.
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider { func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
@ -135,19 +158,22 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P
var lookupHostFn = net.LookupHost var lookupHostFn = net.LookupHost
func isLoopbackHost(host string) (bool, error) { // isAllowedHost allows host to be loopback or known ECS/EKS container IPs
ip := net.ParseIP(host) //
if ip != nil { // host can either be an IP address OR an unresolved hostname - resolution will
return ip.IsLoopback(), nil // be automatically performed in the latter case
func isAllowedHost(host string) (bool, error) {
if ip := net.ParseIP(host); ip != nil {
return isIPAllowed(ip), nil
} }
// Host is not an ip, perform lookup
addrs, err := lookupHostFn(host) addrs, err := lookupHostFn(host)
if err != nil { if err != nil {
return false, err return false, err
} }
for _, addr := range addrs { for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() { if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) {
return false, nil return false, nil
} }
} }
@ -155,6 +181,13 @@ func isLoopbackHost(host string) (bool, error) {
return true, nil return true, nil
} }
func isIPAllowed(ip net.IP) bool {
return ip.IsLoopback() ||
ip.Equal(ecsContainerIPv4) ||
ip.Equal(eksContainerIPv4) ||
ip.Equal(eksContainerIPv6)
}
func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider { func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string var errMsg string
@ -165,10 +198,12 @@ func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string)
host := aws.URLHostname(parsed) host := aws.URLHostname(parsed)
if len(host) == 0 { if len(host) == 0 {
errMsg = "unable to parse host from local HTTP cred provider URL" errMsg = "unable to parse host from local HTTP cred provider URL"
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil { } else if parsed.Scheme == "http" {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr) if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil {
} else if !isLoopback { errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, allowHostErr)
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host) } else if !isAllowedHost {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed.", host)
}
} }
} }
@ -190,6 +225,15 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
func(p *endpointcreds.Provider) { func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar) p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" {
p.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) {
if contents, err := ioutil.ReadFile(authFilePath); err != nil {
return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err)
} else {
return string(contents), nil
}
})
}
}, },
) )
} }

View file

@ -2,6 +2,7 @@ package ec2metadata
import ( import (
"fmt" "fmt"
"github.com/aws/aws-sdk-go/aws"
"net/http" "net/http"
"sync/atomic" "sync/atomic"
"time" "time"
@ -65,7 +66,9 @@ func (t *tokenProvider) fetchTokenHandler(r *request.Request) {
switch requestFailureError.StatusCode() { switch requestFailureError.StatusCode() {
case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed: case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed:
atomic.StoreUint32(&t.disabled, 1) atomic.StoreUint32(&t.disabled, 1)
if t.client.Config.LogLevel.Matches(aws.LogDebugWithDeprecated) {
t.client.Config.Logger.Log(fmt.Sprintf("WARN: failed to get session token, falling back to IMDSv1: %v", requestFailureError)) t.client.Config.Logger.Log(fmt.Sprintf("WARN: failed to get session token, falling back to IMDSv1: %v", requestFailureError))
}
case http.StatusBadRequest: case http.StatusBadRequest:
r.Error = requestFailureError r.Error = requestFailureError
} }

File diff suppressed because it is too large Load diff

View file

@ -256,8 +256,17 @@ func (a *WaiterAcceptor) match(name string, l aws.Logger, req *Request, err erro
s := a.Expected.(int) s := a.Expected.(int)
result = s == req.HTTPResponse.StatusCode result = s == req.HTTPResponse.StatusCode
case ErrorWaiterMatch: case ErrorWaiterMatch:
switch ex := a.Expected.(type) {
case string:
if aerr, ok := err.(awserr.Error); ok { if aerr, ok := err.(awserr.Error); ok {
result = aerr.Code() == a.Expected.(string) result = aerr.Code() == ex
}
case bool:
if ex {
result = err != nil
} else {
result = err == nil
}
} }
default: default:
waiterLogf(l, "WARNING: Waiter %s encountered unexpected matcher: %s", waiterLogf(l, "WARNING: Waiter %s encountered unexpected matcher: %s",

View file

@ -171,6 +171,12 @@ type envConfig struct {
// AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6 // AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE=IPv6
EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState EC2IMDSEndpointMode endpoints.EC2IMDSEndpointModeState
// Specifies that IMDS clients should not fallback to IMDSv1 if token
// requests fail.
//
// AWS_EC2_METADATA_V1_DISABLED=true
EC2IMDSv1Disabled *bool
// Specifies that SDK clients must resolve a dual-stack endpoint for // Specifies that SDK clients must resolve a dual-stack endpoint for
// services. // services.
// //
@ -251,6 +257,9 @@ var (
ec2IMDSEndpointModeEnvKey = []string{ ec2IMDSEndpointModeEnvKey = []string{
"AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE", "AWS_EC2_METADATA_SERVICE_ENDPOINT_MODE",
} }
ec2MetadataV1DisabledEnvKey = []string{
"AWS_EC2_METADATA_V1_DISABLED",
}
useCABundleKey = []string{ useCABundleKey = []string{
"AWS_CA_BUNDLE", "AWS_CA_BUNDLE",
} }
@ -393,6 +402,7 @@ func envConfigLoad(enableSharedConfig bool) (envConfig, error) {
if err := setEC2IMDSEndpointMode(&cfg.EC2IMDSEndpointMode, ec2IMDSEndpointModeEnvKey); err != nil { if err := setEC2IMDSEndpointMode(&cfg.EC2IMDSEndpointMode, ec2IMDSEndpointModeEnvKey); err != nil {
return envConfig{}, err return envConfig{}, err
} }
setBoolPtrFromEnvVal(&cfg.EC2IMDSv1Disabled, ec2MetadataV1DisabledEnvKey)
if err := setUseDualStackEndpointFromEnvVal(&cfg.UseDualStackEndpoint, awsUseDualStackEndpoint); err != nil { if err := setUseDualStackEndpointFromEnvVal(&cfg.UseDualStackEndpoint, awsUseDualStackEndpoint); err != nil {
return cfg, err return cfg, err
@ -414,6 +424,24 @@ func setFromEnvVal(dst *string, keys []string) {
} }
} }
func setBoolPtrFromEnvVal(dst **bool, keys []string) {
for _, k := range keys {
value := os.Getenv(k)
if len(value) == 0 {
continue
}
switch {
case strings.EqualFold(value, "false"):
*dst = new(bool)
**dst = false
case strings.EqualFold(value, "true"):
*dst = new(bool)
**dst = true
}
}
}
func setEC2IMDSEndpointMode(mode *endpoints.EC2IMDSEndpointModeState, keys []string) error { func setEC2IMDSEndpointMode(mode *endpoints.EC2IMDSEndpointModeState, keys []string) error {
for _, k := range keys { for _, k := range keys {
value := os.Getenv(k) value := os.Getenv(k)

View file

@ -779,6 +779,14 @@ func mergeConfigSrcs(cfg, userCfg *aws.Config,
cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint, endpointMode) cfg.EndpointResolver = wrapEC2IMDSEndpoint(cfg.EndpointResolver, ec2IMDSEndpoint, endpointMode)
} }
cfg.EC2MetadataEnableFallback = userCfg.EC2MetadataEnableFallback
if cfg.EC2MetadataEnableFallback == nil && envCfg.EC2IMDSv1Disabled != nil {
cfg.EC2MetadataEnableFallback = aws.Bool(!*envCfg.EC2IMDSv1Disabled)
}
if cfg.EC2MetadataEnableFallback == nil && sharedCfg.EC2IMDSv1Disabled != nil {
cfg.EC2MetadataEnableFallback = aws.Bool(!*sharedCfg.EC2IMDSv1Disabled)
}
cfg.S3UseARNRegion = userCfg.S3UseARNRegion cfg.S3UseARNRegion = userCfg.S3UseARNRegion
if cfg.S3UseARNRegion == nil { if cfg.S3UseARNRegion == nil {
cfg.S3UseARNRegion = &envCfg.S3UseARNRegion cfg.S3UseARNRegion = &envCfg.S3UseARNRegion

View file

@ -80,6 +80,9 @@ const (
// EC2 IMDS Endpoint // EC2 IMDS Endpoint
ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint" ec2MetadataServiceEndpointKey = "ec2_metadata_service_endpoint"
// ECS IMDSv1 disable fallback
ec2MetadataV1DisabledKey = "ec2_metadata_v1_disabled"
// Use DualStack Endpoint Resolution // Use DualStack Endpoint Resolution
useDualStackEndpoint = "use_dualstack_endpoint" useDualStackEndpoint = "use_dualstack_endpoint"
@ -179,6 +182,12 @@ type sharedConfig struct {
// ec2_metadata_service_endpoint=http://fd00:ec2::254 // ec2_metadata_service_endpoint=http://fd00:ec2::254
EC2IMDSEndpoint string EC2IMDSEndpoint string
// Specifies that IMDS clients should not fallback to IMDSv1 if token
// requests fail.
//
// ec2_metadata_v1_disabled=true
EC2IMDSv1Disabled *bool
// Specifies that SDK clients must resolve a dual-stack endpoint for // Specifies that SDK clients must resolve a dual-stack endpoint for
// services. // services.
// //
@ -434,6 +443,7 @@ func (cfg *sharedConfig) setFromIniFile(profile string, file sharedConfigFile, e
ec2MetadataServiceEndpointModeKey, file.Filename, err) ec2MetadataServiceEndpointModeKey, file.Filename, err)
} }
updateString(&cfg.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey) updateString(&cfg.EC2IMDSEndpoint, section, ec2MetadataServiceEndpointKey)
updateBoolPtr(&cfg.EC2IMDSv1Disabled, section, ec2MetadataV1DisabledKey)
updateUseDualStackEndpoint(&cfg.UseDualStackEndpoint, section, useDualStackEndpoint) updateUseDualStackEndpoint(&cfg.UseDualStackEndpoint, section, useDualStackEndpoint)

View file

@ -125,6 +125,7 @@ var requiredSignedHeaders = rules{
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Algorithm": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key": struct{}{},
"X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{}, "X-Amz-Copy-Source-Server-Side-Encryption-Customer-Key-Md5": struct{}{},
"X-Amz-Expected-Bucket-Owner": struct{}{},
"X-Amz-Grant-Full-control": struct{}{}, "X-Amz-Grant-Full-control": struct{}{},
"X-Amz-Grant-Read": struct{}{}, "X-Amz-Grant-Read": struct{}{},
"X-Amz-Grant-Read-Acp": struct{}{}, "X-Amz-Grant-Read-Acp": struct{}{},

View file

@ -5,4 +5,4 @@ package aws
const SDKName = "aws-sdk-go" const SDKName = "aws-sdk-go"
// SDKVersion is the version of this SDK // SDKVersion is the version of this SDK
const SDKVersion = "1.45.25" const SDKVersion = "1.55.5"

View file

@ -122,8 +122,8 @@ func (q *queryParser) parseStruct(v url.Values, value reflect.Value, prefix stri
} }
func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error { func (q *queryParser) parseList(v url.Values, value reflect.Value, prefix string, tag reflect.StructTag) error {
// If it's empty, generate an empty value // If it's empty, and not ec2, generate an empty value
if !value.IsNil() && value.Len() == 0 { if !value.IsNil() && value.Len() == 0 && !q.isEC2 {
v.Set(prefix, "") v.Set(prefix, "")
return nil return nil
} }

File diff suppressed because it is too large Load diff

View file

@ -25,6 +25,15 @@ const (
// "InvalidObjectState". // "InvalidObjectState".
// //
// Object is archived and inaccessible until restored. // Object is archived and inaccessible until restored.
//
// If the object you are retrieving is stored in the S3 Glacier Flexible Retrieval
// storage class, the S3 Glacier Deep Archive storage class, the S3 Intelligent-Tiering
// Archive Access tier, or the S3 Intelligent-Tiering Deep Archive Access tier,
// before you can retrieve the object you must first restore a copy using RestoreObject
// (https://docs.aws.amazon.com/AmazonS3/latest/API/API_RestoreObject.html).
// Otherwise, this operation returns an InvalidObjectState error. For information
// about restoring archived objects, see Restoring Archived Objects (https://docs.aws.amazon.com/AmazonS3/latest/dev/restoring-objects.html)
// in the Amazon S3 User Guide.
ErrCodeInvalidObjectState = "InvalidObjectState" ErrCodeInvalidObjectState = "InvalidObjectState"
// ErrCodeNoSuchBucket for service response error code // ErrCodeNoSuchBucket for service response error code

File diff suppressed because it is too large Load diff

View file

@ -3,15 +3,13 @@
// Package ssooidc provides the client and types for making API // Package ssooidc provides the client and types for making API
// requests to AWS SSO OIDC. // requests to AWS SSO OIDC.
// //
// AWS IAM Identity Center (successor to AWS Single Sign-On) OpenID Connect // IAM Identity Center OpenID Connect (OIDC) is a web service that enables a
// (OIDC) is a web service that enables a client (such as AWS CLI or a native // client (such as CLI or a native application) to register with IAM Identity
// application) to register with IAM Identity Center. The service also enables // Center. The service also enables the client to fetch the users access
// the client to fetch the users access token upon successful authentication // token upon successful authentication and authorization with IAM Identity
// and authorization with IAM Identity Center. // Center.
// //
// Although AWS Single Sign-On was renamed, the sso and identitystore API namespaces // IAM Identity Center uses the sso and identitystore API namespaces.
// will continue to retain their original name for backward compatibility purposes.
// For more information, see IAM Identity Center rename (https://docs.aws.amazon.com/singlesignon/latest/userguide/what-is.html#renamed).
// //
// # Considerations for Using This Guide // # Considerations for Using This Guide
// //
@ -22,21 +20,24 @@
// - The IAM Identity Center OIDC service currently implements only the portions // - The IAM Identity Center OIDC service currently implements only the portions
// of the OAuth 2.0 Device Authorization Grant standard (https://tools.ietf.org/html/rfc8628 // of the OAuth 2.0 Device Authorization Grant standard (https://tools.ietf.org/html/rfc8628
// (https://tools.ietf.org/html/rfc8628)) that are necessary to enable single // (https://tools.ietf.org/html/rfc8628)) that are necessary to enable single
// sign-on authentication with the AWS CLI. Support for other OIDC flows // sign-on authentication with the CLI.
// frequently needed for native applications, such as Authorization Code
// Flow (+ PKCE), will be addressed in future releases.
// //
// - The service emits only OIDC access tokens, such that obtaining a new // - With older versions of the CLI, the service only emits OIDC access tokens,
// token (For example, token refresh) requires explicit user re-authentication. // so to obtain a new token, users must explicitly re-authenticate. To access
// the OIDC flow that supports token refresh and doesnt require re-authentication,
// update to the latest CLI version (1.27.10 for CLI V1 and 2.9.0 for CLI
// V2) with support for OIDC token refresh and configurable IAM Identity
// Center session durations. For more information, see Configure Amazon Web
// Services access portal session duration (https://docs.aws.amazon.com/singlesignon/latest/userguide/configure-user-session.html).
// //
// - The access tokens provided by this service grant access to all AWS account // - The access tokens provided by this service grant access to all Amazon
// entitlements assigned to an IAM Identity Center user, not just a particular // Web Services account entitlements assigned to an IAM Identity Center user,
// application. // not just a particular application.
// //
// - The documentation in this guide does not describe the mechanism to convert // - The documentation in this guide does not describe the mechanism to convert
// the access token into AWS Auth (“sigv4”) credentials for use with // the access token into Amazon Web Services Auth (“sigv4”) credentials
// IAM-protected AWS service endpoints. For more information, see GetRoleCredentials // for use with IAM-protected Amazon Web Services service endpoints. For
// (https://docs.aws.amazon.com/singlesignon/latest/PortalAPIReference/API_GetRoleCredentials.html) // more information, see GetRoleCredentials (https://docs.aws.amazon.com/singlesignon/latest/PortalAPIReference/API_GetRoleCredentials.html)
// in the IAM Identity Center Portal API Reference Guide. // in the IAM Identity Center Portal API Reference Guide.
// //
// For general information about IAM Identity Center, see What is IAM Identity // For general information about IAM Identity Center, see What is IAM Identity

View file

@ -57,6 +57,13 @@ const (
// makes a CreateToken request with an invalid grant type. // makes a CreateToken request with an invalid grant type.
ErrCodeInvalidGrantException = "InvalidGrantException" ErrCodeInvalidGrantException = "InvalidGrantException"
// ErrCodeInvalidRedirectUriException for service response error code
// "InvalidRedirectUriException".
//
// Indicates that one or more redirect URI in the request is not supported for
// this operation.
ErrCodeInvalidRedirectUriException = "InvalidRedirectUriException"
// ErrCodeInvalidRequestException for service response error code // ErrCodeInvalidRequestException for service response error code
// "InvalidRequestException". // "InvalidRequestException".
// //
@ -64,6 +71,13 @@ const (
// a required parameter might be missing or out of range. // a required parameter might be missing or out of range.
ErrCodeInvalidRequestException = "InvalidRequestException" ErrCodeInvalidRequestException = "InvalidRequestException"
// ErrCodeInvalidRequestRegionException for service response error code
// "InvalidRequestRegionException".
//
// Indicates that a token provided as input to the request was issued by and
// is only usable by calling IAM Identity Center endpoints in another region.
ErrCodeInvalidRequestRegionException = "InvalidRequestRegionException"
// ErrCodeInvalidScopeException for service response error code // ErrCodeInvalidScopeException for service response error code
// "InvalidScopeException". // "InvalidScopeException".
// //
@ -99,7 +113,9 @@ var exceptionFromCode = map[string]func(protocol.ResponseMetadata) error{
"InvalidClientException": newErrorInvalidClientException, "InvalidClientException": newErrorInvalidClientException,
"InvalidClientMetadataException": newErrorInvalidClientMetadataException, "InvalidClientMetadataException": newErrorInvalidClientMetadataException,
"InvalidGrantException": newErrorInvalidGrantException, "InvalidGrantException": newErrorInvalidGrantException,
"InvalidRedirectUriException": newErrorInvalidRedirectUriException,
"InvalidRequestException": newErrorInvalidRequestException, "InvalidRequestException": newErrorInvalidRequestException,
"InvalidRequestRegionException": newErrorInvalidRequestRegionException,
"InvalidScopeException": newErrorInvalidScopeException, "InvalidScopeException": newErrorInvalidScopeException,
"SlowDownException": newErrorSlowDownException, "SlowDownException": newErrorSlowDownException,
"UnauthorizedClientException": newErrorUnauthorizedClientException, "UnauthorizedClientException": newErrorUnauthorizedClientException,

View file

@ -51,7 +51,7 @@ const (
func New(p client.ConfigProvider, cfgs ...*aws.Config) *SSOOIDC { func New(p client.ConfigProvider, cfgs ...*aws.Config) *SSOOIDC {
c := p.ClientConfig(EndpointsID, cfgs...) c := p.ClientConfig(EndpointsID, cfgs...)
if c.SigningNameDerived || len(c.SigningName) == 0 { if c.SigningNameDerived || len(c.SigningName) == 0 {
c.SigningName = "awsssooidc" c.SigningName = "sso-oauth"
} }
return newClient(*c.Config, c.Handlers, c.PartitionID, c.Endpoint, c.SigningRegion, c.SigningName, c.ResolvedRegion) return newClient(*c.Config, c.Handlers, c.PartitionID, c.Endpoint, c.SigningRegion, c.SigningName, c.ResolvedRegion)
} }

View file

@ -1460,7 +1460,15 @@ type AssumeRoleInput struct {
// in the IAM User Guide. // in the IAM User Guide.
PolicyArns []*PolicyDescriptorType `type:"list"` PolicyArns []*PolicyDescriptorType `type:"list"`
// Reserved for future use. // A list of previously acquired trusted context assertions in the format of
// a JSON array. The trusted context assertion is signed and encrypted by Amazon
// Web Services STS.
//
// The following is an example of a ProvidedContext value that includes a single
// trusted context assertion and the ARN of the context provider from which
// the trusted context assertion was generated.
//
// [{"ProviderArn":"arn:aws:iam::aws:contextProvider/IdentityCenter","ContextAssertion":"trusted-context-assertion"}]
ProvidedContexts []*ProvidedContext `type:"list"` ProvidedContexts []*ProvidedContext `type:"list"`
// The Amazon Resource Name (ARN) of the role to assume. // The Amazon Resource Name (ARN) of the role to assume.
@ -3405,14 +3413,18 @@ func (s *PolicyDescriptorType) SetArn(v string) *PolicyDescriptorType {
return s return s
} }
// Reserved for future use. // Contains information about the provided context. This includes the signed
// and encrypted trusted context assertion and the context provider ARN from
// which the trusted context assertion was generated.
type ProvidedContext struct { type ProvidedContext struct {
_ struct{} `type:"structure"` _ struct{} `type:"structure"`
// Reserved for future use. // The signed and encrypted trusted context assertion generated by the context
// provider. The trusted context assertion is signed and encrypted by Amazon
// Web Services STS.
ContextAssertion *string `min:"4" type:"string"` ContextAssertion *string `min:"4" type:"string"`
// Reserved for future use. // The context provider ARN from which the trusted context assertion was generated.
ProviderArn *string `min:"20" type:"string"` ProviderArn *string `min:"20" type:"string"`
} }

View file

@ -9,6 +9,8 @@ func Render(doc []byte) []byte {
renderer := NewRoffRenderer() renderer := NewRoffRenderer()
return blackfriday.Run(doc, return blackfriday.Run(doc,
[]blackfriday.Option{blackfriday.WithRenderer(renderer), []blackfriday.Option{
blackfriday.WithExtensions(renderer.GetExtensions())}...) blackfriday.WithRenderer(renderer),
blackfriday.WithExtensions(renderer.GetExtensions()),
}...)
} }

View file

@ -1,6 +1,8 @@
package md2man package md2man
import ( import (
"bufio"
"bytes"
"fmt" "fmt"
"io" "io"
"os" "os"
@ -34,10 +36,10 @@ const (
hruleTag = "\n.ti 0\n\\l'\\n(.lu'\n" hruleTag = "\n.ti 0\n\\l'\\n(.lu'\n"
linkTag = "\n\\[la]" linkTag = "\n\\[la]"
linkCloseTag = "\\[ra]" linkCloseTag = "\\[ra]"
codespanTag = "\\fB\\fC" codespanTag = "\\fB"
codespanCloseTag = "\\fR" codespanCloseTag = "\\fR"
codeTag = "\n.PP\n.RS\n\n.nf\n" codeTag = "\n.EX\n"
codeCloseTag = "\n.fi\n.RE\n" codeCloseTag = ".EE\n" // Do not prepend a newline character since code blocks, by definition, include a newline already (or at least as how blackfriday gives us on).
quoteTag = "\n.PP\n.RS\n" quoteTag = "\n.PP\n.RS\n"
quoteCloseTag = "\n.RE\n" quoteCloseTag = "\n.RE\n"
listTag = "\n.RS\n" listTag = "\n.RS\n"
@ -48,6 +50,7 @@ const (
tableEnd = ".TE\n" tableEnd = ".TE\n"
tableCellStart = "T{\n" tableCellStart = "T{\n"
tableCellEnd = "\nT}\n" tableCellEnd = "\nT}\n"
tablePreprocessor = `'\" t`
) )
// NewRoffRenderer creates a new blackfriday Renderer for generating roff documents // NewRoffRenderer creates a new blackfriday Renderer for generating roff documents
@ -74,6 +77,16 @@ func (r *roffRenderer) GetExtensions() blackfriday.Extensions {
// RenderHeader handles outputting the header at document start // RenderHeader handles outputting the header at document start
func (r *roffRenderer) RenderHeader(w io.Writer, ast *blackfriday.Node) { func (r *roffRenderer) RenderHeader(w io.Writer, ast *blackfriday.Node) {
// We need to walk the tree to check if there are any tables.
// If there are, we need to enable the roff table preprocessor.
ast.Walk(func(node *blackfriday.Node, entering bool) blackfriday.WalkStatus {
if node.Type == blackfriday.Table {
out(w, tablePreprocessor+"\n")
return blackfriday.Terminate
}
return blackfriday.GoToNext
})
// disable hyphenation // disable hyphenation
out(w, ".nh\n") out(w, ".nh\n")
} }
@ -86,8 +99,7 @@ func (r *roffRenderer) RenderFooter(w io.Writer, ast *blackfriday.Node) {
// RenderNode is called for each node in a markdown document; based on the node // RenderNode is called for each node in a markdown document; based on the node
// type the equivalent roff output is sent to the writer // type the equivalent roff output is sent to the writer
func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering bool) blackfriday.WalkStatus { func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering bool) blackfriday.WalkStatus {
walkAction := blackfriday.GoToNext
var walkAction = blackfriday.GoToNext
switch node.Type { switch node.Type {
case blackfriday.Text: case blackfriday.Text:
@ -109,9 +121,16 @@ func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering
out(w, strongCloseTag) out(w, strongCloseTag)
} }
case blackfriday.Link: case blackfriday.Link:
if !entering { // Don't render the link text for automatic links, because this
out(w, linkTag+string(node.LinkData.Destination)+linkCloseTag) // will only duplicate the URL in the roff output.
// See https://daringfireball.net/projects/markdown/syntax#autolink
if !bytes.Equal(node.LinkData.Destination, node.FirstChild.Literal) {
out(w, string(node.FirstChild.Literal))
} }
// Hyphens in a link must be escaped to avoid word-wrap in the rendered man page.
escapedLink := strings.ReplaceAll(string(node.LinkData.Destination), "-", "\\-")
out(w, linkTag+escapedLink+linkCloseTag)
walkAction = blackfriday.SkipChildren
case blackfriday.Image: case blackfriday.Image:
// ignore images // ignore images
walkAction = blackfriday.SkipChildren walkAction = blackfriday.SkipChildren
@ -160,6 +179,11 @@ func (r *roffRenderer) RenderNode(w io.Writer, node *blackfriday.Node, entering
r.handleTableCell(w, node, entering) r.handleTableCell(w, node, entering)
case blackfriday.HTMLSpan: case blackfriday.HTMLSpan:
// ignore other HTML tags // ignore other HTML tags
case blackfriday.HTMLBlock:
if bytes.HasPrefix(node.Literal, []byte("<!--")) {
break // ignore comments, no warning
}
fmt.Fprintln(os.Stderr, "WARNING: go-md2man does not handle node type "+node.Type.String())
default: default:
fmt.Fprintln(os.Stderr, "WARNING: go-md2man does not handle node type "+node.Type.String()) fmt.Fprintln(os.Stderr, "WARNING: go-md2man does not handle node type "+node.Type.String())
} }
@ -254,7 +278,7 @@ func (r *roffRenderer) handleTableCell(w io.Writer, node *blackfriday.Node, ente
start = "\t" start = "\t"
} }
if node.IsHeader { if node.IsHeader {
start += codespanTag start += strongTag
} else if nodeLiteralSize(node) > 30 { } else if nodeLiteralSize(node) > 30 {
start += tableCellStart start += tableCellStart
} }
@ -262,7 +286,7 @@ func (r *roffRenderer) handleTableCell(w io.Writer, node *blackfriday.Node, ente
} else { } else {
var end string var end string
if node.IsHeader { if node.IsHeader {
end = codespanCloseTag end = strongCloseTag
} else if nodeLiteralSize(node) > 30 { } else if nodeLiteralSize(node) > 30 {
end = tableCellEnd end = tableCellEnd
} }
@ -310,6 +334,28 @@ func out(w io.Writer, output string) {
} }
func escapeSpecialChars(w io.Writer, text []byte) { func escapeSpecialChars(w io.Writer, text []byte) {
scanner := bufio.NewScanner(bytes.NewReader(text))
// count the number of lines in the text
// we need to know this to avoid adding a newline after the last line
n := bytes.Count(text, []byte{'\n'})
idx := 0
for scanner.Scan() {
dt := scanner.Bytes()
if idx < n {
idx++
dt = append(dt, '\n')
}
escapeSpecialCharsLine(w, dt)
}
if err := scanner.Err(); err != nil {
panic(err)
}
}
func escapeSpecialCharsLine(w io.Writer, text []byte) {
for i := 0; i < len(text); i++ { for i := 0; i < len(text); i++ {
// escape initial apostrophe or period // escape initial apostrophe or period
if len(text) >= 1 && (text[0] == '\'' || text[0] == '.') { if len(text) >= 1 && (text[0] == '\'' || text[0] == '.') {

View file

@ -60,7 +60,7 @@ func main() {
// Output: {"time":1516134303,"level":"debug","message":"hello world"} // Output: {"time":1516134303,"level":"debug","message":"hello world"}
``` ```
> Note: By default log writes to `os.Stderr` > Note: By default log writes to `os.Stderr`
> Note: The default log level for `log.Print` is *debug* > Note: The default log level for `log.Print` is *trace*
### Contextual Logging ### Contextual Logging
@ -412,15 +412,7 @@ Equivalent of `Lshortfile`:
```go ```go
zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string { zerolog.CallerMarshalFunc = func(pc uintptr, file string, line int) string {
short := file return filepath.Base(file) + ":" + strconv.Itoa(line)
for i := len(file) - 1; i > 0; i-- {
if file[i] == '/' {
short = file[i+1:]
break
}
}
file = short
return file + ":" + strconv.Itoa(line)
} }
log.Logger = log.With().Caller().Logger() log.Logger = log.With().Caller().Logger()
log.Info().Msg("hello world") log.Info().Msg("hello world")
@ -547,7 +539,7 @@ and facilitates the unification of logging and tracing in some systems:
type TracingHook struct{} type TracingHook struct{}
func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
ctx := e.Ctx() ctx := e.GetCtx()
spanId := getSpanIdFromContext(ctx) // as per your tracing framework spanId := getSpanIdFromContext(ctx) // as per your tracing framework
e.Str("span-id", spanId) e.Str("span-id", spanId)
} }
@ -646,10 +638,14 @@ Some settings can be changed and will be applied to all loggers:
* `zerolog.LevelFieldName`: Can be set to customize level field name. * `zerolog.LevelFieldName`: Can be set to customize level field name.
* `zerolog.MessageFieldName`: Can be set to customize message field name. * `zerolog.MessageFieldName`: Can be set to customize message field name.
* `zerolog.ErrorFieldName`: Can be set to customize `Err` field name. * `zerolog.ErrorFieldName`: Can be set to customize `Err` field name.
* `zerolog.TimeFieldFormat`: Can be set to customize `Time` field value formatting. If set with `zerolog.TimeFormatUnix`, `zerolog.TimeFormatUnixMs` or `zerolog.TimeFormatUnixMicro`, times are formated as UNIX timestamp. * `zerolog.TimeFieldFormat`: Can be set to customize `Time` field value formatting. If set with `zerolog.TimeFormatUnix`, `zerolog.TimeFormatUnixMs` or `zerolog.TimeFormatUnixMicro`, times are formatted as UNIX timestamp.
* `zerolog.DurationFieldUnit`: Can be set to customize the unit for time.Duration type fields added by `Dur` (default: `time.Millisecond`). * `zerolog.DurationFieldUnit`: Can be set to customize the unit for time.Duration type fields added by `Dur` (default: `time.Millisecond`).
* `zerolog.DurationFieldInteger`: If set to `true`, `Dur` fields are formatted as integers instead of floats (default: `false`). * `zerolog.DurationFieldInteger`: If set to `true`, `Dur` fields are formatted as integers instead of floats (default: `false`).
* `zerolog.ErrorHandler`: Called whenever zerolog fails to write an event on its output. If not set, an error is printed on the stderr. This handler must be thread safe and non-blocking. * `zerolog.ErrorHandler`: Called whenever zerolog fails to write an event on its output. If not set, an error is printed on the stderr. This handler must be thread safe and non-blocking.
* `zerolog.FloatingPointPrecision`: If set to a value other than -1, controls the number
of digits when formatting float numbers in JSON. See
[strconv.FormatFloat](https://pkg.go.dev/strconv#FormatFloat)
for more details.
## Field Types ## Field Types

View file

@ -183,13 +183,13 @@ func (a *Array) Uint64(i uint64) *Array {
// Float32 appends f as a float32 to the array. // Float32 appends f as a float32 to the array.
func (a *Array) Float32(f float32) *Array { func (a *Array) Float32(f float32) *Array {
a.buf = enc.AppendFloat32(enc.AppendArrayDelim(a.buf), f) a.buf = enc.AppendFloat32(enc.AppendArrayDelim(a.buf), f, FloatingPointPrecision)
return a return a
} }
// Float64 appends f as a float64 to the array. // Float64 appends f as a float64 to the array.
func (a *Array) Float64(f float64) *Array { func (a *Array) Float64(f float64) *Array {
a.buf = enc.AppendFloat64(enc.AppendArrayDelim(a.buf), f) a.buf = enc.AppendFloat64(enc.AppendArrayDelim(a.buf), f, FloatingPointPrecision)
return a return a
} }
@ -201,7 +201,7 @@ func (a *Array) Time(t time.Time) *Array {
// Dur appends d to the array. // Dur appends d to the array.
func (a *Array) Dur(d time.Duration) *Array { func (a *Array) Dur(d time.Duration) *Array {
a.buf = enc.AppendDuration(enc.AppendArrayDelim(a.buf), d, DurationFieldUnit, DurationFieldInteger) a.buf = enc.AppendDuration(enc.AppendArrayDelim(a.buf), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return a return a
} }

View file

@ -28,6 +28,8 @@ const (
colorBold = 1 colorBold = 1
colorDarkGray = 90 colorDarkGray = 90
unknownLevel = "???"
) )
var ( var (
@ -57,12 +59,21 @@ type ConsoleWriter struct {
// TimeFormat specifies the format for timestamp in output. // TimeFormat specifies the format for timestamp in output.
TimeFormat string TimeFormat string
// TimeLocation tells ConsoleWriters default FormatTimestamp
// how to localize the time.
TimeLocation *time.Location
// PartsOrder defines the order of parts in output. // PartsOrder defines the order of parts in output.
PartsOrder []string PartsOrder []string
// PartsExclude defines parts to not display in output. // PartsExclude defines parts to not display in output.
PartsExclude []string PartsExclude []string
// FieldsOrder defines the order of contextual fields in output.
FieldsOrder []string
fieldIsOrdered map[string]int
// FieldsExclude defines contextual fields to not display in output. // FieldsExclude defines contextual fields to not display in output.
FieldsExclude []string FieldsExclude []string
@ -76,6 +87,8 @@ type ConsoleWriter struct {
FormatErrFieldValue Formatter FormatErrFieldValue Formatter
FormatExtra func(map[string]interface{}, *bytes.Buffer) error FormatExtra func(map[string]interface{}, *bytes.Buffer) error
FormatPrepare func(map[string]interface{}) error
} }
// NewConsoleWriter creates and initializes a new ConsoleWriter. // NewConsoleWriter creates and initializes a new ConsoleWriter.
@ -124,6 +137,13 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return n, fmt.Errorf("cannot decode event: %s", err) return n, fmt.Errorf("cannot decode event: %s", err)
} }
if w.FormatPrepare != nil {
err = w.FormatPrepare(evt)
if err != nil {
return n, err
}
}
for _, p := range w.PartsOrder { for _, p := range w.PartsOrder {
w.writePart(buf, evt, p) w.writePart(buf, evt, p)
} }
@ -146,6 +166,15 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return len(p), err return len(p), err
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (w ConsoleWriter) Close() error {
if closer, ok := w.Out.(io.Closer); ok {
return closer.Close()
}
return nil
}
// writeFields appends formatted key-value pairs to buf. // writeFields appends formatted key-value pairs to buf.
func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer) { func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer) {
var fields = make([]string, 0, len(evt)) var fields = make([]string, 0, len(evt))
@ -167,7 +196,12 @@ func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer
} }
fields = append(fields, field) fields = append(fields, field)
} }
if len(w.FieldsOrder) > 0 {
w.orderFields(fields)
} else {
sort.Strings(fields) sort.Strings(fields)
}
// Write space only if something has already been written to the buffer, and if there are fields. // Write space only if something has already been written to the buffer, and if there are fields.
if buf.Len() > 0 && len(fields) > 0 { if buf.Len() > 0 && len(fields) > 0 {
@ -266,13 +300,13 @@ func (w ConsoleWriter) writePart(buf *bytes.Buffer, evt map[string]interface{},
} }
case TimestampFieldName: case TimestampFieldName:
if w.FormatTimestamp == nil { if w.FormatTimestamp == nil {
f = consoleDefaultFormatTimestamp(w.TimeFormat, w.NoColor) f = consoleDefaultFormatTimestamp(w.TimeFormat, w.TimeLocation, w.NoColor)
} else { } else {
f = w.FormatTimestamp f = w.FormatTimestamp
} }
case MessageFieldName: case MessageFieldName:
if w.FormatMessage == nil { if w.FormatMessage == nil {
f = consoleDefaultFormatMessage f = consoleDefaultFormatMessage(w.NoColor, evt[LevelFieldName])
} else { } else {
f = w.FormatMessage f = w.FormatMessage
} }
@ -300,6 +334,32 @@ func (w ConsoleWriter) writePart(buf *bytes.Buffer, evt map[string]interface{},
} }
} }
// orderFields takes an array of field names and an array representing field order
// and returns an array with any ordered fields at the beginning, in order,
// and the remaining fields after in their original order.
func (w ConsoleWriter) orderFields(fields []string) {
if w.fieldIsOrdered == nil {
w.fieldIsOrdered = make(map[string]int)
for i, fieldName := range w.FieldsOrder {
w.fieldIsOrdered[fieldName] = i
}
}
sort.Slice(fields, func(i, j int) bool {
ii, iOrdered := w.fieldIsOrdered[fields[i]]
jj, jOrdered := w.fieldIsOrdered[fields[j]]
if iOrdered && jOrdered {
return ii < jj
}
if iOrdered {
return true
}
if jOrdered {
return false
}
return fields[i] < fields[j]
})
}
// needsQuote returns true when the string s should be quoted in output. // needsQuote returns true when the string s should be quoted in output.
func needsQuote(s string) bool { func needsQuote(s string) bool {
for i := range s { for i := range s {
@ -310,10 +370,10 @@ func needsQuote(s string) bool {
return false return false
} }
// colorize returns the string s wrapped in ANSI code c, unless disabled is true. // colorize returns the string s wrapped in ANSI code c, unless disabled is true or c is 0.
func colorize(s interface{}, c int, disabled bool) string { func colorize(s interface{}, c int, disabled bool) string {
e := os.Getenv("NO_COLOR") e := os.Getenv("NO_COLOR")
if e != "" { if e != "" || c == 0 {
disabled = true disabled = true
} }
@ -334,19 +394,23 @@ func consoleDefaultPartsOrder() []string {
} }
} }
func consoleDefaultFormatTimestamp(timeFormat string, noColor bool) Formatter { func consoleDefaultFormatTimestamp(timeFormat string, location *time.Location, noColor bool) Formatter {
if timeFormat == "" { if timeFormat == "" {
timeFormat = consoleDefaultTimeFormat timeFormat = consoleDefaultTimeFormat
} }
if location == nil {
location = time.Local
}
return func(i interface{}) string { return func(i interface{}) string {
t := "<nil>" t := "<nil>"
switch tt := i.(type) { switch tt := i.(type) {
case string: case string:
ts, err := time.ParseInLocation(TimeFieldFormat, tt, time.Local) ts, err := time.ParseInLocation(TimeFieldFormat, tt, location)
if err != nil { if err != nil {
t = tt t = tt
} else { } else {
t = ts.Local().Format(timeFormat) t = ts.In(location).Format(timeFormat)
} }
case json.Number: case json.Number:
i, err := tt.Int64() i, err := tt.Int64()
@ -367,43 +431,37 @@ func consoleDefaultFormatTimestamp(timeFormat string, noColor bool) Formatter {
} }
ts := time.Unix(sec, nsec) ts := time.Unix(sec, nsec)
t = ts.Format(timeFormat) t = ts.In(location).Format(timeFormat)
} }
} }
return colorize(t, colorDarkGray, noColor) return colorize(t, colorDarkGray, noColor)
} }
} }
func stripLevel(ll string) string {
if len(ll) == 0 {
return unknownLevel
}
if len(ll) > 3 {
ll = ll[:3]
}
return strings.ToUpper(ll)
}
func consoleDefaultFormatLevel(noColor bool) Formatter { func consoleDefaultFormatLevel(noColor bool) Formatter {
return func(i interface{}) string { return func(i interface{}) string {
var l string
if ll, ok := i.(string); ok { if ll, ok := i.(string); ok {
switch ll { level, _ := ParseLevel(ll)
case LevelTraceValue: fl, ok := FormattedLevels[level]
l = colorize("TRC", colorMagenta, noColor) if ok {
case LevelDebugValue: return colorize(fl, LevelColors[level], noColor)
l = colorize("DBG", colorYellow, noColor) }
case LevelInfoValue: return stripLevel(ll)
l = colorize("INF", colorGreen, noColor)
case LevelWarnValue:
l = colorize("WRN", colorRed, noColor)
case LevelErrorValue:
l = colorize(colorize("ERR", colorRed, noColor), colorBold, noColor)
case LevelFatalValue:
l = colorize(colorize("FTL", colorRed, noColor), colorBold, noColor)
case LevelPanicValue:
l = colorize(colorize("PNC", colorRed, noColor), colorBold, noColor)
default:
l = colorize(ll, colorBold, noColor)
} }
} else {
if i == nil { if i == nil {
l = colorize("???", colorBold, noColor) return unknownLevel
} else {
l = strings.ToUpper(fmt.Sprintf("%s", i))[0:3]
} }
} return stripLevel(fmt.Sprintf("%s", i))
return l
} }
} }
@ -425,12 +483,19 @@ func consoleDefaultFormatCaller(noColor bool) Formatter {
} }
} }
func consoleDefaultFormatMessage(i interface{}) string { func consoleDefaultFormatMessage(noColor bool, level interface{}) Formatter {
if i == nil { return func(i interface{}) string {
if i == nil || i == "" {
return "" return ""
} }
switch level {
case LevelInfoValue, LevelWarnValue, LevelErrorValue, LevelFatalValue, LevelPanicValue:
return colorize(fmt.Sprintf("%s", i), colorBold, noColor)
default:
return fmt.Sprintf("%s", i) return fmt.Sprintf("%s", i)
} }
}
}
func consoleDefaultFormatFieldName(noColor bool) Formatter { func consoleDefaultFormatFieldName(noColor bool) Formatter {
return func(i interface{}) string { return func(i interface{}) string {
@ -450,6 +515,6 @@ func consoleDefaultFormatErrFieldName(noColor bool) Formatter {
func consoleDefaultFormatErrFieldValue(noColor bool) Formatter { func consoleDefaultFormatErrFieldValue(noColor bool) Formatter {
return func(i interface{}) string { return func(i interface{}) string {
return colorize(fmt.Sprintf("%s", i), colorRed, noColor) return colorize(colorize(fmt.Sprintf("%s", i), colorBold, noColor), colorRed, noColor)
} }
} }

View file

@ -3,7 +3,7 @@ package zerolog
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"math" "math"
"net" "net"
"time" "time"
@ -23,7 +23,7 @@ func (c Context) Logger() Logger {
// Only map[string]interface{} and []interface{} are accepted. []interface{} must // Only map[string]interface{} and []interface{} are accepted. []interface{} must
// alternate string keys and arbitrary values, and extraneous ones are ignored. // alternate string keys and arbitrary values, and extraneous ones are ignored.
func (c Context) Fields(fields interface{}) Context { func (c Context) Fields(fields interface{}) Context {
c.l.context = appendFields(c.l.context, fields) c.l.context = appendFields(c.l.context, fields, c.l.stack)
return c return c
} }
@ -57,7 +57,7 @@ func (c Context) Array(key string, arr LogArrayMarshaler) Context {
// Object marshals an object that implement the LogObjectMarshaler interface. // Object marshals an object that implement the LogObjectMarshaler interface.
func (c Context) Object(key string, obj LogObjectMarshaler) Context { func (c Context) Object(key string, obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.Object(key, obj) e.Object(key, obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf) c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e) putEvent(e)
@ -66,7 +66,7 @@ func (c Context) Object(key string, obj LogObjectMarshaler) Context {
// EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface. // EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface.
func (c Context) EmbedObject(obj LogObjectMarshaler) Context { func (c Context) EmbedObject(obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.EmbedObject(obj) e.EmbedObject(obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf) c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e) putEvent(e)
@ -163,6 +163,22 @@ func (c Context) Errs(key string, errs []error) Context {
// Err adds the field "error" with serialized err to the logger context. // Err adds the field "error" with serialized err to the logger context.
func (c Context) Err(err error) Context { func (c Context) Err(err error) Context {
if c.l.stack && ErrorStackMarshaler != nil {
switch m := ErrorStackMarshaler(err).(type) {
case nil:
case LogObjectMarshaler:
c = c.Object(ErrorStackFieldName, m)
case error:
if m != nil && !isNilValue(m) {
c = c.Str(ErrorStackFieldName, m.Error())
}
case string:
c = c.Str(ErrorStackFieldName, m)
default:
c = c.Interface(ErrorStackFieldName, m)
}
}
return c.AnErr(ErrorFieldName, err) return c.AnErr(ErrorFieldName, err)
} }
@ -309,25 +325,25 @@ func (c Context) Uints64(key string, i []uint64) Context {
// Float32 adds the field key with f as a float32 to the logger context. // Float32 adds the field key with f as a float32 to the logger context.
func (c Context) Float32(key string, f float32) Context { func (c Context) Float32(key string, f float32) Context {
c.l.context = enc.AppendFloat32(enc.AppendKey(c.l.context, key), f) c.l.context = enc.AppendFloat32(enc.AppendKey(c.l.context, key), f, FloatingPointPrecision)
return c return c
} }
// Floats32 adds the field key with f as a []float32 to the logger context. // Floats32 adds the field key with f as a []float32 to the logger context.
func (c Context) Floats32(key string, f []float32) Context { func (c Context) Floats32(key string, f []float32) Context {
c.l.context = enc.AppendFloats32(enc.AppendKey(c.l.context, key), f) c.l.context = enc.AppendFloats32(enc.AppendKey(c.l.context, key), f, FloatingPointPrecision)
return c return c
} }
// Float64 adds the field key with f as a float64 to the logger context. // Float64 adds the field key with f as a float64 to the logger context.
func (c Context) Float64(key string, f float64) Context { func (c Context) Float64(key string, f float64) Context {
c.l.context = enc.AppendFloat64(enc.AppendKey(c.l.context, key), f) c.l.context = enc.AppendFloat64(enc.AppendKey(c.l.context, key), f, FloatingPointPrecision)
return c return c
} }
// Floats64 adds the field key with f as a []float64 to the logger context. // Floats64 adds the field key with f as a []float64 to the logger context.
func (c Context) Floats64(key string, f []float64) Context { func (c Context) Floats64(key string, f []float64) Context {
c.l.context = enc.AppendFloats64(enc.AppendKey(c.l.context, key), f) c.l.context = enc.AppendFloats64(enc.AppendKey(c.l.context, key), f, FloatingPointPrecision)
return c return c
} }
@ -349,13 +365,13 @@ func (c Context) Timestamp() Context {
return c return c
} }
// Time adds the field key with t formated as string using zerolog.TimeFieldFormat. // Time adds the field key with t formatted as string using zerolog.TimeFieldFormat.
func (c Context) Time(key string, t time.Time) Context { func (c Context) Time(key string, t time.Time) Context {
c.l.context = enc.AppendTime(enc.AppendKey(c.l.context, key), t, TimeFieldFormat) c.l.context = enc.AppendTime(enc.AppendKey(c.l.context, key), t, TimeFieldFormat)
return c return c
} }
// Times adds the field key with t formated as string using zerolog.TimeFieldFormat. // Times adds the field key with t formatted as string using zerolog.TimeFieldFormat.
func (c Context) Times(key string, t []time.Time) Context { func (c Context) Times(key string, t []time.Time) Context {
c.l.context = enc.AppendTimes(enc.AppendKey(c.l.context, key), t, TimeFieldFormat) c.l.context = enc.AppendTimes(enc.AppendKey(c.l.context, key), t, TimeFieldFormat)
return c return c
@ -363,27 +379,42 @@ func (c Context) Times(key string, t []time.Time) Context {
// Dur adds the fields key with d divided by unit and stored as a float. // Dur adds the fields key with d divided by unit and stored as a float.
func (c Context) Dur(key string, d time.Duration) Context { func (c Context) Dur(key string, d time.Duration) Context {
c.l.context = enc.AppendDuration(enc.AppendKey(c.l.context, key), d, DurationFieldUnit, DurationFieldInteger) c.l.context = enc.AppendDuration(enc.AppendKey(c.l.context, key), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return c return c
} }
// Durs adds the fields key with d divided by unit and stored as a float. // Durs adds the fields key with d divided by unit and stored as a float.
func (c Context) Durs(key string, d []time.Duration) Context { func (c Context) Durs(key string, d []time.Duration) Context {
c.l.context = enc.AppendDurations(enc.AppendKey(c.l.context, key), d, DurationFieldUnit, DurationFieldInteger) c.l.context = enc.AppendDurations(enc.AppendKey(c.l.context, key), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return c return c
} }
// Interface adds the field key with obj marshaled using reflection. // Interface adds the field key with obj marshaled using reflection.
func (c Context) Interface(key string, i interface{}) Context { func (c Context) Interface(key string, i interface{}) Context {
if obj, ok := i.(LogObjectMarshaler); ok {
return c.Object(key, obj)
}
c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i) c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i)
return c return c
} }
// Type adds the field key with val's type using reflection.
func (c Context) Type(key string, val interface{}) Context {
c.l.context = enc.AppendType(enc.AppendKey(c.l.context, key), val)
return c
}
// Any is a wrapper around Context.Interface. // Any is a wrapper around Context.Interface.
func (c Context) Any(key string, i interface{}) Context { func (c Context) Any(key string, i interface{}) Context {
return c.Interface(key, i) return c.Interface(key, i)
} }
// Reset removes all the context fields.
func (c Context) Reset() Context {
c.l.context = enc.AppendBeginMarker(make([]byte, 0, 500))
return c
}
type callerHook struct { type callerHook struct {
callerSkipFrameCount int callerSkipFrameCount int
} }

View file

@ -13,13 +13,13 @@ type encoder interface {
AppendBool(dst []byte, val bool) []byte AppendBool(dst []byte, val bool) []byte
AppendBools(dst []byte, vals []bool) []byte AppendBools(dst []byte, vals []bool) []byte
AppendBytes(dst, s []byte) []byte AppendBytes(dst, s []byte) []byte
AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool) []byte AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool, precision int) []byte
AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool) []byte AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool, precision int) []byte
AppendEndMarker(dst []byte) []byte AppendEndMarker(dst []byte) []byte
AppendFloat32(dst []byte, val float32) []byte AppendFloat32(dst []byte, val float32, precision int) []byte
AppendFloat64(dst []byte, val float64) []byte AppendFloat64(dst []byte, val float64, precision int) []byte
AppendFloats32(dst []byte, vals []float32) []byte AppendFloats32(dst []byte, vals []float32, precision int) []byte
AppendFloats64(dst []byte, vals []float64) []byte AppendFloats64(dst []byte, vals []float64, precision int) []byte
AppendHex(dst, s []byte) []byte AppendHex(dst, s []byte) []byte
AppendIPAddr(dst []byte, ip net.IP) []byte AppendIPAddr(dst []byte, ip net.IP) []byte
AppendIPPrefix(dst []byte, pfx net.IPNet) []byte AppendIPPrefix(dst []byte, pfx net.IPNet) []byte

View file

@ -164,7 +164,7 @@ func (e *Event) Fields(fields interface{}) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = appendFields(e.buf, fields) e.buf = appendFields(e.buf, fields, e.stack)
return e return e
} }
@ -644,7 +644,7 @@ func (e *Event) Float32(key string, f float32) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendFloat32(enc.AppendKey(e.buf, key), f) e.buf = enc.AppendFloat32(enc.AppendKey(e.buf, key), f, FloatingPointPrecision)
return e return e
} }
@ -653,7 +653,7 @@ func (e *Event) Floats32(key string, f []float32) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendFloats32(enc.AppendKey(e.buf, key), f) e.buf = enc.AppendFloats32(enc.AppendKey(e.buf, key), f, FloatingPointPrecision)
return e return e
} }
@ -662,7 +662,7 @@ func (e *Event) Float64(key string, f float64) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendFloat64(enc.AppendKey(e.buf, key), f) e.buf = enc.AppendFloat64(enc.AppendKey(e.buf, key), f, FloatingPointPrecision)
return e return e
} }
@ -671,7 +671,7 @@ func (e *Event) Floats64(key string, f []float64) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendFloats64(enc.AppendKey(e.buf, key), f) e.buf = enc.AppendFloats64(enc.AppendKey(e.buf, key), f, FloatingPointPrecision)
return e return e
} }
@ -713,7 +713,7 @@ func (e *Event) Dur(key string, d time.Duration) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendDuration(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger) e.buf = enc.AppendDuration(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return e return e
} }
@ -724,7 +724,7 @@ func (e *Event) Durs(key string, d []time.Duration) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = enc.AppendDurations(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger) e.buf = enc.AppendDurations(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return e return e
} }
@ -739,7 +739,7 @@ func (e *Event) TimeDiff(key string, t time.Time, start time.Time) *Event {
if t.After(start) { if t.After(start) {
d = t.Sub(start) d = t.Sub(start)
} }
e.buf = enc.AppendDuration(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger) e.buf = enc.AppendDuration(enc.AppendKey(e.buf, key), d, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
return e return e
} }

7
vendor/github.com/rs/zerolog/example.jsonl generated vendored Normal file
View file

@ -0,0 +1,7 @@
{"time":"5:41PM","level":"info","message":"Starting listener","listen":":8080","pid":37556}
{"time":"5:41PM","level":"debug","message":"Access","database":"myapp","host":"localhost:4962","pid":37556}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":23}
{"time":"5:41PM","level":"info","message":"Access","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"warn","message":"Slow request","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":10}
{"time":"5:41PM","level":"error","message":"Database connection lost","database":"myapp","pid":37556,"error":"connection reset by peer"}

View file

@ -12,13 +12,13 @@ func isNilValue(i interface{}) bool {
return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0 return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0
} }
func appendFields(dst []byte, fields interface{}) []byte { func appendFields(dst []byte, fields interface{}, stack bool) []byte {
switch fields := fields.(type) { switch fields := fields.(type) {
case []interface{}: case []interface{}:
if n := len(fields); n&0x1 == 1 { // odd number if n := len(fields); n&0x1 == 1 { // odd number
fields = fields[:n-1] fields = fields[:n-1]
} }
dst = appendFieldList(dst, fields) dst = appendFieldList(dst, fields, stack)
case map[string]interface{}: case map[string]interface{}:
keys := make([]string, 0, len(fields)) keys := make([]string, 0, len(fields))
for key := range fields { for key := range fields {
@ -28,13 +28,13 @@ func appendFields(dst []byte, fields interface{}) []byte {
kv := make([]interface{}, 2) kv := make([]interface{}, 2)
for _, key := range keys { for _, key := range keys {
kv[0], kv[1] = key, fields[key] kv[0], kv[1] = key, fields[key]
dst = appendFieldList(dst, kv) dst = appendFieldList(dst, kv, stack)
} }
} }
return dst return dst
} }
func appendFieldList(dst []byte, kvList []interface{}) []byte { func appendFieldList(dst []byte, kvList []interface{}, stack bool) []byte {
for i, n := 0, len(kvList); i < n; i += 2 { for i, n := 0, len(kvList); i < n; i += 2 {
key, val := kvList[i], kvList[i+1] key, val := kvList[i], kvList[i+1]
if key, ok := key.(string); ok { if key, ok := key.(string); ok {
@ -74,6 +74,21 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
default: default:
dst = enc.AppendInterface(dst, m) dst = enc.AppendInterface(dst, m)
} }
if stack && ErrorStackMarshaler != nil {
dst = enc.AppendKey(dst, ErrorStackFieldName)
switch m := ErrorStackMarshaler(val).(type) {
case nil:
case error:
if m != nil && !isNilValue(m) {
dst = enc.AppendString(dst, m.Error())
}
case string:
dst = enc.AppendString(dst, m)
default:
dst = enc.AppendInterface(dst, m)
}
}
case []error: case []error:
dst = enc.AppendArrayStart(dst) dst = enc.AppendArrayStart(dst)
for i, err := range val { for i, err := range val {
@ -124,13 +139,13 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
case uint64: case uint64:
dst = enc.AppendUint64(dst, val) dst = enc.AppendUint64(dst, val)
case float32: case float32:
dst = enc.AppendFloat32(dst, val) dst = enc.AppendFloat32(dst, val, FloatingPointPrecision)
case float64: case float64:
dst = enc.AppendFloat64(dst, val) dst = enc.AppendFloat64(dst, val, FloatingPointPrecision)
case time.Time: case time.Time:
dst = enc.AppendTime(dst, val, TimeFieldFormat) dst = enc.AppendTime(dst, val, TimeFieldFormat)
case time.Duration: case time.Duration:
dst = enc.AppendDuration(dst, val, DurationFieldUnit, DurationFieldInteger) dst = enc.AppendDuration(dst, val, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
case *string: case *string:
if val != nil { if val != nil {
dst = enc.AppendString(dst, *val) dst = enc.AppendString(dst, *val)
@ -205,13 +220,13 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
} }
case *float32: case *float32:
if val != nil { if val != nil {
dst = enc.AppendFloat32(dst, *val) dst = enc.AppendFloat32(dst, *val, FloatingPointPrecision)
} else { } else {
dst = enc.AppendNil(dst) dst = enc.AppendNil(dst)
} }
case *float64: case *float64:
if val != nil { if val != nil {
dst = enc.AppendFloat64(dst, *val) dst = enc.AppendFloat64(dst, *val, FloatingPointPrecision)
} else { } else {
dst = enc.AppendNil(dst) dst = enc.AppendNil(dst)
} }
@ -223,7 +238,7 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
} }
case *time.Duration: case *time.Duration:
if val != nil { if val != nil {
dst = enc.AppendDuration(dst, *val, DurationFieldUnit, DurationFieldInteger) dst = enc.AppendDuration(dst, *val, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
} else { } else {
dst = enc.AppendNil(dst) dst = enc.AppendNil(dst)
} }
@ -252,13 +267,13 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
case []uint64: case []uint64:
dst = enc.AppendUints64(dst, val) dst = enc.AppendUints64(dst, val)
case []float32: case []float32:
dst = enc.AppendFloats32(dst, val) dst = enc.AppendFloats32(dst, val, FloatingPointPrecision)
case []float64: case []float64:
dst = enc.AppendFloats64(dst, val) dst = enc.AppendFloats64(dst, val, FloatingPointPrecision)
case []time.Time: case []time.Time:
dst = enc.AppendTimes(dst, val, TimeFieldFormat) dst = enc.AppendTimes(dst, val, TimeFieldFormat)
case []time.Duration: case []time.Duration:
dst = enc.AppendDurations(dst, val, DurationFieldUnit, DurationFieldInteger) dst = enc.AppendDurations(dst, val, DurationFieldUnit, DurationFieldInteger, FloatingPointPrecision)
case nil: case nil:
dst = enc.AppendNil(dst) dst = enc.AppendNil(dst)
case net.IP: case net.IP:

View file

@ -1,6 +1,7 @@
package zerolog package zerolog
import ( import (
"bytes"
"encoding/json" "encoding/json"
"strconv" "strconv"
"sync/atomic" "sync/atomic"
@ -81,8 +82,22 @@ var (
} }
// InterfaceMarshalFunc allows customization of interface marshaling. // InterfaceMarshalFunc allows customization of interface marshaling.
// Default: "encoding/json.Marshal" // Default: "encoding/json.Marshal" with disabled HTML escaping
InterfaceMarshalFunc = json.Marshal InterfaceMarshalFunc = func(v interface{}) ([]byte, error) {
var buf bytes.Buffer
encoder := json.NewEncoder(&buf)
encoder.SetEscapeHTML(false)
err := encoder.Encode(v)
if err != nil {
return nil, err
}
b := buf.Bytes()
if len(b) > 0 {
// Remove trailing \n which is added by Encode.
return b[:len(b)-1], nil
}
return b, nil
}
// TimeFieldFormat defines the time format of the Time field type. If set to // TimeFieldFormat defines the time format of the Time field type. If set to
// TimeFormatUnix, TimeFormatUnixMs, TimeFormatUnixMicro or TimeFormatUnixNano, the time is formatted as a UNIX // TimeFormatUnix, TimeFormatUnixMs, TimeFormatUnixMicro or TimeFormatUnixNano, the time is formatted as a UNIX
@ -108,6 +123,39 @@ var (
// DefaultContextLogger is returned from Ctx() if there is no logger associated // DefaultContextLogger is returned from Ctx() if there is no logger associated
// with the context. // with the context.
DefaultContextLogger *Logger DefaultContextLogger *Logger
// LevelColors are used by ConsoleWriter's consoleDefaultFormatLevel to color
// log levels.
LevelColors = map[Level]int{
TraceLevel: colorBlue,
DebugLevel: 0,
InfoLevel: colorGreen,
WarnLevel: colorYellow,
ErrorLevel: colorRed,
FatalLevel: colorRed,
PanicLevel: colorRed,
}
// FormattedLevels are used by ConsoleWriter's consoleDefaultFormatLevel
// for a short level name.
FormattedLevels = map[Level]string{
TraceLevel: "TRC",
DebugLevel: "DBG",
InfoLevel: "INF",
WarnLevel: "WRN",
ErrorLevel: "ERR",
FatalLevel: "FTL",
PanicLevel: "PNC",
}
// TriggerLevelWriterBufferReuseLimit is a limit in bytes that a buffer is dropped
// from the TriggerLevelWriter buffer pool if the buffer grows above the limit.
TriggerLevelWriterBufferReuseLimit = 64 * 1024
// FloatingPointPrecision, if set to a value other than -1, controls the number
// of digits when formatting float numbers in JSON. See strconv.FormatFloat for
// more details.
FloatingPointPrecision = -1
) )
var ( var (

View file

@ -95,7 +95,7 @@ func decodeFloat(src *bufio.Reader) (float64, int) {
switch minor { switch minor {
case additionalTypeFloat16: case additionalTypeFloat16:
panic(fmt.Errorf("float16 is not suppported in decodeFloat")) panic(fmt.Errorf("float16 is not supported in decodeFloat"))
case additionalTypeFloat32: case additionalTypeFloat32:
pb := readNBytes(src, 4) pb := readNBytes(src, 4)

View file

@ -29,7 +29,7 @@ func (e Encoder) appendFloatTimestamp(dst []byte, t time.Time) []byte {
nanos := t.Nanosecond() nanos := t.Nanosecond()
var val float64 var val float64
val = float64(secs)*1.0 + float64(nanos)*1e-9 val = float64(secs)*1.0 + float64(nanos)*1e-9
return e.AppendFloat64(dst, val) return e.AppendFloat64(dst, val, -1)
} }
// AppendTime encodes and adds a timestamp to the dst byte array. // AppendTime encodes and adds a timestamp to the dst byte array.
@ -64,17 +64,17 @@ func (e Encoder) AppendTimes(dst []byte, vals []time.Time, unused string) []byte
// AppendDuration encodes and adds a duration to the dst byte array. // AppendDuration encodes and adds a duration to the dst byte array.
// useInt field indicates whether to store the duration as seconds (integer) or // useInt field indicates whether to store the duration as seconds (integer) or
// as seconds+nanoseconds (float). // as seconds+nanoseconds (float).
func (e Encoder) AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool) []byte { func (e Encoder) AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool, unused int) []byte {
if useInt { if useInt {
return e.AppendInt64(dst, int64(d/unit)) return e.AppendInt64(dst, int64(d/unit))
} }
return e.AppendFloat64(dst, float64(d)/float64(unit)) return e.AppendFloat64(dst, float64(d)/float64(unit), unused)
} }
// AppendDurations encodes and adds an array of durations to the dst byte array. // AppendDurations encodes and adds an array of durations to the dst byte array.
// useInt field indicates whether to store the duration as seconds (integer) or // useInt field indicates whether to store the duration as seconds (integer) or
// as seconds+nanoseconds (float). // as seconds+nanoseconds (float).
func (e Encoder) AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool) []byte { func (e Encoder) AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool, unused int) []byte {
major := majorTypeArray major := majorTypeArray
l := len(vals) l := len(vals)
if l == 0 { if l == 0 {
@ -87,7 +87,7 @@ func (e Encoder) AppendDurations(dst []byte, vals []time.Duration, unit time.Dur
dst = appendCborTypePrefix(dst, major, uint64(l)) dst = appendCborTypePrefix(dst, major, uint64(l))
} }
for _, d := range vals { for _, d := range vals {
dst = e.AppendDuration(dst, d, unit, useInt) dst = e.AppendDuration(dst, d, unit, useInt, unused)
} }
return dst return dst
} }

View file

@ -352,7 +352,7 @@ func (e Encoder) AppendUints64(dst []byte, vals []uint64) []byte {
} }
// AppendFloat32 encodes and inserts a single precision float value into the dst byte array. // AppendFloat32 encodes and inserts a single precision float value into the dst byte array.
func (Encoder) AppendFloat32(dst []byte, val float32) []byte { func (Encoder) AppendFloat32(dst []byte, val float32, unused int) []byte {
switch { switch {
case math.IsNaN(float64(val)): case math.IsNaN(float64(val)):
return append(dst, "\xfa\x7f\xc0\x00\x00"...) return append(dst, "\xfa\x7f\xc0\x00\x00"...)
@ -372,7 +372,7 @@ func (Encoder) AppendFloat32(dst []byte, val float32) []byte {
} }
// AppendFloats32 encodes and inserts an array of single precision float value into the dst byte array. // AppendFloats32 encodes and inserts an array of single precision float value into the dst byte array.
func (e Encoder) AppendFloats32(dst []byte, vals []float32) []byte { func (e Encoder) AppendFloats32(dst []byte, vals []float32, unused int) []byte {
major := majorTypeArray major := majorTypeArray
l := len(vals) l := len(vals)
if l == 0 { if l == 0 {
@ -385,13 +385,13 @@ func (e Encoder) AppendFloats32(dst []byte, vals []float32) []byte {
dst = appendCborTypePrefix(dst, major, uint64(l)) dst = appendCborTypePrefix(dst, major, uint64(l))
} }
for _, v := range vals { for _, v := range vals {
dst = e.AppendFloat32(dst, v) dst = e.AppendFloat32(dst, v, unused)
} }
return dst return dst
} }
// AppendFloat64 encodes and inserts a double precision float value into the dst byte array. // AppendFloat64 encodes and inserts a double precision float value into the dst byte array.
func (Encoder) AppendFloat64(dst []byte, val float64) []byte { func (Encoder) AppendFloat64(dst []byte, val float64, unused int) []byte {
switch { switch {
case math.IsNaN(val): case math.IsNaN(val):
return append(dst, "\xfb\x7f\xf8\x00\x00\x00\x00\x00\x00"...) return append(dst, "\xfb\x7f\xf8\x00\x00\x00\x00\x00\x00"...)
@ -412,7 +412,7 @@ func (Encoder) AppendFloat64(dst []byte, val float64) []byte {
} }
// AppendFloats64 encodes and inserts an array of double precision float values into the dst byte array. // AppendFloats64 encodes and inserts an array of double precision float values into the dst byte array.
func (e Encoder) AppendFloats64(dst []byte, vals []float64) []byte { func (e Encoder) AppendFloats64(dst []byte, vals []float64, unused int) []byte {
major := majorTypeArray major := majorTypeArray
l := len(vals) l := len(vals)
if l == 0 { if l == 0 {
@ -425,7 +425,7 @@ func (e Encoder) AppendFloats64(dst []byte, vals []float64) []byte {
dst = appendCborTypePrefix(dst, major, uint64(l)) dst = appendCborTypePrefix(dst, major, uint64(l))
} }
for _, v := range vals { for _, v := range vals {
dst = e.AppendFloat64(dst, v) dst = e.AppendFloat64(dst, v, unused)
} }
return dst return dst
} }

View file

@ -88,24 +88,24 @@ func appendUnixNanoTimes(dst []byte, vals []time.Time, div int64) []byte {
// AppendDuration formats the input duration with the given unit & format // AppendDuration formats the input duration with the given unit & format
// and appends the encoded string to the input byte slice. // and appends the encoded string to the input byte slice.
func (e Encoder) AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool) []byte { func (e Encoder) AppendDuration(dst []byte, d time.Duration, unit time.Duration, useInt bool, precision int) []byte {
if useInt { if useInt {
return strconv.AppendInt(dst, int64(d/unit), 10) return strconv.AppendInt(dst, int64(d/unit), 10)
} }
return e.AppendFloat64(dst, float64(d)/float64(unit)) return e.AppendFloat64(dst, float64(d)/float64(unit), precision)
} }
// AppendDurations formats the input durations with the given unit & format // AppendDurations formats the input durations with the given unit & format
// and appends the encoded string list to the input byte slice. // and appends the encoded string list to the input byte slice.
func (e Encoder) AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool) []byte { func (e Encoder) AppendDurations(dst []byte, vals []time.Duration, unit time.Duration, useInt bool, precision int) []byte {
if len(vals) == 0 { if len(vals) == 0 {
return append(dst, '[', ']') return append(dst, '[', ']')
} }
dst = append(dst, '[') dst = append(dst, '[')
dst = e.AppendDuration(dst, vals[0], unit, useInt) dst = e.AppendDuration(dst, vals[0], unit, useInt, precision)
if len(vals) > 1 { if len(vals) > 1 {
for _, d := range vals[1:] { for _, d := range vals[1:] {
dst = e.AppendDuration(append(dst, ','), d, unit, useInt) dst = e.AppendDuration(append(dst, ','), d, unit, useInt, precision)
} }
} }
dst = append(dst, ']') dst = append(dst, ']')

View file

@ -299,7 +299,7 @@ func (Encoder) AppendUints64(dst []byte, vals []uint64) []byte {
return dst return dst
} }
func appendFloat(dst []byte, val float64, bitSize int) []byte { func appendFloat(dst []byte, val float64, bitSize, precision int) []byte {
// JSON does not permit NaN or Infinity. A typical JSON encoder would fail // JSON does not permit NaN or Infinity. A typical JSON encoder would fail
// with an error, but a logging library wants the data to get through so we // with an error, but a logging library wants the data to get through so we
// make a tradeoff and store those types as string. // make a tradeoff and store those types as string.
@ -311,26 +311,47 @@ func appendFloat(dst []byte, val float64, bitSize int) []byte {
case math.IsInf(val, -1): case math.IsInf(val, -1):
return append(dst, `"-Inf"`...) return append(dst, `"-Inf"`...)
} }
return strconv.AppendFloat(dst, val, 'f', -1, bitSize) // convert as if by es6 number to string conversion
// see also https://cs.opensource.google/go/go/+/refs/tags/go1.20.3:src/encoding/json/encode.go;l=573
strFmt := byte('f')
// If precision is set to a value other than -1, we always just format the float using that precision.
if precision == -1 {
// Use float32 comparisons for underlying float32 value to get precise cutoffs right.
if abs := math.Abs(val); abs != 0 {
if bitSize == 64 && (abs < 1e-6 || abs >= 1e21) || bitSize == 32 && (float32(abs) < 1e-6 || float32(abs) >= 1e21) {
strFmt = 'e'
}
}
}
dst = strconv.AppendFloat(dst, val, strFmt, precision, bitSize)
if strFmt == 'e' {
// Clean up e-09 to e-9
n := len(dst)
if n >= 4 && dst[n-4] == 'e' && dst[n-3] == '-' && dst[n-2] == '0' {
dst[n-2] = dst[n-1]
dst = dst[:n-1]
}
}
return dst
} }
// AppendFloat32 converts the input float32 to a string and // AppendFloat32 converts the input float32 to a string and
// appends the encoded string to the input byte slice. // appends the encoded string to the input byte slice.
func (Encoder) AppendFloat32(dst []byte, val float32) []byte { func (Encoder) AppendFloat32(dst []byte, val float32, precision int) []byte {
return appendFloat(dst, float64(val), 32) return appendFloat(dst, float64(val), 32, precision)
} }
// AppendFloats32 encodes the input float32s to json and // AppendFloats32 encodes the input float32s to json and
// appends the encoded string list to the input byte slice. // appends the encoded string list to the input byte slice.
func (Encoder) AppendFloats32(dst []byte, vals []float32) []byte { func (Encoder) AppendFloats32(dst []byte, vals []float32, precision int) []byte {
if len(vals) == 0 { if len(vals) == 0 {
return append(dst, '[', ']') return append(dst, '[', ']')
} }
dst = append(dst, '[') dst = append(dst, '[')
dst = appendFloat(dst, float64(vals[0]), 32) dst = appendFloat(dst, float64(vals[0]), 32, precision)
if len(vals) > 1 { if len(vals) > 1 {
for _, val := range vals[1:] { for _, val := range vals[1:] {
dst = appendFloat(append(dst, ','), float64(val), 32) dst = appendFloat(append(dst, ','), float64(val), 32, precision)
} }
} }
dst = append(dst, ']') dst = append(dst, ']')
@ -339,21 +360,21 @@ func (Encoder) AppendFloats32(dst []byte, vals []float32) []byte {
// AppendFloat64 converts the input float64 to a string and // AppendFloat64 converts the input float64 to a string and
// appends the encoded string to the input byte slice. // appends the encoded string to the input byte slice.
func (Encoder) AppendFloat64(dst []byte, val float64) []byte { func (Encoder) AppendFloat64(dst []byte, val float64, precision int) []byte {
return appendFloat(dst, val, 64) return appendFloat(dst, val, 64, precision)
} }
// AppendFloats64 encodes the input float64s to json and // AppendFloats64 encodes the input float64s to json and
// appends the encoded string list to the input byte slice. // appends the encoded string list to the input byte slice.
func (Encoder) AppendFloats64(dst []byte, vals []float64) []byte { func (Encoder) AppendFloats64(dst []byte, vals []float64, precision int) []byte {
if len(vals) == 0 { if len(vals) == 0 {
return append(dst, '[', ']') return append(dst, '[', ']')
} }
dst = append(dst, '[') dst = append(dst, '[')
dst = appendFloat(dst, vals[0], 64) dst = appendFloat(dst, vals[0], 64, precision)
if len(vals) > 1 { if len(vals) > 1 {
for _, val := range vals[1:] { for _, val := range vals[1:] {
dst = appendFloat(append(dst, ','), val, 64) dst = appendFloat(append(dst, ','), val, 64, precision)
} }
} }
dst = append(dst, ']') dst = append(dst, ']')

34
vendor/github.com/rs/zerolog/log.go generated vendored
View file

@ -24,7 +24,7 @@
// //
// Sub-loggers let you chain loggers with additional context: // Sub-loggers let you chain loggers with additional context:
// //
// sublogger := log.With().Str("component": "foo").Logger() // sublogger := log.With().Str("component", "foo").Logger()
// sublogger.Info().Msg("hello world") // sublogger.Info().Msg("hello world")
// // Output: {"time":1494567715,"level":"info","message":"hello world","component":"foo"} // // Output: {"time":1494567715,"level":"info","message":"hello world","component":"foo"}
// //
@ -118,7 +118,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@ -246,7 +245,7 @@ type Logger struct {
// you may consider using sync wrapper. // you may consider using sync wrapper.
func New(w io.Writer) Logger { func New(w io.Writer) Logger {
if w == nil { if w == nil {
w = ioutil.Discard w = io.Discard
} }
lw, ok := w.(LevelWriter) lw, ok := w.(LevelWriter)
if !ok { if !ok {
@ -326,10 +325,13 @@ func (l Logger) Sample(s Sampler) Logger {
} }
// Hook returns a logger with the h Hook. // Hook returns a logger with the h Hook.
func (l Logger) Hook(h Hook) Logger { func (l Logger) Hook(hooks ...Hook) Logger {
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+1) if len(hooks) == 0 {
return l
}
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+len(hooks))
copy(newHooks, l.hooks) copy(newHooks, l.hooks)
l.hooks = append(newHooks, h) l.hooks = append(newHooks, hooks...)
return l return l
} }
@ -385,7 +387,14 @@ func (l *Logger) Err(err error) *Event {
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func (l *Logger) Fatal() *Event { func (l *Logger) Fatal() *Event {
return l.newEvent(FatalLevel, func(msg string) { os.Exit(1) }) return l.newEvent(FatalLevel, func(msg string) {
if closer, ok := l.w.(io.Closer); ok {
// Close the writer to flush any buffered message. Otherwise the message
// will be lost as os.Exit() terminates the program immediately.
closer.Close()
}
os.Exit(1)
})
} }
// Panic starts a new message with panic level. The panic() function // Panic starts a new message with panic level. The panic() function
@ -450,6 +459,14 @@ func (l *Logger) Printf(format string, v ...interface{}) {
} }
} }
// Println sends a log event using debug level and no extra field.
// Arguments are handled in the manner of fmt.Println.
func (l *Logger) Println(v ...interface{}) {
if e := l.Debug(); e.Enabled() {
e.CallerSkipFrame(1).Msg(fmt.Sprintln(v...))
}
}
// Write implements the io.Writer interface. This is useful to set as a writer // Write implements the io.Writer interface. This is useful to set as a writer
// for the standard library log. // for the standard library log.
func (l Logger) Write(p []byte) (n int, err error) { func (l Logger) Write(p []byte) (n int, err error) {
@ -488,6 +505,9 @@ func (l *Logger) newEvent(level Level, done func(string)) *Event {
// should returns true if the log event should be logged. // should returns true if the log event should be logged.
func (l *Logger) should(lvl Level) bool { func (l *Logger) should(lvl Level) bool {
if l.w == nil {
return false
}
if lvl < l.level || lvl < GlobalLevel() { if lvl < l.level || lvl < GlobalLevel() {
return false return false
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 116 KiB

View file

@ -84,7 +84,7 @@ func (s *BurstSampler) Sample(lvl Level) bool {
} }
func (s *BurstSampler) inc() uint32 { func (s *BurstSampler) inc() uint32 {
now := time.Now().UnixNano() now := TimestampFunc().UnixNano()
resetAt := atomic.LoadInt64(&s.resetAt) resetAt := atomic.LoadInt64(&s.resetAt)
var c uint32 var c uint32
if now > resetAt { if now > resetAt {

View file

@ -78,3 +78,12 @@ func (sw syslogWriter) WriteLevel(level Level, p []byte) (n int, err error) {
n = len(p) n = len(p)
return return
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (sw syslogWriter) Close() error {
if c, ok := sw.w.(io.Closer); ok {
return c.Close()
}
return nil
}

View file

@ -27,6 +27,15 @@ func (lw LevelWriterAdapter) WriteLevel(l Level, p []byte) (n int, err error) {
return lw.Write(p) return lw.Write(p)
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (lw LevelWriterAdapter) Close() error {
if closer, ok := lw.Writer.(io.Closer); ok {
return closer.Close()
}
return nil
}
type syncWriter struct { type syncWriter struct {
mu sync.Mutex mu sync.Mutex
lw LevelWriter lw LevelWriter
@ -57,6 +66,15 @@ func (s *syncWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return s.lw.WriteLevel(l, p) return s.lw.WriteLevel(l, p)
} }
func (s *syncWriter) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if closer, ok := s.lw.(io.Closer); ok {
return closer.Close()
}
return nil
}
type multiLevelWriter struct { type multiLevelWriter struct {
writers []LevelWriter writers []LevelWriter
} }
@ -89,6 +107,20 @@ func (t multiLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return n, err return n, err
} }
// Calls close on all the underlying writers that are io.Closers. If any of the
// Close methods return an error, the remainder of the closers are not closed
// and the error is returned.
func (t multiLevelWriter) Close() error {
for _, w := range t.writers {
if closer, ok := w.(io.Closer); ok {
if err := closer.Close(); err != nil {
return err
}
}
}
return nil
}
// MultiLevelWriter creates a writer that duplicates its writes to all the // MultiLevelWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command. If some writers // provided writers, similar to the Unix tee(1) command. If some writers
// implement LevelWriter, their WriteLevel method will be used instead of Write. // implement LevelWriter, their WriteLevel method will be used instead of Write.
@ -180,3 +212,135 @@ func (w *FilteredLevelWriter) WriteLevel(level Level, p []byte) (int, error) {
} }
return len(p), nil return len(p), nil
} }
var triggerWriterPool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024))
},
}
// TriggerLevelWriter buffers log lines at the ConditionalLevel or below
// until a trigger level (or higher) line is emitted. Log lines with level
// higher than ConditionalLevel are always written out to the destination
// writer. If trigger never happens, buffered log lines are never written out.
//
// It can be used to configure "log level per request".
type TriggerLevelWriter struct {
// Destination writer. If LevelWriter is provided (usually), its WriteLevel is used
// instead of Write.
io.Writer
// ConditionalLevel is the level (and below) at which lines are buffered until
// a trigger level (or higher) line is emitted. Usually this is set to DebugLevel.
ConditionalLevel Level
// TriggerLevel is the lowest level that triggers the sending of the conditional
// level lines. Usually this is set to ErrorLevel.
TriggerLevel Level
buf *bytes.Buffer
triggered bool
mu sync.Mutex
}
func (w *TriggerLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// At first trigger level or above log line, we flush the buffer and change the
// trigger state to triggered.
if !w.triggered && l >= w.TriggerLevel {
err := w.trigger()
if err != nil {
return 0, err
}
}
// Unless triggered, we buffer everything at and below ConditionalLevel.
if !w.triggered && l <= w.ConditionalLevel {
if w.buf == nil {
w.buf = triggerWriterPool.Get().(*bytes.Buffer)
}
// We prefix each log line with a byte with the level.
// Hopefully we will never have a level value which equals a newline
// (which could interfere with reconstruction of log lines in the trigger method).
w.buf.WriteByte(byte(l))
w.buf.Write(p)
return len(p), nil
}
// Anything above ConditionalLevel is always passed through.
// Once triggered, everything is passed through.
if lw, ok := w.Writer.(LevelWriter); ok {
return lw.WriteLevel(l, p)
}
return w.Write(p)
}
// trigger expects lock to be held.
func (w *TriggerLevelWriter) trigger() error {
if w.triggered {
return nil
}
w.triggered = true
if w.buf == nil {
return nil
}
p := w.buf.Bytes()
for len(p) > 0 {
// We do not use bufio.Scanner here because we already have full buffer
// in the memory and we do not want extra copying from the buffer to
// scanner's token slice, nor we want to hit scanner's token size limit,
// and we also want to preserve newlines.
i := bytes.IndexByte(p, '\n')
line := p[0 : i+1]
p = p[i+1:]
// We prefixed each log line with a byte with the level.
level := Level(line[0])
line = line[1:]
var err error
if lw, ok := w.Writer.(LevelWriter); ok {
_, err = lw.WriteLevel(level, line)
} else {
_, err = w.Write(line)
}
if err != nil {
return err
}
}
return nil
}
// Trigger forces flushing the buffer and change the trigger state to
// triggered, if the writer has not already been triggered before.
func (w *TriggerLevelWriter) Trigger() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.trigger()
}
// Close closes the writer and returns the buffer to the pool.
func (w *TriggerLevelWriter) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.buf == nil {
return nil
}
// We return the buffer only if it has not grown above the limit.
// This prevents accumulation of large buffers in the pool just
// because occasionally a large buffer might be needed.
if w.buf.Cap() <= TriggerLevelWriterBufferReuseLimit {
w.buf.Reset()
triggerWriterPool.Put(w.buf)
}
w.buf = nil
return nil
}

View file

@ -28,6 +28,8 @@ var (
uint32Type = reflect.TypeOf(uint32(1)) uint32Type = reflect.TypeOf(uint32(1))
uint64Type = reflect.TypeOf(uint64(1)) uint64Type = reflect.TypeOf(uint64(1))
uintptrType = reflect.TypeOf(uintptr(1))
float32Type = reflect.TypeOf(float32(1)) float32Type = reflect.TypeOf(float32(1))
float64Type = reflect.TypeOf(float64(1)) float64Type = reflect.TypeOf(float64(1))
@ -308,11 +310,11 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Struct: case reflect.Struct:
{ {
// All structs enter here. We're not interested in most types. // All structs enter here. We're not interested in most types.
if !canConvert(obj1Value, timeType) { if !obj1Value.CanConvert(timeType) {
break break
} }
// time.Time can compared! // time.Time can be compared!
timeObj1, ok := obj1.(time.Time) timeObj1, ok := obj1.(time.Time)
if !ok { if !ok {
timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time) timeObj1 = obj1Value.Convert(timeType).Interface().(time.Time)
@ -328,7 +330,7 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
case reflect.Slice: case reflect.Slice:
{ {
// We only care about the []byte type. // We only care about the []byte type.
if !canConvert(obj1Value, bytesType) { if !obj1Value.CanConvert(bytesType) {
break break
} }
@ -345,6 +347,26 @@ func compare(obj1, obj2 interface{}, kind reflect.Kind) (CompareType, bool) {
return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true return CompareType(bytes.Compare(bytesObj1, bytesObj2)), true
} }
case reflect.Uintptr:
{
uintptrObj1, ok := obj1.(uintptr)
if !ok {
uintptrObj1 = obj1Value.Convert(uintptrType).Interface().(uintptr)
}
uintptrObj2, ok := obj2.(uintptr)
if !ok {
uintptrObj2 = obj2Value.Convert(uintptrType).Interface().(uintptr)
}
if uintptrObj1 > uintptrObj2 {
return compareGreater, true
}
if uintptrObj1 == uintptrObj2 {
return compareEqual, true
}
if uintptrObj1 < uintptrObj2 {
return compareLess, true
}
}
} }
return compareEqual, false return compareEqual, false

View file

@ -1,16 +0,0 @@
//go:build go1.17
// +build go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_legacy.go
package assert
import "reflect"
// Wrapper around reflect.Value.CanConvert, for compatibility
// reasons.
func canConvert(value reflect.Value, to reflect.Type) bool {
return value.CanConvert(to)
}

View file

@ -1,16 +0,0 @@
//go:build !go1.17
// +build !go1.17
// TODO: once support for Go 1.16 is dropped, this file can be
// merged/removed with assertion_compare_go1.17_test.go and
// assertion_compare_can_convert.go
package assert
import "reflect"
// Older versions of Go does not have the reflect.Value.CanConvert
// method.
func canConvert(value reflect.Value, to reflect.Type) bool {
return false
}

View file

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -107,7 +104,7 @@ func EqualExportedValuesf(t TestingT, expected interface{}, actual interface{},
return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...) return EqualExportedValues(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted") // assert.EqualValuesf(t, uint32(123), int32(123), "error message %s", "formatted")
@ -616,6 +613,16 @@ func NotErrorIsf(t TestingT, err error, target error, msg string, args ...interf
return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...) return NotErrorIs(t, err, target, append([]interface{}{msg}, args...)...)
} }
// NotImplementsf asserts that an object does not implement the specified interface.
//
// assert.NotImplementsf(t, (*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func NotImplementsf(t TestingT, interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
return NotImplements(t, interfaceObject, object, append([]interface{}{msg}, args...)...)
}
// NotNilf asserts that the specified object is not nil. // NotNilf asserts that the specified object is not nil.
// //
// assert.NotNilf(t, err, "error message %s", "formatted") // assert.NotNilf(t, err, "error message %s", "formatted")
@ -660,10 +667,12 @@ func NotSamef(t TestingT, expected interface{}, actual interface{}, msg string,
return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...) return NotSame(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubsetf(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // assert.NotSubsetf(t, [1, 3, 4], [1, 2], "error message %s", "formatted")
// assert.NotSubsetf(t, {"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func NotSubsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -747,10 +756,11 @@ func Samef(t TestingT, expected interface{}, actual interface{}, msg string, arg
return Same(t, expected, actual, append([]interface{}{msg}, args...)...) return Same(t, expected, actual, append([]interface{}{msg}, args...)...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subsetf(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // assert.Subsetf(t, [1, 2, 3], [1, 2], "error message %s", "formatted")
// assert.Subsetf(t, {"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool { func Subsetf(t TestingT, list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()

View file

@ -1,7 +1,4 @@
/* // Code generated with github.com/stretchr/testify/_codegen; DO NOT EDIT.
* CODE GENERATED AUTOMATICALLY WITH github.com/stretchr/testify/_codegen
* THIS FILE MUST NOT BE EDITED BY HAND
*/
package assert package assert
@ -189,7 +186,7 @@ func (a *Assertions) EqualExportedValuesf(expected interface{}, actual interface
return EqualExportedValuesf(a.t, expected, actual, msg, args...) return EqualExportedValuesf(a.t, expected, actual, msg, args...)
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValues(uint32(123), int32(123)) // a.EqualValues(uint32(123), int32(123))
@ -200,7 +197,7 @@ func (a *Assertions) EqualValues(expected interface{}, actual interface{}, msgAn
return EqualValues(a.t, expected, actual, msgAndArgs...) return EqualValues(a.t, expected, actual, msgAndArgs...)
} }
// EqualValuesf asserts that two objects are equal or convertable to the same types // EqualValuesf asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted") // a.EqualValuesf(uint32(123), int32(123), "error message %s", "formatted")
@ -1221,6 +1218,26 @@ func (a *Assertions) NotErrorIsf(err error, target error, msg string, args ...in
return NotErrorIsf(a.t, err, target, msg, args...) return NotErrorIsf(a.t, err, target, msg, args...)
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// a.NotImplements((*MyInterface)(nil), new(MyObject))
func (a *Assertions) NotImplements(interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplements(a.t, interfaceObject, object, msgAndArgs...)
}
// NotImplementsf asserts that an object does not implement the specified interface.
//
// a.NotImplementsf((*MyInterface)(nil), new(MyObject), "error message %s", "formatted")
func (a *Assertions) NotImplementsf(interfaceObject interface{}, object interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok {
h.Helper()
}
return NotImplementsf(a.t, interfaceObject, object, msg, args...)
}
// NotNil asserts that the specified object is not nil. // NotNil asserts that the specified object is not nil.
// //
// a.NotNil(err) // a.NotNil(err)
@ -1309,10 +1326,12 @@ func (a *Assertions) NotSamef(expected interface{}, actual interface{}, msg stri
return NotSamef(a.t, expected, actual, msg, args...) return NotSamef(a.t, expected, actual, msg, args...)
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubset([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // a.NotSubset([1, 3, 4], [1, 2])
// a.NotSubset({"x": 1, "y": 2}, {"z": 3})
func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1320,10 +1339,12 @@ func (a *Assertions) NotSubset(list interface{}, subset interface{}, msgAndArgs
return NotSubset(a.t, list, subset, msgAndArgs...) return NotSubset(a.t, list, subset, msgAndArgs...)
} }
// NotSubsetf asserts that the specified list(array, slice...) contains not all // NotSubsetf asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// a.NotSubsetf([1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]", "error message %s", "formatted") // a.NotSubsetf([1, 3, 4], [1, 2], "error message %s", "formatted")
// a.NotSubsetf({"x": 1, "y": 2}, {"z": 3}, "error message %s", "formatted")
func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) NotSubsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1483,10 +1504,11 @@ func (a *Assertions) Samef(expected interface{}, actual interface{}, msg string,
return Samef(a.t, expected, actual, msg, args...) return Samef(a.t, expected, actual, msg, args...)
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subset([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // a.Subset([1, 2, 3], [1, 2])
// a.Subset({"x": 1, "y": 2}, {"x": 1})
func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool { func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()
@ -1494,10 +1516,11 @@ func (a *Assertions) Subset(list interface{}, subset interface{}, msgAndArgs ...
return Subset(a.t, list, subset, msgAndArgs...) return Subset(a.t, list, subset, msgAndArgs...)
} }
// Subsetf asserts that the specified list(array, slice...) contains all // Subsetf asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// a.Subsetf([1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]", "error message %s", "formatted") // a.Subsetf([1, 2, 3], [1, 2], "error message %s", "formatted")
// a.Subsetf({"x": 1, "y": 2}, {"x": 1}, "error message %s", "formatted")
func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool { func (a *Assertions) Subsetf(list interface{}, subset interface{}, msg string, args ...interface{}) bool {
if h, ok := a.t.(tHelper); ok { if h, ok := a.t.(tHelper); ok {
h.Helper() h.Helper()

View file

@ -19,7 +19,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/pmezard/go-difflib/difflib" "github.com/pmezard/go-difflib/difflib"
yaml "gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
//go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl" //go:generate sh -c "cd ../_codegen && go build && cd - && ../_codegen/_codegen -output-package=assert -template=assertion_format.go.tmpl"
@ -110,7 +110,12 @@ func copyExportedFields(expected interface{}) interface{} {
return result.Interface() return result.Interface()
case reflect.Array, reflect.Slice: case reflect.Array, reflect.Slice:
result := reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len()) var result reflect.Value
if expectedKind == reflect.Array {
result = reflect.New(reflect.ArrayOf(expectedValue.Len(), expectedType.Elem())).Elem()
} else {
result = reflect.MakeSlice(expectedType, expectedValue.Len(), expectedValue.Len())
}
for i := 0; i < expectedValue.Len(); i++ { for i := 0; i < expectedValue.Len(); i++ {
index := expectedValue.Index(i) index := expectedValue.Index(i)
if isNil(index) { if isNil(index) {
@ -140,6 +145,8 @@ func copyExportedFields(expected interface{}) interface{} {
// structures. // structures.
// //
// This function does no assertion of any kind. // This function does no assertion of any kind.
//
// Deprecated: Use [EqualExportedValues] instead.
func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool { func ObjectsExportedFieldsAreEqual(expected, actual interface{}) bool {
expectedCleaned := copyExportedFields(expected) expectedCleaned := copyExportedFields(expected)
actualCleaned := copyExportedFields(actual) actualCleaned := copyExportedFields(actual)
@ -153,19 +160,42 @@ func ObjectsAreEqualValues(expected, actual interface{}) bool {
return true return true
} }
actualType := reflect.TypeOf(actual)
if actualType == nil {
return false
}
expectedValue := reflect.ValueOf(expected) expectedValue := reflect.ValueOf(expected)
if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) { actualValue := reflect.ValueOf(actual)
// Attempt comparison after type conversion if !expectedValue.IsValid() || !actualValue.IsValid() {
return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual) return false
} }
expectedType := expectedValue.Type()
actualType := actualValue.Type()
if !expectedType.ConvertibleTo(actualType) {
return false return false
} }
if !isNumericType(expectedType) || !isNumericType(actualType) {
// Attempt comparison after type conversion
return reflect.DeepEqual(
expectedValue.Convert(actualType).Interface(), actual,
)
}
// If BOTH values are numeric, there are chances of false positives due
// to overflow or underflow. So, we need to make sure to always convert
// the smaller type to a larger type before comparing.
if expectedType.Size() >= actualType.Size() {
return actualValue.Convert(expectedType).Interface() == expected
}
return expectedValue.Convert(actualType).Interface() == actual
}
// isNumericType returns true if the type is one of:
// int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64,
// float32, float64, complex64, complex128
func isNumericType(t reflect.Type) bool {
return t.Kind() >= reflect.Int && t.Kind() <= reflect.Complex128
}
/* CallerInfo is necessary because the assert functions use the testing object /* CallerInfo is necessary because the assert functions use the testing object
internally, causing it to print the file:line of the assert method, rather than where internally, causing it to print the file:line of the assert method, rather than where
the problem actually occurred in calling code.*/ the problem actually occurred in calling code.*/
@ -266,7 +296,7 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
// Aligns the provided message so that all lines after the first line start at the same location as the first line. // Aligns the provided message so that all lines after the first line start at the same location as the first line.
// Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab). // Assumes that the first line starts at the correct location (after carriage return, tab, label, spacer and tab).
// The longestLabelLen parameter specifies the length of the longest label in the output (required becaues this is the // The longestLabelLen parameter specifies the length of the longest label in the output (required because this is the
// basis on which the alignment occurs). // basis on which the alignment occurs).
func indentMessageLines(message string, longestLabelLen int) string { func indentMessageLines(message string, longestLabelLen int) string {
outBuf := new(bytes.Buffer) outBuf := new(bytes.Buffer)
@ -382,6 +412,25 @@ func Implements(t TestingT, interfaceObject interface{}, object interface{}, msg
return true return true
} }
// NotImplements asserts that an object does not implement the specified interface.
//
// assert.NotImplements(t, (*MyInterface)(nil), new(MyObject))
func NotImplements(t TestingT, interfaceObject interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok {
h.Helper()
}
interfaceType := reflect.TypeOf(interfaceObject).Elem()
if object == nil {
return Fail(t, fmt.Sprintf("Cannot check if nil does not implement %v", interfaceType), msgAndArgs...)
}
if reflect.TypeOf(object).Implements(interfaceType) {
return Fail(t, fmt.Sprintf("%T implements %v", object, interfaceType), msgAndArgs...)
}
return true
}
// IsType asserts that the specified objects are of the same type. // IsType asserts that the specified objects are of the same type.
func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool { func IsType(t TestingT, expectedType interface{}, object interface{}, msgAndArgs ...interface{}) bool {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
@ -496,7 +545,7 @@ func samePointers(first, second interface{}) bool {
// representations appropriate to be presented to the user. // representations appropriate to be presented to the user.
// //
// If the values are not of like type, the returned strings will be prefixed // If the values are not of like type, the returned strings will be prefixed
// with the type name, and the value will be enclosed in parenthesis similar // with the type name, and the value will be enclosed in parentheses similar
// to a type conversion in the Go grammar. // to a type conversion in the Go grammar.
func formatUnequalValues(expected, actual interface{}) (e string, a string) { func formatUnequalValues(expected, actual interface{}) (e string, a string) {
if reflect.TypeOf(expected) != reflect.TypeOf(actual) { if reflect.TypeOf(expected) != reflect.TypeOf(actual) {
@ -523,7 +572,7 @@ func truncatingFormat(data interface{}) string {
return value return value
} }
// EqualValues asserts that two objects are equal or convertable to the same types // EqualValues asserts that two objects are equal or convertible to the same types
// and equal. // and equal.
// //
// assert.EqualValues(t, uint32(123), int32(123)) // assert.EqualValues(t, uint32(123), int32(123))
@ -566,12 +615,19 @@ func EqualExportedValues(t TestingT, expected, actual interface{}, msgAndArgs ..
return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to match exactly\n\t%v != %v", aType, bType), msgAndArgs...)
} }
if aType.Kind() == reflect.Ptr {
aType = aType.Elem()
}
if bType.Kind() == reflect.Ptr {
bType = bType.Elem()
}
if aType.Kind() != reflect.Struct { if aType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", aType.Kind(), reflect.Struct), msgAndArgs...)
} }
if bType.Kind() != reflect.Struct { if bType.Kind() != reflect.Struct {
return Fail(t, fmt.Sprintf("Types expected to both be struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...) return Fail(t, fmt.Sprintf("Types expected to both be struct or pointer to struct \n\t%v != %v", bType.Kind(), reflect.Struct), msgAndArgs...)
} }
expected = copyExportedFields(expected) expected = copyExportedFields(expected)
@ -620,17 +676,6 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, "Expected value not to be nil.", msgAndArgs...) return Fail(t, "Expected value not to be nil.", msgAndArgs...)
} }
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing. // isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool { func isNil(object interface{}) bool {
if object == nil { if object == nil {
@ -638,16 +683,13 @@ func isNil(object interface{}) bool {
} }
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind() switch value.Kind() {
isNilableKind := containsKind( case
[]reflect.Kind{
reflect.Chan, reflect.Func, reflect.Chan, reflect.Func,
reflect.Interface, reflect.Map, reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice, reflect.UnsafePointer}, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
kind)
if isNilableKind && value.IsNil() { return value.IsNil()
return true
} }
return false return false
@ -731,16 +773,14 @@ func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
} }
// getLen try to get length of object. // getLen tries to get the length of an object.
// return (false, 0) if impossible. // It returns (0, false) if impossible.
func getLen(x interface{}) (ok bool, length int) { func getLen(x interface{}) (length int, ok bool) {
v := reflect.ValueOf(x) v := reflect.ValueOf(x)
defer func() { defer func() {
if e := recover(); e != nil { ok = recover() == nil
ok = false
}
}() }()
return true, v.Len() return v.Len(), true
} }
// Len asserts that the specified object has specific length. // Len asserts that the specified object has specific length.
@ -751,13 +791,13 @@ func Len(t TestingT, object interface{}, length int, msgAndArgs ...interface{})
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
ok, l := getLen(object) l, ok := getLen(object)
if !ok { if !ok {
return Fail(t, fmt.Sprintf("\"%s\" could not be applied builtin len()", object), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" could not be applied builtin len()", object), msgAndArgs...)
} }
if l != length { if l != length {
return Fail(t, fmt.Sprintf("\"%s\" should have %d item(s), but has %d", object, length, l), msgAndArgs...) return Fail(t, fmt.Sprintf("\"%v\" should have %d item(s), but has %d", object, length, l), msgAndArgs...)
} }
return true return true
} }
@ -919,10 +959,11 @@ func NotContains(t TestingT, s, contains interface{}, msgAndArgs ...interface{})
} }
// Subset asserts that the specified list(array, slice...) contains all // Subset asserts that the specified list(array, slice...) or map contains all
// elements given in the specified subset(array, slice...). // elements given in the specified subset list(array, slice...) or map.
// //
// assert.Subset(t, [1, 2, 3], [1, 2], "But [1, 2, 3] does contain [1, 2]") // assert.Subset(t, [1, 2, 3], [1, 2])
// assert.Subset(t, {"x": 1, "y": 2}, {"x": 1})
func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -975,10 +1016,12 @@ func Subset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok
return true return true
} }
// NotSubset asserts that the specified list(array, slice...) contains not all // NotSubset asserts that the specified list(array, slice...) or map does NOT
// elements given in the specified subset(array, slice...). // contain all elements given in the specified subset list(array, slice...) or
// map.
// //
// assert.NotSubset(t, [1, 3, 4], [1, 2], "But [1, 3, 4] does not contain [1, 2]") // assert.NotSubset(t, [1, 3, 4], [1, 2])
// assert.NotSubset(t, {"x": 1, "y": 2}, {"z": 3})
func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) { func NotSubset(t TestingT, list, subset interface{}, msgAndArgs ...interface{}) (ok bool) {
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
@ -1439,7 +1482,7 @@ func InEpsilon(t TestingT, expected, actual interface{}, epsilon float64, msgAnd
h.Helper() h.Helper()
} }
if math.IsNaN(epsilon) { if math.IsNaN(epsilon) {
return Fail(t, "epsilon must not be NaN") return Fail(t, "epsilon must not be NaN", msgAndArgs...)
} }
actualEpsilon, err := calcRelativeError(expected, actual) actualEpsilon, err := calcRelativeError(expected, actual)
if err != nil { if err != nil {
@ -1458,19 +1501,26 @@ func InEpsilonSlice(t TestingT, expected, actual interface{}, epsilon float64, m
if h, ok := t.(tHelper); ok { if h, ok := t.(tHelper); ok {
h.Helper() h.Helper()
} }
if expected == nil || actual == nil ||
reflect.TypeOf(actual).Kind() != reflect.Slice || if expected == nil || actual == nil {
reflect.TypeOf(expected).Kind() != reflect.Slice {
return Fail(t, "Parameters must be slice", msgAndArgs...) return Fail(t, "Parameters must be slice", msgAndArgs...)
} }
actualSlice := reflect.ValueOf(actual)
expectedSlice := reflect.ValueOf(expected) expectedSlice := reflect.ValueOf(expected)
actualSlice := reflect.ValueOf(actual)
for i := 0; i < actualSlice.Len(); i++ { if expectedSlice.Type().Kind() != reflect.Slice {
result := InEpsilon(t, actualSlice.Index(i).Interface(), expectedSlice.Index(i).Interface(), epsilon) return Fail(t, "Expected value must be slice", msgAndArgs...)
if !result { }
return result
expectedLen := expectedSlice.Len()
if !IsType(t, expected, actual) || !Len(t, actual, expectedLen) {
return false
}
for i := 0; i < expectedLen; i++ {
if !InEpsilon(t, expectedSlice.Index(i).Interface(), actualSlice.Index(i).Interface(), epsilon, "at index %d", i) {
return false
} }
} }
@ -1870,23 +1920,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
} }
// FailNow panics. // FailNow panics.
func (c *CollectT) FailNow() { func (*CollectT) FailNow() {
panic("Assertion failed") panic("Assertion failed")
} }
// Reset clears the collected errors. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Reset() { func (*CollectT) Reset() {
c.errors = nil panic("Reset() is deprecated")
} }
// Copy copies the collected errors to the supplied t. // Deprecated: That was a method for internal usage that should not have been published. Now just panics.
func (c *CollectT) Copy(t TestingT) { func (*CollectT) Copy(TestingT) {
if tt, ok := t.(tHelper); ok { panic("Copy() is deprecated")
tt.Helper()
}
for _, err := range c.errors {
t.Errorf("%v", err)
}
} }
// EventuallyWithT asserts that given condition will be met in waitFor time, // EventuallyWithT asserts that given condition will be met in waitFor time,
@ -1912,8 +1957,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
h.Helper() h.Helper()
} }
collect := new(CollectT) var lastFinishedTickErrs []error
ch := make(chan bool, 1) ch := make(chan []error, 1)
timer := time.NewTimer(waitFor) timer := time.NewTimer(waitFor)
defer timer.Stop() defer timer.Stop()
@ -1924,19 +1969,25 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
for tick := ticker.C; ; { for tick := ticker.C; ; {
select { select {
case <-timer.C: case <-timer.C:
collect.Copy(t) for _, err := range lastFinishedTickErrs {
t.Errorf("%v", err)
}
return Fail(t, "Condition never satisfied", msgAndArgs...) return Fail(t, "Condition never satisfied", msgAndArgs...)
case <-tick: case <-tick:
tick = nil tick = nil
collect.Reset()
go func() { go func() {
condition(collect) collect := new(CollectT)
ch <- len(collect.errors) == 0 defer func() {
ch <- collect.errors
}() }()
case v := <-ch: condition(collect)
if v { }()
case errs := <-ch:
if len(errs) == 0 {
return true return true
} }
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
lastFinishedTickErrs = errs
tick = ticker.C tick = ticker.C
} }
} }

View file

@ -12,7 +12,7 @@ import (
// an error if building a new request fails. // an error if building a new request fails.
func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) { func httpCode(handler http.HandlerFunc, method, url string, values url.Values) (int, error) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url, nil) req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return -1, err return -1, err
} }
@ -32,12 +32,12 @@ func HTTPSuccess(t TestingT, handler http.HandlerFunc, method, url string, value
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent isSuccessCode := code >= http.StatusOK && code <= http.StatusPartialContent
if !isSuccessCode { if !isSuccessCode {
Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP success status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isSuccessCode return isSuccessCode
@ -54,12 +54,12 @@ func HTTPRedirect(t TestingT, handler http.HandlerFunc, method, url string, valu
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect isRedirectCode := code >= http.StatusMultipleChoices && code <= http.StatusTemporaryRedirect
if !isRedirectCode { if !isRedirectCode {
Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP redirect status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isRedirectCode return isRedirectCode
@ -76,12 +76,12 @@ func HTTPError(t TestingT, handler http.HandlerFunc, method, url string, values
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
isErrorCode := code >= http.StatusBadRequest isErrorCode := code >= http.StatusBadRequest
if !isErrorCode { if !isErrorCode {
Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP error status code for %q but received %d", url+"?"+values.Encode(), code), msgAndArgs...)
} }
return isErrorCode return isErrorCode
@ -98,12 +98,12 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
} }
code, err := httpCode(handler, method, url, values) code, err := httpCode(handler, method, url, values)
if err != nil { if err != nil {
Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err)) Fail(t, fmt.Sprintf("Failed to build test request, got error: %s", err), msgAndArgs...)
} }
successful := code == statuscode successful := code == statuscode
if !successful { if !successful {
Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code)) Fail(t, fmt.Sprintf("Expected HTTP status code %d for %q but received %d", statuscode, url+"?"+values.Encode(), code), msgAndArgs...)
} }
return successful return successful
@ -113,7 +113,10 @@ func HTTPStatusCode(t TestingT, handler http.HandlerFunc, method, url string, va
// empty string if building a new request fails. // empty string if building a new request fails.
func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string { func HTTPBody(handler http.HandlerFunc, method, url string, values url.Values) string {
w := httptest.NewRecorder() w := httptest.NewRecorder()
req, err := http.NewRequest(method, url+"?"+values.Encode(), nil) if len(values) > 0 {
url += "?" + values.Encode()
}
req, err := http.NewRequest(method, url, http.NoBody)
if err != nil { if err != nil {
return "" return ""
} }
@ -135,7 +138,7 @@ func HTTPBodyContains(t TestingT, handler http.HandlerFunc, method, url string,
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if !contains { if !contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return contains return contains
@ -155,7 +158,7 @@ func HTTPBodyNotContains(t TestingT, handler http.HandlerFunc, method, url strin
contains := strings.Contains(body, fmt.Sprint(str)) contains := strings.Contains(body, fmt.Sprint(str))
if contains { if contains {
Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body)) Fail(t, fmt.Sprintf("Expected response body for \"%s\" to NOT contain \"%s\" but found \"%s\"", url+"?"+values.Encode(), str, body), msgAndArgs...)
} }
return !contains return !contains

View file

@ -4,6 +4,7 @@
.*envrc .*envrc
.envrc .envrc
.idea .idea
# goimports is installed here if not available
/.local/ /.local/
/site/ /site/
coverage.txt coverage.txt

4
vendor/github.com/urfave/cli/v2/.golangci.yaml generated vendored Normal file
View file

@ -0,0 +1,4 @@
# https://golangci-lint.run/usage/configuration/
linters:
enable:
- misspell

View file

@ -23,8 +23,8 @@ var (
fmt.Sprintf("See %s", appActionDeprecationURL), 2) fmt.Sprintf("See %s", appActionDeprecationURL), 2)
ignoreFlagPrefix = "test." // this is to ignore test flags when adding flags from other packages ignoreFlagPrefix = "test." // this is to ignore test flags when adding flags from other packages
SuggestFlag SuggestFlagFunc = suggestFlag SuggestFlag SuggestFlagFunc = nil // initialized in suggestions.go unless built with urfave_cli_no_suggest
SuggestCommand SuggestCommandFunc = suggestCommand SuggestCommand SuggestCommandFunc = nil // initialized in suggestions.go unless built with urfave_cli_no_suggest
SuggestDidYouMeanTemplate string = suggestDidYouMeanTemplate SuggestDidYouMeanTemplate string = suggestDidYouMeanTemplate
) )
@ -39,6 +39,8 @@ type App struct {
Usage string Usage string
// Text to override the USAGE section of help // Text to override the USAGE section of help
UsageText string UsageText string
// Whether this command supports arguments
Args bool
// Description of the program argument format. // Description of the program argument format.
ArgsUsage string ArgsUsage string
// Version of the program // Version of the program
@ -227,8 +229,6 @@ func (a *App) Setup() {
a.separator.disabled = true a.separator.disabled = true
} }
var newCommands []*Command
for _, c := range a.Commands { for _, c := range a.Commands {
cname := c.Name cname := c.Name
if c.HelpName != "" { if c.HelpName != "" {
@ -237,9 +237,7 @@ func (a *App) Setup() {
c.separator = a.separator c.separator = a.separator
c.HelpName = fmt.Sprintf("%s %s", a.HelpName, cname) c.HelpName = fmt.Sprintf("%s %s", a.HelpName, cname)
c.flagCategories = newFlagCategoriesFromFlags(c.Flags) c.flagCategories = newFlagCategoriesFromFlags(c.Flags)
newCommands = append(newCommands, c)
} }
a.Commands = newCommands
if a.Command(helpCommand.Name) == nil && !a.HideHelp { if a.Command(helpCommand.Name) == nil && !a.HideHelp {
if !a.HideHelpCommand { if !a.HideHelpCommand {
@ -329,6 +327,9 @@ func (a *App) RunContext(ctx context.Context, arguments []string) (err error) {
a.rootCommand = a.newRootCommand() a.rootCommand = a.newRootCommand()
cCtx.Command = a.rootCommand cCtx.Command = a.rootCommand
if err := checkDuplicatedCmds(a.rootCommand); err != nil {
return err
}
return a.rootCommand.Run(cCtx, arguments...) return a.rootCommand.Run(cCtx, arguments...)
} }
@ -363,6 +364,9 @@ func (a *App) suggestFlagFromError(err error, command string) (string, error) {
hideHelp = hideHelp || cmd.HideHelp hideHelp = hideHelp || cmd.HideHelp
} }
if SuggestFlag == nil {
return "", err
}
suggestion := SuggestFlag(flags, flag, hideHelp) suggestion := SuggestFlag(flags, flag, hideHelp)
if len(suggestion) == 0 { if len(suggestion) == 0 {
return "", err return "", err
@ -455,30 +459,6 @@ func (a *App) handleExitCoder(cCtx *Context, err error) {
} }
} }
func (a *App) commandNames() []string {
var cmdNames []string
for _, cmd := range a.Commands {
cmdNames = append(cmdNames, cmd.Names()...)
}
return cmdNames
}
func (a *App) validCommandName(checkCmdName string) bool {
valid := false
allCommandNames := a.commandNames()
for _, cmdName := range allCommandNames {
if checkCmdName == cmdName {
valid = true
break
}
}
return valid
}
func (a *App) argsWithDefaultCommand(oldArgs Args) Args { func (a *App) argsWithDefaultCommand(oldArgs Args) Args {
if a.DefaultCommand != "" { if a.DefaultCommand != "" {
rawArgs := append([]string{a.DefaultCommand}, oldArgs.Slice()...) rawArgs := append([]string{a.DefaultCommand}, oldArgs.Slice()...)

View file

@ -104,17 +104,17 @@ func newFlagCategoriesFromFlags(fs []Flag) FlagCategories {
var categorized bool var categorized bool
for _, fl := range fs { for _, fl := range fs {
if cf, ok := fl.(CategorizableFlag); ok { if cf, ok := fl.(CategorizableFlag); ok {
if cat := cf.GetCategory(); cat != "" { if cat := cf.GetCategory(); cat != "" && cf.IsVisible() {
fc.AddFlag(cat, cf) fc.AddFlag(cat, cf)
categorized = true categorized = true
} }
} }
} }
if categorized == true { if categorized {
for _, fl := range fs { for _, fl := range fs {
if cf, ok := fl.(CategorizableFlag); ok { if cf, ok := fl.(CategorizableFlag); ok {
if cf.GetCategory() == "" { if cf.GetCategory() == "" && cf.IsVisible() {
fc.AddFlag("", fl) fc.AddFlag("", fl)
} }
} }

View file

@ -20,6 +20,8 @@ type Command struct {
UsageText string UsageText string
// A longer explanation of how the command works // A longer explanation of how the command works
Description string Description string
// Whether this command supports arguments
Args bool
// A short description of the arguments of this command // A short description of the arguments of this command
ArgsUsage string ArgsUsage string
// The category the command is part of // The category the command is part of
@ -130,15 +132,12 @@ func (c *Command) setup(ctx *Context) {
} }
sort.Sort(c.categories.(*commandCategories)) sort.Sort(c.categories.(*commandCategories))
var newCmds []*Command
for _, scmd := range c.Subcommands { for _, scmd := range c.Subcommands {
if scmd.HelpName == "" { if scmd.HelpName == "" {
scmd.HelpName = fmt.Sprintf("%s %s", c.HelpName, scmd.Name) scmd.HelpName = fmt.Sprintf("%s %s", c.HelpName, scmd.Name)
} }
scmd.separator = c.separator scmd.separator = c.separator
newCmds = append(newCmds, scmd)
} }
c.Subcommands = newCmds
if c.BashComplete == nil { if c.BashComplete == nil {
c.BashComplete = DefaultCompleteWithFlags(c) c.BashComplete = DefaultCompleteWithFlags(c)
@ -149,6 +148,9 @@ func (c *Command) Run(cCtx *Context, arguments ...string) (err error) {
if !c.isRoot { if !c.isRoot {
c.setup(cCtx) c.setup(cCtx)
if err := checkDuplicatedCmds(c); err != nil {
return err
}
} }
a := args(arguments) a := args(arguments)
@ -404,3 +406,16 @@ func hasCommand(commands []*Command, command *Command) bool {
return false return false
} }
func checkDuplicatedCmds(parent *Command) error {
seen := make(map[string]struct{})
for _, c := range parent.Subcommands {
for _, name := range c.Names() {
if _, exists := seen[name]; exists {
return fmt.Errorf("parent command [%s] has duplicated subcommand name or alias: %s", parent.Name, name)
}
seen[name] = struct{}{}
}
}
return nil
}

View file

@ -144,7 +144,7 @@ func (cCtx *Context) Lineage() []*Context {
return lineage return lineage
} }
// Count returns the num of occurences of this flag // Count returns the num of occurrences of this flag
func (cCtx *Context) Count(name string) int { func (cCtx *Context) Count(name string) int {
if fs := cCtx.lookupFlagSet(name); fs != nil { if fs := cCtx.lookupFlagSet(name); fs != nil {
if cf, ok := fs.Lookup(name).Value.(Countable); ok { if cf, ok := fs.Lookup(name).Value.(Countable); ok {

View file

@ -84,8 +84,11 @@ func (f *BoolFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return fmt.Sprintf("%v", f.defaultValue) return fmt.Sprintf("%v", f.defaultValue)
} }
return fmt.Sprintf("%v", f.Value)
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *BoolFlag) GetEnvVars() []string { func (f *BoolFlag) GetEnvVars() []string {
@ -105,6 +108,7 @@ func (f *BoolFlag) RunAction(c *Context) error {
func (f *BoolFlag) Apply(set *flag.FlagSet) error { func (f *BoolFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {

View file

@ -32,8 +32,11 @@ func (f *DurationFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return f.defaultValue.String() return f.defaultValue.String()
} }
return f.Value.String()
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *DurationFlag) GetEnvVars() []string { func (f *DurationFlag) GetEnvVars() []string {
@ -44,6 +47,7 @@ func (f *DurationFlag) GetEnvVars() []string {
func (f *DurationFlag) Apply(set *flag.FlagSet) error { func (f *DurationFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {

View file

@ -32,7 +32,10 @@ func (f *Float64Flag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
return f.GetValue() if f.defaultValueSet {
return fmt.Sprintf("%v", f.defaultValue)
}
return fmt.Sprintf("%v", f.Value)
} }
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
@ -42,6 +45,9 @@ func (f *Float64Flag) GetEnvVars() []string {
// Apply populates the flag given the flag set and environment // Apply populates the flag given the flag set and environment
func (f *Float64Flag) Apply(set *flag.FlagSet) error { func (f *Float64Flag) Apply(set *flag.FlagSet) error {
f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {
valFloat, err := strconv.ParseFloat(val, 64) valFloat, err := strconv.ParseFloat(val, 64)

View file

@ -53,9 +53,15 @@ func (f *GenericFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValue != nil { val := f.Value
return f.defaultValue.String() if f.defaultValueSet {
val = f.defaultValue
} }
if val != nil {
return val.String()
}
return "" return ""
} }
@ -70,6 +76,7 @@ func (f *GenericFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
if f.Value != nil { if f.Value != nil {
f.defaultValue = &stringGeneric{value: f.Value.String()} f.defaultValue = &stringGeneric{value: f.Value.String()}
f.defaultValueSet = true
} }
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {

View file

@ -32,8 +32,11 @@ func (f *IntFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return fmt.Sprintf("%d", f.defaultValue) return fmt.Sprintf("%d", f.defaultValue)
} }
return fmt.Sprintf("%d", f.Value)
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *IntFlag) GetEnvVars() []string { func (f *IntFlag) GetEnvVars() []string {
@ -44,6 +47,7 @@ func (f *IntFlag) GetEnvVars() []string {
func (f *IntFlag) Apply(set *flag.FlagSet) error { func (f *IntFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {

View file

@ -32,8 +32,11 @@ func (f *Int64Flag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return fmt.Sprintf("%d", f.defaultValue) return fmt.Sprintf("%d", f.defaultValue)
} }
return fmt.Sprintf("%d", f.Value)
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *Int64Flag) GetEnvVars() []string { func (f *Int64Flag) GetEnvVars() []string {
@ -44,6 +47,7 @@ func (f *Int64Flag) GetEnvVars() []string {
func (f *Int64Flag) Apply(set *flag.FlagSet) error { func (f *Int64Flag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {

View file

@ -33,10 +33,14 @@ func (f *PathFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValue == "" { val := f.Value
return f.defaultValue if f.defaultValueSet {
val = f.defaultValue
} }
return fmt.Sprintf("%q", f.defaultValue) if val == "" {
return val
}
return fmt.Sprintf("%q", val)
} }
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
@ -48,6 +52,7 @@ func (f *PathFlag) GetEnvVars() []string {
func (f *PathFlag) Apply(set *flag.FlagSet) error { func (f *PathFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, _, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, _, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
f.Value = val f.Value = val

View file

@ -31,10 +31,15 @@ func (f *StringFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValue == "" { val := f.Value
return f.defaultValue if f.defaultValueSet {
val = f.defaultValue
} }
return fmt.Sprintf("%q", f.defaultValue)
if val == "" {
return val
}
return fmt.Sprintf("%q", val)
} }
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
@ -46,6 +51,7 @@ func (f *StringFlag) GetEnvVars() []string {
func (f *StringFlag) Apply(set *flag.FlagSet) error { func (f *StringFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, _, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, _, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
f.Value = val f.Value = val

View file

@ -120,8 +120,13 @@ func (f *TimestampFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValue != nil && f.defaultValue.timestamp != nil { val := f.Value
return f.defaultValue.timestamp.String() if f.defaultValueSet {
val = f.defaultValue
}
if val != nil && val.timestamp != nil {
return val.timestamp.String()
} }
return "" return ""
@ -144,6 +149,7 @@ func (f *TimestampFlag) Apply(set *flag.FlagSet) error {
f.Value.SetLocation(f.Timezone) f.Value.SetLocation(f.Timezone)
f.defaultValue = f.Value.clone() f.defaultValue = f.Value.clone()
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if err := f.Value.Set(val); err != nil { if err := f.Value.Set(val); err != nil {

View file

@ -25,6 +25,7 @@ func (f *UintFlag) GetCategory() string {
func (f *UintFlag) Apply(set *flag.FlagSet) error { func (f *UintFlag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {
@ -69,8 +70,11 @@ func (f *UintFlag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return fmt.Sprintf("%d", f.defaultValue) return fmt.Sprintf("%d", f.defaultValue)
} }
return fmt.Sprintf("%d", f.Value)
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *UintFlag) GetEnvVars() []string { func (f *UintFlag) GetEnvVars() []string {

View file

@ -25,6 +25,7 @@ func (f *Uint64Flag) GetCategory() string {
func (f *Uint64Flag) Apply(set *flag.FlagSet) error { func (f *Uint64Flag) Apply(set *flag.FlagSet) error {
// set default value so that environment wont be able to overwrite it // set default value so that environment wont be able to overwrite it
f.defaultValue = f.Value f.defaultValue = f.Value
f.defaultValueSet = true
if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found { if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
if val != "" { if val != "" {
@ -69,8 +70,11 @@ func (f *Uint64Flag) GetDefaultText() string {
if f.DefaultText != "" { if f.DefaultText != "" {
return f.DefaultText return f.DefaultText
} }
if f.defaultValueSet {
return fmt.Sprintf("%d", f.defaultValue) return fmt.Sprintf("%d", f.defaultValue)
} }
return fmt.Sprintf("%d", f.Value)
}
// GetEnvVars returns the env vars for this flag // GetEnvVars returns the env vars for this flag
func (f *Uint64Flag) GetEnvVars() []string { func (f *Uint64Flag) GetEnvVars() []string {

View file

@ -190,6 +190,15 @@ func (f *Uint64SliceFlag) Get(ctx *Context) []uint64 {
return ctx.Uint64Slice(f.Name) return ctx.Uint64Slice(f.Name)
} }
// RunAction executes flag action if set
func (f *Uint64SliceFlag) RunAction(c *Context) error {
if f.Action != nil {
return f.Action(c, c.Uint64Slice(f.Name))
}
return nil
}
// Uint64Slice looks up the value of a local Uint64SliceFlag, returns // Uint64Slice looks up the value of a local Uint64SliceFlag, returns
// nil if not found // nil if not found
func (cCtx *Context) Uint64Slice(name string) []uint64 { func (cCtx *Context) Uint64Slice(name string) []uint64 {

View file

@ -201,6 +201,15 @@ func (f *UintSliceFlag) Get(ctx *Context) []uint {
return ctx.UintSlice(f.Name) return ctx.UintSlice(f.Name)
} }
// RunAction executes flag action if set
func (f *UintSliceFlag) RunAction(c *Context) error {
if f.Action != nil {
return f.Action(c, c.UintSlice(f.Name))
}
return nil
}
// UintSlice looks up the value of a local UintSliceFlag, returns // UintSlice looks up the value of a local UintSliceFlag, returns
// nil if not found // nil if not found
func (cCtx *Context) UintSlice(name string) []uint { func (cCtx *Context) UintSlice(name string) []uint {

View file

@ -27,15 +27,15 @@ application:
VARIABLES VARIABLES
var ( var (
SuggestFlag SuggestFlagFunc = suggestFlag SuggestFlag SuggestFlagFunc = nil // initialized in suggestions.go unless built with urfave_cli_no_suggest
SuggestCommand SuggestCommandFunc = suggestCommand SuggestCommand SuggestCommandFunc = nil // initialized in suggestions.go unless built with urfave_cli_no_suggest
SuggestDidYouMeanTemplate string = suggestDidYouMeanTemplate SuggestDidYouMeanTemplate string = suggestDidYouMeanTemplate
) )
var AppHelpTemplate = `NAME: var AppHelpTemplate = `NAME:
{{template "helpNameTemplate" .}} {{template "helpNameTemplate" .}}
USAGE: USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}{{if .Version}}{{if not .HideVersion}} {{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Version}}{{if not .HideVersion}}
VERSION: VERSION:
{{.Version}}{{end}}{{end}}{{if .Description}} {{.Version}}{{end}}{{end}}{{if .Description}}
@ -136,7 +136,7 @@ var SubcommandHelpTemplate = `NAME:
{{template "helpNameTemplate" .}} {{template "helpNameTemplate" .}}
USAGE: USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}{{if .Description}} {{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Description}}
DESCRIPTION: DESCRIPTION:
{{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}} {{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}}
@ -253,6 +253,8 @@ type App struct {
Usage string Usage string
// Text to override the USAGE section of help // Text to override the USAGE section of help
UsageText string UsageText string
// Whether this command supports arguments
Args bool
// Description of the program argument format. // Description of the program argument format.
ArgsUsage string ArgsUsage string
// Version of the program // Version of the program
@ -523,6 +525,8 @@ type Command struct {
UsageText string UsageText string
// A longer explanation of how the command works // A longer explanation of how the command works
Description string Description string
// Whether this command supports arguments
Args bool
// A short description of the arguments of this command // A short description of the arguments of this command
ArgsUsage string ArgsUsage string
// The category the command is part of // The category the command is part of
@ -649,7 +653,7 @@ func (cCtx *Context) Bool(name string) bool
Bool looks up the value of a local BoolFlag, returns false if not found Bool looks up the value of a local BoolFlag, returns false if not found
func (cCtx *Context) Count(name string) int func (cCtx *Context) Count(name string) int
Count returns the num of occurences of this flag Count returns the num of occurrences of this flag
func (cCtx *Context) Duration(name string) time.Duration func (cCtx *Context) Duration(name string) time.Duration
Duration looks up the value of a local DurationFlag, returns 0 if not found Duration looks up the value of a local DurationFlag, returns 0 if not found
@ -2142,6 +2146,9 @@ func (f *Uint64SliceFlag) IsVisible() bool
func (f *Uint64SliceFlag) Names() []string func (f *Uint64SliceFlag) Names() []string
Names returns the names of the flag Names returns the names of the flag
func (f *Uint64SliceFlag) RunAction(c *Context) error
RunAction executes flag action if set
func (f *Uint64SliceFlag) String() string func (f *Uint64SliceFlag) String() string
String returns a readable representation of this value (for usage defaults) String returns a readable representation of this value (for usage defaults)
@ -2307,6 +2314,9 @@ func (f *UintSliceFlag) IsVisible() bool
func (f *UintSliceFlag) Names() []string func (f *UintSliceFlag) Names() []string
Names returns the names of the flag Names returns the names of the flag
func (f *UintSliceFlag) RunAction(c *Context) error
RunAction executes flag action if set
func (f *UintSliceFlag) String() string func (f *UintSliceFlag) String() string
String returns a readable representation of this value (for usage defaults) String returns a readable representation of this value (for usage defaults)

View file

@ -69,7 +69,7 @@ var helpCommand = &Command{
} }
// Case 3, 5 // Case 3, 5
if (len(cCtx.Command.Subcommands) == 1 && !cCtx.Command.HideHelp) || if (len(cCtx.Command.Subcommands) == 1 && !cCtx.Command.HideHelp && !cCtx.Command.HideHelpCommand) ||
(len(cCtx.Command.Subcommands) == 0 && cCtx.Command.HideHelp) { (len(cCtx.Command.Subcommands) == 0 && cCtx.Command.HideHelp) {
templ := cCtx.Command.CustomHelpTemplate templ := cCtx.Command.CustomHelpTemplate
if templ == "" { if templ == "" {
@ -248,7 +248,6 @@ func ShowCommandHelpAndExit(c *Context, command string, code int) {
// ShowCommandHelp prints help for the given command // ShowCommandHelp prints help for the given command
func ShowCommandHelp(ctx *Context, command string) error { func ShowCommandHelp(ctx *Context, command string) error {
commands := ctx.App.Commands commands := ctx.App.Commands
if ctx.Command.Subcommands != nil { if ctx.Command.Subcommands != nil {
commands = ctx.Command.Subcommands commands = ctx.Command.Subcommands
@ -278,7 +277,7 @@ func ShowCommandHelp(ctx *Context, command string) error {
if ctx.App.CommandNotFound == nil { if ctx.App.CommandNotFound == nil {
errMsg := fmt.Sprintf("No help topic for '%v'", command) errMsg := fmt.Sprintf("No help topic for '%v'", command)
if ctx.App.Suggest { if ctx.App.Suggest && SuggestCommand != nil {
if suggestion := SuggestCommand(ctx.Command.Subcommands, command); suggestion != "" { if suggestion := SuggestCommand(ctx.Command.Subcommands, command); suggestion != "" {
errMsg += ". " + suggestion errMsg += ". " + suggestion
} }
@ -337,7 +336,6 @@ func ShowCommandCompletions(ctx *Context, command string) {
DefaultCompleteWithFlags(c)(ctx) DefaultCompleteWithFlags(c)(ctx)
} }
} }
} }
// printHelpCustom is the default implementation of HelpPrinterCustom. // printHelpCustom is the default implementation of HelpPrinterCustom.
@ -345,7 +343,6 @@ func ShowCommandCompletions(ctx *Context, command string) {
// The customFuncs map will be combined with a default template.FuncMap to // The customFuncs map will be combined with a default template.FuncMap to
// allow using arbitrary functions in template rendering. // allow using arbitrary functions in template rendering.
func printHelpCustom(out io.Writer, templ string, data interface{}, customFuncs map[string]interface{}) { func printHelpCustom(out io.Writer, templ string, data interface{}, customFuncs map[string]interface{}) {
const maxLineLength = 10000 const maxLineLength = 10000
funcMap := template.FuncMap{ funcMap := template.FuncMap{
@ -376,17 +373,26 @@ func printHelpCustom(out io.Writer, templ string, data interface{}, customFuncs
w := tabwriter.NewWriter(out, 1, 8, 2, ' ', 0) w := tabwriter.NewWriter(out, 1, 8, 2, ' ', 0)
t := template.Must(template.New("help").Funcs(funcMap).Parse(templ)) t := template.Must(template.New("help").Funcs(funcMap).Parse(templ))
t.New("helpNameTemplate").Parse(helpNameTemplate) templates := map[string]string{
t.New("usageTemplate").Parse(usageTemplate) "helpNameTemplate": helpNameTemplate,
t.New("descriptionTemplate").Parse(descriptionTemplate) "usageTemplate": usageTemplate,
t.New("visibleCommandTemplate").Parse(visibleCommandTemplate) "descriptionTemplate": descriptionTemplate,
t.New("copyrightTemplate").Parse(copyrightTemplate) "visibleCommandTemplate": visibleCommandTemplate,
t.New("versionTemplate").Parse(versionTemplate) "copyrightTemplate": copyrightTemplate,
t.New("visibleFlagCategoryTemplate").Parse(visibleFlagCategoryTemplate) "versionTemplate": versionTemplate,
t.New("visibleFlagTemplate").Parse(visibleFlagTemplate) "visibleFlagCategoryTemplate": visibleFlagCategoryTemplate,
t.New("visibleGlobalFlagCategoryTemplate").Parse(strings.Replace(visibleFlagCategoryTemplate, "OPTIONS", "GLOBAL OPTIONS", -1)) "visibleFlagTemplate": visibleFlagTemplate,
t.New("authorsTemplate").Parse(authorsTemplate) "visibleGlobalFlagCategoryTemplate": strings.Replace(visibleFlagCategoryTemplate, "OPTIONS", "GLOBAL OPTIONS", -1),
t.New("visibleCommandCategoryTemplate").Parse(visibleCommandCategoryTemplate) "authorsTemplate": authorsTemplate,
"visibleCommandCategoryTemplate": visibleCommandCategoryTemplate,
}
for name, value := range templates {
if _, err := t.New(name).Parse(value); err != nil {
if os.Getenv("CLI_TEMPLATE_ERROR_DEBUG") != "" {
_, _ = fmt.Fprintf(ErrWriter, "CLI TEMPLATE ERROR: %#v\n", err)
}
}
}
err := t.Execute(w, data) err := t.Execute(w, data)
if err != nil { if err != nil {
@ -415,6 +421,9 @@ func checkVersion(cCtx *Context) bool {
} }
func checkHelp(cCtx *Context) bool { func checkHelp(cCtx *Context) bool {
if HelpFlag == nil {
return false
}
found := false found := false
for _, name := range HelpFlag.Names() { for _, name := range HelpFlag.Names() {
if cCtx.Bool(name) { if cCtx.Bool(name) {
@ -426,24 +435,6 @@ func checkHelp(cCtx *Context) bool {
return found return found
} }
func checkCommandHelp(c *Context, name string) bool {
if c.Bool("h") || c.Bool("help") {
_ = ShowCommandHelp(c, name)
return true
}
return false
}
func checkSubcommandHelp(cCtx *Context) bool {
if cCtx.Bool("h") || cCtx.Bool("help") {
_ = ShowSubcommandHelp(cCtx)
return true
}
return false
}
func checkShellCompleteFlag(a *App, arguments []string) (bool, []string) { func checkShellCompleteFlag(a *App, arguments []string) (bool, []string) {
if !a.EnableBashCompletion { if !a.EnableBashCompletion {
return false, arguments return false, arguments
@ -456,6 +447,15 @@ func checkShellCompleteFlag(a *App, arguments []string) (bool, []string) {
return false, arguments return false, arguments
} }
for _, arg := range arguments {
// If arguments include "--", shell completion is disabled
// because after "--" only positional arguments are accepted.
// https://unix.stackexchange.com/a/11382
if arg == "--" {
return false, arguments
}
}
return true, arguments[:pos] return true, arguments[:pos]
} }
@ -505,7 +505,6 @@ func wrap(input string, offset int, wrapAt int) string {
ss = append(ss, wrapped) ss = append(ss, wrapped)
} else { } else {
ss = append(ss, padding+wrapped) ss = append(ss, padding+wrapped)
} }
} }

View file

@ -1,7 +1,7 @@
# NOTE: the mkdocs dependencies will need to be installed out of # NOTE: the mkdocs dependencies will need to be installed out of
# band until this whole thing gets more automated: # band until this whole thing gets more automated:
# #
# pip install -r mkdocs-requirements.txt # pip install -r mkdocs-reqs.txt
# #
site_name: urfave/cli site_name: urfave/cli

View file

@ -1,3 +1,6 @@
//go:build !urfave_cli_no_suggest
// +build !urfave_cli_no_suggest
package cli package cli
import ( import (
@ -6,6 +9,11 @@ import (
"github.com/xrash/smetrics" "github.com/xrash/smetrics"
) )
func init() {
SuggestFlag = suggestFlag
SuggestCommand = suggestCommand
}
func jaroWinkler(a, b string) float64 { func jaroWinkler(a, b string) float64 {
// magic values are from https://github.com/xrash/smetrics/blob/039620a656736e6ad994090895784a7af15e0b80/jaro-winkler.go#L8 // magic values are from https://github.com/xrash/smetrics/blob/039620a656736e6ad994090895784a7af15e0b80/jaro-winkler.go#L8
const ( const (
@ -21,7 +29,7 @@ func suggestFlag(flags []Flag, provided string, hideHelp bool) string {
for _, flag := range flags { for _, flag := range flags {
flagNames := flag.Names() flagNames := flag.Names()
if !hideHelp { if !hideHelp && HelpFlag != nil {
flagNames = append(flagNames, HelpFlag.Names()...) flagNames = append(flagNames, HelpFlag.Names()...)
} }
for _, name := range flagNames { for _, name := range flagNames {

View file

@ -1,7 +1,7 @@
package cli package cli
var helpNameTemplate = `{{$v := offset .HelpName 6}}{{wrap .HelpName 3}}{{if .Usage}} - {{wrap .Usage $v}}{{end}}` var helpNameTemplate = `{{$v := offset .HelpName 6}}{{wrap .HelpName 3}}{{if .Usage}} - {{wrap .Usage $v}}{{end}}`
var usageTemplate = `{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}}{{if .VisibleFlags}} [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}` var usageTemplate = `{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}}{{if .VisibleFlags}} [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}`
var descriptionTemplate = `{{wrap .Description 3}}` var descriptionTemplate = `{{wrap .Description 3}}`
var authorsTemplate = `{{with $length := len .Authors}}{{if ne 1 $length}}S{{end}}{{end}}: var authorsTemplate = `{{with $length := len .Authors}}{{if ne 1 $length}}S{{end}}{{end}}:
{{range $index, $author := .Authors}}{{if $index}} {{range $index, $author := .Authors}}{{if $index}}
@ -35,7 +35,7 @@ var AppHelpTemplate = `NAME:
{{template "helpNameTemplate" .}} {{template "helpNameTemplate" .}}
USAGE: USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}{{if .Version}}{{if not .HideVersion}} {{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}[global options]{{end}}{{if .Commands}} command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Version}}{{if not .HideVersion}}
VERSION: VERSION:
{{.Version}}{{end}}{{end}}{{if .Description}} {{.Version}}{{end}}{{end}}{{if .Description}}
@ -83,7 +83,7 @@ var SubcommandHelpTemplate = `NAME:
{{template "helpNameTemplate" .}} {{template "helpNameTemplate" .}}
USAGE: USAGE:
{{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}} {{if .ArgsUsage}}{{.ArgsUsage}}{{else}}[arguments...]{{end}}{{end}}{{if .Description}} {{if .UsageText}}{{wrap .UsageText 3}}{{else}}{{.HelpName}} {{if .VisibleFlags}}command [command options]{{end}}{{if .ArgsUsage}} {{.ArgsUsage}}{{else}}{{if .Args}} [arguments...]{{end}}{{end}}{{end}}{{if .Description}}
DESCRIPTION: DESCRIPTION:
{{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}} {{template "descriptionTemplate" .}}{{end}}{{if .VisibleCommands}}

View file

@ -24,6 +24,7 @@ type Float64SliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *Float64Slice defaultValue *Float64Slice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -70,6 +71,7 @@ type GenericFlag struct {
EnvVars []string EnvVars []string
defaultValue Generic defaultValue Generic
defaultValueSet bool
TakesFile bool TakesFile bool
@ -121,6 +123,7 @@ type Int64SliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *Int64Slice defaultValue *Int64Slice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -167,6 +170,7 @@ type IntSliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *IntSlice defaultValue *IntSlice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -213,6 +217,7 @@ type PathFlag struct {
EnvVars []string EnvVars []string
defaultValue Path defaultValue Path
defaultValueSet bool
TakesFile bool TakesFile bool
@ -264,6 +269,7 @@ type StringSliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *StringSlice defaultValue *StringSlice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -314,6 +320,7 @@ type TimestampFlag struct {
EnvVars []string EnvVars []string
defaultValue *Timestamp defaultValue *Timestamp
defaultValueSet bool
Layout string Layout string
@ -367,6 +374,7 @@ type Uint64SliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *Uint64Slice defaultValue *Uint64Slice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -413,6 +421,7 @@ type UintSliceFlag struct {
EnvVars []string EnvVars []string
defaultValue *UintSlice defaultValue *UintSlice
defaultValueSet bool
separator separatorSpec separator separatorSpec
@ -459,6 +468,7 @@ type BoolFlag struct {
EnvVars []string EnvVars []string
defaultValue bool defaultValue bool
defaultValueSet bool
Count *int Count *int
@ -512,6 +522,7 @@ type Float64Flag struct {
EnvVars []string EnvVars []string
defaultValue float64 defaultValue float64
defaultValueSet bool
Action func(*Context, float64) error Action func(*Context, float64) error
} }
@ -561,6 +572,7 @@ type IntFlag struct {
EnvVars []string EnvVars []string
defaultValue int defaultValue int
defaultValueSet bool
Base int Base int
@ -612,6 +624,7 @@ type Int64Flag struct {
EnvVars []string EnvVars []string
defaultValue int64 defaultValue int64
defaultValueSet bool
Base int Base int
@ -663,6 +676,7 @@ type StringFlag struct {
EnvVars []string EnvVars []string
defaultValue string defaultValue string
defaultValueSet bool
TakesFile bool TakesFile bool
@ -714,6 +728,7 @@ type DurationFlag struct {
EnvVars []string EnvVars []string
defaultValue time.Duration defaultValue time.Duration
defaultValueSet bool
Action func(*Context, time.Duration) error Action func(*Context, time.Duration) error
} }
@ -763,6 +778,7 @@ type UintFlag struct {
EnvVars []string EnvVars []string
defaultValue uint defaultValue uint
defaultValueSet bool
Base int Base int
@ -814,6 +830,7 @@ type Uint64Flag struct {
EnvVars []string EnvVars []string
defaultValue uint64 defaultValue uint64
defaultValueSet bool
Base int Base int

View file

@ -6,36 +6,58 @@ import (
// The Soundex encoding. It is a phonetic algorithm that considers how the words sound in English. Soundex maps a string to a 4-byte code consisting of the first letter of the original string and three numbers. Strings that sound similar should map to the same code. // The Soundex encoding. It is a phonetic algorithm that considers how the words sound in English. Soundex maps a string to a 4-byte code consisting of the first letter of the original string and three numbers. Strings that sound similar should map to the same code.
func Soundex(s string) string { func Soundex(s string) string {
m := map[byte]string{ b := strings.Builder{}
'B': "1", 'P': "1", 'F': "1", 'V': "1", b.Grow(4)
'C': "2", 'S': "2", 'K': "2", 'G': "2", 'J': "2", 'Q': "2", 'X': "2", 'Z': "2",
'D': "3", 'T': "3",
'L': "4",
'M': "5", 'N': "5",
'R': "6",
}
s = strings.ToUpper(s)
r := string(s[0])
p := s[0] p := s[0]
for i := 1; i < len(s) && len(r) < 4; i++ { if p <= 'z' && p >= 'a' {
p -= 32 // convert to uppercase
}
b.WriteByte(p)
n := 0
for i := 1; i < len(s); i++ {
c := s[i] c := s[i]
if (c < 'A' || c > 'Z') || (c == p) { if c <= 'z' && c >= 'a' {
c -= 32 // convert to uppercase
} else if c < 'A' || c > 'Z' {
continue
}
if c == p {
continue continue
} }
p = c p = c
if n, ok := m[c]; ok { switch c {
r += n case 'B', 'P', 'F', 'V':
c = '1'
case 'C', 'S', 'K', 'G', 'J', 'Q', 'X', 'Z':
c = '2'
case 'D', 'T':
c = '3'
case 'L':
c = '4'
case 'M', 'N':
c = '5'
case 'R':
c = '6'
default:
continue
}
b.WriteByte(c)
n++
if n == 3 {
break
} }
} }
for i := len(r); i < 4; i++ { for i := n; i < 3; i++ {
r += "0" b.WriteByte('0')
} }
return r return b.String()
} }

18
vendor/modules.txt vendored
View file

@ -1,5 +1,5 @@
# github.com/aws/aws-sdk-go v1.45.25 # github.com/aws/aws-sdk-go v1.55.5
## explicit; go 1.11 ## explicit; go 1.19
github.com/aws/aws-sdk-go/aws github.com/aws/aws-sdk-go/aws
github.com/aws/aws-sdk-go/aws/arn github.com/aws/aws-sdk-go/aws/arn
github.com/aws/aws-sdk-go/aws/auth/bearer github.com/aws/aws-sdk-go/aws/auth/bearer
@ -51,7 +51,7 @@ github.com/aws/aws-sdk-go/service/sso/ssoiface
github.com/aws/aws-sdk-go/service/ssooidc github.com/aws/aws-sdk-go/service/ssooidc
github.com/aws/aws-sdk-go/service/sts github.com/aws/aws-sdk-go/service/sts
github.com/aws/aws-sdk-go/service/sts/stsiface github.com/aws/aws-sdk-go/service/sts/stsiface
# github.com/cpuguy83/go-md2man/v2 v2.0.2 # github.com/cpuguy83/go-md2man/v2 v2.0.4
## explicit; go 1.11 ## explicit; go 1.11
github.com/cpuguy83/go-md2man/v2/md2man github.com/cpuguy83/go-md2man/v2/md2man
# github.com/davecgh/go-spew v1.1.1 # github.com/davecgh/go-spew v1.1.1
@ -76,7 +76,7 @@ github.com/pkg/sftp/internal/encoding/ssh/filexfer
# github.com/pmezard/go-difflib v1.0.0 # github.com/pmezard/go-difflib v1.0.0
## explicit ## explicit
github.com/pmezard/go-difflib/difflib github.com/pmezard/go-difflib/difflib
# github.com/rs/zerolog v1.31.0 # github.com/rs/zerolog v1.33.0
## explicit; go 1.15 ## explicit; go 1.15
github.com/rs/zerolog github.com/rs/zerolog
github.com/rs/zerolog/internal/cbor github.com/rs/zerolog/internal/cbor
@ -85,14 +85,14 @@ github.com/rs/zerolog/log
# github.com/russross/blackfriday/v2 v2.1.0 # github.com/russross/blackfriday/v2 v2.1.0
## explicit ## explicit
github.com/russross/blackfriday/v2 github.com/russross/blackfriday/v2
# github.com/stretchr/testify v1.8.4 # github.com/stretchr/testify v1.9.0
## explicit; go 1.20 ## explicit; go 1.17
github.com/stretchr/testify/assert github.com/stretchr/testify/assert
# github.com/urfave/cli/v2 v2.25.7 # github.com/urfave/cli/v2 v2.27.4
## explicit; go 1.18 ## explicit; go 1.18
github.com/urfave/cli/v2 github.com/urfave/cli/v2
# github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 # github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1
## explicit ## explicit; go 1.15
github.com/xrash/smetrics github.com/xrash/smetrics
# golang.org/x/crypto v0.14.0 # golang.org/x/crypto v0.14.0
## explicit; go 1.17 ## explicit; go 1.17