diff --git a/daemon/logger/awslogs/cloudwatchlogs.go b/daemon/logger/awslogs/cloudwatchlogs.go index 3a7f2f631d..97882538b4 100644 --- a/daemon/logger/awslogs/cloudwatchlogs.go +++ b/daemon/logger/awslogs/cloudwatchlogs.go @@ -14,6 +14,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials/endpointcreds" "github.com/aws/aws-sdk-go/aws/ec2metadata" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" @@ -26,16 +27,17 @@ import ( ) const ( - name = "awslogs" - regionKey = "awslogs-region" - regionEnvKey = "AWS_REGION" - logGroupKey = "awslogs-group" - logStreamKey = "awslogs-stream" - logCreateGroupKey = "awslogs-create-group" - tagKey = "tag" - datetimeFormatKey = "awslogs-datetime-format" - multilinePatternKey = "awslogs-multiline-pattern" - batchPublishFrequency = 5 * time.Second + name = "awslogs" + regionKey = "awslogs-region" + regionEnvKey = "AWS_REGION" + logGroupKey = "awslogs-group" + logStreamKey = "awslogs-stream" + logCreateGroupKey = "awslogs-create-group" + tagKey = "tag" + datetimeFormatKey = "awslogs-datetime-format" + multilinePatternKey = "awslogs-multiline-pattern" + credentialsEndpointKey = "awslogs-credentials-endpoint" + batchPublishFrequency = 5 * time.Second // See: http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_PutLogEvents.html perEventBytes = 26 @@ -50,6 +52,8 @@ const ( invalidSequenceTokenCode = "InvalidSequenceTokenException" resourceNotFoundCode = "ResourceNotFoundException" + credentialsEndpoint = "http://169.254.170.2" + userAgentHeader = "User-Agent" ) @@ -198,6 +202,10 @@ var newRegionFinder = func() regionFinder { return ec2metadata.New(session.New()) } +// newSDKEndpoint is a variable such that the implementation +// can be swapped out for unit tests. +var newSDKEndpoint = credentialsEndpoint + // newAWSLogsClient creates the service client for Amazon CloudWatch Logs. // Customizations to the default client from the SDK include a Docker-specific // User-Agent string and automatic region detection using the EC2 Instance @@ -222,11 +230,33 @@ func newAWSLogsClient(info logger.Info) (api, error) { } region = &r } + + sess, err := session.NewSession() + if err != nil { + return nil, errors.New("Failed to create a service client session for for awslogs driver") + } + + // attach region to cloudwatchlogs config + sess.Config.Region = region + + if uri, ok := info.Config[credentialsEndpointKey]; ok { + logrus.Debugf("Trying to get credentials from awslogs-credentials-endpoint") + + endpoint := fmt.Sprintf("%s%s", newSDKEndpoint, uri) + creds := endpointcreds.NewCredentialsClient(*sess.Config, sess.Handlers, endpoint, + func(p *endpointcreds.Provider) { + p.ExpiryWindow = 5 * time.Minute + }) + + // attach credentials to cloudwatchlogs config + sess.Config.Credentials = creds + } + logrus.WithFields(logrus.Fields{ "region": *region, }).Debug("Created awslogs client") - client := cloudwatchlogs.New(session.New(), aws.NewConfig().WithRegion(*region)) + client := cloudwatchlogs.New(sess) client.Handlers.Build.PushBackNamed(request.NamedHandler{ Name: "DockerUserAgentHandler", @@ -525,6 +555,7 @@ func ValidateLogOpt(cfg map[string]string) error { case tagKey: case datetimeFormatKey: case multilinePatternKey: + case credentialsEndpointKey: default: return fmt.Errorf("unknown log opt '%s' for %s log driver", key, name) } diff --git a/daemon/logger/awslogs/cloudwatchlogs_test.go b/daemon/logger/awslogs/cloudwatchlogs_test.go index 989eb6f52c..7d482d8196 100644 --- a/daemon/logger/awslogs/cloudwatchlogs_test.go +++ b/daemon/logger/awslogs/cloudwatchlogs_test.go @@ -3,7 +3,10 @@ package awslogs import ( "errors" "fmt" + "io/ioutil" "net/http" + "net/http/httptest" + "os" "reflect" "regexp" "runtime" @@ -1065,3 +1068,121 @@ func BenchmarkUnwrapEvents(b *testing.B) { as.Len(res, maximumLogEventsPerPut) } } + +func TestNewAWSLogsClientCredentialEndpointDetect(t *testing.T) { + // required for the cloudwatchlogs client + os.Setenv("AWS_REGION", "us-west-2") + defer os.Unsetenv("AWS_REGION") + + credsResp := `{ + "AccessKeyId" : "test-access-key-id", + "SecretAccessKey": "test-secret-access-key" + }` + + expectedAccessKeyID := "test-access-key-id" + expectedSecretAccessKey := "test-secret-access-key" + + testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprintln(w, credsResp) + })) + defer testServer.Close() + + // set the SDKEndpoint in the driver + newSDKEndpoint = testServer.URL + + info := logger.Info{ + Config: map[string]string{}, + } + + info.Config["awslogs-credentials-endpoint"] = "/creds" + + c, err := newAWSLogsClient(info) + assert.NoError(t, err) + + client := c.(*cloudwatchlogs.CloudWatchLogs) + + creds, err := client.Config.Credentials.Get() + assert.NoError(t, err) + + assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID) + assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey) +} + +func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) { + // required for the cloudwatchlogs client + os.Setenv("AWS_REGION", "us-west-2") + defer os.Unsetenv("AWS_REGION") + + expectedAccessKeyID := "test-access-key-id" + expectedSecretAccessKey := "test-secret-access-key" + + os.Setenv("AWS_ACCESS_KEY_ID", expectedAccessKeyID) + defer os.Unsetenv("AWS_ACCESS_KEY_ID") + + os.Setenv("AWS_SECRET_ACCESS_KEY", expectedSecretAccessKey) + defer os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + info := logger.Info{ + Config: map[string]string{}, + } + + c, err := newAWSLogsClient(info) + assert.NoError(t, err) + + client := c.(*cloudwatchlogs.CloudWatchLogs) + + creds, err := client.Config.Credentials.Get() + assert.NoError(t, err) + + assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID) + assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey) + +} + +func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) { + // required for the cloudwatchlogs client + os.Setenv("AWS_REGION", "us-west-2") + defer os.Unsetenv("AWS_REGION") + + expectedAccessKeyID := "test-access-key-id" + expectedSecretAccessKey := "test-secret-access-key" + + contentStr := ` + [default] + aws_access_key_id = "test-access-key-id" + aws_secret_access_key = "test-secret-access-key" + ` + content := []byte(contentStr) + + tmpfile, err := ioutil.TempFile("", "example") + defer os.Remove(tmpfile.Name()) // clean up + assert.NoError(t, err) + + _, err = tmpfile.Write(content) + assert.NoError(t, err) + + err = tmpfile.Close() + assert.NoError(t, err) + + os.Unsetenv("AWS_ACCESS_KEY_ID") + os.Unsetenv("AWS_SECRET_ACCESS_KEY") + + os.Setenv("AWS_SHARED_CREDENTIALS_FILE", tmpfile.Name()) + defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE") + + info := logger.Info{ + Config: map[string]string{}, + } + + c, err := newAWSLogsClient(info) + assert.NoError(t, err) + + client := c.(*cloudwatchlogs.CloudWatchLogs) + + creds, err := client.Config.Credentials.Get() + assert.NoError(t, err) + + assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID) + assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey) +}