From 79aa65c1faa5ddd924ccebf377db64f049459afc Mon Sep 17 00:00:00 2001 From: Sebastiaan van Stijn Date: Sat, 26 Feb 2022 13:45:12 +0100 Subject: [PATCH] registry: return "errdefs" compatible error types Adding some small utility functions to make generating them easier. Signed-off-by: Sebastiaan van Stijn --- registry/config.go | 26 ++++++++++++-------------- registry/config_test.go | 8 +++++--- registry/endpoint_v1.go | 15 +++++++-------- registry/errors.go | 13 +++++++++++++ registry/registry.go | 9 ++++----- registry/service.go | 7 +++---- registry/service_v2.go | 2 +- registry/session.go | 6 +++--- 8 files changed, 48 insertions(+), 38 deletions(-) diff --git a/registry/config.go b/registry/config.go index fae73bb4e3..61e8447ddf 100644 --- a/registry/config.go +++ b/registry/config.go @@ -1,7 +1,6 @@ package registry // import "github.com/docker/docker/registry" import ( - "fmt" "net" "net/url" "regexp" @@ -97,17 +96,17 @@ func (config *serviceConfig) loadAllowNondistributableArtifacts(registries []str return err } if hasScheme(r) { - return fmt.Errorf("allow-nondistributable-artifacts registry %s should not contain '://'", r) + return invalidParamf("allow-nondistributable-artifacts registry %s should not contain '://'", r) } if _, ipnet, err := net.ParseCIDR(r); err == nil { // Valid CIDR. cidrs[ipnet.String()] = (*registrytypes.NetIPNet)(ipnet) - } else if err := validateHostPort(r); err == nil { + } else if err = validateHostPort(r); err == nil { // Must be `host:port` if not CIDR. hostnames[r] = true } else { - return fmt.Errorf("allow-nondistributable-artifacts registry %s is not valid: %v", r, err) + return invalidParamWrapf(err, "allow-nondistributable-artifacts registry %s is not valid", r) } } @@ -188,7 +187,7 @@ skip: // before returning err, roll back to original data config.ServiceConfig.InsecureRegistryCIDRs = originalCIDRs config.ServiceConfig.IndexConfigs = originalIndexInfos - return fmt.Errorf("insecure registry %s should not contain '://'", r) + return invalidParamf("insecure registry %s should not contain '://'", r) } // Check if CIDR was passed to --insecure-registry _, ipnet, err := net.ParseCIDR(r) @@ -207,8 +206,7 @@ skip: if err := validateHostPort(r); err != nil { config.ServiceConfig.InsecureRegistryCIDRs = originalCIDRs config.ServiceConfig.IndexConfigs = originalIndexInfos - return fmt.Errorf("insecure registry %s is not valid: %v", r, err) - + return invalidParamWrapf(err, "insecure registry %s is not valid", r) } // Assume `host:port` if not CIDR. config.IndexConfigs[r] = ®istrytypes.IndexInfo{ @@ -310,18 +308,18 @@ func isCIDRMatch(cidrs []*registrytypes.NetIPNet, URLHost string) bool { func ValidateMirror(val string) (string, error) { uri, err := url.Parse(val) if err != nil { - return "", fmt.Errorf("invalid mirror: %q is not a valid URI", val) + return "", invalidParamWrapf(err, "invalid mirror: %q is not a valid URI", val) } if uri.Scheme != "http" && uri.Scheme != "https" { - return "", fmt.Errorf("invalid mirror: unsupported scheme %q in %q", uri.Scheme, uri) + return "", invalidParamf("invalid mirror: unsupported scheme %q in %q", uri.Scheme, uri) } if (uri.Path != "" && uri.Path != "/") || uri.RawQuery != "" || uri.Fragment != "" { - return "", fmt.Errorf("invalid mirror: path, query, or fragment at end of the URI %q", uri) + return "", invalidParamf("invalid mirror: path, query, or fragment at end of the URI %q", uri) } if uri.User != nil { // strip password from output uri.User = url.UserPassword(uri.User.Username(), "xxxxx") - return "", fmt.Errorf("invalid mirror: username/password not allowed in URI %q", uri) + return "", invalidParamf("invalid mirror: username/password not allowed in URI %q", uri) } return strings.TrimSuffix(val, "/") + "/", nil } @@ -333,7 +331,7 @@ func ValidateIndexName(val string) (string, error) { val = "docker.io" } if strings.HasPrefix(val, "-") || strings.HasSuffix(val, "-") { - return "", fmt.Errorf("invalid index name (%s). Cannot begin or end with a hyphen", val) + return "", invalidParamf("invalid index name (%s). Cannot begin or end with a hyphen", val) } return val, nil } @@ -352,7 +350,7 @@ func validateHostPort(s string) error { // If match against the `host:port` pattern fails, // it might be `IPv6:port`, which will be captured by net.ParseIP(host) if !validHostPortRegex.MatchString(s) && net.ParseIP(host) == nil { - return fmt.Errorf("invalid host %q", host) + return invalidParamf("invalid host %q", host) } if port != "" { v, err := strconv.Atoi(port) @@ -360,7 +358,7 @@ func validateHostPort(s string) error { return err } if v < 0 || v > 65535 { - return fmt.Errorf("invalid port %q", port) + return invalidParamf("invalid port %q", port) } } return nil diff --git a/registry/config_test.go b/registry/config_test.go index 2968058d23..123cedee0f 100644 --- a/registry/config_test.go +++ b/registry/config_test.go @@ -6,6 +6,7 @@ import ( "strings" "testing" + "github.com/docker/docker/errdefs" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) @@ -255,9 +256,8 @@ func TestLoadInsecureRegistries(t *testing.T) { if err == nil { t.Fatalf("expect error '%s', got no error", testCase.err) } - if !strings.Contains(err.Error(), testCase.err) { - t.Fatalf("expect error '%s', got '%s'", testCase.err, err) - } + assert.ErrorContains(t, err, testCase.err) + assert.Check(t, errdefs.IsInvalidParameter(err)) } } } @@ -313,6 +313,7 @@ func TestNewServiceConfig(t *testing.T) { _, err := newServiceConfig(testCase.opts) if testCase.errStr != "" { assert.Check(t, is.Error(err, testCase.errStr)) + assert.Check(t, errdefs.IsInvalidParameter(err)) } else { assert.Check(t, err) } @@ -377,5 +378,6 @@ func TestValidateIndexNameWithError(t *testing.T) { for _, testCase := range invalid { _, err := ValidateIndexName(testCase.index) assert.Check(t, is.Error(err, testCase.err)) + assert.Check(t, errdefs.IsInvalidParameter(err)) } } diff --git a/registry/endpoint_v1.go b/registry/endpoint_v1.go index 0a8b2cdacc..a61e94cf8c 100644 --- a/registry/endpoint_v1.go +++ b/registry/endpoint_v1.go @@ -3,7 +3,6 @@ package registry // import "github.com/docker/docker/registry" import ( "crypto/tls" "encoding/json" - "fmt" "io" "net/http" "net/url" @@ -64,7 +63,7 @@ func validateEndpoint(endpoint *v1Endpoint) error { if endpoint.IsSecure { // If registry is secure and HTTPS failed, show user the error and tell them about `--insecure-registry` // in case that's what they need. DO NOT accept unknown CA certificates, and DO NOT fallback to HTTP. - return fmt.Errorf("invalid registry endpoint %s: %v. If this private registry supports only HTTP or HTTPS with an unknown CA certificate, please add `--insecure-registry %s` to the daemon's arguments. In the case of HTTPS, if you have access to the registry's CA certificate, no need for the flag; simply place the CA certificate at /etc/docker/certs.d/%s/ca.crt", endpoint, err, endpoint.URL.Host, endpoint.URL.Host) + return invalidParamf("invalid registry endpoint %s: %v. If this private registry supports only HTTP or HTTPS with an unknown CA certificate, please add `--insecure-registry %s` to the daemon's arguments. In the case of HTTPS, if you have access to the registry's CA certificate, no need for the flag; simply place the CA certificate at /etc/docker/certs.d/%s/ca.crt", endpoint, err, endpoint.URL.Host, endpoint.URL.Host) } // If registry is insecure and HTTPS failed, fallback to HTTP. @@ -76,7 +75,7 @@ func validateEndpoint(endpoint *v1Endpoint) error { return nil } - return fmt.Errorf("invalid registry endpoint %q. HTTPS attempt: %v. HTTP attempt: %v", endpoint, err, err2) + return invalidParamf("invalid registry endpoint %q. HTTPS attempt: %v. HTTP attempt: %v", endpoint, err, err2) } return nil @@ -99,7 +98,7 @@ func trimV1Address(address string) (string, error) { for k, v := range apiVersions { if k != APIVersion1 && apiVersionStr == v { - return "", fmt.Errorf("unsupported V1 version path %s", apiVersionStr) + return "", invalidParamf("unsupported V1 version path %s", apiVersionStr) } } @@ -118,7 +117,7 @@ func newV1EndpointFromStr(address string, tlsConfig *tls.Config, userAgent strin uri, err := url.Parse(address) if err != nil { - return nil, err + return nil, invalidParam(err) } // TODO(tiborvass): make sure a ConnectTimeout transport is used @@ -148,19 +147,19 @@ func (e *v1Endpoint) ping() (v1PingResult, error) { pingURL := e.String() + "_ping" req, err := http.NewRequest(http.MethodGet, pingURL, nil) if err != nil { - return v1PingResult{}, err + return v1PingResult{}, invalidParam(err) } resp, err := e.client.Do(req) if err != nil { - return v1PingResult{}, err + return v1PingResult{}, invalidParam(err) } defer resp.Body.Close() jsonString, err := io.ReadAll(resp.Body) if err != nil { - return v1PingResult{}, fmt.Errorf("error while reading the http response: %s", err) + return v1PingResult{}, invalidParamWrapf(err, "error while reading response from %s", pingURL) } // If the header is absent, we assume true for compatibility with earlier diff --git a/registry/errors.go b/registry/errors.go index 4906303efc..7dc20ad8ff 100644 --- a/registry/errors.go +++ b/registry/errors.go @@ -5,6 +5,7 @@ import ( "github.com/docker/distribution/registry/api/errcode" "github.com/docker/docker/errdefs" + "github.com/pkg/errors" ) func translateV2AuthError(err error) error { @@ -21,3 +22,15 @@ func translateV2AuthError(err error) error { return err } + +func invalidParam(err error) error { + return errdefs.InvalidParameter(err) +} + +func invalidParamf(format string, args ...interface{}) error { + return errdefs.InvalidParameter(errors.Errorf(format, args...)) +} + +func invalidParamWrapf(err error, format string, args ...interface{}) error { + return errdefs.InvalidParameter(errors.Wrapf(err, format, args...)) +} diff --git a/registry/registry.go b/registry/registry.go index 4241075c3f..983e4243bf 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -3,7 +3,6 @@ package registry // import "github.com/docker/docker/registry" import ( "crypto/tls" - "fmt" "net" "net/http" "os" @@ -53,7 +52,7 @@ func hasFile(files []os.DirEntry, name string) bool { func ReadCertsDirectory(tlsConfig *tls.Config, directory string) error { fs, err := os.ReadDir(directory) if err != nil && !os.IsNotExist(err) { - return err + return invalidParam(err) } for _, f := range fs { @@ -61,7 +60,7 @@ func ReadCertsDirectory(tlsConfig *tls.Config, directory string) error { if tlsConfig.RootCAs == nil { systemPool, err := tlsconfig.SystemCertPool() if err != nil { - return fmt.Errorf("unable to get system cert pool: %v", err) + return invalidParamWrapf(err, "unable to get system cert pool") } tlsConfig.RootCAs = systemPool } @@ -77,7 +76,7 @@ func ReadCertsDirectory(tlsConfig *tls.Config, directory string) error { keyName := certName[:len(certName)-5] + ".key" logrus.Debugf("cert: %s", filepath.Join(directory, f.Name())) if !hasFile(fs, keyName) { - return fmt.Errorf("missing key %s for client certificate %s. Note that CA certificates should use the extension .crt", keyName, certName) + return invalidParamf("missing key %s for client certificate %s. CA certificates must use the extension .crt", keyName, certName) } cert, err := tls.LoadX509KeyPair(filepath.Join(directory, certName), filepath.Join(directory, keyName)) if err != nil { @@ -90,7 +89,7 @@ func ReadCertsDirectory(tlsConfig *tls.Config, directory string) error { certName := keyName[:len(keyName)-4] + ".cert" logrus.Debugf("key: %s", filepath.Join(directory, f.Name())) if !hasFile(fs, certName) { - return fmt.Errorf("Missing client certificate %s for key %s", certName, keyName) + return invalidParamf("missing client certificate %s for key %s", certName, keyName) } } } diff --git a/registry/service.go b/registry/service.go index ea2fb35c14..b5e24ebac8 100644 --- a/registry/service.go +++ b/registry/service.go @@ -13,7 +13,6 @@ import ( "github.com/docker/docker/api/types" registrytypes "github.com/docker/docker/api/types/registry" "github.com/docker/docker/errdefs" - "github.com/pkg/errors" "github.com/sirupsen/logrus" ) @@ -117,7 +116,7 @@ func (s *defaultService) Auth(ctx context.Context, authConfig *types.AuthConfig, } u, err := url.Parse(serverAddress) if err != nil { - return "", "", errdefs.InvalidParameter(errors.Errorf("unable to parse server address: %v", err)) + return "", "", invalidParamWrapf(err, "unable to parse server address") } registryHostName = u.Host } @@ -127,7 +126,7 @@ func (s *defaultService) Auth(ctx context.Context, authConfig *types.AuthConfig, // to a mirror. endpoints, err := s.LookupPushEndpoints(registryHostName) if err != nil { - return "", "", errdefs.InvalidParameter(err) + return "", "", invalidParam(err) } for _, endpoint := range endpoints { @@ -162,7 +161,7 @@ func splitReposSearchTerm(reposName string) (string, string) { func (s *defaultService) Search(ctx context.Context, term string, limit int, authConfig *types.AuthConfig, userAgent string, headers map[string][]string) (*registrytypes.SearchResults, error) { // TODO Use ctx when searching for repositories if hasScheme(term) { - return nil, errors.New(`invalid repository name (ex: "registry.domain.tld/myrepos")`) + return nil, invalidParamf("invalid repository name: repository name (%s) should not have a scheme", term) } indexName, remoteName := splitReposSearchTerm(term) diff --git a/registry/service_v2.go b/registry/service_v2.go index 8736748833..11faf239e0 100644 --- a/registry/service_v2.go +++ b/registry/service_v2.go @@ -16,7 +16,7 @@ func (s *defaultService) lookupV2Endpoints(hostname string) (endpoints []APIEndp } mirrorURL, err := url.Parse(mirror) if err != nil { - return nil, err + return nil, invalidParam(err) } mirrorTLSConfig, err := s.tlsConfig(mirrorURL.Host) if err != nil { diff --git a/registry/session.go b/registry/session.go index 6316b75c43..ed3813505e 100644 --- a/registry/session.go +++ b/registry/session.go @@ -169,7 +169,7 @@ func authorizeClient(client *http.Client, authConfig *types.AuthConfig, endpoint jar, err := cookiejar.New(nil) if err != nil { - return errors.New("cookiejar.New is not supposed to return an error") + return errdefs.System(errors.New("cookiejar.New is not supposed to return an error")) } client.Jar = jar @@ -187,14 +187,14 @@ func newSession(client *http.Client, endpoint *v1Endpoint) *session { // searchRepositories performs a search against the remote repository func (r *session) searchRepositories(term string, limit int) (*registrytypes.SearchResults, error) { if limit < 1 || limit > 100 { - return nil, errdefs.InvalidParameter(errors.Errorf("Limit %d is outside the range of [1, 100]", limit)) + return nil, invalidParamf("limit %d is outside the range of [1, 100]", limit) } logrus.Debugf("Index server: %s", r.indexEndpoint) u := r.indexEndpoint.String() + "search?q=" + url.QueryEscape(term) + "&n=" + url.QueryEscape(fmt.Sprintf("%d", limit)) req, err := http.NewRequest(http.MethodGet, u, nil) if err != nil { - return nil, errors.Wrap(errdefs.InvalidParameter(err), "Error building request") + return nil, invalidParamWrapf(err, "error building request") } // Have the AuthTransport send authentication, when logged in. req.Header.Set("X-Docker-Token", "true")