@VisibleForTesting static void prepareHeaders( Metadata headers, CallOptions callOptions, String userAgent, DecompressorRegistry decompressorRegistry, Compressor compressor) { // Fill out the User-Agent header. headers.removeAll(USER_AGENT_KEY); if (userAgent != null) { headers.put(USER_AGENT_KEY, userAgent); } headers.removeAll(MESSAGE_ENCODING_KEY); if (compressor != Codec.Identity.NONE) { headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); } headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); if (!decompressorRegistry.getAdvertisedMessageEncodings().isEmpty()) { String acceptEncoding = ACCEPT_ENCODING_JOINER.join(decompressorRegistry.getAdvertisedMessageEncodings()); headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding); } }
@Test public void prepareHeaders_acceptedEncodingsAdded() { Metadata m = new Metadata(); DecompressorRegistry customRegistry = DecompressorRegistry.newEmptyInstance(); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "a"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "b"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "c"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, false); // not advertised ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry); Iterable<String> acceptedEncodings = Splitter.on(',').split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); // Order may be different, since decoder priorities have not yet been implemented. assertEquals(ImmutableSet.of("b", "a"), ImmutableSet.copyOf(acceptedEncodings)); }
/** The user agent handed to {@code prepareHeaders} must end up in the user-agent header. */
@Test
public void prepareHeaders_userAgentAdded() {
  Metadata m = new Metadata();

  ClientCallImpl.prepareHeaders(
      m, CallOptions.DEFAULT, "user agent", DecompressorRegistry.getDefaultInstance());

  // JUnit's contract is assertEquals(expected, actual); keeping that order makes a failure
  // message report the right value as "expected".
  assertEquals("user agent", m.get(GrpcUtil.USER_AGENT_KEY));
}
/** Selecting the identity codec must not produce a message-encoding header. */
@Test
public void prepareHeaders_ignoreIdentityEncoding() {
  Metadata headers = new Metadata();
  CallOptions identityOptions = CallOptions.DEFAULT.withCompressor(Codec.Identity.NONE);
  DecompressorRegistry registry = DecompressorRegistry.getDefaultInstance();

  ClientCallImpl.prepareHeaders(headers, identityOptions, "user agent", registry);

  String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
  assertNull(messageEncoding);
}
/** A non-identity compressor must be announced through the message-encoding header. */
@Test
public void prepareHeaders_messageEncodingAdded() {
  Metadata m = new Metadata();
  // One codec instance serves both as the call's compressor and as the source of the
  // expected header value.
  Codec.Gzip gzip = new Codec.Gzip();
  CallOptions callOptions = CallOptions.DEFAULT.withCompressor(gzip);

  ClientCallImpl.prepareHeaders(
      m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

  // assertEquals takes (expected, actual).
  assertEquals(gzip.getMessageEncoding(), m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
}
/** An authority set on the call options must be copied into the authority header. */
@Test
public void prepareHeaders_authorityAdded() {
  Metadata m = new Metadata();
  CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth");

  ClientCallImpl.prepareHeaders(
      m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

  // assertEquals takes (expected, actual); the original had the two swapped.
  assertEquals("auth", m.get(GrpcUtil.AUTHORITY_KEY));
}
/**
 * Encoding headers supplied by the caller are reserved: with an empty registry and the
 * identity codec there is nothing to replace them with, so both must be gone afterwards.
 */
@Test
public void prepareHeaders_removeReservedHeaders() {
  Metadata headers = new Metadata();
  headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
  headers.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
  DecompressorRegistry nothingAdvertised = DecompressorRegistry.emptyInstance();

  ClientCallImpl.prepareHeaders(headers, nothingAdvertised, Codec.Identity.NONE);

  String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
  assertNull(messageEncoding);
  String acceptEncoding = headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY);
  assertNull(acceptEncoding);
}
/**
 * All four reserved headers supplied by the caller must be stripped when the call itself
 * carries no authority, no user agent, no compressor and no advertised decompressors.
 */
@Test
public void prepareHeaders_removeReservedHeaders() {
  Metadata headers = new Metadata();
  headers.put(GrpcUtil.AUTHORITY_KEY, "auth");
  headers.put(GrpcUtil.USER_AGENT_KEY, "user agent");
  headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
  headers.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");
  DecompressorRegistry emptyRegistry = DecompressorRegistry.newEmptyInstance();

  // A null user agent means "do not send one".
  ClientCallImpl.prepareHeaders(headers, CallOptions.DEFAULT, null, emptyRegistry);

  String authority = headers.get(GrpcUtil.AUTHORITY_KEY);
  assertNull(authority);
  String userAgent = headers.get(GrpcUtil.USER_AGENT_KEY);
  assertNull(userAgent);
  String messageEncoding = headers.get(GrpcUtil.MESSAGE_ENCODING_KEY);
  assertNull(messageEncoding);
  String acceptEncoding = headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY);
  assertNull(acceptEncoding);
}
/**
 * Starting a call must hand the transport metadata whose accept-encoding values equal the
 * encodings advertised by the configured decompressor registry.
 */
@Test
public void advertisedEncodingsAreSent() {
  ClientCallImpl<Void, Void> call =
      new ClientCallImpl<Void, Void>(
          method,
          MoreExecutors.directExecutor(),
          CallOptions.DEFAULT,
          provider,
          deadlineCancellationExecutor);
  call.setDecompressorRegistry(decompressorRegistry);

  Metadata outgoing = new Metadata();
  call.start(callListener, outgoing);

  // Grab the metadata the call passed on to the transport.
  ArgumentCaptor<Metadata> sentHeaders = ArgumentCaptor.forClass(Metadata.class);
  verify(transport).newStream(eq(method), sentHeaders.capture(), same(CallOptions.DEFAULT));

  Iterable<String> headerValues =
      sentHeaders.getValue().getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY);
  Set<String> acceptedEncodings = ImmutableSet.copyOf(headerValues);
  assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
}
/**
 * Starting a call must hand the transport metadata whose accept-encoding values equal the
 * encodings advertised by the configured decompressor registry.
 */
@Test
public void advertisedEncodingsAreSent() {
  // Collaborators: a mocked transport that yields a mocked stream for any call.
  final ClientTransport transport = mock(ClientTransport.class);
  final ClientStream stream = mock(ClientStream.class);
  when(transport.newStream(
          any(MethodDescriptor.class), any(Metadata.class), any(ClientStreamListener.class)))
      .thenReturn(stream);
  ClientTransportProvider provider =
      new ClientTransportProvider() {
        @Override
        public ListenableFuture<ClientTransport> get(CallOptions callOptions) {
          return Futures.immediateFuture(transport);
        }
      };
  MethodDescriptor<Void, Void> descriptor =
      MethodDescriptor.create(
          MethodType.UNARY,
          "service/method",
          new TestMarshaller<Void>(),
          new TestMarshaller<Void>());

  ClientCallImpl<Void, Void> call =
      new ClientCallImpl<Void, Void>(
          descriptor, executor, CallOptions.DEFAULT, provider, deadlineCancellationExecutor);
  call.setDecompressorRegistry(decompressorRegistry);
  call.start(new TestClientCallListener<Void>(), new Metadata());

  // Grab the metadata the call passed on to the transport.
  ArgumentCaptor<Metadata> sentHeaders = ArgumentCaptor.forClass(Metadata.class);
  verify(transport)
      .newStream(eq(descriptor), sentHeaders.capture(), isA(ClientStreamListener.class));

  Iterable<String> headerValues =
      sentHeaders.getValue().getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY);
  Set<String> acceptedEncodings = ImmutableSet.copyOf(headerValues);
  assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
}
/**
 * Creates the {@link GrpcService} described by this builder, ready to be bound to a {@link
 * com.linecorp.armeria.server.ServerBuilder}.
 *
 * <p>gRPC services are mounted at a path that corresponds to their protobuf package, so the
 * result is almost always bound to a prefix, e.g. through {@link
 * com.linecorp.armeria.server.ServerBuilder#serviceUnder(String, Service)}.
 *
 * <p>When no decompressor or compressor registry was configured, the respective default
 * instance is used.
 */
