@Test
public void callerContextPropagatedToListener() throws Exception {
  // Snapshot the thread's context so it can be restored in the finally block. This test attaches
  // several contexts (ending with "badValue"); without restoring, that context would stay attached
  // and leak into any later test that runs on the same thread.
  Context origContext = Context.current();
  // Held in a local (instead of created inline) so its thread can be shut down when the test ends.
  java.util.concurrent.ExecutorService callbackExecutor = Executors.newSingleThreadExecutor();
  try {
    // Attach the context which is recorded when the call is created
    final Context.Key<String> testKey = Context.key("testing");
    Context.current().withValue(testKey, "testValue").attach();
    ClientCallImpl<Void, Void> call = new ClientCallImpl<Void, Void>(
        DESCRIPTOR,
        new SerializingExecutor(callbackExecutor),
        CallOptions.DEFAULT,
        provider,
        deadlineCancellationExecutor)
        .setDecompressorRegistry(decompressorRegistry);

    Context.ROOT.attach();

    // Override the value after creating the call, this should not be seen by callbacks
    Context.current().withValue(testKey, "badValue").attach();

    final AtomicBoolean onHeadersCalled = new AtomicBoolean();
    final AtomicBoolean onMessageCalled = new AtomicBoolean();
    final AtomicBoolean onReadyCalled = new AtomicBoolean();
    final AtomicBoolean observedIncorrectContext = new AtomicBoolean();
    final CountDownLatch latch = new CountDownLatch(1);

    call.start(new ClientCall.Listener<Void>() {
      @Override
      public void onHeaders(Metadata headers) {
        onHeadersCalled.set(true);
        checkContext();
      }

      @Override
      public void onMessage(Void message) {
        onMessageCalled.set(true);
        checkContext();
      }

      @Override
      public void onClose(Status status, Metadata trailers) {
        checkContext();
        latch.countDown();
      }

      @Override
      public void onReady() {
        onReadyCalled.set(true);
        checkContext();
      }

      // Records a mismatch instead of asserting: callbacks are delivered through the call's
      // executor (the test waits on a latch for them), so a thrown assertion would presumably
      // not fail the test thread directly.
      private void checkContext() {
        if (!"testValue".equals(testKey.get())) {
          observedIncorrectContext.set(true);
        }
      }
    }, new Metadata());

    verify(stream).start(listenerArgumentCaptor.capture());
    ClientStreamListener listener = listenerArgumentCaptor.getValue();
    listener.onReady();
    listener.headersRead(new Metadata());
    listener.messageRead(new ByteArrayInputStream(new byte[0]));
    listener.messageRead(new ByteArrayInputStream(new byte[0]));
    listener.closed(Status.OK, new Metadata());

    // onClose is the last callback, so once the latch opens every callback has run.
    assertTrue(latch.await(5, TimeUnit.SECONDS));

    assertTrue(onHeadersCalled.get());
    assertTrue(onMessageCalled.get());
    assertTrue(onReadyCalled.get());
    assertFalse(observedIncorrectContext.get());
  } finally {
    // Undo every attach() performed above and release the callback thread.
    origContext.attach();
    callbackExecutor.shutdown();
  }
}
/** Tests for {@link Contexts}. */ @RunWith(JUnit4.class) public class ContextsTest { private static Context.Key<Object> contextKey = Context.key("key"); /** For use in comparing context by reference. */ private Context uniqueContext = Context.ROOT.withValue(contextKey, new Object()); @SuppressWarnings("unchecked") private MethodDescriptor<Object, Object> method = mock(MethodDescriptor.class); @SuppressWarnings("unchecked") private ServerCall<Object, Object> call = mock(ServerCall.class); private Metadata headers = new Metadata(); @Test public void interceptCall_basic() { Context origContext = Context.current(); final Object message = new Object(); final List<Integer> methodCalls = new ArrayList<Integer>(); final ServerCall.Listener<Object> listener = new ServerCall.Listener<Object>() { @Override public void onMessage(Object messageIn) { assertSame(message, messageIn); assertSame(uniqueContext, Context.current()); methodCalls.add(1); } @Override public void onHalfClose() { assertSame(uniqueContext, Context.current()); methodCalls.add(2); } @Override public void onCancel() { assertSame(uniqueContext, Context.current()); methodCalls.add(3); } @Override public void onComplete() { assertSame(uniqueContext, Context.current()); methodCalls.add(4); } @Override public void onReady() { assertSame(uniqueContext, Context.current()); methodCalls.add(5); } }; ServerCall.Listener<Object> wrapped = interceptCall( uniqueContext, call, headers, new ServerCallHandler<Object, Object>() { @Override public ServerCall.Listener<Object> startCall( ServerCall<Object, Object> call, Metadata headers) { assertSame(ContextsTest.this.method, method); assertSame(ContextsTest.this.call, call); assertSame(ContextsTest.this.headers, headers); assertSame(uniqueContext, Context.current()); return listener; } }); assertSame(origContext, Context.current()); wrapped.onMessage(message); wrapped.onHalfClose(); wrapped.onCancel(); wrapped.onComplete(); wrapped.onReady(); assertEquals(Arrays.asList(1, 2, 3, 
4, 5), methodCalls); assertSame(origContext, Context.current()); } @Test public void interceptCall_restoresIfNextThrows() { Context origContext = Context.current(); try { interceptCall( uniqueContext, call, headers, new ServerCallHandler<Object, Object>() { @Override public ServerCall.Listener<Object> startCall( ServerCall<Object, Object> call, Metadata headers) { throw new RuntimeException(); } }); fail("Expected exception"); } catch (RuntimeException expected) { } assertSame(origContext, Context.current()); } @Test public void interceptCall_restoresIfListenerThrows() { Context origContext = Context.current(); final ServerCall.Listener<Object> listener = new ServerCall.Listener<Object>() { @Override public void onMessage(Object messageIn) { throw new RuntimeException(); } @Override public void onHalfClose() { throw new RuntimeException(); } @Override public void onCancel() { throw new RuntimeException(); } @Override public void onComplete() { throw new RuntimeException(); } @Override public void onReady() { throw new RuntimeException(); } }; ServerCall.Listener<Object> wrapped = interceptCall( uniqueContext, call, headers, new ServerCallHandler<Object, Object>() { @Override public ServerCall.Listener<Object> startCall( ServerCall<Object, Object> call, Metadata headers) { return listener; } }); try { wrapped.onMessage(new Object()); fail("Exception expected"); } catch (RuntimeException expected) { } try { wrapped.onHalfClose(); fail("Exception expected"); } catch (RuntimeException expected) { } try { wrapped.onCancel(); fail("Exception expected"); } catch (RuntimeException expected) { } try { wrapped.onComplete(); fail("Exception expected"); } catch (RuntimeException expected) { } try { wrapped.onReady(); fail("Exception expected"); } catch (RuntimeException expected) { } assertSame(origContext, Context.current()); } @Test public void statusFromCancelled_returnNullIfCtxNotCancelled() { Context context = Context.current(); assertFalse(context.isCancelled()); 
assertNull(statusFromCancelled(context)); } @Test public void statusFromCancelled_returnStatusAsSetOnCtx() { Context.CancellableContext cancellableContext = Context.current().withCancellation(); cancellableContext.cancel(Status.DEADLINE_EXCEEDED.withDescription("foo bar").asException()); Status status = statusFromCancelled(cancellableContext); assertNotNull(status); assertEquals(Status.Code.DEADLINE_EXCEEDED, status.getCode()); assertEquals("foo bar", status.getDescription()); } @Test public void statusFromCancelled_shouldReturnStatusWithCauseAttached() { Context.CancellableContext cancellableContext = Context.current().withCancellation(); Throwable t = new Throwable(); cancellableContext.cancel(t); Status status = statusFromCancelled(cancellableContext); assertNotNull(status); assertEquals(Status.Code.CANCELLED, status.getCode()); assertSame(t, status.getCause()); } @Test public void statusFromCancelled_TimeoutExceptionShouldMapToDeadlineExceeded() { FakeClock fakeClock = new FakeClock(); Context.CancellableContext cancellableContext = Context.current() .withDeadlineAfter(100, TimeUnit.MILLISECONDS, fakeClock.scheduledExecutorService); fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS); fakeClock.forwardMillis(100); assertTrue(cancellableContext.isCancelled()); assertThat(cancellableContext.cancellationCause(), instanceOf(TimeoutException.class)); Status status = statusFromCancelled(cancellableContext); assertNotNull(status); assertEquals(Status.Code.DEADLINE_EXCEEDED, status.getCode()); assertEquals("context timed out", status.getDescription()); } @Test public void statusFromCancelled_returnCancelledIfCauseIsNull() { Context.CancellableContext cancellableContext = Context.current().withCancellation(); cancellableContext.cancel(null); assertTrue(cancellableContext.isCancelled()); Status status = statusFromCancelled(cancellableContext); assertNotNull(status); assertEquals(Status.Code.CANCELLED, status.getCode()); } /** This is a whitebox test, to verify 
a special case of the implementation. */ @Test public void statusFromCancelled_StatusUnknownShouldWork() { Context.CancellableContext cancellableContext = Context.current().withCancellation(); Exception e = Status.UNKNOWN.asException(); cancellableContext.cancel(e); assertTrue(cancellableContext.isCancelled()); Status status = statusFromCancelled(cancellableContext); assertNotNull(status); assertEquals(Status.Code.UNKNOWN, status.getCode()); assertSame(e, status.getCause()); } @Test public void statusFromCancelled_shouldThrowIfCtxIsNull() { try { statusFromCancelled(null); fail("NPE expected"); } catch (NullPointerException npe) { assertEquals("context must not be null", npe.getMessage()); } } }