From a616cf3b88ee264b9a565dc0c25e583444ba64e2 Mon Sep 17 00:00:00 2001 From: Alexander Morozov Date: Tue, 11 Oct 2016 16:31:45 -0700 Subject: [PATCH] pkg/authorization: make it goroutine-safe It was racy on config reload Signed-off-by: Alexander Morozov --- pkg/authorization/middleware.go | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pkg/authorization/middleware.go b/pkg/authorization/middleware.go index 58734ec496..c8de8db839 100644 --- a/pkg/authorization/middleware.go +++ b/pkg/authorization/middleware.go @@ -2,6 +2,7 @@ package authorization import ( "net/http" + "sync" "github.com/Sirupsen/logrus" "golang.org/x/net/context" @@ -10,6 +11,7 @@ import ( // Middleware uses a list of plugins to // handle authorization in the API requests. type Middleware struct { + mu sync.Mutex plugins []Plugin } @@ -23,14 +25,19 @@ func NewMiddleware(names []string) *Middleware { // SetPlugins sets the plugin used for authorization func (m *Middleware) SetPlugins(names []string) { + m.mu.Lock() m.plugins = newPlugins(names) + m.mu.Unlock() } // WrapHandler returns a new handler function wrapping the previous one in the request chain. func (m *Middleware) WrapHandler(handler func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error) func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { return func(ctx context.Context, w http.ResponseWriter, r *http.Request, vars map[string]string) error { - if len(m.plugins) == 0 { + m.mu.Lock() + plugins := m.plugins + m.mu.Unlock() + if len(plugins) == 0 { return handler(ctx, w, r, vars) } @@ -46,7 +53,7 @@ func (m *Middleware) WrapHandler(handler func(ctx context.Context, w http.Respon userAuthNMethod = "TLS" } - authCtx := NewCtx(m.plugins, user, userAuthNMethod, r.Method, r.RequestURI) + authCtx := NewCtx(plugins, user, userAuthNMethod, r.Method, r.RequestURI) if err := authCtx.AuthZRequest(w, r); err != nil { logrus.Errorf("AuthZRequest for %s %s returned error: %s", r.Method, r.RequestURI, err)