2016-02-11 20:42:12 -05:00
|
|
|
// +build !windows

// TODO Windows: These tests use a Unix socket. It might be possible to port
// them to Windows by using a named pipe instead.
|
|
|
|
|
2015-11-12 06:06:47 -05:00
|
|
|
package authorization
|
|
|
|
|
|
|
|
import (
|
2016-06-12 11:23:19 -04:00
|
|
|
"bytes"
|
2015-11-12 06:06:47 -05:00
|
|
|
"encoding/json"
|
|
|
|
"io/ioutil"
|
|
|
|
"net"
|
|
|
|
"net/http"
|
|
|
|
"net/http/httptest"
|
|
|
|
"os"
|
|
|
|
"path"
|
|
|
|
"reflect"
|
2016-03-01 20:16:40 -05:00
|
|
|
"strings"
|
2016-06-12 11:23:19 -04:00
|
|
|
"testing"
|
2016-03-01 20:16:40 -05:00
|
|
|
|
2015-12-16 06:01:04 -05:00
|
|
|
"github.com/docker/docker/pkg/plugins"
|
2015-12-29 19:27:12 -05:00
|
|
|
"github.com/docker/go-connections/tlsconfig"
|
2015-12-16 06:01:04 -05:00
|
|
|
"github.com/gorilla/mux"
|
2015-11-12 06:06:47 -05:00
|
|
|
)
|
|
|
|
|
2016-06-12 11:23:19 -04:00
|
|
|
// pluginAddress is the path of the Unix socket shared by the fake plugin
// server and the plugin client in these tests. It is relative, so the socket
// is created in the current working directory.
const pluginAddress = "authz-test-plugin.sock"
|
2015-11-12 06:06:47 -05:00
|
|
|
|
2015-12-15 03:49:18 -05:00
|
|
|
func TestAuthZRequestPluginError(t *testing.T) {
|
|
|
|
server := authZPluginTestServer{t: t}
|
2016-05-31 21:34:35 -04:00
|
|
|
server.start()
|
2015-12-15 03:49:18 -05:00
|
|
|
defer server.stop()
|
|
|
|
|
|
|
|
authZPlugin := createTestPlugin(t)
|
|
|
|
|
|
|
|
request := Request{
|
|
|
|
User: "user",
|
|
|
|
RequestBody: []byte("sample body"),
|
2016-06-12 11:23:19 -04:00
|
|
|
RequestURI: "www.authz.com/auth",
|
2015-12-15 03:49:18 -05:00
|
|
|
RequestMethod: "GET",
|
|
|
|
RequestHeaders: map[string]string{"header": "value"},
|
|
|
|
}
|
|
|
|
server.replayResponse = Response{
|
|
|
|
Err: "an error",
|
|
|
|
}
|
|
|
|
|
|
|
|
actualResponse, err := authZPlugin.AuthZRequest(&request)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to authorize request %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Response must be equal")
|
2015-12-15 03:49:18 -05:00
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Requests must be equal")
|
2015-12-15 03:49:18 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-11-12 06:06:47 -05:00
|
|
|
func TestAuthZRequestPlugin(t *testing.T) {
|
|
|
|
server := authZPluginTestServer{t: t}
|
2016-05-31 21:34:35 -04:00
|
|
|
server.start()
|
2015-11-12 06:06:47 -05:00
|
|
|
defer server.stop()
|
|
|
|
|
|
|
|
authZPlugin := createTestPlugin(t)
|
|
|
|
|
|
|
|
request := Request{
|
|
|
|
User: "user",
|
|
|
|
RequestBody: []byte("sample body"),
|
2016-06-12 11:23:19 -04:00
|
|
|
RequestURI: "www.authz.com/auth",
|
2015-11-12 06:06:47 -05:00
|
|
|
RequestMethod: "GET",
|
|
|
|
RequestHeaders: map[string]string{"header": "value"},
|
|
|
|
}
|
|
|
|
server.replayResponse = Response{
|
|
|
|
Allow: true,
|
|
|
|
Msg: "Sample message",
|
|
|
|
}
|
|
|
|
|
|
|
|
actualResponse, err := authZPlugin.AuthZRequest(&request)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to authorize request %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Response must be equal")
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Requests must be equal")
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestAuthZResponsePlugin(t *testing.T) {
|
|
|
|
server := authZPluginTestServer{t: t}
|
2016-05-31 21:34:35 -04:00
|
|
|
server.start()
|
2015-11-12 06:06:47 -05:00
|
|
|
defer server.stop()
|
|
|
|
|
|
|
|
authZPlugin := createTestPlugin(t)
|
|
|
|
|
|
|
|
request := Request{
|
|
|
|
User: "user",
|
2016-06-12 11:23:19 -04:00
|
|
|
RequestURI: "someting.com/auth",
|
2015-11-12 06:06:47 -05:00
|
|
|
RequestBody: []byte("sample body"),
|
|
|
|
}
|
|
|
|
server.replayResponse = Response{
|
|
|
|
Allow: true,
|
|
|
|
Msg: "Sample message",
|
|
|
|
}
|
|
|
|
|
|
|
|
actualResponse, err := authZPlugin.AuthZResponse(&request)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to authorize request %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Response must be equal")
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Requests must be equal")
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestResponseModifier(t *testing.T) {
|
|
|
|
r := httptest.NewRecorder()
|
|
|
|
m := NewResponseModifier(r)
|
|
|
|
m.Header().Set("h1", "v1")
|
|
|
|
m.Write([]byte("body"))
|
2016-07-19 03:40:20 -04:00
|
|
|
m.WriteHeader(http.StatusInternalServerError)
|
2015-11-12 06:06:47 -05:00
|
|
|
|
2016-02-04 10:41:41 -05:00
|
|
|
m.FlushAll()
|
2015-11-12 06:06:47 -05:00
|
|
|
if r.Header().Get("h1") != "v1" {
|
|
|
|
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
|
|
|
|
t.Fatalf("Body value must exists %s", r.Body.Bytes())
|
|
|
|
}
|
2016-07-19 03:40:20 -04:00
|
|
|
if r.Code != http.StatusInternalServerError {
|
2015-11-12 06:06:47 -05:00
|
|
|
t.Fatalf("Status code must be correct %d", r.Code)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-02-23 05:04:16 -05:00
|
|
|
func TestDrainBody(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
|
|
length int // length is the message length send to drainBody
|
|
|
|
expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
|
|
|
|
}{
|
|
|
|
{10, 10}, // Small message size
|
|
|
|
{maxBodySize - 1, maxBodySize - 1}, // Max message size
|
|
|
|
{maxBodySize * 2, 0}, // Large message size (skip copying body)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, test := range tests {
|
|
|
|
msg := strings.Repeat("a", test.length)
|
|
|
|
body, closer, err := drainBody(ioutil.NopCloser(bytes.NewReader([]byte(msg))))
|
2016-05-08 08:01:23 -04:00
|
|
|
if err != nil {
|
|
|
|
t.Fatal(err)
|
|
|
|
}
|
2016-02-23 05:04:16 -05:00
|
|
|
if len(body) != test.expectedBodyLength {
|
|
|
|
t.Fatalf("Body must be copied, actual length: '%d'", len(body))
|
|
|
|
}
|
|
|
|
if closer == nil {
|
2016-06-12 11:23:19 -04:00
|
|
|
t.Fatal("Closer must not be nil")
|
2016-02-23 05:04:16 -05:00
|
|
|
}
|
|
|
|
modified, err := ioutil.ReadAll(closer)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Error must not be nil: '%v'", err)
|
|
|
|
}
|
|
|
|
if len(modified) != len(msg) {
|
|
|
|
t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2015-11-12 06:06:47 -05:00
|
|
|
func TestResponseModifierOverride(t *testing.T) {
|
|
|
|
r := httptest.NewRecorder()
|
|
|
|
m := NewResponseModifier(r)
|
|
|
|
m.Header().Set("h1", "v1")
|
|
|
|
m.Write([]byte("body"))
|
2016-07-19 03:40:20 -04:00
|
|
|
m.WriteHeader(http.StatusInternalServerError)
|
2015-11-12 06:06:47 -05:00
|
|
|
|
|
|
|
overrideHeader := make(http.Header)
|
|
|
|
overrideHeader.Add("h1", "v2")
|
|
|
|
overrideHeaderBytes, err := json.Marshal(overrideHeader)
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("override header failed %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
m.OverrideHeader(overrideHeaderBytes)
|
|
|
|
m.OverrideBody([]byte("override body"))
|
2016-07-19 03:40:20 -04:00
|
|
|
m.OverrideStatusCode(http.StatusNotFound)
|
2016-02-04 10:41:41 -05:00
|
|
|
m.FlushAll()
|
2015-11-12 06:06:47 -05:00
|
|
|
if r.Header().Get("h1") != "v2" {
|
|
|
|
t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
|
|
|
|
}
|
|
|
|
if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
|
|
|
|
t.Fatalf("Body value must exists %s", r.Body.Bytes())
|
|
|
|
}
|
2016-07-19 03:40:20 -04:00
|
|
|
if r.Code != http.StatusNotFound {
|
2015-11-12 06:06:47 -05:00
|
|
|
t.Fatalf("Status code must be correct %d", r.Code)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// createTestPlugin creates a new sample authorization plugin
|
|
|
|
func createTestPlugin(t *testing.T) *authorizationPlugin {
|
|
|
|
pwd, err := os.Getwd()
|
|
|
|
if err != nil {
|
2016-05-08 08:01:23 -04:00
|
|
|
t.Fatal(err)
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
|
2016-05-16 11:50:55 -04:00
|
|
|
client, err := plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), &tlsconfig.Options{InsecureSkipVerify: true})
|
2015-11-12 06:06:47 -05:00
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("Failed to create client %v", err)
|
|
|
|
}
|
|
|
|
|
2016-05-16 11:50:55 -04:00
|
|
|
return &authorizationPlugin{name: "plugin", plugin: client}
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// AuthZPluginTestServer is a simple server that implements the authZ plugin interface
|
|
|
|
type authZPluginTestServer struct {
|
|
|
|
listener net.Listener
|
|
|
|
t *testing.T
|
|
|
|
// request stores the request sent from the daemon to the plugin
|
|
|
|
recordedRequest Request
|
|
|
|
// response stores the response sent from the plugin to the daemon
|
|
|
|
replayResponse Response
|
2016-05-31 21:34:35 -04:00
|
|
|
server *httptest.Server
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// start starts the test server that implements the plugin
|
|
|
|
func (t *authZPluginTestServer) start() {
|
|
|
|
r := mux.NewRouter()
|
2016-06-12 11:23:19 -04:00
|
|
|
l, err := net.Listen("unix", pluginAddress)
|
|
|
|
if err != nil {
|
|
|
|
t.t.Fatal(err)
|
|
|
|
}
|
2015-11-12 06:06:47 -05:00
|
|
|
t.listener = l
|
|
|
|
r.HandleFunc("/Plugin.Activate", t.activate)
|
|
|
|
r.HandleFunc("/"+AuthZApiRequest, t.auth)
|
|
|
|
r.HandleFunc("/"+AuthZApiResponse, t.auth)
|
2016-05-31 21:34:35 -04:00
|
|
|
t.server = &httptest.Server{
|
|
|
|
Listener: l,
|
|
|
|
Config: &http.Server{
|
|
|
|
Handler: r,
|
|
|
|
Addr: pluginAddress,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
t.server.Start()
|
2015-11-12 06:06:47 -05:00
|
|
|
}
|
|
|
|
|
|
|
|
// stop stops the test server that implements the plugin
|
|
|
|
func (t *authZPluginTestServer) stop() {
|
2016-05-31 21:34:35 -04:00
|
|
|
t.server.Close()
|
2015-11-12 06:06:47 -05:00
|
|
|
os.Remove(pluginAddress)
|
|
|
|
if t.listener != nil {
|
|
|
|
t.listener.Close()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// auth is a used to record/replay the authentication api messages
|
|
|
|
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
|
|
|
|
t.recordedRequest = Request{}
|
2016-06-12 11:23:19 -04:00
|
|
|
body, err := ioutil.ReadAll(r.Body)
|
|
|
|
if err != nil {
|
|
|
|
t.t.Fatal(err)
|
|
|
|
}
|
2016-05-08 08:01:23 -04:00
|
|
|
r.Body.Close()
|
2015-11-12 06:06:47 -05:00
|
|
|
json.Unmarshal(body, &t.recordedRequest)
|
2016-06-12 11:23:19 -04:00
|
|
|
b, err := json.Marshal(t.replayResponse)
|
|
|
|
if err != nil {
|
|
|
|
t.t.Fatal(err)
|
|
|
|
}
|
2015-11-12 06:06:47 -05:00
|
|
|
w.Write(b)
|
|
|
|
}
|
|
|
|
|
|
|
|
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
|
2016-06-12 11:23:19 -04:00
|
|
|
b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
|
|
|
|
if err != nil {
|
|
|
|
t.t.Fatal(err)
|
|
|
|
}
|
2015-11-12 06:06:47 -05:00
|
|
|
w.Write(b)
|
|
|
|
}
|