diff --git a/client/ping.go b/client/ping.go index 631ed02005..a4c2e2c4dd 100644 --- a/client/ping.go +++ b/client/ping.go @@ -18,13 +18,15 @@ func (cli *Client) Ping(ctx context.Context) (types.Ping, error) { } defer ensureReaderClosed(serverResp) - ping.APIVersion = serverResp.header.Get("API-Version") + if serverResp.header != nil { + ping.APIVersion = serverResp.header.Get("API-Version") - if serverResp.header.Get("Docker-Experimental") == "true" { - ping.Experimental = true + if serverResp.header.Get("Docker-Experimental") == "true" { + ping.Experimental = true + } + ping.OSType = serverResp.header.Get("OSType") } - ping.OSType = serverResp.header.Get("OSType") - - return ping, nil + err = cli.checkResponseErr(serverResp) + return ping, err } diff --git a/client/ping_test.go b/client/ping_test.go new file mode 100644 index 0000000000..7a4a1a9024 --- /dev/null +++ b/client/ping_test.go @@ -0,0 +1,82 @@ +package client + +import ( + "errors" + "io/ioutil" + "net/http" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/net/context" +) + +// TestPingFail tests that when a server sends a non-successful response that we +// can still grab API details, when set. +// Some of this is just excercising the code paths to make sure there are no +// panics. +func TestPingFail(t *testing.T) { + var withHeader bool + client := &Client{ + client: newMockClient(func(req *http.Request) (*http.Response, error) { + resp := &http.Response{StatusCode: http.StatusInternalServerError} + if withHeader { + resp.Header = http.Header{} + resp.Header.Set("API-Version", "awesome") + resp.Header.Set("Docker-Experimental", "true") + } + resp.Body = ioutil.NopCloser(strings.NewReader("some error with the server")) + return resp, nil + }), + } + + ping, err := client.Ping(context.Background()) + assert.Error(t, err) + assert.Equal(t, false, ping.Experimental) + assert.Equal(t, "", ping.APIVersion) + + withHeader = true + ping2, err := client.Ping(context.Background()) + assert.Error(t, err) + assert.Equal(t, true, ping2.Experimental) + assert.Equal(t, "awesome", ping2.APIVersion) +} + +// TestPingWithError tests the case where there is a protocol error in the ping. +// This test is mostly just testing that there are no panics in this code path. +func TestPingWithError(t *testing.T) { + client := &Client{ + client: newMockClient(func(req *http.Request) (*http.Response, error) { + resp := &http.Response{StatusCode: http.StatusInternalServerError} + resp.Header = http.Header{} + resp.Header.Set("API-Version", "awesome") + resp.Header.Set("Docker-Experimental", "true") + resp.Body = ioutil.NopCloser(strings.NewReader("some error with the server")) + return resp, errors.New("some error") + }), + } + + ping, err := client.Ping(context.Background()) + assert.Error(t, err) + assert.Equal(t, false, ping.Experimental) + assert.Equal(t, "", ping.APIVersion) +} + +// TestPingSuccess tests that we are able to get the expected API headers/ping +// details on success. +func TestPingSuccess(t *testing.T) { + client := &Client{ + client: newMockClient(func(req *http.Request) (*http.Response, error) { + resp := &http.Response{StatusCode: http.StatusInternalServerError} + resp.Header = http.Header{} + resp.Header.Set("API-Version", "awesome") + resp.Header.Set("Docker-Experimental", "true") + resp.Body = ioutil.NopCloser(strings.NewReader("some error with the server")) + return resp, nil + }), + } + ping, err := client.Ping(context.Background()) + assert.Error(t, err) + assert.Equal(t, true, ping.Experimental) + assert.Equal(t, "awesome", ping.APIVersion) +} diff --git a/client/request.go b/client/request.go index 6457b316a3..3e7d43feac 100644 --- a/client/request.go +++ b/client/request.go @@ -24,6 +24,7 @@ type serverResponse struct { body io.ReadCloser header http.Header statusCode int + reqURL *url.URL } // head sends an http request to the docker API using the method HEAD. @@ -118,11 +119,18 @@ func (cli *Client) sendRequest(ctx context.Context, method, path string, query u if err != nil { return serverResponse{}, err } - return cli.doRequest(ctx, req) + resp, err := cli.doRequest(ctx, req) + if err != nil { + return resp, err + } + if err := cli.checkResponseErr(resp); err != nil { + return resp, err + } + return resp, nil } func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResponse, error) { - serverResp := serverResponse{statusCode: -1} + serverResp := serverResponse{statusCode: -1, reqURL: req.URL} resp, err := ctxhttp.Do(ctx, cli.client, req) if err != nil { @@ -179,37 +187,44 @@ func (cli *Client) doRequest(ctx context.Context, req *http.Request) (serverResp if resp != nil { serverResp.statusCode = resp.StatusCode + serverResp.body = resp.Body + serverResp.header = resp.Header } - - if serverResp.statusCode < 200 || serverResp.statusCode >= 400 { - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return serverResp, err - } - if len(body) == 0 { - return serverResp, fmt.Errorf("Error: request returned %s for API route and version %s, check if the server supports the requested API version", http.StatusText(serverResp.statusCode), req.URL) - } - - var errorMessage string - if (cli.version == "" || versions.GreaterThan(cli.version, "1.23")) && - resp.Header.Get("Content-Type") == "application/json" { - var errorResponse types.ErrorResponse - if err := json.Unmarshal(body, &errorResponse); err != nil { - return serverResp, fmt.Errorf("Error reading JSON: %v", err) - } - errorMessage = errorResponse.Message - } else { - errorMessage = string(body) - } - - return serverResp, fmt.Errorf("Error response from daemon: %s", strings.TrimSpace(errorMessage)) - } - - serverResp.body = resp.Body - serverResp.header = resp.Header return serverResp, nil } +func (cli *Client) checkResponseErr(serverResp serverResponse) error { + if serverResp.statusCode >= 200 && serverResp.statusCode < 400 { + return nil + } + + body, err := ioutil.ReadAll(serverResp.body) + if err != nil { + return err + } + if len(body) == 0 { + return fmt.Errorf("Error: request returned %s for API route and version %s, check if the server supports the requested API version", http.StatusText(serverResp.statusCode), serverResp.reqURL) + } + + var ct string + if serverResp.header != nil { + ct = serverResp.header.Get("Content-Type") + } + + var errorMessage string + if (cli.version == "" || versions.GreaterThan(cli.version, "1.23")) && ct == "application/json" { + var errorResponse types.ErrorResponse + if err := json.Unmarshal(body, &errorResponse); err != nil { + return fmt.Errorf("Error reading JSON: %v", err) + } + errorMessage = errorResponse.Message + } else { + errorMessage = string(body) + } + + return fmt.Errorf("Error response from daemon: %s", strings.TrimSpace(errorMessage)) +} + func (cli *Client) addHeaders(req *http.Request, headers headers) *http.Request { // Add CLI Config's HTTP Headers BEFORE we set the Docker headers // then the user can't change OUR headers @@ -239,9 +254,9 @@ func encodeData(data interface{}) (*bytes.Buffer, error) { } func ensureReaderClosed(response serverResponse) { - if body := response.body; body != nil { + if response.body != nil { // Drain up to 512 bytes and close the body to let the Transport reuse the connection - io.CopyN(ioutil.Discard, body, 512) + io.CopyN(ioutil.Discard, response.body, 512) response.body.Close() } }