diff --git a/daemon/logger/awslogs/cloudwatchlogs.go b/daemon/logger/awslogs/cloudwatchlogs.go index 3d6466f09d..0f989af379 100644 --- a/daemon/logger/awslogs/cloudwatchlogs.go +++ b/daemon/logger/awslogs/cloudwatchlogs.go @@ -29,6 +29,7 @@ import ( const ( name = "awslogs" regionKey = "awslogs-region" + endpointKey = "awslogs-endpoint" regionEnvKey = "AWS_REGION" logGroupKey = "awslogs-group" logStreamKey = "awslogs-stream" @@ -111,11 +112,11 @@ type eventBatch struct { // New creates an awslogs logger using the configuration passed in on the // context. Supported context configuration variables are awslogs-region, -// awslogs-group, awslogs-stream, awslogs-create-group, awslogs-multiline-pattern -// and awslogs-datetime-format. When available, configuration is -// also taken from environment variables AWS_REGION, AWS_ACCESS_KEY_ID, -// AWS_SECRET_ACCESS_KEY, the shared credentials file (~/.aws/credentials), and -// the EC2 Instance Metadata Service. +// awslogs-endpoint, awslogs-group, awslogs-stream, awslogs-create-group, +// awslogs-multiline-pattern and awslogs-datetime-format. +// When available, configuration is also taken from environment variables +// AWS_REGION, AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, the shared credentials +// file (~/.aws/credentials), and the EC2 Instance Metadata Service. func New(info logger.Info) (logger.Logger, error) { logGroupName := info.Config[logGroupKey] logStreamName, err := loggerutils.ParseLogTag(info, "{{.FullID}}") @@ -262,13 +263,16 @@ var newSDKEndpoint = credentialsEndpoint // User-Agent string and automatic region detection using the EC2 Instance // Metadata Service when region is otherwise unspecified. func newAWSLogsClient(info logger.Info) (api, error) { - var region *string + var region, endpoint *string if os.Getenv(regionEnvKey) != "" { region = aws.String(os.Getenv(regionEnvKey)) } if info.Config[regionKey] != "" { region = aws.String(info.Config[regionKey]) } + if info.Config[endpointKey] != "" { + endpoint = aws.String(info.Config[endpointKey]) + } if region == nil || *region == "" { logrus.Info("Trying to get region from EC2 Metadata") ec2MetadataClient := newRegionFinder() @@ -290,6 +294,11 @@ func newAWSLogsClient(info logger.Info) (api, error) { // attach region to cloudwatchlogs config sess.Config.Region = region + // attach endpoint to cloudwatchlogs config + if endpoint != nil { + sess.Config.Endpoint = endpoint + } + if uri, ok := info.Config[credentialsEndpointKey]; ok { logrus.Debugf("Trying to get credentials from awslogs-credentials-endpoint") @@ -606,7 +615,7 @@ func (l *logStream) putLogEvents(events []*cloudwatchlogs.InputLogEvent, sequenc return resp.NextSequenceToken, nil } -// ValidateLogOpt looks for awslogs-specific log options awslogs-region, +// ValidateLogOpt looks for awslogs-specific log options awslogs-region, awslogs-endpoint // awslogs-group, awslogs-stream, awslogs-create-group, awslogs-datetime-format, // awslogs-multiline-pattern func ValidateLogOpt(cfg map[string]string) error { @@ -616,6 +625,7 @@ func ValidateLogOpt(cfg map[string]string) error { case logStreamKey: case logCreateGroupKey: case regionKey: + case endpointKey: case tagKey: case datetimeFormatKey: case multilinePatternKey: diff --git a/daemon/logger/awslogs/cloudwatchlogs_test.go b/daemon/logger/awslogs/cloudwatchlogs_test.go index 6955d910c3..172be51a30 100644 --- a/daemon/logger/awslogs/cloudwatchlogs_test.go +++ b/daemon/logger/awslogs/cloudwatchlogs_test.go @@ -67,13 +67,11 @@ func TestNewAWSLogsClientUserAgentHandler(t *testing.T) { } client, err := newAWSLogsClient(info) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) + realClient, ok := client.(*cloudwatchlogs.CloudWatchLogs) - if !ok { - t.Fatal("Could not cast client to cloudwatchlogs.CloudWatchLogs") - } + assert.Check(t, ok, "Could not cast client to cloudwatchlogs.CloudWatchLogs") + buildHandlerList := realClient.Handlers.Build request := &request.Request{ HTTPRequest: &http.Request{ @@ -90,6 +88,26 @@ func TestNewAWSLogsClientUserAgentHandler(t *testing.T) { } } +func TestNewAWSLogsClientAWSLogsEndpoint(t *testing.T) { + endpoint := "mock-endpoint" + info := logger.Info{ + Config: map[string]string{ + regionKey: "us-east-1", + endpointKey: endpoint, + }, + } + + client, err := newAWSLogsClient(info) + assert.NilError(t, err) + + realClient, ok := client.(*cloudwatchlogs.CloudWatchLogs) + assert.Check(t, ok, "Could not cast client to cloudwatchlogs.CloudWatchLogs") + + endpointWithScheme := realClient.Endpoint + expectedEndpointWithScheme := "https://" + endpoint + assert.Equal(t, endpointWithScheme, expectedEndpointWithScheme, "Wrong endpoint") +} + func TestNewAWSLogsClientRegionDetect(t *testing.T) { info := logger.Info{ Config: map[string]string{}, @@ -104,9 +122,7 @@ func TestNewAWSLogsClientRegionDetect(t *testing.T) { } _, err := newAWSLogsClient(info) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) } func TestCreateSuccess(t *testing.T) { @@ -196,9 +212,7 @@ func TestCreateAlreadyExists(t *testing.T) { err := stream.create() - if err != nil { - t.Fatal("Expected nil err") - } + assert.NilError(t, err) } func TestLogClosed(t *testing.T) { @@ -242,9 +256,8 @@ func TestLogBlocking(t *testing.T) { } select { case err := <-errorCh: - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) + case <-time.After(30 * time.Second): t.Fatal("timed out waiting for read") } @@ -258,9 +271,7 @@ func TestLogNonBlockingBufferEmpty(t *testing.T) { logNonBlocking: true, } err := stream.Log(&logger.Message{}) - if err != nil { - t.Fatal(err) - } + assert.NilError(t, err) } func TestLogNonBlockingBufferFull(t *testing.T) { @@ -1246,9 +1257,7 @@ func TestCreateTagSuccess(t *testing.T) { err := stream.create() - if err != nil { - t.Errorf("Received unexpected err: %v\n", err) - } + assert.NilError(t, err) argument := <-mockClient.createLogStreamArgument if *argument.LogStreamName != "test-container/container-abcdefghijklmnopqrstuvwxyz01234567890" { @@ -1340,7 +1349,6 @@ func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) { assert.Check(t, is.Equal(expectedAccessKeyID, creds.AccessKeyID)) assert.Check(t, is.Equal(expectedSecretAccessKey, creds.SecretAccessKey)) - } func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) {