Skip to content

Commit 019ca96

Browse files
committed
fix: Fix channel clamping logic in ChannelPools to respect channel bounds
1 parent f62aa0c commit 019ca96

3 files changed

Lines changed: 104 additions & 78 deletions

File tree

sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPool.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,10 @@ void resize() {
310310
if (minChannels < settings.getMinChannelCount()) {
311311
minChannels = settings.getMinChannelCount();
312312
}
313+
// Limit in case the calculated min channel count exceeds the configured max channel count
314+
if (minChannels > settings.getMaxChannelCount()) {
315+
minChannels = settings.getMaxChannelCount();
316+
}
313317

314318
// Number of channels if each channel operated at minimum capacity
315319
// Note: getMinRpcsPerChannel() can return 0, but division by 0 shouldn't cause a problem.
@@ -319,8 +323,9 @@ void resize() {
319323
if (maxChannels > settings.getMaxChannelCount()) {
320324
maxChannels = settings.getMaxChannelCount();
321325
}
322-
if (maxChannels < minChannels) {
323-
maxChannels = minChannels;
326+
// Limit in case the calculated max channel count falls below the configured min channel count
327+
if (maxChannels < settings.getMinChannelCount()) {
328+
maxChannels = settings.getMinChannelCount();
324329
}
325330

326331
// If the pool were to be resized, try to aim for the middle of the bound, but limit rate of

sdk-platform-java/gax-java/gax-grpc/src/main/java/com/google/api/gax/grpc/ChannelPoolSettings.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ public ChannelPoolSettings build() {
182182
Preconditions.checkState(
183183
s.getMinChannelCount() > 0, "Minimum channel count must be at least 1");
184184
Preconditions.checkState(
185-
s.getMinChannelCount() <= s.getMaxRpcsPerChannel(), "absolute channel range is invalid");
185+
s.getMinChannelCount() <= s.getMaxChannelCount(), "absolute channel range is invalid");
186186
Preconditions.checkState(
187187
s.getMinChannelCount() <= s.getInitialChannelCount(),
188188
"initial channel count be at least minChannelCount");

sdk-platform-java/gax-java/gax-grpc/src/test/java/com/google/api/gax/grpc/ChannelPoolTest.java

Lines changed: 96 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,26 @@ private void verifyTargetChannel(
144144
}
145145
}
146146

147+
private static ChannelFactory createMockChannelFactory(
148+
List<ManagedChannel> channels, List<ClientCall<Object, Object>> startedCalls) {
149+
return () -> {
150+
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
151+
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
152+
.thenAnswer(
153+
invocation -> {
154+
@SuppressWarnings("unchecked")
155+
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
156+
if (startedCalls != null) {
157+
startedCalls.add(clientCall);
158+
}
159+
return clientCall;
160+
});
161+
162+
channels.add(channel);
163+
return channel;
164+
};
165+
}
166+
147167
@Test
148168
void ensureEvenDistribution() throws InterruptedException, IOException {
149169
int numChannels = 10;
@@ -451,21 +471,7 @@ void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() throws Exce
451471
List<ManagedChannel> channels = new ArrayList<>();
452472
List<ClientCall<Object, Object>> startedCalls = new ArrayList<>();
453473

454-
ChannelFactory channelFactory =
455-
() -> {
456-
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
457-
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
458-
.thenAnswer(
459-
invocation -> {
460-
@SuppressWarnings("unchecked")
461-
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
462-
startedCalls.add(clientCall);
463-
return clientCall;
464-
});
465-
466-
channels.add(channel);
467-
return channel;
468-
};
474+
ChannelFactory channelFactory = createMockChannelFactory(channels, startedCalls);
469475

470476
pool =
471477
new ChannelPool(
@@ -531,21 +537,7 @@ void customResizeDeltaIsRespected() throws Exception {
531537
List<ManagedChannel> channels = new ArrayList<>();
532538
List<ClientCall<Object, Object>> startedCalls = new ArrayList<>();
533539

534-
ChannelFactory channelFactory =
535-
() -> {
536-
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
537-
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
538-
.thenAnswer(
539-
invocation -> {
540-
@SuppressWarnings("unchecked")
541-
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
542-
startedCalls.add(clientCall);
543-
return clientCall;
544-
});
545-
546-
channels.add(channel);
547-
return channel;
548-
};
540+
ChannelFactory channelFactory = createMockChannelFactory(channels, startedCalls);
549541

550542
pool =
551543
new ChannelPool(
@@ -578,21 +570,7 @@ void removedIdleChannelsAreShutdown() throws Exception {
578570
List<ManagedChannel> channels = new ArrayList<>();
579571
List<ClientCall<Object, Object>> startedCalls = new ArrayList<>();
580572

581-
ChannelFactory channelFactory =
582-
() -> {
583-
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
584-
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
585-
.thenAnswer(
586-
invocation -> {
587-
@SuppressWarnings("unchecked")
588-
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
589-
startedCalls.add(clientCall);
590-
return clientCall;
591-
});
592-
593-
channels.add(channel);
594-
return channel;
595-
};
573+
ChannelFactory channelFactory = createMockChannelFactory(channels, startedCalls);
596574

597575
pool =
598576
new ChannelPool(
@@ -619,21 +597,7 @@ void removedActiveChannelsAreShutdown() throws Exception {
619597
List<ManagedChannel> channels = new ArrayList<>();
620598
List<ClientCall<Object, Object>> startedCalls = new ArrayList<>();
621599

622-
ChannelFactory channelFactory =
623-
() -> {
624-
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
625-
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
626-
.thenAnswer(
627-
invocation -> {
628-
@SuppressWarnings("unchecked")
629-
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
630-
startedCalls.add(clientCall);
631-
return clientCall;
632-
});
633-
634-
channels.add(channel);
635-
return channel;
636-
};
600+
ChannelFactory channelFactory = createMockChannelFactory(channels, startedCalls);
637601

638602
pool =
639603
new ChannelPool(
@@ -734,21 +698,7 @@ void repeatedResizingLogsWarningOnExpand() throws Exception {
734698
List<ManagedChannel> channels = new ArrayList<>();
735699
List<ClientCall<Object, Object>> startedCalls = new ArrayList<>();
736700

737-
ChannelFactory channelFactory =
738-
() -> {
739-
ManagedChannel channel = Mockito.mock(ManagedChannel.class);
740-
Mockito.when(channel.newCall(Mockito.any(), Mockito.any()))
741-
.thenAnswer(
742-
invocation -> {
743-
@SuppressWarnings("unchecked")
744-
ClientCall<Object, Object> clientCall = Mockito.mock(ClientCall.class);
745-
startedCalls.add(clientCall);
746-
return clientCall;
747-
});
748-
749-
channels.add(channel);
750-
return channel;
751-
};
701+
ChannelFactory channelFactory = createMockChannelFactory(channels, startedCalls);
752702

753703
pool =
754704
new ChannelPool(
@@ -899,4 +849,75 @@ void testDoubleRelease() throws Exception {
899849
ChannelPool.LOG.removeHandler(logHandler);
900850
}
901851
}
852+
853+
@Test
854+
void settingsValidationFailsWhenMinChannelsExceedsMaxChannels() {
855+
Assertions.assertThrows(
856+
IllegalStateException.class,
857+
() -> ChannelPoolSettings.builder().setMinChannelCount(2).setMaxChannelCount(1).build());
858+
}
859+
860+
@Test
861+
void minChannelsClampedToMaxChannelCountUnderHighLoad() throws Exception {
862+
ScheduledExecutorService executor = Mockito.mock(ScheduledExecutorService.class);
863+
FixedExecutorProvider provider = FixedExecutorProvider.create(executor);
864+
865+
List<ManagedChannel> channels = new ArrayList<>();
866+
ChannelFactory channelFactory = createMockChannelFactory(channels, null);
867+
868+
pool =
869+
new ChannelPool(
870+
ChannelPoolSettings.builder()
871+
.setInitialChannelCount(1)
872+
.setMinRpcsPerChannel(1)
873+
.setMaxRpcsPerChannel(2)
874+
.setMaxResizeDelta(10) // Allow large growth
875+
.setMinChannelCount(1)
876+
.setMaxChannelCount(5)
877+
.build(),
878+
channelFactory,
879+
provider);
880+
assertThat(pool.entries.get()).hasSize(1);
881+
882+
// Add 20 RPCs, which would require 10 channels (20/2)
883+
// But max is 5
884+
for (int i = 0; i < 20; i++) {
885+
ClientCalls.futureUnaryCall(
886+
pool.newCall(METHOD_RECOGNIZE, CallOptions.DEFAULT), Color.getDefaultInstance());
887+
}
888+
889+
pool.resize();
890+
891+
// Should be clamped to maxChannelCount = 5
892+
assertThat(pool.entries.get()).hasSize(5);
893+
}
894+
895+
@Test
896+
void maxChannelsClampedToMinChannelCountUnderLowLoad() throws Exception {
897+
ScheduledExecutorService executor = Mockito.mock(ScheduledExecutorService.class);
898+
FixedExecutorProvider provider = FixedExecutorProvider.create(executor);
899+
900+
List<ManagedChannel> channels = new ArrayList<>();
901+
ChannelFactory channelFactory = createMockChannelFactory(channels, null);
902+
903+
pool =
904+
new ChannelPool(
905+
ChannelPoolSettings.builder()
906+
.setInitialChannelCount(5)
907+
.setMinRpcsPerChannel(1)
908+
.setMaxRpcsPerChannel(2)
909+
.setMinChannelCount(3)
910+
.setMaxChannelCount(10)
911+
.build(),
912+
channelFactory,
913+
provider);
914+
assertThat(pool.entries.get()).hasSize(5);
915+
916+
// With no outstanding RPCs, the pool should want to shrink to 0
917+
// But min is 3
918+
pool.resize();
919+
920+
// Should be clamped to minChannelCount = 3
921+
assertThat(pool.entries.get()).hasSize(3);
922+
}
902923
}

0 commit comments

Comments
 (0)