Merge pull request #23425 from runcom/authz-race
pkg: authorization: lock when lazy loading
This commit is contained in:
commit
789aee497c
3 changed files with 44 additions and 27 deletions
|
@ -6,6 +6,7 @@
|
||||||
package authorization
|
package authorization
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net"
|
"net"
|
||||||
|
@ -14,17 +15,17 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"bytes"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
"github.com/docker/docker/pkg/plugins"
|
"github.com/docker/docker/pkg/plugins"
|
||||||
"github.com/docker/go-connections/tlsconfig"
|
"github.com/docker/go-connections/tlsconfig"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
)
|
)
|
||||||
|
|
||||||
const pluginAddress = "authzplugin.sock"
|
const (
|
||||||
|
pluginAddress = "authz-test-plugin.sock"
|
||||||
|
)
|
||||||
|
|
||||||
func TestAuthZRequestPluginError(t *testing.T) {
|
func TestAuthZRequestPluginError(t *testing.T) {
|
||||||
server := authZPluginTestServer{t: t}
|
server := authZPluginTestServer{t: t}
|
||||||
|
@ -36,7 +37,7 @@ func TestAuthZRequestPluginError(t *testing.T) {
|
||||||
request := Request{
|
request := Request{
|
||||||
User: "user",
|
User: "user",
|
||||||
RequestBody: []byte("sample body"),
|
RequestBody: []byte("sample body"),
|
||||||
RequestURI: "www.authz.com",
|
RequestURI: "www.authz.com/auth",
|
||||||
RequestMethod: "GET",
|
RequestMethod: "GET",
|
||||||
RequestHeaders: map[string]string{"header": "value"},
|
RequestHeaders: map[string]string{"header": "value"},
|
||||||
}
|
}
|
||||||
|
@ -50,10 +51,10 @@ func TestAuthZRequestPluginError(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
||||||
t.Fatalf("Response must be equal")
|
t.Fatal("Response must be equal")
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(request, server.recordedRequest) {
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
||||||
t.Fatalf("Requests must be equal")
|
t.Fatal("Requests must be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -67,7 +68,7 @@ func TestAuthZRequestPlugin(t *testing.T) {
|
||||||
request := Request{
|
request := Request{
|
||||||
User: "user",
|
User: "user",
|
||||||
RequestBody: []byte("sample body"),
|
RequestBody: []byte("sample body"),
|
||||||
RequestURI: "www.authz.com",
|
RequestURI: "www.authz.com/auth",
|
||||||
RequestMethod: "GET",
|
RequestMethod: "GET",
|
||||||
RequestHeaders: map[string]string{"header": "value"},
|
RequestHeaders: map[string]string{"header": "value"},
|
||||||
}
|
}
|
||||||
|
@ -82,10 +83,10 @@ func TestAuthZRequestPlugin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
||||||
t.Fatalf("Response must be equal")
|
t.Fatal("Response must be equal")
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(request, server.recordedRequest) {
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
||||||
t.Fatalf("Requests must be equal")
|
t.Fatal("Requests must be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -98,6 +99,7 @@ func TestAuthZResponsePlugin(t *testing.T) {
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
User: "user",
|
User: "user",
|
||||||
|
RequestURI: "someting.com/auth",
|
||||||
RequestBody: []byte("sample body"),
|
RequestBody: []byte("sample body"),
|
||||||
}
|
}
|
||||||
server.replayResponse = Response{
|
server.replayResponse = Response{
|
||||||
|
@ -111,10 +113,10 @@ func TestAuthZResponsePlugin(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
|
||||||
t.Fatalf("Response must be equal")
|
t.Fatal("Response must be equal")
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(request, server.recordedRequest) {
|
if !reflect.DeepEqual(request, server.recordedRequest) {
|
||||||
t.Fatalf("Requests must be equal")
|
t.Fatal("Requests must be equal")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +160,7 @@ func TestDrainBody(t *testing.T) {
|
||||||
t.Fatalf("Body must be copied, actual length: '%d'", len(body))
|
t.Fatalf("Body must be copied, actual length: '%d'", len(body))
|
||||||
}
|
}
|
||||||
if closer == nil {
|
if closer == nil {
|
||||||
t.Fatalf("Closer must not be nil")
|
t.Fatal("Closer must not be nil")
|
||||||
}
|
}
|
||||||
modified, err := ioutil.ReadAll(closer)
|
modified, err := ioutil.ReadAll(closer)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -229,8 +231,10 @@ type authZPluginTestServer struct {
|
||||||
// start starts the test server that implements the plugin
|
// start starts the test server that implements the plugin
|
||||||
func (t *authZPluginTestServer) start() {
|
func (t *authZPluginTestServer) start() {
|
||||||
r := mux.NewRouter()
|
r := mux.NewRouter()
|
||||||
os.Remove(pluginAddress)
|
l, err := net.Listen("unix", pluginAddress)
|
||||||
l, _ := net.ListenUnix("unix", &net.UnixAddr{Name: pluginAddress, Net: "unix"})
|
if err != nil {
|
||||||
|
t.t.Fatal(err)
|
||||||
|
}
|
||||||
t.listener = l
|
t.listener = l
|
||||||
r.HandleFunc("/Plugin.Activate", t.activate)
|
r.HandleFunc("/Plugin.Activate", t.activate)
|
||||||
r.HandleFunc("/"+AuthZApiRequest, t.auth)
|
r.HandleFunc("/"+AuthZApiRequest, t.auth)
|
||||||
|
@ -257,14 +261,23 @@ func (t *authZPluginTestServer) stop() {
|
||||||
// auth is a used to record/replay the authentication api messages
|
// auth is a used to record/replay the authentication api messages
|
||||||
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
|
func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
|
||||||
t.recordedRequest = Request{}
|
t.recordedRequest = Request{}
|
||||||
body, _ := ioutil.ReadAll(r.Body)
|
body, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.t.Fatal(err)
|
||||||
|
}
|
||||||
r.Body.Close()
|
r.Body.Close()
|
||||||
json.Unmarshal(body, &t.recordedRequest)
|
json.Unmarshal(body, &t.recordedRequest)
|
||||||
b, _ := json.Marshal(t.replayResponse)
|
b, err := json.Marshal(t.replayResponse)
|
||||||
|
if err != nil {
|
||||||
|
t.t.Fatal(err)
|
||||||
|
}
|
||||||
w.Write(b)
|
w.Write(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
|
func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
|
||||||
b, _ := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
|
b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
|
||||||
|
if err != nil {
|
||||||
|
t.t.Fatal(err)
|
||||||
|
}
|
||||||
w.Write(b)
|
w.Write(b)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
package authorization
|
package authorization
|
||||||
|
|
||||||
import "github.com/docker/docker/pkg/plugins"
|
import (
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/docker/docker/pkg/plugins"
|
||||||
|
)
|
||||||
|
|
||||||
// Plugin allows third party plugins to authorize requests and responses
|
// Plugin allows third party plugins to authorize requests and responses
|
||||||
// in the context of docker API
|
// in the context of docker API
|
||||||
|
@ -33,6 +37,7 @@ func NewPlugins(names []string) []Plugin {
|
||||||
type authorizationPlugin struct {
|
type authorizationPlugin struct {
|
||||||
plugin *plugins.Plugin
|
plugin *plugins.Plugin
|
||||||
name string
|
name string
|
||||||
|
once sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthorizationPlugin(name string) Plugin {
|
func newAuthorizationPlugin(name string) Plugin {
|
||||||
|
@ -72,12 +77,11 @@ func (a *authorizationPlugin) AuthZResponse(authReq *Request) (*Response, error)
|
||||||
// initPlugin initializes the authorization plugin if needed
|
// initPlugin initializes the authorization plugin if needed
|
||||||
func (a *authorizationPlugin) initPlugin() error {
|
func (a *authorizationPlugin) initPlugin() error {
|
||||||
// Lazy loading of plugins
|
// Lazy loading of plugins
|
||||||
if a.plugin == nil {
|
|
||||||
var err error
|
var err error
|
||||||
|
a.once.Do(func() {
|
||||||
|
if a.plugin == nil {
|
||||||
a.plugin, err = plugins.Get(a.name, AuthZApiImplements)
|
a.plugin, err = plugins.Get(a.name, AuthZApiImplements)
|
||||||
if err != nil {
|
}
|
||||||
|
})
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -130,7 +130,7 @@ func (c *Client) callWithRetry(serviceMethod string, data io.Reader, retry bool)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
retries++
|
retries++
|
||||||
logrus.Warnf("Unable to connect to plugin: %s:%s, retrying in %v", req.URL.Host, req.URL.Path, timeOff)
|
logrus.Warnf("Unable to connect to plugin: %s%s: %v, retrying in %v", req.URL.Host, req.URL.Path, err, timeOff)
|
||||||
time.Sleep(timeOff)
|
time.Sleep(timeOff)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue