Merge pull request #35055 from adnxn/creds-endpoint

Add credentials endpoint option for awslogs driver
This commit is contained in:
Michael Crosby 2017-10-24 14:45:14 -04:00 committed by GitHub
commit 158c072bde
2 changed files with 163 additions and 11 deletions

View File

@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/credentials/endpointcreds"
"github.com/aws/aws-sdk-go/aws/ec2metadata"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
@ -26,16 +27,17 @@ import (
)
const (
name = "awslogs"
regionKey = "awslogs-region"
regionEnvKey = "AWS_REGION"
logGroupKey = "awslogs-group"
logStreamKey = "awslogs-stream"
logCreateGroupKey = "awslogs-create-group"
tagKey = "tag"
datetimeFormatKey = "awslogs-datetime-format"
multilinePatternKey = "awslogs-multiline-pattern"
batchPublishFrequency = 5 * time.Second
name = "awslogs"
regionKey = "awslogs-region"
regionEnvKey = "AWS_REGION"
logGroupKey = "awslogs-group"
logStreamKey = "awslogs-stream"
logCreateGroupKey = "awslogs-create-group"
tagKey = "tag"
datetimeFormatKey = "awslogs-datetime-format"
multilinePatternKey = "awslogs-multiline-pattern"
credentialsEndpointKey = "awslogs-credentials-endpoint"
batchPublishFrequency = 5 * time.Second
// See: http://docs.aws.amazon.com/AmazonCloudWatchLogs/latest/APIReference/API_PutLogEvents.html
perEventBytes = 26
@ -50,6 +52,8 @@ const (
invalidSequenceTokenCode = "InvalidSequenceTokenException"
resourceNotFoundCode = "ResourceNotFoundException"
credentialsEndpoint = "http://169.254.170.2"
userAgentHeader = "User-Agent"
)
@ -198,6 +202,10 @@ var newRegionFinder = func() regionFinder {
return ec2metadata.New(session.New())
}
// newSDKEndpoint is a variable such that the implementation
// can be swapped out for unit tests.
var newSDKEndpoint = credentialsEndpoint
// newAWSLogsClient creates the service client for Amazon CloudWatch Logs.
// Customizations to the default client from the SDK include a Docker-specific
// User-Agent string and automatic region detection using the EC2 Instance
@ -222,11 +230,33 @@ func newAWSLogsClient(info logger.Info) (api, error) {
}
region = &r
}
sess, err := session.NewSession()
if err != nil {
return nil, errors.New("Failed to create a service client session for for awslogs driver")
}
// attach region to cloudwatchlogs config
sess.Config.Region = region
if uri, ok := info.Config[credentialsEndpointKey]; ok {
logrus.Debugf("Trying to get credentials from awslogs-credentials-endpoint")
endpoint := fmt.Sprintf("%s%s", newSDKEndpoint, uri)
creds := endpointcreds.NewCredentialsClient(*sess.Config, sess.Handlers, endpoint,
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
})
// attach credentials to cloudwatchlogs config
sess.Config.Credentials = creds
}
logrus.WithFields(logrus.Fields{
"region": *region,
}).Debug("Created awslogs client")
client := cloudwatchlogs.New(session.New(), aws.NewConfig().WithRegion(*region))
client := cloudwatchlogs.New(sess)
client.Handlers.Build.PushBackNamed(request.NamedHandler{
Name: "DockerUserAgentHandler",
@ -525,6 +555,7 @@ func ValidateLogOpt(cfg map[string]string) error {
case tagKey:
case datetimeFormatKey:
case multilinePatternKey:
case credentialsEndpointKey:
default:
return fmt.Errorf("unknown log opt '%s' for %s log driver", key, name)
}

View File

@ -3,7 +3,10 @@ package awslogs
import (
"errors"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"os"
"reflect"
"regexp"
"runtime"
@ -1065,3 +1068,121 @@ func BenchmarkUnwrapEvents(b *testing.B) {
as.Len(res, maximumLogEventsPerPut)
}
}
func TestNewAWSLogsClientCredentialEndpointDetect(t *testing.T) {
// required for the cloudwatchlogs client
os.Setenv("AWS_REGION", "us-west-2")
defer os.Unsetenv("AWS_REGION")
credsResp := `{
"AccessKeyId" : "test-access-key-id",
"SecretAccessKey": "test-secret-access-key"
}`
expectedAccessKeyID := "test-access-key-id"
expectedSecretAccessKey := "test-secret-access-key"
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
fmt.Fprintln(w, credsResp)
}))
defer testServer.Close()
// set the SDKEndpoint in the driver
newSDKEndpoint = testServer.URL
info := logger.Info{
Config: map[string]string{},
}
info.Config["awslogs-credentials-endpoint"] = "/creds"
c, err := newAWSLogsClient(info)
assert.NoError(t, err)
client := c.(*cloudwatchlogs.CloudWatchLogs)
creds, err := client.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
func TestNewAWSLogsClientCredentialEnvironmentVariable(t *testing.T) {
// required for the cloudwatchlogs client
os.Setenv("AWS_REGION", "us-west-2")
defer os.Unsetenv("AWS_REGION")
expectedAccessKeyID := "test-access-key-id"
expectedSecretAccessKey := "test-secret-access-key"
os.Setenv("AWS_ACCESS_KEY_ID", expectedAccessKeyID)
defer os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Setenv("AWS_SECRET_ACCESS_KEY", expectedSecretAccessKey)
defer os.Unsetenv("AWS_SECRET_ACCESS_KEY")
info := logger.Info{
Config: map[string]string{},
}
c, err := newAWSLogsClient(info)
assert.NoError(t, err)
client := c.(*cloudwatchlogs.CloudWatchLogs)
creds, err := client.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}
func TestNewAWSLogsClientCredentialSharedFile(t *testing.T) {
// required for the cloudwatchlogs client
os.Setenv("AWS_REGION", "us-west-2")
defer os.Unsetenv("AWS_REGION")
expectedAccessKeyID := "test-access-key-id"
expectedSecretAccessKey := "test-secret-access-key"
contentStr := `
[default]
aws_access_key_id = "test-access-key-id"
aws_secret_access_key = "test-secret-access-key"
`
content := []byte(contentStr)
tmpfile, err := ioutil.TempFile("", "example")
defer os.Remove(tmpfile.Name()) // clean up
assert.NoError(t, err)
_, err = tmpfile.Write(content)
assert.NoError(t, err)
err = tmpfile.Close()
assert.NoError(t, err)
os.Unsetenv("AWS_ACCESS_KEY_ID")
os.Unsetenv("AWS_SECRET_ACCESS_KEY")
os.Setenv("AWS_SHARED_CREDENTIALS_FILE", tmpfile.Name())
defer os.Unsetenv("AWS_SHARED_CREDENTIALS_FILE")
info := logger.Info{
Config: map[string]string{},
}
c, err := newAWSLogsClient(info)
assert.NoError(t, err)
client := c.(*cloudwatchlogs.CloudWatchLogs)
creds, err := client.Config.Credentials.Get()
assert.NoError(t, err)
assert.Equal(t, expectedAccessKeyID, creds.AccessKeyID)
assert.Equal(t, expectedSecretAccessKey, creds.SecretAccessKey)
}