diff --git a/pkg/truncindex/truncindex.go b/pkg/truncindex/truncindex.go index 54a0bf353e..74776e65e6 100644 --- a/pkg/truncindex/truncindex.go +++ b/pkg/truncindex/truncindex.go @@ -125,8 +125,13 @@ func (idx *TruncIndex) Get(s string) (string, error) { return "", ErrNotExist } -// Iterate iterates over all stored IDs, and passes each of them to the given handler. +// Iterate iterates over all stored IDs and passes each of them to the given +// handler. Take care that the handler method does not call any public +// method on truncindex as the internal locking is not reentrant/recursive +// and will result in deadlock. func (idx *TruncIndex) Iterate(handler func(id string)) { + idx.Lock() + defer idx.Unlock() idx.trie.Visit(func(prefix patricia.Prefix, item patricia.Item) error { handler(string(prefix)) return nil diff --git a/pkg/truncindex/truncindex_test.go b/pkg/truncindex/truncindex_test.go index 8197baf7d4..89658cabb9 100644 --- a/pkg/truncindex/truncindex_test.go +++ b/pkg/truncindex/truncindex_test.go @@ -3,6 +3,7 @@ package truncindex import ( "math/rand" "testing" + "time" "github.com/docker/docker/pkg/stringid" ) @@ -98,6 +99,7 @@ func TestTruncIndex(t *testing.T) { assertIndexGet(t, index, id, id, false) assertIndexIterate(t) + assertIndexIterateDoNotPanic(t) } func assertIndexIterate(t *testing.T) { @@ -121,6 +123,28 @@ func assertIndexIterate(t *testing.T) { }) } +func assertIndexIterateDoNotPanic(t *testing.T) { + ids := []string{ + "19b36c2c326ccc11e726eee6ee78a0baf166ef96", + "28b36c2c326ccc11e726eee6ee78a0baf166ef96", + } + + index := NewTruncIndex(ids) + iterationStarted := make(chan bool, 1) + + go func() { + <-iterationStarted + index.Delete("19b36c2c326ccc11e726eee6ee78a0baf166ef96") + }() + + index.Iterate(func(targetId string) { + if targetId == "19b36c2c326ccc11e726eee6ee78a0baf166ef96" { + iterationStarted <- true + time.Sleep(100 * time.Millisecond) + } + }) +} + func assertIndexGet(t *testing.T, index *TruncIndex, input, expectedResult string, expectError bool) { if result, err := index.Get(input); err != nil && !expectError { t.Fatalf("Unexpected error getting '%s': %s", input, err)