/**
   * Resets the reserved headers on {@code headers} so that they reflect this call's settings.
   *
   * <p>For each of the User-Agent, message-encoding and message-accept-encoding keys, any value
   * already present in {@code headers} is removed first and a fresh value is written only when
   * one applies.
   *
   * @param headers the metadata to modify in place
   * @param callOptions not read by this method
   * @param userAgent the User-Agent value to send, or {@code null} to send none
   * @param decompressorRegistry source of the encodings advertised in the accept-encoding header
   * @param compressor the compressor for outbound messages; {@code Codec.Identity.NONE} means no
   *     message-encoding header is written
   */
  @VisibleForTesting
  static void prepareHeaders(
      Metadata headers,
      CallOptions callOptions,
      String userAgent,
      DecompressorRegistry decompressorRegistry,
      Compressor compressor) {
    // Fill out the User-Agent header; a caller-supplied value is always discarded.
    headers.removeAll(USER_AGENT_KEY);
    if (userAgent != null) {
      headers.put(USER_AGENT_KEY, userAgent);
    }

    headers.removeAll(MESSAGE_ENCODING_KEY);
    if (compressor != Codec.Identity.NONE) {
      headers.put(MESSAGE_ENCODING_KEY, compressor.getMessageEncoding());
    }

    headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
    // Read the advertised encodings exactly once so the emptiness check and the joined header
    // value are derived from the same snapshot, even if the registry changes concurrently.
    java.util.Collection<String> advertisedEncodings =
        decompressorRegistry.getAdvertisedMessageEncodings();
    if (!advertisedEncodings.isEmpty()) {
      headers.put(MESSAGE_ACCEPT_ENCODING_KEY, ACCEPT_ENCODING_JOINER.join(advertisedEncodings));
    }
  }
  @Test
  public void prepareHeaders_acceptedEncodingsAdded() {
    // Three stub decompressors that differ only in the encoding name they report.
    Decompressor decompressorA =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "a";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };
    Decompressor decompressorB =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "b";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };
    Decompressor decompressorC =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "c";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };

    DecompressorRegistry customRegistry = DecompressorRegistry.newEmptyInstance();
    customRegistry.register(decompressorA, true);
    customRegistry.register(decompressorB, true);
    customRegistry.register(decompressorC, false); // not advertised

    Metadata headers = new Metadata();
    ClientCallImpl.prepareHeaders(headers, CallOptions.DEFAULT, "user agent", customRegistry);

    // The header is a single comma-separated value; compare as a set because the
    // order may be different, since decoder priorities have not yet been implemented.
    assertEquals(
        ImmutableSet.of("b", "a"),
        ImmutableSet.copyOf(
            Splitter.on(',').split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY))));
  }
  /** A non-null user agent passed to {@code prepareHeaders} must end up in the headers. */
  @Test
  public void prepareHeaders_userAgentAdded() {
    Metadata m = new Metadata();
    ClientCallImpl.prepareHeaders(
        m, CallOptions.DEFAULT, "user agent", DecompressorRegistry.getDefaultInstance());

    // JUnit's contract is assertEquals(expected, actual); keeping that order makes failure
    // messages report the right value as "expected".
    assertEquals("user agent", m.get(GrpcUtil.USER_AGENT_KEY));
  }
  /** The identity codec must not produce a message-encoding header. */
  @Test
  public void prepareHeaders_ignoreIdentityEncoding() {
    Metadata headers = new Metadata();

    ClientCallImpl.prepareHeaders(
        headers,
        CallOptions.DEFAULT.withCompressor(Codec.Identity.NONE),
        "user agent",
        DecompressorRegistry.getDefaultInstance());

    assertNull(headers.get(GrpcUtil.MESSAGE_ENCODING_KEY));
  }
  /** A gzip compressor in the call options must be reflected in the message-encoding header. */
  @Test
  public void prepareHeaders_messageEncodingAdded() {
    Metadata m = new Metadata();
    CallOptions callOptions = CallOptions.DEFAULT.withCompressor(new Codec.Gzip());
    ClientCallImpl.prepareHeaders(
        m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

    // Expected value first, actual second, per the JUnit assertEquals contract.
    assertEquals(new Codec.Gzip().getMessageEncoding(), m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
  }
  /** An authority set on the call options must be written to the authority header. */
  @Test
  public void prepareHeaders_authorityAdded() {
    Metadata m = new Metadata();
    CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth");
    ClientCallImpl.prepareHeaders(
        m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

    // Expected value first, actual second, per the JUnit assertEquals contract.
    assertEquals("auth", m.get(GrpcUtil.AUTHORITY_KEY));
  }
  /** Caller-supplied encoding headers are dropped when nothing replaces them. */
  @Test
  public void prepareHeaders_removeReservedHeaders() {
    Metadata headers = new Metadata();
    // Pre-populate both reserved encoding headers as a caller might.
    headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
    headers.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");

    // Empty registry + identity codec: there is no replacement value for either header.
    DecompressorRegistry noDecompressors = DecompressorRegistry.emptyInstance();
    ClientCallImpl.prepareHeaders(headers, noDecompressors, Codec.Identity.NONE);

    assertNull(headers.get(GrpcUtil.MESSAGE_ENCODING_KEY));
    assertNull(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
  }
  /** All four reserved headers supplied by the caller are dropped when nothing replaces them. */
  @Test
  public void prepareHeaders_removeReservedHeaders() {
    Metadata headers = new Metadata();
    // Pre-populate every reserved header as a caller might.
    headers.put(GrpcUtil.AUTHORITY_KEY, "auth");
    headers.put(GrpcUtil.USER_AGENT_KEY, "user agent");
    headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
    headers.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");

    // Default call options, no user agent and an empty registry: no replacement values exist.
    DecompressorRegistry noDecompressors = DecompressorRegistry.newEmptyInstance();
    ClientCallImpl.prepareHeaders(headers, CallOptions.DEFAULT, null, noDecompressors);

    assertNull(headers.get(GrpcUtil.AUTHORITY_KEY));
    assertNull(headers.get(GrpcUtil.USER_AGENT_KEY));
    assertNull(headers.get(GrpcUtil.MESSAGE_ENCODING_KEY));
    assertNull(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
  }
  /** The encodings advertised by the registry must be sent in the accept-encoding header. */
  @Test
  public void advertisedEncodingsAreSent() {
    Metadata requestHeaders = new Metadata();
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                method,
                MoreExecutors.directExecutor(),
                CallOptions.DEFAULT,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);
    call.start(callListener, requestHeaders);

    // Capture the metadata that was actually handed to the transport.
    ArgumentCaptor<Metadata> sentHeaders = ArgumentCaptor.forClass(Metadata.class);
    verify(transport).newStream(eq(method), sentHeaders.capture(), same(CallOptions.DEFAULT));

    Set<String> acceptedEncodings =
        ImmutableSet.copyOf(sentHeaders.getValue().getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
    assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
  }
  /** The encodings advertised by the registry must be sent in the accept-encoding header. */
  @Test
  public void advertisedEncodingsAreSent() {
    MethodDescriptor<Void, Void> unaryMethod =
        MethodDescriptor.create(
            MethodType.UNARY,
            "service/method",
            new TestMarshaller<Void>(),
            new TestMarshaller<Void>());

    // A mocked transport that always hands back the same mocked stream.
    final ClientTransport mockTransport = mock(ClientTransport.class);
    final ClientStream mockStream = mock(ClientStream.class);
    when(mockTransport.newStream(
            any(MethodDescriptor.class), any(Metadata.class), any(ClientStreamListener.class)))
        .thenReturn(mockStream);
    ClientTransportProvider transportProvider =
        new ClientTransportProvider() {
          @Override
          public ListenableFuture<ClientTransport> get(CallOptions callOptions) {
            return Futures.immediateFuture(mockTransport);
          }
        };

    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                unaryMethod,
                executor,
                CallOptions.DEFAULT,
                transportProvider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);
    call.start(new TestClientCallListener<Void>(), new Metadata());

    // Capture the metadata that was actually handed to the transport.
    ArgumentCaptor<Metadata> sentHeaders = ArgumentCaptor.forClass(Metadata.class);
    verify(mockTransport)
        .newStream(eq(unaryMethod), sentHeaders.capture(), isA(ClientStreamListener.class));

    Set<String> acceptedEncodings =
        ImmutableSet.copyOf(sentHeaders.getValue().getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
    assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
  }
 /**
  * Creates the {@link GrpcService} described by this builder, ready to be bound to a {@link
  * com.linecorp.armeria.server.ServerBuilder}. Registries that were not configured fall back to
  * the library's default instances. Because GRPC services are mounted at a path derived from
  * their protobuf package, the result is almost always bound under a prefix, e.g. with {@link
  * com.linecorp.armeria.server.ServerBuilder#serviceUnder(String, Service)}.
  */
 public GrpcService build() {
   DecompressorRegistry decompressors =
       firstNonNull(decompressorRegistry, DecompressorRegistry.getDefaultInstance());
   CompressorRegistry compressors =
       firstNonNull(compressorRegistry, CompressorRegistry.getDefaultInstance());
   return new GrpcService(registryBuilder.build(), decompressors, compressors);
 }
 /** Registers gzip in the test's decompressor registry before every test. */
 @Before
 public void setUp() {
   Codec.Gzip gzip = new Codec.Gzip();
   decompressorRegistry.register(gzip, true);
 }
/**
 * Test for {@link ClientCallImpl}.
 *
 * <p>Covers the accept-encoding header sent when a call starts and the reserved-header handling
 * of {@code ClientCallImpl.prepareHeaders}.
 */
@RunWith(JUnit4.class)
public class ClientCallImplTest {

  private final SerializingExecutor executor =
      new SerializingExecutor(MoreExecutors.directExecutor());
  // NOTE(review): nothing in this class shuts this pool down.
  private final ScheduledExecutorService deadlineCancellationExecutor =
      Executors.newScheduledThreadPool(0);
  // NOTE(review): this is the shared default instance, and setUp() registers into it.
  private final DecompressorRegistry decompressorRegistry =
      DecompressorRegistry.getDefaultInstance();

  /** Registers gzip as an advertised decompressor before every test. */
  @Before
  public void setUp() {
    decompressorRegistry.register(new Codec.Gzip(), true);
  }

  /** The encodings advertised by the registry must be sent in the accept-encoding header. */
  @Test
  public void advertisedEncodingsAreSent() {
    MethodDescriptor<Void, Void> descriptor =
        MethodDescriptor.create(
            MethodType.UNARY,
            "service/method",
            new TestMarshaller<Void>(),
            new TestMarshaller<Void>());
    final ClientTransport transport = mock(ClientTransport.class);
    final ClientStream stream = mock(ClientStream.class);
    ClientTransportProvider provider =
        new ClientTransportProvider() {
          @Override
          public ListenableFuture<ClientTransport> get(CallOptions callOptions) {
            return Futures.immediateFuture(transport);
          }
        };
    when(transport.newStream(
            any(MethodDescriptor.class), any(Metadata.class), any(ClientStreamListener.class)))
        .thenReturn(stream);
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                descriptor, executor, CallOptions.DEFAULT, provider, deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    call.start(new TestClientCallListener<Void>(), new Metadata());

    // Capture the metadata that was actually handed to the transport.
    ArgumentCaptor<Metadata> metadataCaptor = ArgumentCaptor.forClass(Metadata.class);
    verify(transport)
        .newStream(eq(descriptor), metadataCaptor.capture(), isA(ClientStreamListener.class));
    Metadata actual = metadataCaptor.getValue();

    Set<String> acceptedEncodings =
        ImmutableSet.copyOf(actual.getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
    assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
  }

  /** An authority set on the call options must be written to the authority header. */
  @Test
  public void prepareHeaders_authorityAdded() {
    Metadata m = new Metadata();
    CallOptions callOptions = CallOptions.DEFAULT.withAuthority("auth");
    ClientCallImpl.prepareHeaders(
        m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

    // assertEquals takes (expected, actual); this order keeps failure messages accurate.
    assertEquals("auth", m.get(GrpcUtil.AUTHORITY_KEY));
  }

  /** A non-null user agent must end up in the headers. */
  @Test
  public void prepareHeaders_userAgentAdded() {
    Metadata m = new Metadata();
    ClientCallImpl.prepareHeaders(
        m, CallOptions.DEFAULT, "user agent", DecompressorRegistry.getDefaultInstance());

    assertEquals("user agent", m.get(GrpcUtil.USER_AGENT_KEY));
  }

  /** A gzip compressor in the call options must be reflected in the message-encoding header. */
  @Test
  public void prepareHeaders_messageEncodingAdded() {
    Metadata m = new Metadata();
    CallOptions callOptions = CallOptions.DEFAULT.withCompressor(new Codec.Gzip());
    ClientCallImpl.prepareHeaders(
        m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

    assertEquals(new Codec.Gzip().getMessageEncoding(), m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
  }

  /** The identity codec must not produce a message-encoding header. */
  @Test
  public void prepareHeaders_ignoreIdentityEncoding() {
    Metadata m = new Metadata();
    CallOptions callOptions = CallOptions.DEFAULT.withCompressor(Codec.Identity.NONE);
    ClientCallImpl.prepareHeaders(
        m, callOptions, "user agent", DecompressorRegistry.getDefaultInstance());

    assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
  }

  /** Only decompressors registered as advertised may appear in the accept-encoding header. */
  @Test
  public void prepareHeaders_acceptedEncodingsAdded() {
    Metadata m = new Metadata();
    DecompressorRegistry customRegistry = DecompressorRegistry.newEmptyInstance();
    customRegistry.register(new FixedEncodingDecompressor("a"), true);
    customRegistry.register(new FixedEncodingDecompressor("b"), true);
    customRegistry.register(new FixedEncodingDecompressor("c"), false); // not advertised

    ClientCallImpl.prepareHeaders(m, CallOptions.DEFAULT, "user agent", customRegistry);

    Iterable<String> acceptedEncodings =
        Splitter.on(',').split(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));

    // Order may be different, since decoder priorities have not yet been implemented.
    assertEquals(ImmutableSet.of("b", "a"), ImmutableSet.copyOf(acceptedEncodings));
  }

  /** Caller-supplied reserved headers are dropped when nothing replaces them. */
  @Test
  public void prepareHeaders_removeReservedHeaders() {
    Metadata m = new Metadata();
    m.put(GrpcUtil.AUTHORITY_KEY, "auth");
    m.put(GrpcUtil.USER_AGENT_KEY, "user agent");
    m.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
    m.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");

    ClientCallImpl.prepareHeaders(
        m, CallOptions.DEFAULT, null, DecompressorRegistry.newEmptyInstance());

    assertNull(m.get(GrpcUtil.AUTHORITY_KEY));
    assertNull(m.get(GrpcUtil.USER_AGENT_KEY));
    assertNull(m.get(GrpcUtil.MESSAGE_ENCODING_KEY));
    assertNull(m.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
  }

  /** Stub decompressor that only reports a fixed encoding name; it never decompresses. */
  private static final class FixedEncodingDecompressor implements Decompressor {
    private final String encoding;

    FixedEncodingDecompressor(String encoding) {
      this.encoding = encoding;
    }

    @Override
    public String getMessageEncoding() {
      return encoding;
    }

    @Override
    public InputStream decompress(InputStream is) throws IOException {
      return null;
    }
  }

  /** Marshaller stub for {@code Void} payloads; both directions yield {@code null}. */
  private static class TestMarshaller<T> implements Marshaller<T> {
    @Override
    public InputStream stream(T value) {
      return null;
    }

    @Override
    public T parse(InputStream stream) {
      return null;
    }
  }

  /** Listener that keeps every inherited callback unchanged. */
  private static class TestClientCallListener<T> extends ClientCall.Listener<T> {}
}
/** Test for {@link ClientCallImpl}. */
@RunWith(JUnit4.class)
public class ClientCallImplTest {

  // Unary Void->Void method descriptor used by the context and deadline tests.
  // NOTE(review): built with exactly the same arguments as the instance field `method` below.
  private static final MethodDescriptor<Void, Void> DESCRIPTOR =
      MethodDescriptor.create(
          MethodType.UNARY,
          "service/method",
          new TestMarshaller<Void>(),
          new TestMarshaller<Void>());

  // Fake clock whose scheduled executor drives deadline timers; tests advance it explicitly.
  private final FakeClock fakeClock = new FakeClock();
  private final ScheduledExecutorService deadlineCancellationExecutor =
      fakeClock.scheduledExecutorService;
  // Registry derived from the default instance with gzip added as an advertised decompressor.
  private final DecompressorRegistry decompressorRegistry =
      DecompressorRegistry.getDefaultInstance().with(new Codec.Gzip(), true);
  // Unary Void->Void method descriptor used by the header and authority tests.
  private final MethodDescriptor<Void, Void> method =
      MethodDescriptor.create(
          MethodType.UNARY,
          "service/method",
          new TestMarshaller<Void>(),
          new TestMarshaller<Void>());

  // NOTE(review): streamListener and clientTransport are not referenced by the tests visible in
  // this part of the file; confirm whether later tests still need them.
  @Mock private ClientStreamListener streamListener;
  @Mock private ClientTransport clientTransport;
  // Captures the Status passed to stream.cancel(...) / listener.onClose(...).
  @Captor private ArgumentCaptor<Status> statusCaptor;

  // Transport returned by `provider`; stubbed in setUp() to return `stream`.
  @Mock private ClientTransport transport;

  // Transport provider handed to every ClientCallImpl under test.
  @Mock private ClientTransportProvider provider;

  // Stream returned by `transport.newStream(...)`.
  @Mock private ClientStream stream;

  // Default application-level listener passed to call.start(...).
  @Mock private ClientCall.Listener<Void> callListener;

  // Captures the ClientStreamListener that ClientCallImpl passes to stream.start(...).
  @Captor private ArgumentCaptor<ClientStreamListener> listenerArgumentCaptor;

  // NOTE(review): same captured type as statusCaptor above; possibly redundant.
  @Captor private ArgumentCaptor<Status> statusArgumentCaptor;

  /** Initializes the Mockito fields and wires provider -> transport -> stream. */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);

    // The transport hands out the mocked stream for any method, headers and call options.
    when(transport.newStream(
            any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class)))
        .thenReturn(stream);
    // Every call, whatever its options, is routed to the mocked transport.
    when(provider.get(any(CallOptions.class))).thenReturn(transport);
  }

  /** Re-attaches the root context so a context attached by one test does not reach the next. */
  @After
  public void tearDown() {
    Context.ROOT.attach();
  }

  /** The encodings advertised by the registry must be sent in the accept-encoding header. */
  @Test
  public void advertisedEncodingsAreSent() {
    Metadata requestHeaders = new Metadata();
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                method,
                MoreExecutors.directExecutor(),
                CallOptions.DEFAULT,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);
    call.start(callListener, requestHeaders);

    // Capture the metadata that was actually handed to the transport.
    ArgumentCaptor<Metadata> sentHeaders = ArgumentCaptor.forClass(Metadata.class);
    verify(transport).newStream(eq(method), sentHeaders.capture(), same(CallOptions.DEFAULT));

    Set<String> acceptedEncodings =
        ImmutableSet.copyOf(sentHeaders.getValue().getAll(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
    assertEquals(decompressorRegistry.getAdvertisedMessageEncodings(), acceptedEncodings);
  }

  /** An authority override in the call options must be forwarded to the stream. */
  @Test
  public void authorityPropagatedToStream() {
    CallOptions optionsWithAuthority = CallOptions.DEFAULT.withAuthority("overridden-authority");
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                method,
                MoreExecutors.directExecutor(),
                optionsWithAuthority,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    verify(stream).setAuthority("overridden-authority");
  }

  /** The exact method, metadata and call options instances must reach the transport. */
  @Test
  public void callOptionsPropagatedToTransport() {
    final Metadata requestHeaders = new Metadata();
    final CallOptions options = CallOptions.DEFAULT.withAuthority("dummy_value");
    final ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                method, MoreExecutors.directExecutor(), options, provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    call.start(callListener, requestHeaders);

    // same(...) checks identity, not equality: no copies may be made along the way.
    verify(transport).newStream(same(method), same(requestHeaders), same(options));
  }

  /** Without an authority in the call options, the stream's authority must not be set. */
  @Test
  public void authorityNotPropagatedToStream() {
    // Deliberately the defaults: no authority is provided.
    CallOptions optionsWithoutAuthority = CallOptions.DEFAULT;
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                method,
                MoreExecutors.directExecutor(),
                optionsWithoutAuthority,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    verify(stream, never()).setAuthority(any(String.class));
  }

  /** {@code prepareHeaders} must leave a caller-supplied User-Agent header in place. */
  @Test
  public void prepareHeaders_userAgentIgnored() {
    Metadata headers = new Metadata();
    headers.put(GrpcUtil.USER_AGENT_KEY, "batmobile");

    ClientCallImpl.prepareHeaders(headers, decompressorRegistry, Codec.Identity.NONE);

    // The header is still present at this point; per the original comment, the User Agent is
    // removed and set by the transport rather than by prepareHeaders.
    assertThat(headers.get(GrpcUtil.USER_AGENT_KEY)).isNotNull();
  }

  /** The identity codec must not produce a message-encoding header. */
  @Test
  public void prepareHeaders_ignoreIdentityEncoding() {
    Metadata headers = new Metadata();

    ClientCallImpl.prepareHeaders(headers, decompressorRegistry, Codec.Identity.NONE);

    assertNull(headers.get(GrpcUtil.MESSAGE_ENCODING_KEY));
  }

  /** Only decompressors registered as advertised may appear in the accept-encoding header. */
  @Test
  public void prepareHeaders_acceptedEncodingsAdded() {
    // Three stub decompressors that differ only in the encoding name they report.
    Decompressor decompressorA =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "a";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };
    Decompressor decompressorB =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "b";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };
    Decompressor decompressorC =
        new Decompressor() {
          @Override
          public String getMessageEncoding() {
            return "c";
          }

          @Override
          public InputStream decompress(InputStream is) throws IOException {
            return null;
          }
        };

    DecompressorRegistry customRegistry =
        DecompressorRegistry.emptyInstance()
            .with(decompressorA, true)
            .with(decompressorB, true)
            .with(decompressorC, false); // not advertised

    Metadata headers = new Metadata();
    ClientCallImpl.prepareHeaders(headers, customRegistry, Codec.Identity.NONE);

    // Compare as a set because the order may be different, since decoder priorities have not
    // yet been implemented.
    assertEquals(
        ImmutableSet.of("b", "a"),
        ImmutableSet.copyOf(
            ACCEPT_ENCODING_SPLITER.split(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY))));
  }

  /** Caller-supplied encoding headers are dropped when nothing replaces them. */
  @Test
  public void prepareHeaders_removeReservedHeaders() {
    Metadata headers = new Metadata();
    // Pre-populate both reserved encoding headers as a caller might.
    headers.put(GrpcUtil.MESSAGE_ENCODING_KEY, "gzip");
    headers.put(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY, "gzip");

    // Empty registry + identity codec: there is no replacement value for either header.
    DecompressorRegistry noDecompressors = DecompressorRegistry.emptyInstance();
    ClientCallImpl.prepareHeaders(headers, noDecompressors, Codec.Identity.NONE);

    assertNull(headers.get(GrpcUtil.MESSAGE_ENCODING_KEY));
    assertNull(headers.get(GrpcUtil.MESSAGE_ACCEPT_ENCODING_KEY));
  }

  /**
   * Verifies that listener callbacks observe the context that was current when the call was
   * created, not whatever context is current when the callbacks are delivered.
   */
  @Test
  public void callerContextPropagatedToListener() throws Exception {
    // Attach the context which is recorded when the call is created
    final Context.Key<String> testKey = Context.key("testing");
    Context.current().withValue(testKey, "testValue").attach();

    // NOTE(review): the single-thread executor created here is not shut down by this test.
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                DESCRIPTOR,
                new SerializingExecutor(Executors.newSingleThreadExecutor()),
                CallOptions.DEFAULT,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    Context.ROOT.attach();

    // Override the value after creating the call, this should not be seen by callbacks
    Context.current().withValue(testKey, "badValue").attach();

    // Flags set from the listener callbacks and read back on the test thread below.
    final AtomicBoolean onHeadersCalled = new AtomicBoolean();
    final AtomicBoolean onMessageCalled = new AtomicBoolean();
    final AtomicBoolean onReadyCalled = new AtomicBoolean();
    final AtomicBoolean observedIncorrectContext = new AtomicBoolean();
    // Released by onClose so the test can wait for all callbacks to be delivered.
    final CountDownLatch latch = new CountDownLatch(1);

    call.start(
        new ClientCall.Listener<Void>() {
          @Override
          public void onHeaders(Metadata headers) {
            onHeadersCalled.set(true);
            checkContext();
          }

          @Override
          public void onMessage(Void message) {
            onMessageCalled.set(true);
            checkContext();
          }

          @Override
          public void onClose(Status status, Metadata trailers) {
            checkContext();
            latch.countDown();
          }

          @Override
          public void onReady() {
            onReadyCalled.set(true);
            checkContext();
          }

          // Records a failure instead of asserting, since callbacks run off the test thread.
          private void checkContext() {
            if (!"testValue".equals(testKey.get())) {
              observedIncorrectContext.set(true);
            }
          }
        },
        new Metadata());

    // Drive the call by invoking the stream listener that ClientCallImpl registered.
    verify(stream).start(listenerArgumentCaptor.capture());
    ClientStreamListener listener = listenerArgumentCaptor.getValue();
    listener.onReady();
    listener.headersRead(new Metadata());
    listener.messageRead(new ByteArrayInputStream(new byte[0]));
    listener.messageRead(new ByteArrayInputStream(new byte[0]));
    listener.closed(Status.OK, new Metadata());

    assertTrue(latch.await(5, TimeUnit.SECONDS));

    assertTrue(onHeadersCalled.get());
    assertTrue(onMessageCalled.get());
    assertTrue(onReadyCalled.get());
    assertFalse(observedIncorrectContext.get());
  }

  /**
   * Verifies that cancelling the context captured at call creation cancels the started stream
   * exactly once with a {@code CANCELLED} status.
   */
  @Test
  public void contextCancellationCancelsStream() throws Exception {
    // Attach the context which is recorded when the call is created
    Context.CancellableContext cancellableContext = Context.current().withCancellation();
    Context previous = cancellableContext.attach();

    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                DESCRIPTOR,
                new SerializingExecutor(Executors.newSingleThreadExecutor()),
                CallOptions.DEFAULT,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    // Restore the caller's context; the call keeps the cancellable one it captured.
    previous.attach();

    call.start(callListener, new Metadata());

    Throwable t = new Throwable();
    cancellableContext.cancel(t);

    // A single verification is enough: it checks the one cancel() invocation and captures the
    // status it was called with.
    verify(stream, times(1)).cancel(statusCaptor.capture());
    assertEquals(Status.Code.CANCELLED, statusCaptor.getValue().getCode());
  }

  /**
   * Verifies that starting a call whose captured context is already cancelled closes the
   * listener with {@code CANCELLED} (carrying the original cause) and never touches the
   * transport.
   */
  @Test
  public void contextAlreadyCancelledNotifiesImmediately() throws Exception {
    // Attach the context which is recorded when the call is created
    Context.CancellableContext cancellableContext = Context.current().withCancellation();
    Throwable cause = new Throwable();
    cancellableContext.cancel(cause);
    Context previous = cancellableContext.attach();

    // NOTE(review): the single-thread executor created here is not shut down by this test.
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                DESCRIPTOR,
                new SerializingExecutor(Executors.newSingleThreadExecutor()),
                CallOptions.DEFAULT,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);

    // Restore the caller's context; the call keeps the cancelled one it captured.
    previous.attach();

    // Completed from onClose so the test thread can wait for the asynchronous callback.
    final SettableFuture<Status> statusFuture = SettableFuture.create();
    call.start(
        new ClientCall.Listener<Void>() {
          @Override
          public void onClose(Status status, Metadata trailers) {
            statusFuture.set(status);
          }
        },
        new Metadata());

    // Caller should receive onClose callback.
    Status status = statusFuture.get(5, TimeUnit.SECONDS);
    assertEquals(Status.Code.CANCELLED, status.getCode());
    assertSame(cause, status.getCause());

    // Following operations should be no-op.
    call.request(1);
    call.sendMessage(null);
    call.halfClose();

    // Stream should never be created.
    verifyZeroInteractions(transport);

    // A further sendMessage (now after halfClose) is expected to be rejected.
    try {
      call.sendMessage(null);
      fail("Call has been cancelled");
    } catch (IllegalStateException ise) {
      // expected
    }
  }

  /**
   * Verifies that a call whose deadline has already passed when it starts never creates a
   * stream and closes its listener with {@code DEADLINE_EXCEEDED}.
   */
  @Test
  public void deadlineExceededBeforeCallStarted() {
    CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(0, TimeUnit.SECONDS);
    fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS);
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
                DESCRIPTOR,
                new SerializingExecutor(Executors.newSingleThreadExecutor()),
                callOptions,
                provider,
                deadlineCancellationExecutor)
            .setDecompressorRegistry(decompressorRegistry);
    call.start(callListener, new Metadata());
    // Check the same three-argument newStream overload that setUp() stubs and the other tests
    // verify; checking only the two-argument overload would miss a stream created through it.
    verify(transport, times(0))
        .newStream(any(MethodDescriptor.class), any(Metadata.class), any(CallOptions.class));
    // onClose is delivered through the call executor, hence the timeout.
    verify(callListener, timeout(1000)).onClose(statusCaptor.capture(), any(Metadata.class));
    assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode());
    verifyZeroInteractions(provider);
  }

  /** A deadline on the current context must show up as a timeout in the request metadata. */
  @Test
  public void contextDeadlineShouldBePropagatedInMetadata() {
    long contextDeadlineNanos = TimeUnit.SECONDS.toNanos(1);
    Context.current()
        .withDeadlineAfter(
            contextDeadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor)
        .attach();

    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            CallOptions.DEFAULT,
            provider,
            deadlineCancellationExecutor);
    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    assertTrue(requestHeaders.containsKey(GrpcUtil.TIMEOUT_KEY));
    Long timeout = requestHeaders.get(GrpcUtil.TIMEOUT_KEY);
    assertNotNull(timeout);

    // Allow for the time that elapsed between creating the deadline and starting the call.
    long toleranceNanos = TimeUnit.MILLISECONDS.toNanos(400);
    assertTimeoutBetween(timeout, contextDeadlineNanos - toleranceNanos, contextDeadlineNanos);
  }

  /** A 1s context deadline must win over a 2s call-options deadline. */
  @Test
  public void contextDeadlineShouldOverrideLargerMetadataTimeout() {
    long contextDeadlineNanos = TimeUnit.SECONDS.toNanos(1);
    Context.current()
        .withDeadlineAfter(
            contextDeadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor)
        .attach();

    CallOptions laterDeadlineOptions = CallOptions.DEFAULT.withDeadlineAfter(2, TimeUnit.SECONDS);
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            laterDeadlineOptions,
            provider,
            deadlineCancellationExecutor);
    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    assertTrue(requestHeaders.containsKey(GrpcUtil.TIMEOUT_KEY));
    Long timeout = requestHeaders.get(GrpcUtil.TIMEOUT_KEY);
    assertNotNull(timeout);

    // The timeout must track the (smaller) context deadline, within a 400ms tolerance.
    long toleranceNanos = TimeUnit.MILLISECONDS.toNanos(400);
    assertTimeoutBetween(timeout, contextDeadlineNanos - toleranceNanos, contextDeadlineNanos);
  }

  /** A 1s call-options deadline must win over a 2s context deadline. */
  @Test
  public void contextDeadlineShouldNotOverrideSmallerMetadataTimeout() {
    long contextDeadlineNanos = TimeUnit.SECONDS.toNanos(2);
    Context.current()
        .withDeadlineAfter(
            contextDeadlineNanos, TimeUnit.NANOSECONDS, deadlineCancellationExecutor)
        .attach();

    CallOptions earlierDeadlineOptions =
        CallOptions.DEFAULT.withDeadlineAfter(1, TimeUnit.SECONDS);
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            earlierDeadlineOptions,
            provider,
            deadlineCancellationExecutor);
    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    assertTrue(requestHeaders.containsKey(GrpcUtil.TIMEOUT_KEY));
    Long timeout = requestHeaders.get(GrpcUtil.TIMEOUT_KEY);
    assertNotNull(timeout);

    // The timeout must track the (smaller) call-options deadline, within a 400ms tolerance.
    long callOptsNanos = TimeUnit.SECONDS.toNanos(1);
    long toleranceNanos = TimeUnit.MILLISECONDS.toNanos(400);
    assertTimeoutBetween(timeout, callOptsNanos - toleranceNanos, callOptsNanos);
  }

  /** Once a call-options deadline passes, the stream is cancelled with DEADLINE_EXCEEDED. */
  @Test
  public void expiredDeadlineCancelsStream_CallOptions() {
    fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS);

    CallOptions optionsWithDeadline =
        CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS));
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            optionsWithDeadline,
            provider,
            deadlineCancellationExecutor);
    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);

    // Step the fake clock just past the 1000ms deadline so the deadline timer runs.
    fakeClock.forwardMillis(1001);

    verify(stream, times(1)).cancel(statusCaptor.capture());
    Status cancelStatus = statusCaptor.getValue();
    assertEquals(Status.Code.DEADLINE_EXCEEDED, cancelStatus.getCode());
  }

  /** Once a context deadline passes, the stream is cancelled with DEADLINE_EXCEEDED. */
  @Test
  public void expiredDeadlineCancelsStream_Context() {
    fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS);

    Context.current()
        .withDeadlineAfter(1000, TimeUnit.MILLISECONDS, deadlineCancellationExecutor)
        .attach();

    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            CallOptions.DEFAULT,
            provider,
            deadlineCancellationExecutor);

    call.start(callListener, new Metadata());

    // Step just past the 1000ms deadline, matching expiredDeadlineCancelsStream_CallOptions.
    // Advancing by 1001 seconds would also pass if cancellation fired far too late.
    fakeClock.forwardMillis(1001);

    verify(stream, times(1)).cancel(statusCaptor.capture());
    assertEquals(Status.Code.DEADLINE_EXCEEDED, statusCaptor.getValue().getCode());
  }

  /** An explicit cancel must win: the deadline timer may not cancel the stream a second time. */
  @Test
  public void streamCancelAbortsDeadlineTimer() {
    fakeClock.forwardTime(System.nanoTime(), TimeUnit.NANOSECONDS);

    CallOptions optionsWithDeadline =
        CallOptions.DEFAULT.withDeadline(Deadline.after(1000, TimeUnit.MILLISECONDS));
    ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            optionsWithDeadline,
            provider,
            deadlineCancellationExecutor);
    Metadata requestHeaders = new Metadata();
    call.start(callListener, requestHeaders);
    call.cancel("canceled", null);

    // Run the deadline timer, which should have been cancelled by the previous call to cancel()
    fakeClock.forwardMillis(1001);

    // Exactly one cancel, and it carries the explicit CANCELLED status.
    verify(stream, times(1)).cancel(statusCaptor.capture());
    Status cancelStatus = statusCaptor.getValue();
    assertEquals(Status.CANCELLED.getCode(), cancelStatus.getCode());
  }

  /** With neither a context deadline nor a call-options deadline, no timeout header is written. */
  @Test
  public void timeoutShouldNotBeSet() {
    Metadata requestHeaders = new Metadata();
    ClientCallImpl<Void, Void> callWithoutDeadline =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            CallOptions.DEFAULT,
            provider,
            deadlineCancellationExecutor);

    callWithoutDeadline.start(callListener, requestHeaders);

    boolean timeoutPresent = requestHeaders.containsKey(GrpcUtil.TIMEOUT_KEY);
    assertFalse(timeoutPresent);
  }

  /** Cancelling the call from inside {@code onMessage} must be forwarded to the stream. */
  @Test
  public void cancelInOnMessageShouldInvokeStreamCancel() throws Exception {
    final Exception cause = new Exception();
    final ClientCallImpl<Void, Void> call =
        new ClientCallImpl<Void, Void>(
            DESCRIPTOR,
            MoreExecutors.directExecutor(),
            CallOptions.DEFAULT,
            provider,
            deadlineCancellationExecutor);
    // Listener that cancels the call as soon as the first message is delivered.
    ClientCall.Listener<Void> cancellingListener =
        new ClientCall.Listener<Void>() {
          @Override
          public void onMessage(Void message) {
            call.cancel("foo", cause);
          }
        };

    call.start(cancellingListener, new Metadata());
    call.halfClose();
    call.request(1);

    // Grab the listener the call handed to the stream and drive it as a transport would.
    verify(stream).start(listenerArgumentCaptor.capture());
    ClientStreamListener transportListener = listenerArgumentCaptor.getValue();
    transportListener.onReady();
    transportListener.headersRead(new Metadata());
    transportListener.messageRead(new ByteArrayInputStream(new byte[0]));

    // The cancel issued in onMessage must arrive at the stream with its description and cause.
    verify(stream).cancel(statusCaptor.capture());
    Status cancelStatus = statusCaptor.getValue();
    assertEquals(Status.CANCELLED.getCode(), cancelStatus.getCode());
    assertEquals("foo", cancelStatus.getDescription());
    assertSame(cause, cancelStatus.getCause());
  }

  /**
   * Stub {@link Marshaller} whose two methods ignore their argument and return {@code null}.
   *
   * @param <T> the message type this marshaller is declared for
   */
  private static class TestMarshaller<T> implements Marshaller<T> {
    /** Always returns {@code null}; {@code value} is not inspected. */
    @Override
    public InputStream stream(T value) {
      return null;
    }

    /** Always returns {@code null}; {@code stream} is not read. */
    @Override
    public T parse(InputStream stream) {
      return null;
    }
  }

  /**
   * Asserts that {@code timeout} lies in the closed interval {@code [from, to]}; the upper bound
   * is checked first.
   *
   * @param timeout the observed timeout, reported in the failure message as nanoseconds
   * @param from inclusive lower bound
   * @param to inclusive upper bound
   */
  private static void assertTimeoutBetween(long timeout, long from, long to) {
    String description = "timeout: " + timeout + " ns";
    assertTrue(description, to >= timeout);
    assertTrue(description, from <= timeout);
  }
}
Beispiel #15
0
/** Implementation of {@link ClientCall}. */
final class ClientCallImpl<ReqT, RespT> extends ClientCall<ReqT, RespT>
    implements Context.CancellationListener {

  private static final Logger log = Logger.getLogger(ClientCallImpl.class.getName());

  private final MethodDescriptor<ReqT, RespT> method;
  // Executor on which listener callbacks are run; the wrapper type is chosen in the constructor.
  private final Executor callExecutor;
  // Context.current() as captured by the constructor.
  private final Context parentContext;
  // Per-call context derived from parentContext in start(); null until then.
  private volatile Context context;
  // True for UNARY and SERVER_STREAMING methods; sendMessage() skips the flush for these.
  private final boolean unaryRequest;
  private final CallOptions callOptions;
  // Null until start() assigns it.
  private ClientStream stream;
  // Set to true by the stream listener's close callback; start() re-checks it right after
  // registering this call as a context listener.
  private volatile boolean contextListenerShouldBeRemoved;
  // Set by the first cancel(); later cancel() calls return immediately.
  private boolean cancelCalled;
  // Set by halfClose(); guards against a second halfClose() and against later sendMessage().
  private boolean halfCloseCalled;
  private final ClientTransportProvider clientTransportProvider;
  // Optional; prepareHeaders() only writes a User-Agent header when this is non-null.
  private String userAgent;
  // Passed to Context.withDeadline() in start() when the call narrows the parent's deadline.
  private ScheduledExecutorService deadlineCancellationExecutor;
  // Both registries start as the default instances and can be replaced through the setters.
  private DecompressorRegistry decompressorRegistry = DecompressorRegistry.getDefaultInstance();
  private CompressorRegistry compressorRegistry = CompressorRegistry.getDefaultInstance();

  ClientCallImpl(
      MethodDescriptor<ReqT, RespT> method,
      Executor executor,
      CallOptions callOptions,
      ClientTransportProvider clientTransportProvider,
      ScheduledExecutorService deadlineCancellationExecutor) {
    this.method = method;
    // If we know that the executor is a direct executor, we don't need to wrap it with a
    // SerializingExecutor. This is purely for performance reasons.
    // See https://github.com/grpc/grpc-java/issues/368
    this.callExecutor =
        executor == directExecutor()
            ? new SerializeReentrantCallsDirectExecutor()
            : new SerializingExecutor(executor);
    // Propagate the context from the thread which initiated the call to all callbacks.
    this.parentContext = Context.current();
    this.unaryRequest =
        method.getType() == MethodType.UNARY || method.getType() == MethodType.SERVER_STREAMING;
    this.callOptions = callOptions;
    this.clientTransportProvider = clientTransportProvider;
    this.deadlineCancellationExecutor = deadlineCancellationExecutor;
  }

  /**
   * {@link Context.CancellationListener} callback: cancels the stream with the status that
   * {@code statusFromCancelled} derives from the cancelled context. This call registers itself as
   * a listener at the end of {@code start()}.
   */
  @Override
  public void cancelled(Context context) {
    stream.cancel(statusFromCancelled(context));
  }

  /** Provider of {@link ClientTransport}s; consulted by {@code start()} to create the stream. */
  interface ClientTransportProvider {
    /**
     * Returns a transport for a new call.
     *
     * @param callOptions the options of the call that needs a transport
     */
    ClientTransport get(CallOptions callOptions);
  }

  /**
   * Sets the value written to the User-Agent header when the call is started; {@code null} means
   * no User-Agent header is written.
   *
   * @return this call, for chaining
   */
  ClientCallImpl<ReqT, RespT> setUserAgent(String userAgent) {
    this.userAgent = userAgent;
    return this;
  }

  /**
   * Replaces the registry used to advertise accepted encodings and to look up the decompressor
   * for response messages.
   *
   * @return this call, for chaining
   */
  ClientCallImpl<ReqT, RespT> setDecompressorRegistry(DecompressorRegistry decompressorRegistry) {
    this.decompressorRegistry = decompressorRegistry;
    return this;
  }

  /**
   * Replaces the registry in which {@code start()} looks up the compressor named by the call
   * options.
   *
   * @return this call, for chaining
   */
  ClientCallImpl<ReqT, RespT> setCompressorRegistry(CompressorRegistry compressorRegistry) {
    this.compressorRegistry = compressorRegistry;
    return this;
  }

  /**
   * Rewrites the User-Agent, message-encoding and accept-encoding entries of {@code headers}.
   * Each key is cleared first and is only written back when there is a value to send.
   *
   * @param headers metadata to modify in place
   * @param callOptions not consulted by this method
   * @param userAgent User-Agent value, or {@code null} to leave the header absent
   * @param decompressorRegistry source of the advertised message encodings
   * @param compressor request compressor; {@code Codec.Identity.NONE} leaves the encoding header
   *     absent
   */
  @VisibleForTesting
  static void prepareHeaders(
      Metadata headers,
      CallOptions callOptions,
      String userAgent,
      DecompressorRegistry decompressorRegistry,
      Compressor compressor) {
    // User-Agent: drop whatever the caller supplied, then write ours if there is one.
    headers.removeAll(USER_AGENT_KEY);
    boolean hasUserAgent = userAgent != null;
    if (hasUserAgent) {
      headers.put(USER_AGENT_KEY, userAgent);
    }

    // Request encoding: only announced when a real (non-identity) compressor is in use.
    headers.removeAll(MESSAGE_ENCODING_KEY);
    boolean compressing = compressor != Codec.Identity.NONE;
    if (compressing) {
      String requestEncoding = compressor.getMessageEncoding();
      headers.put(MESSAGE_ENCODING_KEY, requestEncoding);
    }

    // Accepted encodings: the advertised encodings joined into a single header value.
    headers.removeAll(MESSAGE_ACCEPT_ENCODING_KEY);
    boolean nothingAdvertised = decompressorRegistry.getAdvertisedMessageEncodings().isEmpty();
    if (!nothingAdvertised) {
      headers.put(
          MESSAGE_ACCEPT_ENCODING_KEY,
          ACCEPT_ENCODING_JOINER.join(decompressorRegistry.getAdvertisedMessageEncodings()));
    }
  }

  /**
   * Starts the call: derives the per-call context, resolves the compressor, fills in the headers,
   * creates and starts the stream, and finally registers this call for context cancellation.
   *
   * <p>If the context is already cancelled, or the compressor named in the call options cannot be
   * found, no real stream is created; the observer is closed through the call executor instead.
   *
   * @param observer listener that receives the call's callbacks; must not be null
   * @param headers request metadata; modified in place; must not be null
   */
  @Override
  public void start(final Listener<RespT> observer, Metadata headers) {
    checkState(stream == null, "Already started");
    checkNotNull(observer, "observer");
    checkNotNull(headers, "headers");

    // Create the context
    final Deadline effectiveDeadline = min(callOptions.getDeadline(), parentContext.getDeadline());
    // Reference comparison: a new deadline is only installed when the effective deadline is not
    // the very object the parent context already carries.
    if (effectiveDeadline != parentContext.getDeadline()) {
      context = parentContext.withDeadline(effectiveDeadline, deadlineCancellationExecutor);
    } else {
      context = parentContext.withCancellation();
    }

    if (context.isCancelled()) {
      // Context is already cancelled so no need to create a real stream, just notify the observer
      // of cancellation via callback on the executor
      stream = NoopClientStream.INSTANCE;
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public void runInContext() {
              observer.onClose(statusFromCancelled(context), new Metadata());
            }
          });
      return;
    }
    final String compressorName = callOptions.getCompressor();
    Compressor compressor = null;
    if (compressorName != null) {
      compressor = compressorRegistry.lookupCompressor(compressorName);
      if (compressor == null) {
        // Unknown compressor: fail the call with INTERNAL without touching a transport.
        stream = NoopClientStream.INSTANCE;
        callExecutor.execute(
            new ContextRunnable(context) {
              @Override
              public void runInContext() {
                observer.onClose(
                    Status.INTERNAL.withDescription(
                        String.format("Unable to find compressor by name %s", compressorName)),
                    new Metadata());
              }
            });
        return;
      }
    } else {
      // No compressor requested in the call options.
      compressor = Codec.Identity.NONE;
    }

    prepareHeaders(headers, callOptions, userAgent, decompressorRegistry, compressor);

    final boolean deadlineExceeded = effectiveDeadline != null && effectiveDeadline.isExpired();
    if (!deadlineExceeded) {
      updateTimeoutHeaders(
          effectiveDeadline, callOptions.getDeadline(), parentContext.getDeadline(), headers);
      ClientTransport transport = clientTransportProvider.get(callOptions);
      stream = transport.newStream(method, headers);
    } else {
      // The deadline has already passed: no transport is asked for a stream.
      stream = new FailingClientStream(DEADLINE_EXCEEDED);
    }

    if (callOptions.getAuthority() != null) {
      stream.setAuthority(callOptions.getAuthority());
    }
    stream.setCompressor(compressor);

    stream.start(new ClientStreamListenerImpl(observer));
    if (compressor != Codec.Identity.NONE) {
      stream.setMessageCompression(true);
    }

    // Delay any sources of cancellation after start(), because most of the transports are broken if
    // they receive cancel before start. Issue #1343 has more details

    // Propagate later Context cancellation to the remote side.
    context.addListener(this, directExecutor());
    if (contextListenerShouldBeRemoved) {
      // Race detected! ClientStreamListener.closed may already have run, and tried to remove this
      // listener, before it was added just above, which would leave it registered.
      // Go ahead and remove it again, just to be sure it is gone.
      context.removeListener(this);
    }
  }

  /**
   * Replaces the timeout header with the time remaining until {@code effectiveDeadline}, clamped
   * to be non-negative. Without an effective deadline the header is only removed.
   *
   * @param effectiveDeadline the deadline that applies to the call, or {@code null} if none
   * @param callDeadline the deadline from the call options; only used for logging
   * @param outerCallDeadline the deadline of the parent context; only used for logging
   * @param headers metadata to modify in place
   */
  private static void updateTimeoutHeaders(
      @Nullable Deadline effectiveDeadline,
      @Nullable Deadline callDeadline,
      @Nullable Deadline outerCallDeadline,
      Metadata headers) {
    headers.removeAll(TIMEOUT_KEY);

    if (effectiveDeadline != null) {
      long remainingNanos = effectiveDeadline.timeRemaining(TimeUnit.NANOSECONDS);
      long effectiveTimeout = max(0, remainingNanos);
      headers.put(TIMEOUT_KEY, effectiveTimeout);

      logIfContextNarrowedTimeout(
          effectiveTimeout, effectiveDeadline, outerCallDeadline, callDeadline);
    }
  }

  /**
   * Logs at INFO level when the effective deadline is the very deadline object carried by the
   * outer context, i.e. when the context rather than the call options determined the timeout.
   *
   * @param effectiveTimeout the timeout, in nanoseconds, that was written to the headers
   * @param effectiveDeadline the deadline that applies to the call
   * @param outerCallDeadline the deadline of the parent context, if any
   * @param callDeadline the deadline from the call options, if any
   */
  private static void logIfContextNarrowedTimeout(
      long effectiveTimeout,
      Deadline effectiveDeadline,
      @Nullable Deadline outerCallDeadline,
      @Nullable Deadline callDeadline) {
    if (log.isLoggable(Level.INFO) && outerCallDeadline == effectiveDeadline) {
      String summary =
          String.format("Call timeout set to '%d' ns, due to context deadline.", effectiveTimeout);
      String explicitPart;
      if (callDeadline != null) {
        long callTimeout = callDeadline.timeRemaining(TimeUnit.NANOSECONDS);
        explicitPart = String.format(" Explicit call timeout was '%d' ns.", callTimeout);
      } else {
        explicitPart = " Explicit call timeout was not set.";
      }
      log.info(summary + explicitPart);
    }
  }

  /**
   * Returns the result of {@code deadline0.minimum(deadline1)} when both are present, the one
   * that is present when only one is, and {@code null} when neither is.
   */
  private static Deadline min(@Nullable Deadline deadline0, @Nullable Deadline deadline1) {
    if (deadline0 != null && deadline1 != null) {
      return deadline0.minimum(deadline1);
    }
    // At most one of the two is set here; hand back whichever it is (possibly null).
    return deadline0 != null ? deadline0 : deadline1;
  }

  /**
   * Forwards the request for {@code numMessages} more messages to the stream.
   *
   * <p>Fails the "Not started" state check when no stream exists yet, and the argument check when
   * {@code numMessages} is negative; the state check runs first.
   */
  @Override
  public void request(int numMessages) {
    Preconditions.checkState(stream != null, "Not started");
    checkArgument(numMessages >= 0, "Number requested must be non-negative");
    stream.request(numMessages);
  }

  /**
   * Cancels the call. Only the first invocation has any effect. When a stream exists it is
   * cancelled with a {@code CANCELLED} status carrying the supplied message and cause; in every
   * case this call is afterwards removed from the context's listeners if a context was created.
   *
   * @param message optional description for the cancellation status
   * @param cause optional cause for the cancellation status
   */
  @Override
  public void cancel(@Nullable String message, @Nullable Throwable cause) {
    if (cancelCalled) {
      return;
    }
    cancelCalled = true;
    try {
      // Cancel is also used while handling exceptions, so the stream may never have been
      // created; the finally block below still runs in that case.
      if (stream == null) {
        return;
      }

      Status cancellation = Status.CANCELLED;
      if (message == null && cause == null) {
        // TODO(zhangkun83): log a warning with this exception once cancel() has been deleted from
        // ClientCall.
        cancellation =
            cancellation.withCause(
                new CancellationException("Client called cancel() without any detail"));
      } else {
        if (message != null) {
          cancellation = cancellation.withDescription(message);
        }
        if (cause != null) {
          cancellation = cancellation.withCause(cause);
        }
      }
      stream.cancel(cancellation);
    } finally {
      if (context != null) {
        context.removeListener(ClientCallImpl.this);
      }
    }
  }

  /**
   * Half-closes the stream. The state checks require, in this order, that the call has been
   * started, has not been cancelled, and has not been half-closed before.
   */
  @Override
  public void halfClose() {
    Preconditions.checkState(stream != null, "Not started");
    Preconditions.checkState(!cancelCalled, "call was cancelled");
    Preconditions.checkState(!halfCloseCalled, "call already half-closed");
    // Record the half-close before delegating so a repeated call fails the check above.
    halfCloseCalled = true;
    stream.halfClose();
  }

  /**
   * Serializes {@code message} via the method descriptor and writes it to the stream. Any
   * failure while doing so cancels the stream instead of propagating to the caller.
   *
   * <p>The state checks require that the call has been started and has been neither cancelled
   * nor half-closed.
   */
  @Override
  public void sendMessage(ReqT message) {
    Preconditions.checkState(stream != null, "Not started");
    Preconditions.checkState(!cancelCalled, "call was cancelled");
    Preconditions.checkState(!halfCloseCalled, "call was half-closed");
    try {
      // TODO(notcarl): Find out if the serialized request stream needs to be closed.
      stream.writeMessage(method.streamRequest(message));
    } catch (Throwable e) {
      Status failure = Status.CANCELLED.withCause(e).withDescription("Failed to stream message");
      stream.cancel(failure);
      return;
    }
    // Unary requests are not flushed here because halfClose should follow shortly; that lets
    // END_STREAM=true ride on the last message frame without opening the possibility of broken
    // applications forgetting to call halfClose without noticing.
    if (unaryRequest) {
      return;
    }
    stream.flush();
  }

  /**
   * Forwards the message-compression flag to the stream; fails the "Not started" state check
   * when no stream exists yet.
   */
  @Override
  public void setMessageCompression(boolean enabled) {
    checkState(stream != null, "Not started");
    stream.setMessageCompression(enabled);
  }

  /**
   * Returns the stream's readiness as reported by {@code stream.isReady()}.
   *
   * <p>Like the other stream-delegating methods of this class, this fails the "Not started" state
   * check when no stream exists yet, rather than dereferencing a null stream.
   */
  @Override
  public boolean isReady() {
    checkState(stream != null, "Not started");
    return stream.isReady();
  }

  /**
   * Forwards stream events to the application's {@link Listener}. Every listener callback is
   * submitted to {@code callExecutor} wrapped in a {@link ContextRunnable} for the call's context.
   */
  private class ClientStreamListenerImpl implements ClientStreamListener {
    private final Listener<RespT> observer;
    // Only read and written inside runnables submitted to callExecutor. Once true, header and
    // message callbacks are no longer delivered to the observer.
    private boolean closed;

    public ClientStreamListenerImpl(Listener<RespT> observer) {
      this.observer = Preconditions.checkNotNull(observer, "observer");
    }

    /**
     * Selects the decompressor named by the response's message-encoding header (identity when the
     * header is absent) and then delivers the headers to the observer. An encoding with no
     * registered decompressor cancels the stream with {@code INTERNAL} instead.
     */
    @Override
    public void headersRead(final Metadata headers) {
      Decompressor decompressor = Codec.Identity.NONE;
      if (headers.containsKey(MESSAGE_ENCODING_KEY)) {
        String encoding = headers.get(MESSAGE_ENCODING_KEY);
        decompressor = decompressorRegistry.lookupDecompressor(encoding);
        if (decompressor == null) {
          stream.cancel(
              Status.INTERNAL.withDescription(
                  String.format("Can't find decompressor for %s", encoding)));
          return;
        }
      }
      stream.setDecompressor(decompressor);

      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                if (closed) {
                  return;
                }

                observer.onHeaders(headers);
              } catch (Throwable t) {
                // A failing observer must not break the transport; cancel the call instead.
                stream.cancel(
                    Status.CANCELLED.withCause(t).withDescription("Failed to read headers"));
              }
            }
          });
    }

    /**
     * Parses {@code message} and delivers it to the observer, unless the call has already been
     * closed. The message stream is closed in every case, including when it is dropped because
     * the call is closed, so it is never left open.
     */
    @Override
    public void messageRead(final InputStream message) {
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                try {
                  // After close the message is dropped, but it still has to be closed below.
                  if (!closed) {
                    observer.onMessage(method.parseResponse(message));
                  }
                } finally {
                  message.close();
                }
              } catch (Throwable t) {
                // Covers parse failures, observer failures and failures while closing.
                stream.cancel(
                    Status.CANCELLED.withCause(t).withDescription("Failed to read message."));
              }
            }
          });
    }

    /**
     * Delivers the final status to the observer and removes the call from the context's
     * listeners. A {@code CANCELLED} status received after the context deadline has expired is
     * reported as {@code DEADLINE_EXCEEDED} with fresh, empty trailers.
     */
    @Override
    public void closed(Status status, Metadata trailers) {
      Deadline deadline = context.getDeadline();
      if (status.getCode() == Status.Code.CANCELLED && deadline != null) {
        // When the server's deadline expires, it can only reset the stream with CANCEL and no
        // description. Since our timer may be delayed in firing, we double-check the deadline and
        // turn the failure into the likely more helpful DEADLINE_EXCEEDED status.
        if (deadline.isExpired()) {
          status = DEADLINE_EXCEEDED;
          // Replace trailers to prevent mixing sources of status and trailers.
          trailers = new Metadata();
        }
      }
      final Status savedStatus = status;
      final Metadata savedTrailers = trailers;
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              try {
                closed = true;
                // Lets start() notice a close that happened before it registered the listener.
                contextListenerShouldBeRemoved = true;
                observer.onClose(savedStatus, savedTrailers);
              } finally {
                context.removeListener(ClientCallImpl.this);
              }
            }
          });
    }

    /** Relays the stream's readiness notification to the observer. */
    @Override
    public void onReady() {
      callExecutor.execute(
          new ContextRunnable(context) {
            @Override
            public final void runInContext() {
              observer.onReady();
            }
          });
    }
  }
}