diff --git a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go index c67f673f5b2b..4e1dd5513a41 100644 --- a/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go +++ b/xds/internal/balancer/ringhash/e2e/ringhash_balancer_test.go @@ -162,8 +162,8 @@ func (s) TestRingHash_ReconnectToMoveOutOfTransientFailure(t *testing.T) { } } -// startTestServiceBackends starts num stub servers. It returns their addresses. -// Servers are closed when the test is stopped. +// startTestServiceBackends starts num stub servers. It returns the list of +// stubservers. Servers are closed when the test is stopped. func startTestServiceBackends(t *testing.T, num int) []*stubserver.StubServer { t.Helper() @@ -176,6 +176,7 @@ func startTestServiceBackends(t *testing.T, num int) []*stubserver.StubServer { return servers } +// backendAddrs returns a list of address strings for the given stubservers. func backendAddrs(servers []*stubserver.StubServer) []string { addrs := make([]string, 0, len(servers)) for _, s := range servers { @@ -2185,24 +2186,18 @@ func (s) TestRingHash_FallBackWithinEndpoint(t *testing.T) { t.Errorf("Got %v RPCs routed to a backend, want %v", got, numRPCs) } - var otherEndpointAddr string + // Due to the channel ID hashing policy, the request could go to the first + // address of either endpoint. var backendIdx int switch initialAddr { case backendAddrs[0]: - otherEndpointAddr = backendAddrs[1] backendIdx = 0 - case backendAddrs[1]: - otherEndpointAddr = backendAddrs[0] - backendIdx = 1 case backendAddrs[2]: - otherEndpointAddr = backendAddrs[3] backendIdx = 2 - case backendAddrs[3]: - otherEndpointAddr = backendAddrs[2] - backendIdx = 3 default: - t.Fatalf("Request sent to unknown backend: %q", initialAddr) + t.Fatalf("Request sent to unexpected backend: %q", initialAddr) } + otherEndpointAddr := backendAddrs[backendIdx+1] // Shut down the previously used backend. backends[backendIdx].Stop() diff --git a/xds/internal/balancer/ringhash/picker.go b/xds/internal/balancer/ringhash/picker.go index 983a6b8f568c..f9e5139da7b4 100644 --- a/xds/internal/balancer/ringhash/picker.go +++ b/xds/internal/balancer/ringhash/picker.go @@ -19,6 +19,8 @@ package ringhash import ( + "fmt" + "google.golang.org/grpc/balancer" "google.golang.org/grpc/connectivity" "google.golang.org/grpc/internal/grpclog" @@ -49,7 +51,7 @@ func (p *picker) Pick(info balancer.PickInfo) (balancer.PickResult, error) { return balState.Picker.Pick(info) case connectivity.TransientFailure: default: - p.logger.Errorf("Found child balancer in unknown state: %v", balState.ConnectivityState) + panic(fmt.Sprintf("Found child balancer in unknown state: %v", balState.ConnectivityState)) } } // All children are in transient failure. Return the first failure. diff --git a/xds/internal/balancer/ringhash/ringhash.go b/xds/internal/balancer/ringhash/ringhash.go index da559a45e58d..422760bd49da 100644 --- a/xds/internal/balancer/ringhash/ringhash.go +++ b/xds/internal/balancer/ringhash/ringhash.go @@ -113,36 +113,39 @@ func (b *ringhashBalancer) UpdateState(state balancer.State) { endpoint := childState.Endpoint endpointsSet.Set(endpoint, true) newWeight := getWeightAttribute(endpoint) - var es *endpointState if val, ok := b.endpointStates.Get(endpoint); !ok { - es = &endpointState{} + es := &endpointState{ + balancer: childState.Balancer, + weight: newWeight, + firstAddr: endpoint.Addresses[0].Addr, + state: childState.State, + } b.endpointStates.Set(endpoint, es) b.shouldRegenerateRing = true - es.balancer = childState.Balancer } else { // We have seen this endpoint before and created a `endpointState` - // object for it. If the weight associated with the endpoint has - // changed, update the endpoint state map with the new weight. + // object for it. If the weight or the first address of the endpoint + // has changed, update the endpoint state map with the new weight. // This will be used when a new ring is created. - es = val.(*endpointState) + es := val.(*endpointState) if oldWeight := es.weight; oldWeight != newWeight { b.shouldRegenerateRing = true + es.weight = newWeight } if es.firstAddr != endpoint.Addresses[0].Addr { // If the order of the addresses for a given endpoint change, // that will change the position of the endpoint in the ring. // -A61 b.shouldRegenerateRing = true + es.firstAddr = endpoint.Addresses[0].Addr } oldState := es.state.ConnectivityState newState := childState.State.ConnectivityState if oldState != newState && newState == connectivity.TransientFailure { endpointEnteredTF = true } + es.state = childState.State } - es.weight = newWeight - es.state = childState.State - es.firstAddr = endpoint.Addresses[0].Addr } for _, endpoint := range b.endpointStates.Keys() { diff --git a/xds/internal/balancer/ringhash/ringhash_test.go b/xds/internal/balancer/ringhash/ringhash_test.go index cb750d6c7602..73c3fd0dc564 100644 --- a/xds/internal/balancer/ringhash/ringhash_test.go +++ b/xds/internal/balancer/ringhash/ringhash_test.go @@ -599,7 +599,7 @@ func (s) TestAutoConnectEndpointOnTransientFailure(t *testing.T) { // will be CONNECTING. cc.WaitForConnectivityState(ctx, connectivity.Connecting) p1 := <-cc.NewPickerCh - p1.Pick(balancer.PickInfo{Ctx: context.Background()}) + p1.Pick(balancer.PickInfo{Ctx: ctx}) select { case <-sc0.ConnectCh: case <-time.After(defaultTestTimeout):