1
0
Fork 0
mirror of https://github.com/moby/moby.git synced 2022-11-09 12:21:53 -05:00
moby--moby/vendor/github.com/aws/aws-sdk-go/private/protocol/rest/unmarshal.go
Samuel Karp 44a8e10bfc
awslogs: Update aws-sdk-go to support IMDSv2
AWS recently launched a new version of the EC2 Instance Metadata
Service, which is used to provide credentials to the awslogs driver when
running on Amazon EC2.  This new version of the IMDS adds
defense-in-depth mechanisms against open firewalls, reverse proxies, and
SSRF vulnerabilities and is generally an improvement over the previous
version.  An updated version of the AWS SDK is able to handle both
the previous version and the new version of the IMDS and functions when
either is enabled.

More information about IMDSv2 is available at the following links:

* https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
* https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html

Closes https://github.com/moby/moby/issues/40422

Signed-off-by: Samuel Karp <skarp@amazon.com>
2020-02-06 10:56:05 -08:00

257 lines
6.7 KiB
Go

package rest
import (
"bytes"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strconv"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/request"
awsStrings "github.com/aws/aws-sdk-go/internal/strings"
"github.com/aws/aws-sdk-go/private/protocol"
)
// Named request handlers exposed by the REST protocol package.
var (
	// UnmarshalHandler is the named request handler that runs Unmarshal to
	// decode the body of a rest protocol response.
	UnmarshalHandler = request.NamedHandler{
		Name: "awssdk.rest.Unmarshal",
		Fn:   Unmarshal,
	}

	// UnmarshalMetaHandler is the named request handler that runs
	// UnmarshalMeta to decode the metadata of a rest protocol response.
	UnmarshalMetaHandler = request.NamedHandler{
		Name: "awssdk.rest.UnmarshalMeta",
		Fn:   UnmarshalMeta,
	}
)
// Unmarshal decodes the REST body component of a response into r.Data.
// Nothing happens when the request reports that it has no data to fill; any
// decoding failure is recorded on r.Error.
func Unmarshal(r *request.Request) {
	if !r.DataFilled() {
		return
	}

	// Dereference r.Data so the struct fields can be set through reflection.
	target := reflect.Indirect(reflect.ValueOf(r.Data))
	if err := unmarshalBody(r, target); err != nil {
		r.Error = err
	}
}
// UnmarshalMeta decodes the REST metadata of a response: it records the
// request ID taken from the response headers and, when the request has data
// to fill, unmarshals the location-bound elements into r.Data. Any decoding
// failure is recorded on r.Error.
func UnmarshalMeta(r *request.Request) {
	headers := r.HTTPResponse.Header

	requestID := headers.Get("X-Amzn-Requestid")
	if requestID == "" {
		// Fall back to the alternative spelling of the request ID header.
		requestID = headers.Get("X-Amz-Request-Id")
	}
	r.RequestID = requestID

	if !r.DataFilled() {
		return
	}

	lowerCase := aws.BoolValue(r.Config.LowerCaseHeaderMaps)
	if err := UnmarshalResponse(r.HTTPResponse, r.Data, lowerCase); err != nil {
		r.Error = err
	}
}
// UnmarshalResponse populates data from the location-bound elements of the
// REST response resp (status code, individual headers and prefixed header
// maps). data must be a pointer to the target type. lowerCaseHeaderMaps
// selects lower-cased keys for header maps instead of canonical header keys.
// Any failure to unmarshal into the target type is returned.
func UnmarshalResponse(resp *http.Response, data interface{}, lowerCaseHeaderMaps bool) error {
	target := reflect.Indirect(reflect.ValueOf(data))
	return unmarshalLocationElements(resp, target, lowerCaseHeaderMaps)
}
// unmarshalBody copies the HTTP response body into the payload member of v.
//
// The payload member is named by the `payload` tag on v's `_` field. Nothing
// is done when there is no `_` field, no payload name, the payload member has
// an empty or "structure" `type` tag, or the member cannot be found on v.
//
// Supported payload member types are []byte, *string, io.ReadCloser and
// io.ReadSeeker. Any other type drains and closes the body and returns a
// serialization error, as does a failure to read the body.
func unmarshalBody(r *request.Request, v reflect.Value) error {
	field, ok := v.Type().FieldByName("_")
	if !ok {
		return nil
	}
	payloadName := field.Tag.Get("payload")
	if payloadName == "" {
		return nil
	}

	// A missing payload field yields a zero StructField whose tag is empty,
	// which is rejected by the type tag check below.
	pfield, _ := v.Type().FieldByName(payloadName)
	if ptag := pfield.Tag.Get("type"); ptag == "" || ptag == "structure" {
		return nil
	}

	payload := v.FieldByName(payloadName)
	if !payload.IsValid() {
		return nil
	}

	switch payload.Interface().(type) {
	case []byte:
		defer r.HTTPResponse.Body.Close()
		b, err := ioutil.ReadAll(r.HTTPResponse.Body)
		if err != nil {
			return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
		}
		payload.Set(reflect.ValueOf(b))

	case *string:
		defer r.HTTPResponse.Body.Close()
		b, err := ioutil.ReadAll(r.HTTPResponse.Body)
		if err != nil {
			return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
		}
		str := string(b)
		payload.Set(reflect.ValueOf(&str))

	default:
		switch payload.Type().String() {
		case "io.ReadCloser":
			// The body is handed over as-is and deliberately not closed here.
			payload.Set(reflect.ValueOf(r.HTTPResponse.Body))

		case "io.ReadSeeker":
			// The body is buffered fully, so it can be closed once read.
			defer r.HTTPResponse.Body.Close()
			b, err := ioutil.ReadAll(r.HTTPResponse.Body)
			if err != nil {
				return awserr.New(request.ErrCodeSerialization,
					"failed to read response body", err)
			}
			// Store the *bytes.Reader directly: it implements io.ReadSeeker,
			// whereas a NopCloser wrapper has no Seek method and would make
			// reflect's Set panic on an io.ReadSeeker field.
			payload.Set(reflect.ValueOf(bytes.NewReader(b)))

		default:
			// Drain and close the body before reporting the unsupported type.
			io.Copy(ioutil.Discard, r.HTTPResponse.Body)
			r.HTTPResponse.Body.Close()
			return awserr.New(request.ErrCodeSerialization,
				"failed to decode REST response",
				fmt.Errorf("unknown payload type %s", payload.Type()))
		}
	}

	return nil
}
// unmarshalLocationElements fills the fields of struct value v from the parts
// of resp selected by each field's `location` tag:
//
//   - "statusCode": the HTTP status code
//   - "header":     the single header named by `locationName` (or by the
//     field name when that tag is empty)
//   - "headers":    every header whose key starts with the `locationName`
//     prefix, collected into a map
//
// lowerCaseHeaderMaps is forwarded to unmarshalHeaderMap to choose the key
// normalization of "headers" maps. The first decoding failure is returned
// wrapped in a serialization error.
func unmarshalLocationElements(resp *http.Response, v reflect.Value, lowerCaseHeaderMaps bool) error {
	for i := 0; i < v.NumField(); i++ {
		m, field := v.Field(i), v.Type().Field(i)

		// Skip fields whose name starts with a character that is its own
		// lower-case form (for example unexported fields and "_").
		if n := field.Name; n[0:1] == strings.ToLower(n[0:1]) {
			continue
		}
		if !m.IsValid() {
			continue
		}

		name := field.Tag.Get("locationName")
		if name == "" {
			name = field.Name
		}

		switch field.Tag.Get("location") {
		case "statusCode":
			unmarshalStatusCode(m, resp.StatusCode)

		case "header":
			if err := unmarshalHeader(m, resp.Header.Get(name), field.Tag); err != nil {
				return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
			}

		case "headers":
			prefix := field.Tag.Get("locationName")
			if err := unmarshalHeaderMap(m, resp.Header, prefix, lowerCaseHeaderMaps); err != nil {
				// Return the wrapped error instead of constructing it and
				// dropping it, consistent with the "header" case above.
				return awserr.New(request.ErrCodeSerialization, "failed to decode REST response", err)
			}
		}
	}
	return nil
}
// unmarshalStatusCode stores statusCode into v as a newly allocated *int64.
// Invalid values and values of any type other than *int64 are left untouched.
func unmarshalStatusCode(v reflect.Value, statusCode int) {
	if !v.IsValid() {
		return
	}
	if _, isInt64Ptr := v.Interface().(*int64); !isInt64Ptr {
		return
	}

	code := int64(statusCode)
	v.Set(reflect.ValueOf(&code))
}
// unmarshalHeaderMap collects every header whose key starts with prefix
// (compared case-insensitively) into a map[string]*string and stores it in r.
//
// Map keys are the header keys with the first len(prefix) bytes removed,
// lower-cased when normalize is true and in canonical header form otherwise.
// Only the first value of each header is kept. r is left untouched when it is
// not a map[string]*string or when no header matches. The returned error is
// always nil.
func unmarshalHeaderMap(r reflect.Value, headers http.Header, prefix string, normalize bool) error {
	if len(headers) == 0 {
		return nil
	}
	switch r.Interface().(type) {
	case map[string]*string: // we only support string map value types
		out := map[string]*string{}
		for k, v := range headers {
			// A key with no values has no first element to take; skip it
			// rather than panic on v[0].
			if len(v) == 0 || !awsStrings.HasPrefixFold(k, prefix) {
				continue
			}
			if normalize {
				k = strings.ToLower(k)
			} else {
				k = http.CanonicalHeaderKey(k)
			}
			// Copy the value so the stored pointer does not alias the
			// response's own header storage.
			val := v[0]
			out[k[len(prefix):]] = &val
		}
		if len(out) != 0 {
			r.Set(reflect.ValueOf(out))
		}
	}
	return nil
}
// unmarshalHeader parses the raw header value and stores the result in v,
// choosing the conversion from v's Go type:
//
//   - *string:       the header text itself
//   - []byte:        the base64 (standard encoding) decoded bytes
//   - *bool:         strconv.ParseBool
//   - *int64:        strconv.ParseInt, base 10
//   - *float64:      strconv.ParseFloat
//   - *time.Time:    protocol.ParseTime using the `timestampFormat` tag,
//     falling back to protocol.RFC822TimeFormatName
//   - aws.JSONValue: protocol.DecodeJSONValue, base64-escaped when the
//     `location` tag is "header"
//
// An empty header is skipped for "jsonvalue" and "blob" typed members; for
// every other member it is skipped unless v points at a string. Parse
// failures are returned as-is, and an unsupported target type is an error.
func unmarshalHeader(v reflect.Value, header string, tag reflect.StructTag) error {
	switch tag.Get("type") {
	case "jsonvalue", "blob":
		if header == "" {
			return nil
		}
	default:
		if !v.IsValid() {
			return nil
		}
		if header == "" && v.Elem().Kind() != reflect.String {
			return nil
		}
	}

	switch v.Interface().(type) {
	case *string:
		v.Set(reflect.ValueOf(&header))

	case []byte:
		decoded, err := base64.StdEncoding.DecodeString(header)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(decoded))

	case *bool:
		flag, err := strconv.ParseBool(header)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(&flag))

	case *int64:
		num, err := strconv.ParseInt(header, 10, 64)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(&num))

	case *float64:
		num, err := strconv.ParseFloat(header, 64)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(&num))

	case *time.Time:
		layout := tag.Get("timestampFormat")
		if layout == "" {
			layout = protocol.RFC822TimeFormatName
		}
		parsed, err := protocol.ParseTime(layout, header)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(&parsed))

	case aws.JSONValue:
		mode := protocol.NoEscape
		if tag.Get("location") == "header" {
			mode = protocol.Base64Escape
		}
		decoded, err := protocol.DecodeJSONValue(header, mode)
		if err != nil {
			return err
		}
		v.Set(reflect.ValueOf(decoded))

	default:
		return fmt.Errorf("Unsupported value for param %v (%s)", v.Interface(), v.Type())
	}

	return nil
}