diff --git a/pkg/authorization/api_test.go b/pkg/authorization/api_test.go new file mode 100644 index 0000000000..1031949069 --- /dev/null +++ b/pkg/authorization/api_test.go @@ -0,0 +1,75 @@ +package authorization + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "math/big" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPeerCertificateMarshalJSON(t *testing.T) { + template := &x509.Certificate{ + IsCA: true, + BasicConstraintsValid: true, + SubjectKeyId: []byte{1, 2, 3}, + SerialNumber: big.NewInt(1234), + Subject: pkix.Name{ + Country: []string{"Earth"}, + Organization: []string{"Mother Nature"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(5, 5, 5), + + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + } + // generate private key + privatekey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + publickey := &privatekey.PublicKey + + // create a self-signed certificate. template = parent + var parent = template + raw, err := x509.CreateCertificate(rand.Reader, template, parent, publickey, privatekey) + require.NoError(t, err) + + cert, err := x509.ParseCertificate(raw) + require.NoError(t, err) + + var certs = []*x509.Certificate{cert} + addr := "www.authz.com/auth" + req, err := http.NewRequest("GET", addr, nil) + require.NoError(t, err) + + req.RequestURI = addr + req.TLS = &tls.ConnectionState{} + req.TLS.PeerCertificates = certs + req.Header.Add("header", "value") + + for _, c := range req.TLS.PeerCertificates { + pcObj := PeerCertificate(*c) + + t.Run("Marshalling :", func(t *testing.T) { + raw, err = pcObj.MarshalJSON() + require.NotNil(t, raw) + require.Nil(t, err) + }) + + t.Run("UnMarshalling :", func(t *testing.T) { + err := pcObj.UnmarshalJSON(raw) + require.Nil(t, err) + require.Equal(t, "Earth", pcObj.Subject.Country[0]) + require.Equal(t, true, pcObj.IsCA) + + }) + + } + +} diff --git a/pkg/authorization/middleware_test.go b/pkg/authorization/middleware_test.go new file mode 100644 index 0000000000..fc7401135c --- /dev/null +++ b/pkg/authorization/middleware_test.go @@ -0,0 +1,53 @@ +package authorization + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/docker/docker/pkg/plugingetter" + "github.com/stretchr/testify/require" +) + +func TestMiddleware(t *testing.T) { + pluginNames := []string{"testPlugin1", "testPlugin2"} + var pluginGetter plugingetter.PluginGetter + m := NewMiddleware(pluginNames, pluginGetter) + authPlugins := m.getAuthzPlugins() + require.Equal(t, 2, len(authPlugins)) + require.EqualValues(t, pluginNames[0], authPlugins[0].Name()) + require.EqualValues(t, pluginNames[1], authPlugins[1].Name()) +} + +func TestNewResponseModifier(t *testing.T) { + recorder := httptest.NewRecorder() + modifier := NewResponseModifier(recorder) + modifier.Header().Set("H1", "V1") + modifier.Write([]byte("body")) + require.False(t, modifier.Hijacked()) + modifier.WriteHeader(http.StatusInternalServerError) + require.NotNil(t, modifier.RawBody()) + + raw, err := modifier.RawHeaders() + require.NotNil(t, raw) + require.Nil(t, err) + + headerData := strings.Split(strings.TrimSpace(string(raw)), ":") + require.EqualValues(t, "H1", strings.TrimSpace(headerData[0])) + require.EqualValues(t, "V1", strings.TrimSpace(headerData[1])) + + modifier.Flush() + modifier.FlushAll() + + if recorder.Header().Get("H1") != "V1" { + t.Fatalf("Header value must exists %s", recorder.Header().Get("H1")) + } + +} + +func setAuthzPlugins(m *Middleware, plugins []Plugin) { + m.mu.Lock() + m.plugins = plugins + m.mu.Unlock() +} diff --git a/pkg/authorization/middleware_unix_test.go b/pkg/authorization/middleware_unix_test.go new file mode 100644 index 0000000000..fd684f1208 --- /dev/null +++ b/pkg/authorization/middleware_unix_test.go @@ -0,0 +1,65 @@ +// +build !windows + +package authorization + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/docker/docker/pkg/plugingetter" + "github.com/stretchr/testify/require" + "golang.org/x/net/context" +) + +func TestMiddlewareWrapHandler(t *testing.T) { + server := authZPluginTestServer{t: t} + server.start() + defer server.stop() + + authZPlugin := createTestPlugin(t) + pluginNames := []string{authZPlugin.name} + + var pluginGetter plugingetter.PluginGetter + middleWare := NewMiddleware(pluginNames, pluginGetter) + handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { + return nil + } + + authList := []Plugin{authZPlugin} + middleWare.SetPlugins([]string{"My Test Plugin"}) + setAuthzPlugins(middleWare, authList) + mdHandler := middleWare.WrapHandler(handler) + require.NotNil(t, mdHandler) + + addr := "www.example.com/auth" + req, _ := http.NewRequest("GET", addr, nil) + req.RequestURI = addr + req.Header.Add("header", "value") + + resp := httptest.NewRecorder() + ctx := context.Background() + + t.Run("Error Test Case :", func(t *testing.T) { + server.replayResponse = Response{ + Allow: false, + Msg: "Server Auth Not Allowed", + } + if err := mdHandler(ctx, resp, req, map[string]string{}); err == nil { + require.Error(t, err) + } + + }) + + t.Run("Positive Test Case :", func(t *testing.T) { + server.replayResponse = Response{ + Allow: true, + Msg: "Server Auth Allowed", + } + if err := mdHandler(ctx, resp, req, map[string]string{}); err != nil { + require.NoError(t, err) + } + + }) + +}