public GrpcService build() {
  final DecompressorRegistry effectiveDecompressors =
      firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance());
  final CompressorRegistry effectiveCompressors =
      firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance());
  return new GrpcService(registryBuilder.build(), effectiveDecompressors, effectiveCompressors);
}
/** Registers a gzip codec as an advertised decompressor before every test. */
@Before
public void setUp() {
  final Codec.Gzip gzip = new Codec.Gzip();
  final boolean advertised = true;
  decompressorRegistry.register(gzip, advertised);
}
/** Test for {@link ClientCallImpl}. */ @RunWith(JUnit4.class) public class ClientCallImplTest { private final SerializingExecutor executor = new SerializingExecutor(MoreExecutors.directExecutor()); private final ScheduledExecutorService deadlineCancellationExecutor = Executors.newScheduledThreadPool(0); private final DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); @Before public void setUp() { decompressorRegistry.register(new Codec.Gzip(), true); } @Test public void advertisedEncodingsAreSent() { MethodDescriptor<Void, Void> descriptor = MethodDescriptor.create( MethodType.UNARY, "service/method", new TestMarshaller<Void>(), new TestMarshaller<Void>()); final ClientTransport transport = mock(ClientTransport.class); final ClientStream stream = mock(ClientStream.class); ClientTransportProvider provider = new ClientTransportProvider() { @Override public ListenableFuture<ClientTransport> get(CallOptions callOptions) { return Futures.immediateFuture(transport); } }; when(transport.newStream( any(MethodDescriptor.class), any(Metadata.class), any(ClientStreamListener.class))) .thenReturn(stream); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( descriptor, executor, CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); call.start(new TestClientCallListener<Void>(), new Metadata()); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); verify(transport) .newStream(eq(descriptor), metadataCaptor.capture(), isA(ClientStreamListener.class)); Metadata actual = metadataCaptor.getValue(); Set<String> acceptedEncodings = ImmutableSet.copyOf(actual.getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings); } @Test public void prepareHeaders_authorityAdded() { Metadata m = new Metadata(); CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth"); 
ClientCallImpl.prepareHeaders( m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance()); assertEquals(m.get(GrpcUtil.AUTHORITY_KEY), "auth"); } @Test public void prepareHeaders_userAgentAdded() { Metadata m = new Metadata(); ClientCallImpl.prepareHeaders( m, CallOptions.DEFAULT, "user agent", DecompressorRegistry.getDefaultInstance()); assertEquals(m.get(GrpcUtil.USER_AGENT_KEY), "user agent"); } @Test public void prepareHeaders_messageEncodingAdded() { Metadata m = new Metadata(); CallOptions callOptions = CallOptions.DEFAULT.withCompressor(new Codec.Gzip()); ClientCallImpl.prepareHeaders( m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance()); assertEquals(m.get(GrpcUtil.MESSAGE_ENCODING_KEY), new Codec.Gzip().getMessageEncoding()); } @Test public void prepareHeaders_ignoreIdentityEncoding() { Metadata m = new Metadata(); CallOptions callOptions = CallOptions.DEFAULT.withCompressor(Codec.Identity.NONE); ClientCallImpl.prepareHeaders( m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance()); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); } @Test public void prepareHeaders_acceptedEncodingsAdded() { Metadata m = new Metadata(); DecompressorRegistry customRegistry = DecompressorRegistry.newEmptyInstance(); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "a"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "b"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true); customRegistry.register( new Decompressor() { @Override public String getMessageEncoding() { return "c"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, false); // not advertised ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", 
customRegistry); Iterable<String> acceptedEncodings = Splitter.on(',').split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); // Order may be different, since decoder priorities have not yet been implemented. assertEquals(ImmutableSet.of("b", "a"), ImmutableSet.copyOf(acceptedEncodings)); } @Test public void prepareHeaders_removeReservedHeaders() { Metadata m = new Metadata(); m.put(GrpcUtil.AUTHORITY_KEY, "auth"); m.put(GrpcUtil.USER_AGENT_KEY, "user agent"); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); ClientCallImpl.prepareHeaders( m, CallOptions.DEFAULT, null, DecompressorRegistry.newEmptyInstance()); assertNull(m.get(GrpcUtil.AUTHORITY_KEY)); assertNull(m.get(GrpcUtil.USER_AGENT_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); } private static class TestMarshaller<T> implements Marshaller<T> { @Override public InputStream stream(T value) { return null; } @Override public T parse(InputStream stream) { return null; } } private static class TestClientCallListener<T> extends ClientCall.Listener<T> {} }
/** Test for {@link ClientCallImpl}. */ @RunWith(JUnit4.class) public class ClientCallImplTest { private static final MethodDescriptor<Void, Void> DESCRIPTOR = MethodDescriptor.create( MethodType.UNARY, "service/method", new TestMarshaller<Void>(), new TestMarshaller<Void>()); private final FakeClock fakeClock = new FakeClock(); private final ScheduledExecutorService deadlineCancellationExecutor = fakeClock.scheduledExecutorService; private final DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance().with(new Codec.Gzip(), true); private final MethodDescriptor<Void, Void> method = MethodDescriptor.create( MethodType.UNARY, "service/method", new TestMarshaller<Void>(), new TestMarshaller<Void>()); @Mock private ClientStreamListener streamListener; @Mock private ClientTransport clientTransport; @Captor private ArgumentCaptor<Status> statusCaptor; @Mock private ClientTransport transport; @Mock private ClientTransportProvider provider; @Mock private ClientStream stream; @Mock private ClientCall.Listener<Void> callListener; @Captor private ArgumentCaptor<ClientStreamListener> listenerArgumentCaptor; @Captor private ArgumentCaptor<Status> statusArgumentCaptor; @Before public void setUp() { MockitoAnnotations.initMocks(this); when(provider.get(any(CallOptions.class))).thenReturn(transport); when(transport.newStream( any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class))) .thenReturn(stream); } @After public void tearDown() { Context.ROOT.attach(); } @Test public void advertisedEncodingsAreSent() { ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( method, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class); verify(transport).newStream(eq(method), metadataCaptor.capture(), same(CallOptions.DEFAULT)); 
Metadata actual = metadataCaptor.getValue(); Set<String> acceptedEncodings = ImmutableSet.copyOf(actual.getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings); } @Test public void authorityPropagatedToStream() { ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( method, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withAuthority("overridden-authority"), provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); verify(stream).setAuthority("overridden-authority"); } @Test public void callOptionsPropagatedToTransport() { final CallOptions callOptions = CallOptions.DEFAULT.withAuthority("dummy_value"); final ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( method, MoreExecutors.directExecutor(), callOptions, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); final Metadata metadata = new Metadata(); call.start(callListener, metadata); verify(transport).newStream(same(method), same(metadata), same(callOptions)); } @Test public void authorityNotPropagatedToStream() { ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( method, MoreExecutors.directExecutor(), // Don't provide an authority CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); verify(stream, never()).setAuthority(any(String.class)); } @Test public void prepareHeaders_userAgentIgnored() { Metadata m = new Metadata(); m.put(GrpcUtil.USER_AGENT_KEY, "batmobile"); ClientCallImpl.prepareHeaders(m, decompressorRegistry, Codec.Identity.NONE); // User Agent is removed and set by the transport assertThat(m.get(GrpcUtil.USER_AGENT_KEY)).isNotNull(); } @Test public void prepareHeaders_ignoreIdentityEncoding() { Metadata m = new Metadata(); ClientCallImpl.prepareHeaders(m, 
decompressorRegistry, Codec.Identity.NONE); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); } @Test public void prepareHeaders_acceptedEncodingsAdded() { Metadata m = new Metadata(); DecompressorRegistry customRegistry = DecompressorRegistry.emptyInstance() .with( new Decompressor() { @Override public String getMessageEncoding() { return "a"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true) .with( new Decompressor() { @Override public String getMessageEncoding() { return "b"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, true) .with( new Decompressor() { @Override public String getMessageEncoding() { return "c"; } @Override public InputStream decompress(InputStream is) throws IOException { return null; } }, false); // not advertised ClientCallImpl.prepareHeaders(m, customRegistry, Codec.Identity.NONE); Iterable<String> acceptedEncodings = ACCEPT_ENCODING_SPLITER.split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); // Order may be different, since decoder priorities have not yet been implemented. 
assertEquals(ImmutableSet.of("b", "a"), ImmutableSet.copyOf(acceptedEncodings)); } @Test public void prepareHeaders_removeReservedHeaders() { Metadata m = new Metadata(); m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip"); m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip"); ClientCallImpl.prepareHeaders(m, DecompressorRegistry.emptyInstance(), Codec.Identity.NONE); assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY)); assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY)); } @Test public void callerContextPropagatedToListener() throws Exception { // Attach the context which is recorded when the call is created final Context.Key<String> testKey = Context.key("testing"); Context.current().withValue(testKey, "testValue").attach(); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); Context.ROOT.attach(); // Override the value after creating the call, this should not be seen by callbacks Context.current().withValue(testKey, "badValue").attach(); final AtomicBoolean onHeadersCalled = new AtomicBoolean(); final AtomicBoolean onMessageCalled = new AtomicBoolean(); final AtomicBoolean onReadyCalled = new AtomicBoolean(); final AtomicBoolean observedIncorrectContext = new AtomicBoolean(); final CountDownLatch latch = new CountDownLatch(1); call.start( new ClientCall.Listener<Void>() { @Override public void onHeaders(Metadata headers) { onHeadersCalled.set(true); checkContext(); } @Override public void onMessage(Void message) { onMessageCalled.set(true); checkContext(); } @Override public void onClose(Status status, Metadata trailers) { checkContext(); latch.countDown(); } @Override public void onReady() { onReadyCalled.set(true); checkContext(); } private void checkContext() { if (!"testValue".equals(testKey.get())) { observedIncorrectContext.set(true); } } }, new Metadata()); 
verify(stream).start(listenerArgumentCaptor.capture()); ClientStreamListener listener = listenerArgumentCaptor.getValue(); listener.onReady(); listener.headersRead(new Metadata()); listener.messageRead(new ByteArrayInputStream(new byte[0])); listener.messageRead(new ByteArrayInputStream(new byte[0])); listener.closed(Status.OK, new Metadata()); assertTrue(latch.await(5, TimeUnit.SECONDS)); assertTrue(onHeadersCalled.get()); assertTrue(onMessageCalled.get()); assertTrue(onReadyCalled.get()); assertFalse(observedIncorrectContext.get()); } @Test public void contextCancellationCancelsStream() throws Exception { // Attach the context which is recorded when the call is created Context.CancellableContext cancellableContext = Context.current().withCancellation(); Context previous = cancellableContext.attach(); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); previous.attach(); call.start(callListener, new Metadata()); Throwable t = new Throwable(); cancellableContext.cancel(t); verify(stream, times(1)).cancel(statusArgumentCaptor.capture()); verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.Code.CANCELLED, statusCaptor.getValue().getCode()); } @Test public void contextAlreadyCancelledNotifiesImmediately() throws Exception { // Attach the context which is recorded when the call is created Context.CancellableContext cancellableContext = Context.current().withCancellation(); Throwable cause = new Throwable(); cancellableContext.cancel(cause); Context previous = cancellableContext.attach(); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), CallOptions.DEFAULT, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); previous.attach(); 
final SettableFuture<Status> statusFuture = SettableFuture.create(); call.start( new ClientCall.Listener<Void>() { @Override public void onClose(Status status, Metadata trailers) { statusFuture.set(status); } }, new Metadata()); // Caller should receive onClose callback. Status status = statusFuture.get(5, TimeUnit.SECONDS); assertEquals(Status.Code.CANCELLED, status.getCode()); assertSame(cause, status.getCause()); // Following operations should be no-op. call.request(1); call.sendMessage(null); call.halfClose(); // Stream should never be created. verifyZeroInteractions(transport); try { call.sendMessage(null); fail("Call has been cancelled"); } catch (IllegalStateException ise) { // expected } } @Test public void deadlineExceededBeforeCallStarted() { CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(0, TimeUnit.SECONDS); fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, new SerializingExecutor(Executors.newSingleThreadExecutor()), callOptions, provider, deadlineCancellationExecutor) .setDecompressorRegistry(decompressorRegistry); call.start(callListener, new Metadata()); verify(transport, times(0)).newStream(any(MethodDescriptor.class), any(Metadata.class)); verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class)); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); verifyZeroInteractions(provider); } @Test public void contextDeadlineShouldBePropagatedInMetadata() { long deadlineNanos = TimeUnit.SECONDS.toNanos(1); Context context = Context.current() .withDeadlineAfter(deadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor); context.attach(); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, deadlineCancellationExecutor); Metadata headers = new Metadata(); call.start(callListener, headers); 
assertTrue(headers.containsKey(GrpcUtil.TIMEOUT_KEY)); Long timeout = headers.get(GrpcUtil.TIMEOUT_KEY); assertNotNull(timeout); long deltaNanos = TimeUnit.MILLISECONDS.toNanos(400); assertTimeoutBetween(timeout, deadlineNanos - deltaNanos, deadlineNanos); } @Test public void contextDeadlineShouldOverrideLargerMetadataTimeout() { long deadlineNanos = TimeUnit.SECONDS.toNanos(1); Context context = Context.current() .withDeadlineAfter(deadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor); context.attach(); CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(2, TimeUnit.SECONDS); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), callOpts, provider, deadlineCancellationExecutor); Metadata headers = new Metadata(); call.start(callListener, headers); assertTrue(headers.containsKey(GrpcUtil.TIMEOUT_KEY)); Long timeout = headers.get(GrpcUtil.TIMEOUT_KEY); assertNotNull(timeout); long deltaNanos = TimeUnit.MILLISECONDS.toNanos(400); assertTimeoutBetween(timeout, deadlineNanos - deltaNanos, deadlineNanos); } @Test public void contextDeadlineShouldNotOverrideSmallerMetadataTimeout() { long deadlineNanos = TimeUnit.SECONDS.toNanos(2); Context context = Context.current() .withDeadlineAfter(deadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor); context.attach(); CallOptions callOpts = CallOptions.DEFAULT.withDeadlineAfter(1, TimeUnit.SECONDS); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), callOpts, provider, deadlineCancellationExecutor); Metadata headers = new Metadata(); call.start(callListener, headers); assertTrue(headers.containsKey(GrpcUtil.TIMEOUT_KEY)); Long timeout = headers.get(GrpcUtil.TIMEOUT_KEY); assertNotNull(timeout); long callOptsNanos = TimeUnit.SECONDS.toNanos(1); long deltaNanos = TimeUnit.MILLISECONDS.toNanos(400); assertTimeoutBetween(timeout, callOptsNanos - deltaNanos, callOptsNanos); } @Test public 
void expiredDeadlineCancelsStream_CallOptions() { fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); fakeClock.forwardMillis(1001); verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); } @Test public void expiredDeadlineCancelsStream_Context() { fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); Context.current() .withDeadlineAfter(1000, TimeUnit.MILLISECONDS, deadlineCancellationExecutor) .attach(); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); fakeClock.forwardMillis(TimeUnit.SECONDS.toMillis(1001)); verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode()); } @Test public void streamCancelAbortsDeadlineTimer() { fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS)), provider, deadlineCancellationExecutor); call.start(callListener, new Metadata()); call.cancel("canceled", null); // Run the deadline timer, which should have been cancelled by the previous call to cancel() fakeClock.forwardMillis(1001); verify(stream, times(1)).cancel(statusCaptor.capture()); assertEquals(Status.CANCELLED.getCode(), statusCaptor.getValue().getCode()); } /** Without a context or call options deadline, a timeout should not be set in metadata. 
*/ @Test public void timeoutShouldNotBeSet() { ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, deadlineCancellationExecutor); Metadata headers = new Metadata(); call.start(callListener, headers); assertFalse(headers.containsKey(GrpcUtil.TIMEOUT_KEY)); } @Test public void cancelInOnMessageShouldInvokeStreamCancel() throws Exception { final ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>( DESCRIPTOR, MoreExecutors.directExecutor(), CallOptions.DEFAULT, provider, deadlineCancellationExecutor); final Exception cause = new Exception(); ClientCall.Listener<Void> callListener = new ClientCall.Listener<Void>() { @Override public void onMessage(Void message) { call.cancel("foo", cause); } }; call.start(callListener, new Metadata()); call.halfClose(); call.request(1); verify(stream).start(listenerArgumentCaptor.capture()); ClientStreamListener streamListener = listenerArgumentCaptor.getValue(); streamListener.onReady(); streamListener.headersRead(new Metadata()); streamListener.messageRead(new ByteArrayInputStream(new byte[0])); verify(stream).cancel(statusCaptor.capture()); Status status = statusCaptor.getValue(); assertEquals(Status.CANCELLED.getCode(), status.getCode()); assertEquals("foo", status.getDescription()); assertSame(cause, status.getCause()); } private static class TestMarshaller<T> implements Marshaller<T> { @Override public InputStream stream(T value) { return null; } @Override public T parse(InputStream stream) { return null; } } private static void assertTimeoutBetween(long timeout, long from, long to) { assertTrue("timeout: " + timeout + " ns", timeout <= to); assertTrue("timeout: " + timeout + " ns", timeout >= from); } }
/** Implementation of {@link ClientCall}. */ final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT> implements Context.CancellationListener { private static final Logger log = Logger.getLogger(ClientCallImpl.class.getName()); private final MethodDescriptor<ReqT, RespT> method; private final Executor callExecutor; private final Context parentContext; private volatile Context context; private final boolean unaryRequest; private final CallOptions callOptions; private ClientStream stream; private volatile boolean contextListenerShouldBeRemoved; private boolean cancelCalled; private boolean halfCloseCalled; private final ClientTransportProvider clientTransportProvider; private String userAgent; private ScheduledExecutorService deadlineCancellationExecutor; private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance(); private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance(); ClientCallImpl( MethodDescriptor<ReqT, RespT> method, Executor executor, CallOptions callOptions, ClientTransportProvider clientTransportProvider, ScheduledExecutorService deadlineCancellationExecutor) { this.method = method; // If we know that the executor is a direct executor, we don't need to wrap it with a // SerializingExecutor. This is purely for performance reasons. // See https://github.com/grpc/grpc-java/issues/368 this.callExecutor = executor == directExecutor() ? new SerializeReentrantCallsDirectExecutor() : new SerializingExecutor(executor); // Propagate the context from the thread which initiated the call to all callbacks. 
this.parentContext = Context.current(); this.unaryRequest = method.getType() == MethodType.UNARY || method.getType() == MethodType.SERVER_STREAMING; this.callOptions = callOptions; this.clientTransportProvider = clientTransportProvider; this.deadlineCancellationExecutor = deadlineCancellationExecutor; } @Override public void cancelled(Context context) { stream.cancel(statusFromCancelled(context)); } /** Provider of {@link ClientTransport}s. */ interface ClientTransportProvider { /** Returns a transport for a new call. */ ClientTransport get(CallOptions callOptions); } ClientCallImpl<ReqT, RespT> setUserAgent(String userAgent) { this.userAgent = userAgent; return this; } ClientCallImpl<ReqT, RespT> setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { this.decompressorRegistry = decompressorRegistry; return this; } ClientCallImpl<ReqT, RespT> setCompressorRegistry(CompressorRegistry compressorRegistry) { this.compressorRegistry = compressorRegistry; return this; } @VisibleForTesting static void prepareHeaders( Metadata headers, CallOptions callOptions, String userAgent, DecompressorRegistry decompressorRegistry, Compressor compressor) { // Fill out the User-Agent header. 
headers.removeAll(USER_AGENT_KEY); if (userAgent != null) { headers.put(USER_AGENT_KEY, userAgent); } headers.removeAll(MESSAGE_ENCODING_KEY); if (compressor != Codec.Identity.NONE) { headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding()); } headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY); if (!decompressorRegistry.getAdvertisedMessageEncodings().isEmpty()) { String acceptEncoding = ACCEPT_ENCODING_JOINER.join(decompressorRegistry.getAdvertisedMessageEncodings()); headers.put(MESSAGE_ACCEPT_ENCODING_KEY, acceptEncoding); } } @Override public void start(final Listener<RespT> observer, Metadata headers) { checkState(stream == null, "Already started"); checkNotNull(observer, "observer"); checkNotNull(headers, "headers"); // Create the context final Deadline effectiveDeadline = min(callOptions.getDeadline(), parentContext.getDeadline()); if (effectiveDeadline != parentContext.getDeadline()) { context = parentContext.withDeadline(effectiveDeadline, deadlineCancellationExecutor); } else { context = parentContext.withCancellation(); } if (context.isCancelled()) { // Context is already cancelled so no need to create a real stream, just notify the observer // of cancellation via callback on the executor stream = NoopClientStream.INSTANCE; callExecutor.execute( new ContextRunnable(context) { @Override public void runInContext() { observer.onClose(statusFromCancelled(context), new Metadata()); } }); return; } final String compressorName = callOptions.getCompressor(); Compressor compressor = null; if (compressorName != null) { compressor = compressorRegistry.lookupCompressor(compressorName); if (compressor == null) { stream = NoopClientStream.INSTANCE; callExecutor.execute( new ContextRunnable(context) { @Override public void runInContext() { observer.onClose( Status.INTERNAL.withDescription( String.format("Unable to find compressor by name %s", compressorName)), new Metadata()); } }); return; } } else { compressor = Codec.Identity.NONE; } prepareHeaders(headers, 
callOptions, userAgent, decompressorRegistry, compressor); final boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired(); if (!deadlineExceeded) { updateTimeoutHeaders( effectiveDeadline, callOptions.getDeadline(), parentContext.getDeadline(), headers); ClientTransport transport = clientTransportProvider.get(callOptions); stream = transport.newStream(method, headers); } else { stream = new FailingClientStream(DEADLINE_EXCEEDED); } if (callOptions.getAuthority() != null) { stream.setAuthority(callOptions.getAuthority()); } stream.setCompressor(compressor); stream.start(new ClientStreamListenerImpl(observer)); if (compressor != Codec.Identity.NONE) { stream.setMessageCompression(true); } // Delay any sources of cancellation after start(), because most of the transports are broken if // they receive cancel before start. Issue #1343 has more details // Propagate later Context cancellation to the remote side. context.addListener(this, directExecutor()); if (contextListenerShouldBeRemoved) { // Race detected! ClientStreamListener.closed may have been called before // deadlineCancellationFuture was set, thereby preventing the future from being cancelled. // Go ahead and cancel again, just to be sure it was cancelled. context.removeListener(this); } } /** Based on the deadline, calculate and set the timeout to the given headers. 
*/ private static void updateTimeoutHeaders( @Nullable Deadline effectiveDeadline, @Nullable Deadline callDeadline, @Nullable Deadline outerCallDeadline, Metadata headers) { headers.removeAll(TIMEOUT_KEY); if (effectiveDeadline == null) { return; } long effectiveTimeout = max(0, effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS)); headers.put(TIMEOUT_KEY, effectiveTimeout); logIfContextNarrowedTimeout( effectiveTimeout, effectiveDeadline, outerCallDeadline, callDeadline); } private static void logIfContextNarrowedTimeout( long effectiveTimeout, Deadline effectiveDeadline, @Nullable Deadline outerCallDeadline, @Nullable Deadline callDeadline) { if (!log.isLoggable(Level.INFO) || outerCallDeadline != effectiveDeadline) { return; } StringBuilder builder = new StringBuilder(); builder.append( String.format("Call timeout set to '%d' ns, due to context deadline.", effectiveTimeout)); if (callDeadline == null) { builder.append(" Explicit call timeout was not set."); } else { long callTimeout = callDeadline.timeRemaining(TimeUnit.NANOSECONDS); builder.append(String.format(" Explicit call timeout was '%d' ns.", callTimeout)); } log.info(builder.toString()); } private static Deadline min(@Nullable Deadline deadline0, @Nullable Deadline deadline1) { if (deadline0 == null) { return deadline1; } if (deadline1 == null) { return deadline0; } return deadline0.minimum(deadline1); } @Override public void request(int numMessages) { Preconditions.checkState(stream != null, "Not started"); checkArgument(numMessages >= 0, "Number requested must be non-negative"); stream.request(numMessages); } @Override public void cancel(@Nullable String message, @Nullable Throwable cause) { if (cancelCalled) { return; } cancelCalled = true; try { // Cancel is called in exception handling cases, so it may be the case that the // stream was never successfully created. 
if (stream != null) { Status status = Status.CANCELLED; if (message != null) { status = status.withDescription(message); } if (cause != null) { status = status.withCause(cause); } if (message == null && cause == null) { // TODO(zhangkun83): log a warning with this exception once cancel() has been deleted from // ClientCall. status = status.withCause( new CancellationException("Client called cancel() without any detail")); } stream.cancel(status); } } finally { if (context != null) { context.removeListener(ClientCallImpl.this); } } } @Override public void halfClose() { Preconditions.checkState(stream != null, "Not started"); Preconditions.checkState(!cancelCalled, "call was cancelled"); Preconditions.checkState(!halfCloseCalled, "call already half-closed"); halfCloseCalled = true; stream.halfClose(); } @Override public void sendMessage(ReqT message) { Preconditions.checkState(stream != null, "Not started"); Preconditions.checkState(!cancelCalled, "call was cancelled"); Preconditions.checkState(!halfCloseCalled, "call was half-closed"); try { // TODO(notcarl): Find out if messageIs needs to be closed. InputStream messageIs = method.streamRequest(message); stream.writeMessage(messageIs); } catch (Throwable e) { stream.cancel(Status.CANCELLED.withCause(e).withDescription("Failed to stream message")); return; } // For unary requests, we don't flush since we know that halfClose should be coming soon. This // allows us to piggy-back the END_STREAM=true on the last message frame without opening the // possibility of broken applications forgetting to call halfClose without noticing. 
if (!unaryRequest) {
      stream.flush();
    }
  }

  /**
   * Enables or disables compression of subsequent messages on the stream.
   *
   * @throws IllegalStateException if the call has not been started
   */
  @Override
  public void setMessageCompression(boolean enabled) {
    checkState(stream != null, "Not started");
    stream.setMessageCompression(enabled);
  }

  /** Returns the stream's own readiness. */
  @Override
  public boolean isReady() {
    return stream.isReady();
  }

  /**
   * Relays stream events to the application's {@code Listener}. Every observer callback is run
   * on {@code callExecutor} inside a {@link ContextRunnable} bound to the call's context.
   */
  private class ClientStreamListenerImpl implements ClientStreamListener {
    private final Listener<RespT> observer;
    // Set by the close callback; later header/message callbacks check it and skip the observer.
    private boolean closed;

    public ClientStreamListenerImpl(Listener<RespT> observer) {
      this.observer = Preconditions.checkNotNull(observer, "observer");
    }

    /**
     * Installs the decompressor named by the response's message-encoding entry (identity when the
     * entry is absent), then delivers the headers to the observer. If the encoding is unknown the
     * stream is cancelled with INTERNAL and the headers are not delivered.
     */
    @Override
    public void headersRead(final Metadata headers) {
      Decompressor decompressor = Codec.Identity.NONE;
      if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
        String encoding = headers.get(MESSAGE_ENCODING_KEY);
        decompressor = decompressorRegistry.lookupDecompressor(encoding);
        if (decompressor == null) {
          stream.cancel(
              Status.INTERNAL.withDescription(
                  String.format("Can't find decompressor for %s", encoding)));
          return;
        }
      }
      stream.setDecompressor(decompressor);

      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                if (closed) {
                  return;
                }

                observer.onHeaders(headers);
              } catch (Throwable t) {
                stream.cancel(
                    Status.CANCELLED.withCause(t).withDescription("Failed to read headers"));
                return;
              }
            }
          });
    }

    /**
     * Parses {@code message} and delivers it to the observer. The message stream is closed on
     * every path, including when the call has already been closed and the message is discarded.
     * Any failure while parsing, delivering or closing cancels the stream.
     */
    @Override
    public void messageRead(final InputStream message) {
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                if (closed) {
                  // The normal path below always closes the message; do the same here so a
                  // message that is discarded after the call closed is not leaked.
                  try {
                    message.close();
                  } catch (java.io.IOException ignored) {
                    // Best effort: the call is already closed, so there is nothing to report to.
                  }
                  return;
                }

                try {
                  observer.onMessage(method.parseResponse(message));
                } finally {
                  message.close();
                }
              } catch (Throwable t) {
                stream.cancel(
                    Status.CANCELLED.withCause(t).withDescription("Failed to read message."));
                return;
              }
            }
          });
    }

    /**
     * Delivers the final status and trailers to the observer, then removes the call's context
     * listener. A CANCELLED status observed after the context deadline has expired is reported as
     * DEADLINE_EXCEEDED with fresh, empty trailers.
     */
    @Override
    public void closed(Status status, Metadata trailers) {
      Deadline deadline = context.getDeadline();
      if (status.getCode() == Status.Code.CANCELLED && deadline != null) {
        // When the server's deadline expires, it can only reset the stream with CANCEL and no
        // description. Since our timer may be delayed in firing, we double-check the deadline and
        // turn the failure into the likely more helpful DEADLINE_EXCEEDED status.
        if (deadline.isExpired()) {
          status = DEADLINE_EXCEEDED;
          // Replace trailers to prevent mixing sources of status and trailers.
          trailers = new Metadata();
        }
      }
      final Status savedStatus = status;
      final Metadata savedTrailers = trailers;
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                closed = true;
                // Lets start() notice that this callback ran before it added the listener.
                contextListenerShouldBeRemoved = true;
                observer.onClose(savedStatus, savedTrailers);
              } finally {
                context.removeListener(ClientCallImpl.this);
              }
            }
          });
    }

    /** Forwards the readiness notification to the observer. */
    @Override
    public void onReady() {
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              observer.onReady();
            }
          });
    }
  }
}