diff --git a/libnetwork/networkdb/networkdb_test.go b/libnetwork/networkdb/networkdb_test.go index c490a47f29..16bb293785 100644 --- a/libnetwork/networkdb/networkdb_test.go +++ b/libnetwork/networkdb/networkdb_test.go @@ -13,7 +13,6 @@ import ( "github.com/docker/docker/pkg/stringid" "github.com/docker/go-events" "github.com/hashicorp/memberlist" - "github.com/hashicorp/serf/serf" "github.com/sirupsen/logrus" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" @@ -480,47 +479,43 @@ func TestNetworkDBCRUDMediumCluster(t *testing.T) { func TestNetworkDBNodeJoinLeaveIteration(t *testing.T) { dbs := createNetworkDBInstances(t, 2, "node", DefaultConfig()) - var ( - dbIndex int32 - staleNetworkTime [2]serf.LamportTime - expectNodeCount int - network = "network1" - ) - dbChangeWitness := func(t poll.LogT) poll.Result { - db := dbs[dbIndex] - networkTime := db.networkClock.Time() - if networkTime <= staleNetworkTime[dbIndex] { - return poll.Continue("network time is stale, no change registered yet.") + dbChangeWitness := func(db *NetworkDB) func(network string, expectNodeCount int) { + staleNetworkTime := db.networkClock.Time() + return func(network string, expectNodeCount int) { + check := func(t poll.LogT) poll.Result { + networkTime := db.networkClock.Time() + if networkTime <= staleNetworkTime { + return poll.Continue("network time is stale, no change registered yet.") + } + count := -1 + db.Lock() + if nodes, ok := db.networkNodes[network]; ok { + count = len(nodes) + } + db.Unlock() + if count != expectNodeCount { + return poll.Continue("current number of nodes is %d, expect %d.", count, expectNodeCount) + } + return poll.Success() + } + t.Helper() + poll.WaitOn(t, check, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) } - count := -1 - db.Lock() - if nodes, ok := db.networkNodes[network]; ok { - count = len(nodes) - } - db.Unlock() - if count != expectNodeCount { - return poll.Continue("current number of nodes is %d, expect %d.", count, expectNodeCount) - } - return poll.Success() } // Single node Join/Leave - staleNetworkTime[0], staleNetworkTime[1] = dbs[0].networkClock.Time(), dbs[1].networkClock.Time() + witness0 := dbChangeWitness(dbs[0]) err := dbs[0].JoinNetwork("network1") assert.NilError(t, err) + witness0("network1", 1) - dbIndex, expectNodeCount = 0, 1 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) - - staleNetworkTime[0], staleNetworkTime[1] = dbs[0].networkClock.Time(), dbs[1].networkClock.Time() + witness0 = dbChangeWitness(dbs[0]) err = dbs[0].LeaveNetwork("network1") assert.NilError(t, err) - - dbIndex, expectNodeCount = 0, 0 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) + witness0("network1", 0) // Multiple nodes Join/Leave - staleNetworkTime[0], staleNetworkTime[1] = dbs[0].networkClock.Time(), dbs[1].networkClock.Time() + witness0, witness1 := dbChangeWitness(dbs[0]), dbChangeWitness(dbs[1]) err = dbs[0].JoinNetwork("network1") assert.NilError(t, err) @@ -529,34 +524,30 @@ func TestNetworkDBNodeJoinLeaveIteration(t *testing.T) { // Wait for the propagation on db[0] dbs[0].verifyNetworkExistence(t, dbs[1].config.NodeID, "network1", true) - dbIndex, expectNodeCount = 0, 2 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) + witness0("network1", 2) if n, ok := dbs[0].networks[dbs[0].config.NodeID]["network1"]; !ok || n.leaving { t.Fatalf("The network should not be marked as leaving:%t", n.leaving) } // Wait for the propagation on db[1] dbs[1].verifyNetworkExistence(t, dbs[0].config.NodeID, "network1", true) - dbIndex, expectNodeCount = 1, 2 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) + witness1("network1", 2) if n, ok := dbs[1].networks[dbs[1].config.NodeID]["network1"]; !ok || n.leaving { t.Fatalf("The network should not be marked as leaving:%t", n.leaving) } // Try a quick leave/join - staleNetworkTime[0], staleNetworkTime[1] = dbs[0].networkClock.Time(), dbs[1].networkClock.Time() + witness0, witness1 = dbChangeWitness(dbs[0]), dbChangeWitness(dbs[1]) err = dbs[0].LeaveNetwork("network1") assert.NilError(t, err) err = dbs[0].JoinNetwork("network1") assert.NilError(t, err) dbs[0].verifyNetworkExistence(t, dbs[1].config.NodeID, "network1", true) - dbIndex, expectNodeCount = 0, 2 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) + witness0("network1", 2) dbs[1].verifyNetworkExistence(t, dbs[0].config.NodeID, "network1", true) - dbIndex, expectNodeCount = 1, 2 - poll.WaitOn(t, dbChangeWitness, poll.WithTimeout(3*time.Second), poll.WithDelay(5*time.Millisecond)) + witness1("network1", 2) closeNetworkDBInstances(t, dbs) }