Example #1
0
public class NMRConduit extends AbstractConduit {

  private static final Logger LOG = LogUtils.getL7dLogger(NMRConduit.class);

  private NMR nmr;
  private Bus bus;

  public NMRConduit(EndpointReferenceType target, NMR nmr) {
    this(null, target, nmr);
  }

  public NMRConduit(Bus bus, EndpointReferenceType target, NMR nmr) {
    super(target);
    this.nmr = nmr;
    this.bus = bus;
  }

  public Bus getBus() {
    return bus;
  }

  public NMR getNmr() {
    return nmr;
  }

  protected Logger getLogger() {
    return LOG;
  }

  public void prepare(Message message) throws IOException {
    getLogger().log(Level.FINE, "JBIConduit send message");

    message.setContent(OutputStream.class, new NMRConduitOutputStream(message, nmr, target, this));
  }
}
public class StaxSchemaValidationOutInterceptor extends AbstractPhaseInterceptor<Message> {
  private static final Logger LOG = LogUtils.getL7dLogger(StaxSchemaValidationOutInterceptor.class);

  public StaxSchemaValidationOutInterceptor() {
    super(Phase.PRE_MARSHAL);
  }

  public void handleMessage(Message message) throws Fault {
    XMLStreamWriter writer = message.getContent(XMLStreamWriter.class);
    try {
      setSchemaInMessage(message, writer);
    } catch (XMLStreamException e) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("SCHEMA_ERROR", LOG), e);
    }
  }

  private void setSchemaInMessage(Message message, XMLStreamWriter writer)
      throws XMLStreamException {
    if (ServiceUtils.isSchemaValidationEnabled(SchemaValidationType.OUT, message)) {
      try {
        WoodstoxValidationImpl mgr = new WoodstoxValidationImpl();
        if (mgr.canValidate()) {
          mgr.setupValidation(
              writer,
              message.getExchange().getEndpoint(),
              message.getExchange().getService().getServiceInfos().get(0));
        }
      } catch (Throwable t) {
        // likely no MSV or similar
        LOG.log(Level.FINE, "Problem initializing MSV validation", t);
      }
    }
  }
}
/** Common base for the JAX-WS frontend generators: shared logger and template path prefix. */
public abstract class AbstractJAXWSGenerator extends AbstractGenerator {
  protected static final Logger LOG = LogUtils.getL7dLogger(AbstractJAXWSGenerator.class);

  /** Resource path prefix under which the JAX-WS frontend templates are located. */
  protected static final String TEMPLATE_BASE =
      "org/apache/cxf/tools/wsdlto/frontend/jaxws/template";

  /**
   * Hook for recording generated artifacts with {@code collector}. This default implementation
   * does nothing; subclasses override it when they need to register output.
   */
  public void register(final ClassCollector collector, String packageName, String fileName) {
    // intentionally empty
  }

  /**
   * Subclass-defined flag; see implementations for the exact meaning (presumably whether this
   * generator should be skipped — TODO confirm).
   */
  public abstract boolean passthrough();

  /**
   * Runs the generator.
   *
   * @param penv the tool context
   * @throws ToolException on generation failure
   */
  public abstract void generate(ToolContext penv) throws ToolException;
}
public class SamlHeaderOutInterceptor extends AbstractSamlOutInterceptor {
  private static final Logger LOG = LogUtils.getL7dLogger(SamlHeaderOutInterceptor.class);

  public SamlHeaderOutInterceptor() {
    this(Phase.WRITE);
  }

  public SamlHeaderOutInterceptor(String phase) {
    super(phase);
  }

  public void handleMessage(Message message) throws Fault {
    try {
      SamlAssertionWrapper assertionWrapper = createAssertion(message);

      Document doc = DOMUtils.newDocument();
      Element assertionElement = assertionWrapper.toDOM(doc);
      String encodedToken = encodeToken(DOM2Writer.nodeToString(assertionElement));

      Map<String, List<String>> headers = getHeaders(message);

      StringBuilder builder = new StringBuilder();
      builder.append("SAML").append(" ").append(encodedToken);
      headers.put(
          "Authorization",
          CastUtils.cast(Collections.singletonList(builder.toString()), String.class));

    } catch (Exception ex) {
      StringWriter sw = new StringWriter();
      ex.printStackTrace(new PrintWriter(sw));
      LOG.warning(sw.toString());
      throw new Fault(new RuntimeException(ex.getMessage() + ", stacktrace: " + sw.toString()));
    }
  }

  private Map<String, List<String>> getHeaders(Message message) {
    Map<String, List<String>> headers =
        CastUtils.cast((Map<?, ?>) message.get(Message.PROTOCOL_HEADERS));
    if (headers == null) {
      headers = new HashMap<>();
      message.put(Message.PROTOCOL_HEADERS, headers);
    }
    return headers;
  }
}
/**
 * Base {@link KeyEncryptionProvider} that encrypts (or wraps) a content-encryption key with a
 * fixed key-encryption key, restricted to a configured set of supported algorithms.
 */
public abstract class AbstractWrapKeyEncryptionAlgorithm implements KeyEncryptionProvider {
  protected static final Logger LOG =
      LogUtils.getL7dLogger(AbstractWrapKeyEncryptionAlgorithm.class);
  // All state is fixed at construction time.
  private final Key keyEncryptionKey;
  private final boolean wrap;
  private final KeyAlgorithm algorithm;
  private final Set<String> supportedAlgorithms;

  protected AbstractWrapKeyEncryptionAlgorithm(Key key, Set<String> supportedAlgorithms) {
    this(key, null, true, supportedAlgorithms);
  }

  protected AbstractWrapKeyEncryptionAlgorithm(
      Key key, boolean wrap, Set<String> supportedAlgorithms) {
    this(key, null, wrap, supportedAlgorithms);
  }

  protected AbstractWrapKeyEncryptionAlgorithm(
      Key key, KeyAlgorithm jweAlgo, Set<String> supportedAlgorithms) {
    this(key, jweAlgo, true, supportedAlgorithms);
  }

  /**
   * @param key the key-encryption key
   * @param jweAlgo the configured key-encryption algorithm; may be {@code null}
   * @param wrap {@code true} to use key wrapping, {@code false} to encrypt the CEK bytes
   * @param supportedAlgorithms JWA names accepted by {@link #checkAlgorithm(String)}
   */
  protected AbstractWrapKeyEncryptionAlgorithm(
      Key key, KeyAlgorithm jweAlgo, boolean wrap, Set<String> supportedAlgorithms) {
    this.keyEncryptionKey = key;
    this.algorithm = jweAlgo;
    this.wrap = wrap;
    this.supportedAlgorithms = supportedAlgorithms;
  }

  @Override
  public KeyAlgorithm getAlgorithm() {
    return algorithm;
  }

  /**
   * Validates the header algorithm, then encrypts or wraps {@code cek} with the key-encryption
   * key.
   *
   * @throws JweException if the key-encryption algorithm is missing, mismatched or unsupported
   */
  @Override
  public byte[] getEncryptedContentEncryptionKey(JweHeaders headers, byte[] cek) {
    checkAlgorithms(headers);
    KeyProperties secretKeyProperties = new KeyProperties(getKeyEncryptionAlgoJava(headers));
    AlgorithmParameterSpec spec = getAlgorithmParameterSpec(headers);
    if (spec != null) {
      secretKeyProperties.setAlgoSpec(spec);
    }
    if (!wrap) {
      return CryptoUtils.encryptBytes(cek, keyEncryptionKey, secretKeyProperties);
    } else {
      return CryptoUtils.wrapSecretKey(
          cek, getContentEncryptionAlgoJava(headers), keyEncryptionKey, secretKeyProperties);
    }
  }

  protected String getKeyEncryptionAlgoJava(JweHeaders headers) {
    return AlgorithmUtils.toJavaName(headers.getKeyEncryptionAlgorithm().getJwaName());
  }

  protected String getContentEncryptionAlgoJava(JweHeaders headers) {
    return AlgorithmUtils.toJavaName(headers.getContentEncryptionAlgorithm().getJwaName());
  }

  /** Optional parameter spec for the key cipher; {@code null} by default. */
  protected AlgorithmParameterSpec getAlgorithmParameterSpec(JweHeaders headers) {
    return null;
  }

  /**
   * Rejects a non-null JWA name that is not in the supported set.
   *
   * @return {@code algo} unchanged
   * @throws JweException with {@code INVALID_KEY_ALGORITHM} for unsupported names
   */
  protected String checkAlgorithm(String algo) {
    if (algo != null && !supportedAlgorithms.contains(algo)) {
      LOG.warning("Invalid key encryption algorithm: " + algo);
      throw new JweException(JweException.Error.INVALID_KEY_ALGORITHM);
    }
    return algo;
  }

  /**
   * Ensures the header's key-encryption algorithm matches the configured one, filling it in from
   * the configured algorithm when the header has none.
   *
   * @throws JweException with {@code INVALID_KEY_ALGORITHM} on mismatch, unsupported algorithm,
   *     or when neither the header nor this provider specifies an algorithm
   */
  protected void checkAlgorithms(JweHeaders headers) {
    KeyAlgorithm providedAlgo = headers.getKeyEncryptionAlgorithm();
    if (providedAlgo != null) {
      if (!providedAlgo.equals(algorithm)) {
        LOG.warning("Invalid key encryption algorithm: " + providedAlgo);
        throw new JweException(JweException.Error.INVALID_KEY_ALGORITHM);
      }
      checkAlgorithm(providedAlgo.getJwaName());
      return;
    }
    // Previously this dereferenced a null 'algorithm' (reachable through the constructors that
    // pass null) and failed with a NullPointerException instead of a JweException.
    if (algorithm == null) {
      LOG.warning("Key encryption algorithm is not set");
      throw new JweException(JweException.Error.INVALID_KEY_ALGORITHM);
    }
    checkAlgorithm(algorithm.getJwaName());
    headers.setKeyEncryptionAlgorithm(algorithm);
  }
}
/**
 * Test client to do websocket calls.
 *
 * <p>We may put this in test-tools so that other systests can use this code. For now keep it
 * here to experiment with JAX-RS websocket scenarios.
 *
 * @see JAXRSClientServerWebSocketTest
 */
class WebSocketTestClient {
  private static final Logger LOG = LogUtils.getL7dLogger(WebSocketTestClient.class);

  /** Maximum number of bytes rendered by {@link #makeString(byte[])} before truncating. */
  private static final int MAX_DUMP_LENGTH = 256;

  private List<Object> received;
  private List<Object> fragments;
  private CountDownLatch latch;
  private AsyncHttpClient client;
  private WebSocket websocket;
  private String url;

  public WebSocketTestClient(String url) {
    this.received = new ArrayList<Object>();
    this.fragments = new ArrayList<Object>();
    this.latch = new CountDownLatch(1);
    this.client = new AsyncHttpClient();
    this.url = url;
  }

  /** Opens the websocket to {@code url}, blocking on the upgrade future. */
  public void connect() throws InterruptedException, ExecutionException, IOException {
    websocket =
        client
            .prepareGet(url)
            .execute(
                new WebSocketUpgradeHandler.Builder()
                    .addWebSocketListener(new WsSocketListener())
                    .build())
            .get();
  }

  public void sendTextMessage(String message) {
    websocket.sendTextMessage(message);
  }

  public void sendMessage(byte[] message) {
    websocket.sendMessage(message);
  }

  /** Waits up to {@code secs} seconds for the expected number of complete messages. */
  public boolean await(int secs) throws InterruptedException {
    return latch.await(secs, TimeUnit.SECONDS);
  }

  /** Clears received messages and expects {@code count} new ones. */
  public void reset(int count) {
    latch = new CountDownLatch(count);
    received.clear();
  }

  public List<Object> getReceived() {
    return received;
  }

  /** Parses every received message (String or byte[]) into a {@link Response}. */
  public List<Response> getReceivedResponses() {
    Object[] objs = received.toArray();
    List<Response> responses = new ArrayList<Response>(objs.length);
    for (Object o : objs) {
      responses.add(new Response(o));
    }
    return responses;
  }

  /** Closes the websocket (if it was opened) and the underlying HTTP client. */
  public void close() {
    if (websocket != null) {
      websocket.close();
    }
    client.close();
  }

  class WsSocketListener implements WebSocketTextListener, WebSocketByteListener {

    public void onOpen(WebSocket ws) {
      LOG.info("[ws] opened");
    }

    public void onClose(WebSocket ws) {
      LOG.info("[ws] closed");
    }

    public void onError(Throwable t) {
      LOG.info("[ws] error: " + t);
    }

    public void onMessage(byte[] message) {
      received.add(message);
      LOG.info("[ws] received bytes --> " + makeString(message));
      latch.countDown();
    }

    public void onFragment(byte[] fragment, boolean last) {
      processFragments(fragment, last);
    }

    public void onMessage(String message) {
      received.add(message);
      LOG.info("[ws] received --> " + message);
      latch.countDown();
    }

    public void onFragment(String fragment, boolean last) {
      processFragments(fragment, last);
    }

    /**
     * Buffers fragments; on the last one, joins all buffered fragments of the same kind (String
     * or byte[]) into a single entry in {@code received}.
     */
    private void processFragments(Object f, boolean last) {
      synchronized (fragments) {
        fragments.add(f);
        if (last) {
          if (f instanceof String) {
            // string
            StringBuilder sb = new StringBuilder();
            for (Iterator<Object> it = fragments.iterator(); it.hasNext(); ) {
              Object o = it.next();
              if (o instanceof String) {
                sb.append((String) o);
                it.remove();
              }
            }
            received.add(sb.toString());
          } else {
            // byte[]
            ByteArrayOutputStream bao = new ByteArrayOutputStream();
            for (Iterator<Object> it = fragments.iterator(); it.hasNext(); ) {
              Object o = it.next();
              if (o instanceof byte[]) {
                bao.write((byte[]) o, 0, ((byte[]) o).length);
                it.remove();
              }
            }
            received.add(bao.toByteArray());
          }
        }
      }
    }
  }

  /** Renders {@code data} as a HEX/ASCII dump, truncated after {@value #MAX_DUMP_LENGTH} bytes. */
  private static String makeString(byte[] data) {
    return data == null ? null : makeString(data, 0, data.length).toString();
  }

  /**
   * Dumps {@code length} bytes of {@code data} starting at {@code offset}.
   *
   * <p>The original checked {@code data.length} instead of {@code length}, which recursed
   * forever for payloads over 256 bytes, and it ignored {@code offset}/{@code length}.
   */
  private static StringBuilder makeString(byte[] data, int offset, int length) {
    if (length > MAX_DUMP_LENGTH) {
      return makeString(data, offset, MAX_DUMP_LENGTH).append("...");
    }
    StringBuilder xbuf = new StringBuilder().append("\nHEX: ");
    StringBuilder cbuf = new StringBuilder().append("\nASC: ");
    for (int i = offset; i < offset + length; i++) {
      int b = 0xff & data[i];
      writeHex(xbuf, b);
      writePrintable(cbuf, b);
    }
    return xbuf.append(cbuf);
  }

  private static void writeHex(StringBuilder buf, int b) {
    buf.append(Integer.toHexString(0x100 | (0xff & b)).substring(1)).append(' ');
  }

  private static void writePrintable(StringBuilder buf, int b) {
    if (b == 0x0d) {
      buf.append("\\r");
    } else if (b == 0x0a) {
      buf.append("\\n");
    } else if (b == 0x09) {
      buf.append("\\t");
    } else if ((0x80 & b) != 0) {
      buf.append('.').append(' ');
    } else {
      buf.append((char) b).append(' ');
    }
    buf.append(' ');
  }

  // TODO this is a temporary way to verify the response; we should come up with something better.
  /**
   * Parses a websocket payload of the form: optional numeric status line, header lines
   * ({@code name: value}), blank line, entity.
   */
  public static class Response {
    private Object data;
    private int pos;
    private int statusCode;
    private String contentType;
    private Object entity;

    public Response(Object data) {
      this.data = data;
      String line;
      boolean first = true;
      while ((line = readLine()) != null) {
        if (first && isStatusCode(line)) {
          statusCode = Integer.parseInt(line);
          continue;
        } else {
          first = false;
        }

        int del = line.indexOf(':');
        if (del < 0) {
          // not a "name: value" header; skip instead of failing on substring(0, -1)
          continue;
        }
        String h = line.substring(0, del).trim();
        String v = line.substring(del + 1).trim();
        if ("Content-Type".equalsIgnoreCase(h)) {
          contentType = v;
        }
      }
      if (data instanceof String) {
        entity = ((String) data).substring(pos);
      } else if (data instanceof byte[]) {
        entity = new byte[((byte[]) data).length - pos];
        System.arraycopy((byte[]) data, pos, (byte[]) entity, 0, ((byte[]) entity).length);
      }
    }

    private static boolean isStatusCode(String line) {
      char c = line.charAt(0);
      return '0' <= c && c <= '9';
    }

    public int getStatusCode() {
      return statusCode;
    }

    public String getContentType() {
      return contentType;
    }

    @SuppressWarnings("unused")
    public Object getEntity() {
      return entity;
    }

    public String getTextEntity() {
      return gettext(entity);
    }

    public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append("Status: ").append(statusCode).append("\r\n");
      sb.append("Type: ").append(contentType).append("\r\n");
      sb.append("Entity: ").append(gettext(entity)).append("\r\n");
      return sb.toString();
    }

    /** Reads the next line (CRs dropped); returns {@code null} at a blank line or end of data. */
    private String readLine() {
      StringBuilder sb = new StringBuilder();
      while (pos < length(data)) {
        int c = getchar(data, pos++);
        if (c == '\n') {
          break;
        } else if (c == '\r') {
          continue;
        } else {
          sb.append((char) c);
        }
      }
      if (sb.length() == 0) {
        return null;
      }
      return sb.toString();
    }

    // Previously tested "instanceof char[]" and then cast to String, so String payloads always
    // reported length 0 and were never parsed.
    private int length(Object o) {
      return o instanceof String
          ? ((String) o).length()
          : (o instanceof byte[] ? ((byte[]) o).length : 0);
    }

    private int getchar(Object o, int p) {
      return 0xff
          & (o instanceof String
              ? ((String) o).charAt(p)
              : (o instanceof byte[] ? ((byte[]) o)[p] : -1));
    }

    /** Decodes byte entities as UTF-8 rather than the platform default charset. */
    private String gettext(Object o) {
      return o instanceof String
          ? (String) o
          : (o instanceof byte[]
              ? new String((byte[]) o, java.nio.charset.StandardCharsets.UTF_8)
              : null);
    }
  }
}
Example #7
0
/**
 * Invocation handler behind JAX-WS client proxies. It routes SEI method calls to the matching
 * binding operation (synchronously, or asynchronously for {@code *Async} methods). It maps
 * failures to {@link HTTPException}, {@link SOAPFaultException} or {@link WebServiceException}
 * depending on the binding.
 */
public class JaxWsClientProxy extends org.apache.cxf.frontend.ClientProxy
    implements InvocationHandler, BindingProvider {

  public static final String THREAD_LOCAL_REQUEST_CONTEXT = "thread.local.request.context";

  private static final Logger LOG = LogUtils.getL7dLogger(JaxWsClientProxy.class);

  private final Binding binding;
  private final EndpointReferenceBuilder builder;

  /**
   * @param c the client used for all invocations
   * @param b the binding returned by {@link #getBinding()}
   */
  public JaxWsClientProxy(Client c, Binding b) {
    super(c);
    this.binding = b;
    setupEndpointAddressContext(getClient().getEndpoint());
    this.builder = new EndpointReferenceBuilder((JaxWsEndpointImpl) getClient().getEndpoint());
  }

  /** Copies the endpoint's address, when it has one, into the request context. */
  private void setupEndpointAddressContext(Endpoint endpoint) {
    // NOTE for jms transport the address would be null
    if (null != endpoint && null != endpoint.getEndpointInfo().getAddress()) {
      getRequestContext()
          .put(BindingProvider.ENDPOINT_ADDRESS_PROPERTY, endpoint.getEndpointInfo().getAddress());
    }
  }

  @Override
  public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {

    Endpoint endpoint = getClient().getEndpoint();
    String address = endpoint.getEndpointInfo().getAddress();
    MethodDispatcher dispatcher =
        (MethodDispatcher) endpoint.getService().get(MethodDispatcher.class.getName());
    Object[] params = args;
    if (null == params) {
      params = new Object[0];
    }

    BindingOperationInfo oi = dispatcher.getBindingOperation(method, endpoint);
    if (oi == null) {
      // check for method on BindingProvider and Object
      if (method.getDeclaringClass().equals(BindingProvider.class)
          || method.getDeclaringClass().equals(Object.class)) {
        try {
          return method.invoke(this, params);
        } catch (InvocationTargetException e) {
          throw e.fillInStackTrace().getCause();
        }
      }

      Message msg = new Message("NO_BINDING_OPERATION_INFO", LOG, method.getName());
      throw new WebServiceException(msg.toString());
    }

    client.getRequestContext().put(Method.class.getName(), method);
    boolean isAsync = isAsync(method);

    Object result = null;
    try {
      if (isAsync) {
        result = invokeAsync(method, oi, params);
      } else {
        result = invokeSync(method, oi, params);
      }
    } catch (WebServiceException wex) {
      throw wex.fillInStackTrace();
    } catch (Exception ex) {
      // Exceptions declared by the SEI method are rethrown unchanged.
      for (Class<?> excls : method.getExceptionTypes()) {
        if (excls.isInstance(ex)) {
          throw ex.fillInStackTrace();
        }
      }

      if (getBinding() instanceof HTTPBinding) {
        HTTPException exception = new HTTPException(HttpURLConnection.HTTP_INTERNAL_ERROR);
        exception.initCause(ex);
        throw exception;
      } else if (getBinding() instanceof SOAPBinding) {
        SOAPFault soapFault = createSoapFault((SOAPBinding) getBinding(), ex);
        if (soapFault == null) {
          throw new WebServiceException(ex);
        }
        SOAPFaultException exception = new SOAPFaultException(soapFault);
        if (ex instanceof Fault && ex.getCause() != null) {
          exception.initCause(ex.getCause());
        } else {
          exception.initCause(ex);
        }
        throw exception;
      } else {
        throw new WebServiceException(ex);
      }
    } finally {
      if (addressChanged(address)) {
        setupEndpointAddressContext(getClient().getEndpoint());
      }
    }

    // Drop HANDLER-scoped entries so they are not visible to the application.
    Map<String, Object> respContext = client.getResponseContext();
    Map<String, Scope> scopes =
        CastUtils.cast((Map<?, ?>) respContext.get(WrappedMessageContext.SCOPES));
    if (scopes != null) {
      for (Map.Entry<String, Scope> scope : scopes.entrySet()) {
        if (scope.getValue() == Scope.HANDLER) {
          respContext.remove(scope.getKey());
        }
      }
    }
    return result;
  }

  /** A method is async when its name ends in "Async" and it returns Future or Response. */
  boolean isAsync(Method m) {
    return m.getName().endsWith("Async")
        && (Future.class.equals(m.getReturnType()) || Response.class.equals(m.getReturnType()));
  }

  /**
   * Builds a {@link SOAPFault} describing {@code ex}.
   *
   * @return the fault, or {@code null} if neither SAAJ fault-creation approach worked
   */
  static SOAPFault createSoapFault(SOAPBinding binding, Exception ex) throws SOAPException {
    SOAPFault soapFault;
    try {
      soapFault = binding.getSOAPFactory().createFault();
    } catch (Throwable t) {
      // probably an old version of saaj or something that is not allowing createFault
      // method to work.  Try the saaj 1.2 method of doing this.
      try {
        soapFault = binding.getMessageFactory().createMessage().getSOAPBody().addFault();
      } catch (Throwable t2) {
        // still didn't work, we'll just throw what we have
        return null;
      }
    }

    if (ex instanceof SoapFault) {
      if (!soapFault.getNamespaceURI().equals(((SoapFault) ex).getFaultCode().getNamespaceURI())
          && SOAPConstants.URI_NS_SOAP_1_1_ENVELOPE.equals(
              ((SoapFault) ex).getFaultCode().getNamespaceURI())) {
        // change to 1.1
        try {
          soapFault = SOAPFactory.newInstance().createFault();
        } catch (Throwable t) {
          // ignore
        }
      }
      soapFault.setFaultString(((SoapFault) ex).getReason());
      soapFault.setFaultCode(((SoapFault) ex).getFaultCode());
      soapFault.setFaultActor(((SoapFault) ex).getRole());

      Node nd = soapFault.getOwnerDocument().importNode(((SoapFault) ex).getOrCreateDetail(), true);
      nd = nd.getFirstChild();
      soapFault.addDetail();
      // Move each imported detail child under the new fault's detail element.
      while (nd != null) {
        Node next = nd.getNextSibling();
        soapFault.getDetail().appendChild(nd);
        nd = next;
      }

    } else {
      String msg = ex.getMessage();
      if (msg != null) {
        soapFault.setFaultString(msg);
      }
    }
    return soapFault;
  }

  /** True when the original address was set and the endpoint now reports a different one. */
  private boolean addressChanged(String address) {
    return !(address == null
        || getClient().getEndpoint().getEndpointInfo() == null
        || address.equals(getClient().getEndpoint().getEndpointInfo().getAddress()));
  }

  /**
   * Starts an asynchronous invocation. If the last argument is an {@link AsyncHandler} it is
   * notified on completion.
   *
   * @return a {@link Response} for the pending result
   */
  @SuppressWarnings("unchecked")
  private Object invokeAsync(Method method, BindingOperationInfo oi, Object[] params)
      throws Exception {

    client.setExecutor(getClient().getEndpoint().getExecutor());

    AsyncHandler<Object> handler = null;
    if (params.length > 0 && params[params.length - 1] instanceof AsyncHandler) {
      // parameterized cast instead of the previous raw-type cast
      handler = (AsyncHandler<Object>) params[params.length - 1];
    }
    ClientCallback callback = new JaxwsClientCallback(handler);

    Response<Object> ret = new JaxwsResponseCallback(callback);
    client.invoke(callback, oi, params);
    return ret;
  }

  @Override
  public Map<String, Object> getRequestContext() {
    return new WrappedMessageContext(this.getClient().getRequestContext(), null, Scope.APPLICATION);
  }

  @Override
  public Map<String, Object> getResponseContext() {
    return new WrappedMessageContext(
        this.getClient().getResponseContext(), null, Scope.APPLICATION);
  }

  @Override
  public Binding getBinding() {
    return binding;
  }

  @Override
  public EndpointReference getEndpointReference() {
    return builder.getEndpointReference();
  }

  @Override
  public <T extends EndpointReference> T getEndpointReference(Class<T> clazz) {
    return builder.getEndpointReference(clazz);
  }
}
 /**
  * Looks up the logger for {@link CXFNonSpringServlet}.
  *
  * @return the logger returned by {@code LogUtils.getL7dLogger} for that class
  */
 public static Logger getLogger() {
   final Class<?> owner = CXFNonSpringServlet.class;
   final Logger logger = LogUtils.getL7dLogger(owner);
   return logger;
 }
/**
 * Interceptor which authenticates a current principal and populates Subject
 *
 * @author Sergey Beryozkin
 * @author [email protected]
 */
public class SubjectCreatingInterceptor extends WSS4JInInterceptor {
  protected final SubjectCreator helper = new SubjectCreator();

  private static final Logger LOG = LogUtils.getL7dLogger(SubjectCreatingInterceptor.class);

  // Per-thread security domain for the message currently in handleMessage; cleared in finally.
  private final ThreadLocal<SecurityDomainContext> sdc = new ThreadLocal<SecurityDomainContext>();

  private boolean supportDigestPasswords;

  public SubjectCreatingInterceptor() {
    this(new HashMap<String, Object>());
  }

  /**
   * @param properties WSS4J configuration passed to {@link WSS4JInInterceptor}
   */
  public SubjectCreatingInterceptor(Map<String, Object> properties) {
    super(properties);
    getAfter().add(PolicyBasedWSS4JInInterceptor.class.getName());
  }

  public void setSupportDigestPasswords(boolean support) {
    supportDigestPasswords = support;
  }

  public boolean getSupportDigestPasswords() {
    return supportDigestPasswords;
  }

  /**
   * Uses an existing UsernameToken/SecurityContext on the message to create the Subject;
   * otherwise delegates to WSS4J processing in the superclass.
   */
  @Override
  public void handleMessage(SoapMessage msg) throws Fault {
    Endpoint ep = msg.getExchange().get(Endpoint.class);
    sdc.set(ep.getSecurityDomainContext());
    try {
      SecurityToken token = msg.get(SecurityToken.class);
      SecurityContext context = msg.get(SecurityContext.class);
      if (token == null || context == null || context.getUserPrincipal() == null) {
        super.handleMessage(msg);
        return;
      }
      UsernameToken ut = (UsernameToken) token;

      Subject subject =
          createSubject(
              ut.getName(), ut.getPassword(), ut.isHashed(), ut.getNonce(), ut.getCreatedTime());

      SecurityContext sc = doCreateSecurityContext(context.getUserPrincipal(), subject);
      msg.put(SecurityContext.class, sc);
    } finally {
      // sdc is a final, always-initialised field, so no null check is needed
      sdc.remove();
    }
  }

  @Override
  protected SecurityContext createSecurityContext(final Principal p) {
    Message msg = PhaseInterceptorChain.getCurrentMessage();
    if (msg == null) {
      throw new IllegalStateException("Current message is not available");
    }
    return doCreateSecurityContext(p, msg.get(Subject.class));
  }

  /**
   * Creates default SecurityContext which implements isUserInRole using the following approach :
   * skip the first Subject principal, and then check optional Groups the principal is a member of.
   * Subclasses can override this method and implement a custom strategy instead
   *
   * @param p principal
   * @param subject subject
   * @return security context
   */
  protected SecurityContext doCreateSecurityContext(final Principal p, final Subject subject) {
    return new DefaultSecurityContext(p, subject);
  }

  /**
   * Creates a Subject for the credentials and stores it on the current message.
   *
   * @throws WSSecurityException with FAILED_AUTHENTICATION if creation fails, the subject has no
   *     principals, or its first non-Group principal does not match {@code name}
   */
  protected void setSubject(
      String name, String password, boolean isDigest, String nonce, String created)
      throws WSSecurityException {
    Message msg = PhaseInterceptorChain.getCurrentMessage();
    if (msg == null) {
      throw new IllegalStateException("Current message is not available");
    }
    Subject subject = null;
    try {
      subject = createSubject(name, password, isDigest, nonce, created);
    } catch (Exception ex) {
      String errorMessage = "Failed Authentication : Subject has not been created";
      // Log the cause so the failure is diagnosable; the thrown exception stays generic.
      LOG.log(java.util.logging.Level.SEVERE, errorMessage, ex);
      throw new WSSecurityException(WSSecurityException.ErrorCode.FAILED_AUTHENTICATION);
    }
    if (subject == null
        || subject.getPrincipals().isEmpty()
        || !checkUserPrincipal(subject.getPrincipals(), name)) {
      String errorMessage = "Failed Authentication : Invalid Subject";
      LOG.severe(errorMessage);
      throw new WSSecurityException(WSSecurityException.ErrorCode.FAILED_AUTHENTICATION);
    }
    msg.put(Subject.class, subject);
  }

  /** Compares {@code name} against the first principal that is not a Group. */
  private boolean checkUserPrincipal(Set<Principal> principals, String name) {
    for (Principal p : principals) {
      if (!(p instanceof Group)) {
        return p.getName().equals(name);
      }
    }
    return false;
  }

  /** Installs {@link CustomValidator} for UsernameToken processing. */
  @Override
  protected WSSecurityEngine getSecurityEngine(boolean utNoCallbacks) {
    Map<QName, Object> profiles = new HashMap<QName, Object>(1);

    Validator validator = new CustomValidator();
    profiles.put(WSSecurityEngine.USERNAME_TOKEN, validator);
    return createSecurityEngine(profiles);
  }

  /** UsernameToken validator that authenticates by creating a Subject via setSubject. */
  protected class CustomValidator extends UsernameTokenValidator {

    @Override
    protected void verifyCustomPassword(
        org.apache.wss4j.dom.message.token.UsernameToken usernameToken, RequestData data)
        throws WSSecurityException {
      SubjectCreatingInterceptor.this.setSubject(
          usernameToken.getName(), usernameToken.getPassword(), false, null, null);
    }

    @Override
    protected void verifyPlaintextPassword(
        org.apache.wss4j.dom.message.token.UsernameToken usernameToken, RequestData data)
        throws WSSecurityException {
      SubjectCreatingInterceptor.this.setSubject(
          usernameToken.getName(), usernameToken.getPassword(), false, null, null);
    }

    @Override
    protected void verifyDigestPassword(
        org.apache.wss4j.dom.message.token.UsernameToken usernameToken, RequestData data)
        throws WSSecurityException {
      // Digest passwords are rejected unless explicitly enabled.
      if (!supportDigestPasswords) {
        throw new WSSecurityException(WSSecurityException.ErrorCode.FAILED_AUTHENTICATION);
      }
      String user = usernameToken.getName();
      String password = usernameToken.getPassword();
      boolean isHashed = usernameToken.isHashed();
      String nonce = usernameToken.getNonce();
      String createdTime = usernameToken.getCreated();
      SubjectCreatingInterceptor.this.setSubject(user, password, isHashed, nonce, createdTime);
    }

    @Override
    protected void verifyUnknownPassword(
        org.apache.wss4j.dom.message.token.UsernameToken usernameToken, RequestData data)
        throws WSSecurityException {
      SubjectCreatingInterceptor.this.setSubject(usernameToken.getName(), null, false, null, null);
    }
  }

  /** Delegates to {@link SubjectCreator} using the current thread's security domain context. */
  public Subject createSubject(
      String name, String password, boolean isDigest, String nonce, String created) {
    return helper.createSubject(sdc.get(), name, password, isDigest, nonce, created);
  }

  public void setPropagateContext(boolean propagateContext) {
    this.helper.setPropagateContext(propagateContext);
  }

  public void setTimestampThreshold(int timestampThreshold) {
    this.helper.setTimestampThreshold(timestampThreshold);
  }

  public void setNonceStore(NonceStore nonceStore) {
    this.helper.setNonceStore(nonceStore);
  }

  public void setDecodeNonce(boolean decodeNonce) {
    this.helper.setDecodeNonce(decodeNonce);
  }
}
/** Some abstract functionality for creating a SAML token */
public abstract class AbstractSAMLTokenProvider {

  private static final Logger LOG = LogUtils.getL7dLogger(AbstractSAMLTokenProvider.class);

  /**
   * Signs {@code assertion}. Crypto, callback handler and alias come from the realm when it
   * configures a signature crypto, and from {@code stsProperties} otherwise. The requested
   * signature/c14n algorithms are used when accepted, else the configured defaults.
   *
   * @throws Exception from the callback handler or from signing
   */
  protected void signToken(
      SamlAssertionWrapper assertion,
      RealmProperties samlRealm,
      STSPropertiesMBean stsProperties,
      KeyRequirements keyRequirements)
      throws Exception {
    // Initialise signature objects with defaults of STSPropertiesMBean
    Crypto signatureCrypto = stsProperties.getSignatureCrypto();
    CallbackHandler callbackHandler = stsProperties.getCallbackHandler();
    SignatureProperties signatureProperties = stsProperties.getSignatureProperties();
    String alias = stsProperties.getSignatureUsername();

    if (samlRealm != null) {
      // If SignatureCrypto configured in realm then
      // callbackhandler and alias of STSPropertiesMBean is ignored
      if (samlRealm.getSignatureCrypto() != null) {
        LOG.fine("SAMLRealm signature keystore used");
        signatureCrypto = samlRealm.getSignatureCrypto();
        callbackHandler = samlRealm.getCallbackHandler();
        alias = samlRealm.getSignatureAlias();
      }
      // SignatureProperties can be defined independently of SignatureCrypto
      if (samlRealm.getSignatureProperties() != null) {
        signatureProperties = samlRealm.getSignatureProperties();
      }
    }

    // Get the signature algorithm to use
    String signatureAlgorithm = keyRequirements.getSignatureAlgorithm();
    if (signatureAlgorithm == null) {
      // If none then default to what is configured
      signatureAlgorithm = signatureProperties.getSignatureAlgorithm();
    } else {
      List<String> supportedAlgorithms = signatureProperties.getAcceptedSignatureAlgorithms();
      if (!supportedAlgorithms.contains(signatureAlgorithm)) {
        signatureAlgorithm = signatureProperties.getSignatureAlgorithm();
        if (LOG.isLoggable(Level.FINE)) {
          LOG.fine("SignatureAlgorithm not supported, defaulting to: " + signatureAlgorithm);
        }
      }
    }

    // Get the c14n algorithm to use
    String c14nAlgorithm = keyRequirements.getC14nAlgorithm();
    if (c14nAlgorithm == null) {
      // If none then default to what is configured
      c14nAlgorithm = signatureProperties.getC14nAlgorithm();
    } else {
      List<String> supportedAlgorithms = signatureProperties.getAcceptedC14nAlgorithms();
      if (!supportedAlgorithms.contains(c14nAlgorithm)) {
        c14nAlgorithm = signatureProperties.getC14nAlgorithm();
        if (LOG.isLoggable(Level.FINE)) {
          LOG.fine("C14nAlgorithm not supported, defaulting to: " + c14nAlgorithm);
        }
      }
    }

    // If alias not defined, get the default of the SignatureCrypto
    if ((alias == null || "".equals(alias)) && (signatureCrypto != null)) {
      alias = signatureCrypto.getDefaultX509Identifier();
      if (LOG.isLoggable(Level.FINE)) {
        LOG.fine("Signature alias is null so using default alias: " + alias);
      }
    }
    // Get the password. A realm may configure a SignatureCrypto without a CallbackHandler;
    // previously that failed with a NullPointerException here. Without a handler,
    // signAssertion is called with a null password.
    LOG.fine("Creating SAML Token");
    String password = null;
    if (callbackHandler != null) {
      WSPasswordCallback[] cb = {new WSPasswordCallback(alias, WSPasswordCallback.SIGNATURE)};
      callbackHandler.handle(cb);
      password = cb[0].getPassword();
    }

    LOG.fine("Signing SAML Token");
    boolean useKeyValue = signatureProperties.isUseKeyValue();
    assertion.signAssertion(
        alias,
        password,
        signatureCrypto,
        useKeyValue,
        c14nAlgorithm,
        signatureAlgorithm,
        signatureProperties.getDigestAlgorithm());
  }
}
public class AssociatedManagedConnectionFactoryImpl extends ManagedConnectionFactoryImpl
    implements ResourceAdapterAssociation {

  private static final long serialVersionUID = 4305487562182780773L;
  private static final Logger LOG =
      LogUtils.getL7dLogger(AssociatedManagedConnectionFactoryImpl.class);
  private ResourceAdapter ra;

  public AssociatedManagedConnectionFactoryImpl() {
    super();
  }

  public AssociatedManagedConnectionFactoryImpl(Properties props) {
    super(props);
  }

  public Object createConnectionFactory(ConnectionManager connMgr) throws ResourceException {
    Object connFactory = super.createConnectionFactory(connMgr);
    registerBus();
    return connFactory;
  }

  public void setResourceAdapter(ResourceAdapter aRA) throws ResourceException {
    LOG.info("Associate Resource Adapter with ManagedConnectionFactory by appserver. ra = " + ra);
    if (!(aRA instanceof ResourceAdapterImpl)) {
      throw new ResourceAdapterInternalException(
          "ResourceAdapter is not correct, it should be instance of ResourceAdapterImpl");
    }
    this.ra = aRA;
    mergeResourceAdapterProps();
  }

  public ResourceAdapter getResourceAdapter() {
    return ra;
  }

  /**
   * If outbound-resourceAdapter and the resourceAdapter has same property, the
   * outbound-resourceAdapter property's value would take precedence.
   */
  protected void mergeResourceAdapterProps() {
    Properties raProps = ((ResourceAdapterImpl) ra).getPluginProps();
    Properties props = getPluginProps();
    Enumeration<?> raPropsEnum = raProps.propertyNames();
    while (raPropsEnum.hasMoreElements()) {
      String key = (String) raPropsEnum.nextElement();
      if (!props.containsKey(key)) {
        setProperty(key, raProps.getProperty(key));
      } else {
        LOG.fine(
            "ManagedConnectionFactory's props already contain [" + key + "]. No need to merge");
      }
    }
  }

  protected void registerBus() throws ResourceException {
    if (ra == null) {
      throw new ResourceAdapterInternalException("ResourceAdapter can not be null");
    }

    ((ResourceAdapterImpl) ra).registerBus(getBus());
  }

  protected Object getBootstrapContext() {
    return ((ResourceAdapterImpl) ra).getBootstrapContext();
  }

  // Explicit override these two methods,
  // otherwise when deploy rar to weblogic9.1, it would complaint about this.
  public int hashCode() {
    return super.hashCode();
  }

  public boolean equals(Object o) {
    return super.equals(o);
  }
}
public class HttpAwareXSLTOutInterceptor extends AbstractHttpAwareXSLTInterceptor {

  // NOTE(review): bound to XSLTOutInterceptor rather than this class, presumably so the Fault
  // message keys used below (STAX_COPY, STREAM_COPY, ...) resolve from its bundle — confirm.
  private static final Logger LOG = LogUtils.getL7dLogger(XSLTOutInterceptor.class);

  public HttpAwareXSLTOutInterceptor(String xsltPath) {
    super(Phase.PRE_STREAM, StaxOutInterceptor.class, null, xsltPath);
  }

  public HttpAwareXSLTOutInterceptor(
      String phase, Class<?> before, Class<?> after, String xsltPath) {
    super(phase, before, after, xsltPath);
  }

  /**
   * Wraps the first available outbound content (XMLStreamWriter, then OutputStream, then Writer)
   * so that its contents are XSLT-transformed when it is closed.
   */
  @Override
  public void handleMessage(Message message) {
    if (!shouldSchemaValidate(message)) {
      return;
    }
    if (checkContextProperty(message)) {
      return;
    }

    // 1. Try to get and transform XMLStreamWriter message content
    XMLStreamWriter xWriter = message.getContent(XMLStreamWriter.class);
    if (xWriter != null) {
      transformXWriter(message, xWriter);
    } else {
      // 2. Try to get and transform OutputStream message content
      OutputStream out = message.getContent(OutputStream.class);
      if (out != null) {
        transformOS(message, out);
      } else {
        // 3. Try to get and transform Writer message content (actually used for JMS TextMessage)
        Writer writer = message.getContent(Writer.class);
        if (writer != null) {
          transformWriter(message, writer);
        }
      }
    }
  }

  /**
   * Replaces the XMLStreamWriter content with one that caches output and writes the transformed
   * result to {@code xWriter} on close; also disables the output-stream optimization flag.
   */
  protected void transformXWriter(Message message, XMLStreamWriter xWriter) {
    CachedWriter writer = new CachedWriter();
    XMLStreamWriter delegate = StaxUtils.createXMLStreamWriter(writer);
    XSLTStreamWriter wrapper = new XSLTStreamWriter(getXSLTTemplate(), writer, delegate, xWriter);
    message.setContent(XMLStreamWriter.class, wrapper);
    message.put(AbstractOutDatabindingInterceptor.DISABLE_OUTPUTSTREAM_OPTIMIZATION, Boolean.TRUE);
  }

  /** Replaces the OutputStream content with a cached stream that is transformed on close. */
  protected void transformOS(Message message, OutputStream out) {
    CachedOutputStream wrapper = new CachedOutputStream();
    CachedOutputStreamCallback callback =
        new XSLTCachedOutputStreamCallback(getXSLTTemplate(), out);
    wrapper.registerCallback(callback);
    message.setContent(OutputStream.class, wrapper);
  }

  /** Replaces the Writer content with a cached writer that is transformed on close. */
  protected void transformWriter(Message message, Writer writer) {
    XSLTCachedWriter wrapper = new XSLTCachedWriter(getXSLTTemplate(), writer);
    message.setContent(Writer.class, wrapper);
  }

  /** Logs (without rethrowing) a failure to release a resource after transformation. */
  private static void logCloseFailure(Exception e) {
    LOG.warning("Cannot close stream after transformation: " + e.getMessage());
  }

  /**
   * XMLStreamWriter that buffers everything written to it and, on {@link #close()}, transforms the
   * buffered XML and copies the result to the original writer.
   */
  public static class XSLTStreamWriter extends DelegatingXMLStreamWriter {
    private final Templates xsltTemplate;
    private final CachedWriter cachedWriter;
    private final XMLStreamWriter origXWriter;

    public XSLTStreamWriter(
        Templates xsltTemplate,
        CachedWriter cachedWriter,
        XMLStreamWriter delegateXWriter,
        XMLStreamWriter origXWriter) {
      super(delegateXWriter);
      this.xsltTemplate = xsltTemplate;
      this.cachedWriter = cachedWriter;
      this.origXWriter = origXWriter;
    }

    @Override
    public void close() {
      Reader transformedReader = null;
      try {
        super.flush();
        transformedReader = XSLTUtils.transform(xsltTemplate, cachedWriter.getReader());
        StaxUtils.copy(new StreamSource(transformedReader), origXWriter);
      } catch (XMLStreamException e) {
        throw new Fault("STAX_COPY", LOG, e, e.getMessage());
      } catch (IOException e) {
        throw new Fault("GET_CACHED_INPUT_STREAM", LOG, e, e.getMessage());
      } finally {
        // Release each resource independently: a failure closing one must not leak the others.
        if (transformedReader != null) {
          try {
            transformedReader.close();
          } catch (Exception e) {
            logCloseFailure(e);
          }
        }
        try {
          cachedWriter.close();
        } catch (Exception e) {
          logCloseFailure(e);
        }
        try {
          StaxUtils.close(origXWriter);
        } catch (Exception e) {
          logCloseFailure(e);
        }
        try {
          super.close();
        } catch (Exception e) {
          logCloseFailure(e);
        }
      }
    }
  }

  /**
   * Callback that, when the cached output stream closes, transforms its contents and copies the
   * result to the original stream, which is then closed.
   */
  public static class XSLTCachedOutputStreamCallback implements CachedOutputStreamCallback {
    private final Templates xsltTemplate;
    private final OutputStream origStream;

    public XSLTCachedOutputStreamCallback(Templates xsltTemplate, OutputStream origStream) {
      this.xsltTemplate = xsltTemplate;
      this.origStream = origStream;
    }

    @Override
    public void onFlush(CachedOutputStream wrapper) {}

    @Override
    public void onClose(CachedOutputStream wrapper) {
      try {
        InputStream transformedStream =
            XSLTUtils.transform(xsltTemplate, wrapper.getInputStream());
        IOUtils.copyAndCloseInput(transformedStream, origStream);
      } catch (IOException e) {
        throw new Fault("STREAM_COPY", LOG, e, e.getMessage());
      } finally {
        try {
          origStream.close();
        } catch (IOException e) {
          logCloseFailure(e);
        }
      }
    }
  }

  /**
   * CachedWriter that, when closed, transforms its buffered contents and copies the result to the
   * original writer, which is then closed.
   */
  public static class XSLTCachedWriter extends CachedWriter {
    private final Templates xsltTemplate;
    private final Writer origWriter;

    public XSLTCachedWriter(Templates xsltTemplate, Writer origWriter) {
      this.xsltTemplate = xsltTemplate;
      this.origWriter = origWriter;
    }

    @Override
    protected void doClose() {
      try {
        Reader transformedReader = XSLTUtils.transform(xsltTemplate, getReader());
        IOUtils.copyAndCloseInput(transformedReader, origWriter, IOUtils.DEFAULT_BUFFER_SIZE);
      } catch (IOException e) {
        throw new Fault("READER_COPY", LOG, e, e.getMessage());
      } finally {
        try {
          origWriter.close();
        } catch (IOException e) {
          logCloseFailure(e);
        }
      }
    }
  }
}
public class ServerPolicyOutFaultInterceptor extends AbstractPolicyInterceptor {
  /** Shared instance of this interceptor. */
  public static final ServerPolicyOutFaultInterceptor INSTANCE =
      new ServerPolicyOutFaultInterceptor();
  private static final Logger LOG = LogUtils.getL7dLogger(ServerPolicyOutFaultInterceptor.class);

  public ServerPolicyOutFaultInterceptor() {
    super(PolicyConstants.SERVER_POLICY_OUT_FAULT_INTERCEPTOR_ID, Phase.SETUP);
  }

  /**
   * On the server side of an outbound fault, determines the applicable policy (the contextual
   * override if present, otherwise the engine's effective server fault policy), appends its
   * interceptors to the message chain and stores the chosen alternative's assertions on the
   * message as an {@link AssertionInfoMap}.
   */
  protected void handle(Message msg) {
    if (MessageUtils.isRequestor(msg)) {
      LOG.fine("Is a requestor.");
      return;
    }

    Exchange exchange = msg.getExchange();
    assert null != exchange;

    BindingOperationInfo boi = exchange.get(BindingOperationInfo.class);
    if (boi == null) {
      LOG.fine("No binding operation info.");
      return;
    }

    Endpoint endpoint = exchange.get(Endpoint.class);
    if (endpoint == null) {
      LOG.fine("No endpoint.");
      return;
    }
    EndpointInfo ei = endpoint.getEndpointInfo();

    PolicyEngine engine = exchange.get(Bus.class).getExtension(PolicyEngine.class);
    if (engine == null) {
      return;
    }

    Destination destination = exchange.getDestination();
    Exception fault = exchange.get(Exception.class);

    List<Interceptor<? extends Message>> interceptors =
        new ArrayList<Interceptor<? extends Message>>();
    Collection<Assertion> chosenAlternative = new ArrayList<Assertion>();

    Policy override = (Policy) msg.getContextualProperty(PolicyConstants.POLICY_OVERRIDE);
    if (override == null) {
      // No override: use the engine's effective server fault policy for this fault.
      BindingFaultInfo bfi = getBindingFaultInfo(msg, fault, boi);
      if (bfi == null && !isApplicationFault(msg)) {
        return;
      }

      EffectivePolicy serverPolicy =
          engine.getEffectiveServerFaultPolicy(ei, boi, bfi, destination);
      if (serverPolicy != null) {
        interceptors.addAll(serverPolicy.getInterceptors());
        chosenAlternative.addAll(serverPolicy.getChosenAlternative());
      }
    } else {
      // An overriding policy was supplied on the message context: derive the effective policy
      // from it directly.
      EndpointPolicyImpl endpointPolicy = new EndpointPolicyImpl(override);
      EffectivePolicyImpl overridden = new EffectivePolicyImpl();
      overridden.initialise(endpointPolicy, (PolicyEngineImpl) engine, false, true);
      PolicyUtils.logPolicy(LOG, Level.FINEST, "Using effective policy: ", overridden.getPolicy());

      interceptors.addAll(overridden.getInterceptors());
      chosenAlternative.addAll(overridden.getChosenAlternative());
    }

    for (Interceptor<? extends Message> interceptor : interceptors) {
      msg.getInterceptorChain().add(interceptor);
      LOG.log(Level.FINE, "Added interceptor of type {0}", interceptor.getClass().getSimpleName());
    }

    if (!chosenAlternative.isEmpty()) {
      msg.put(AssertionInfoMap.class, new AssertionInfoMap(chosenAlternative));
    }
  }

  /** True when the message's fault mode marks it as a checked or unchecked application fault. */
  private static boolean isApplicationFault(Message msg) {
    FaultMode mode = msg.get(FaultMode.class);
    return mode == FaultMode.UNCHECKED_APPLICATION_FAULT
        || mode == FaultMode.CHECKED_APPLICATION_FAULT;
  }
}
public final class CorbaObjectReferenceHelper {

  public static final String WSDLI_NAMESPACE_URI = "http://www.w3.org/2006/01/wsdl-instance";
  public static final String ADDRESSING_NAMESPACE_URI = "http://www.w3.org/2005/08/addressing";
  public static final String ADDRESSING_WSDL_NAMESPACE_URI =
      "http://www.w3.org/2006/05/addressing/wsdl";

  private static final Logger LOG = LogUtils.getL7dLogger(CorbaObjectReferenceHelper.class);

  /** Number of leading characters (expected to be "IOR:") stripped before hex decoding. */
  private static final int IOR_PREFIX_LENGTH = 4;
  /** Byte offset of the type_id string length in the decoded encapsulation. */
  private static final int TYPE_ID_LENGTH_OFFSET = 4;
  /** Byte offset of the first type_id character in the decoded encapsulation. */
  private static final int TYPE_ID_OFFSET = 8;

  private CorbaObjectReferenceHelper() {
    // utility class
  }

  /** Returns the document base URI of the given WSDL definition. */
  public static String getWSDLLocation(Definition wsdlDef) {
    return wsdlDef.getDocumentBaseURI();
  }

  /**
   * Returns the QName of the first service that has a port using {@code binding}, or null if no
   * service in {@code wsdlDef} (imports are not searched) has such a port.
   */
  public static QName getServiceName(Binding binding, Definition wsdlDef) {
    LOG.log(Level.FINE, "Getting service name for an object reference");
    Collection<Service> services = CastUtils.cast(wsdlDef.getServices().values());
    for (Service serv : services) {
      Collection<Port> ports = CastUtils.cast(serv.getPorts().values());
      for (Port pt : ports) {
        if (pt.getBinding().equals(binding)) {
          return serv.getQName();
        }
      }
    }
    return null;
  }

  /**
   * Returns the name of the first port using {@code binding}, or null if no port in
   * {@code wsdlDef} (imports are not searched) uses it.
   */
  public static String getEndpointName(Binding binding, Definition wsdlDef) {
    LOG.log(Level.FINE, "Getting endpoint name for an object reference");
    Collection<Service> services = CastUtils.cast(wsdlDef.getServices().values());
    for (Service serv : services) {
      Collection<Port> ports = CastUtils.cast(serv.getPorts().values());
      for (Port pt : ports) {
        if (pt.getBinding().equals(binding)) {
          return pt.getName();
        }
      }
    }
    return null;
  }

  /**
   * Returns the first binding whose {@code BindingType} extensibility element has a repository ID
   * that {@code obj._is_a(...)} accepts, or null if none matches.
   */
  public static Binding getDefaultBinding(Object obj, Definition wsdlDef) {
    LOG.log(Level.FINEST, "Getting binding for a default object reference");
    Collection<Binding> bindings = CastUtils.cast(wsdlDef.getBindings().values());
    for (Binding b : bindings) {
      List<?> extElements = b.getExtensibilityElements();
      // Get the list of all extensibility elements
      for (Iterator<?> extIter = extElements.iterator(); extIter.hasNext(); ) {
        java.lang.Object element = extIter.next();

        // Find a binding type so we can check against its repository ID
        if (element instanceof BindingType) {
          BindingType type = (BindingType) element;
          if (obj._is_a(type.getRepositoryID())) {
            return b;
          }
        }
      }
    }
    return null;
  }

  /**
   * Finds the binding whose {@code BindingType} repository ID equals {@code repId}, searching
   * {@code wsdlDef} first and then its imports recursively.
   *
   * @return metadata holding the matching binding and the definition it was found in; if nothing
   *     matches, the metadata returned by the last search attempted
   */
  public static EprMetaData getBindingForTypeId(String repId, Definition wsdlDef) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(
          Level.FINE,
          "RepositoryId " + repId + ", wsdl namespace " + wsdlDef.getTargetNamespace());
    }
    EprMetaData ret = new EprMetaData();
    Collection<Binding> bindings = CastUtils.cast(wsdlDef.getBindings().values());
    for (Binding b : bindings) {
      List<?> extElements = b.getExtensibilityElements();

      // Get the list of all extensibility elements
      for (Iterator<?> extIter = extElements.iterator(); extIter.hasNext(); ) {
        java.lang.Object element = extIter.next();

        // Find a binding type so we can check against its repository ID
        if (element instanceof BindingType) {
          BindingType type = (BindingType) element;
          if (repId.equals(type.getRepositoryID())) {
            ret.setCandidateWsdlDef(wsdlDef);
            ret.setBinding(b);
            return ret;
          }
        }
      }
    }

    if (!ret.isValid()) {
      // recursively check imports
      Iterator<?> importLists = wsdlDef.getImports().values().iterator();
      while (importLists.hasNext()) {
        List<?> imports = (List<?>) importLists.next();
        for (java.lang.Object imp : imports) {
          if (imp instanceof Import) {
            Definition importDef = ((Import) imp).getDefinition();
            // Trace-level detail: same level as the equivalent message in populateEprInfo.
            if (LOG.isLoggable(Level.FINE)) {
              LOG.log(Level.FINE, "Following import " + importDef.getDocumentBaseURI());
            }
            ret = getBindingForTypeId(repId, importDef);
            if (ret.isValid()) {
              return ret;
            }
          }
        }
      }
    }
    return ret;
  }

  /**
   * Extracts the type_id string from a stringified IOR.
   *
   * <p>The characters after the 4-character prefix are hex-decoded into a CDR encapsulation laid
   * out as: byte-order octet, three alignment bytes, a 4-byte string length (which counts a
   * trailing NUL), then the type_id characters.
   *
   * @param url the stringified IOR, expected to be at least 4 characters long
   * @return the type_id, or {@code ""} if it is empty or the data is too short to contain it
   */
  public static String extractTypeIdFromIOR(String url) {
    byte[] data = DatatypeConverter.parseHexBinary(url.substring(IOR_PREFIX_LENGTH));
    // Too short to hold the byte-order octet, padding and string length: nothing to extract.
    if (data.length < TYPE_ID_OFFSET) {
      return "";
    }
    // A byte-order octet of zero (or any non-positive value) is read as big-endian.
    boolean bigEndian = data[0] <= 0;
    int typeIdStringSize = readIntFromAlignedCDREncaps(data, TYPE_ID_LENGTH_OFFSET, bigEndian);
    // The encoded size includes the trailing NUL, which is not part of the type id.
    int typeIdLength = typeIdStringSize - 1;
    if (typeIdLength <= 0 || typeIdLength > data.length - TYPE_ID_OFFSET) {
      // Empty id, or a declared length that runs past the available data (malformed IOR).
      return "";
    }
    return readStringFromAlignedCDREncaps(data, TYPE_ID_OFFSET, typeIdLength);
  }

  /** Decodes {@code length} bytes starting at {@code startIndex}, one char per byte (0-255). */
  private static String readStringFromAlignedCDREncaps(byte[] data, int startIndex, int length) {
    char[] arr = new char[length];
    for (int i = 0; i < length; i++) {
      arr[i] = (char) (data[startIndex + i] & 0xff);
    }
    return new String(arr);
  }

  /**
   * Reads a 4-byte int at {@code index} in the given byte order.
   *
   * @throws ArrayIndexOutOfBoundsException if fewer than 4 bytes are available at {@code index}
   */
  public static int readIntFromAlignedCDREncaps(byte[] data, int index, boolean bigEndian) {
    if (bigEndian) {
      int partial = ((data[index] << 24) & 0xff000000) | ((data[index + 1] << 16) & 0x00ff0000);
      return partial | ((data[index + 2] << 8) & 0x0000ff00) | ((data[index + 3]) & 0x000000ff);
    } else {
      int partial = ((data[index]) & 0x000000ff) | ((data[index + 1] << 8) & 0x0000ff00);
      return partial
          | ((data[index + 2] << 16) & 0x00ff0000)
          | ((data[index + 3] << 24) & 0xff000000);
    }
  }

  /**
   * Fills in the port name and service QName of {@code info} from the first port that uses
   * {@code info}'s binding. If the candidate definition has no such port, the definition's imports
   * are searched recursively, and each import searched becomes the candidate definition.
   * Does nothing when {@code info} is not valid.
   */
  public static void populateEprInfo(EprMetaData info) {
    if (!info.isValid()) {
      return;
    }
    Binding match = info.getBinding();
    Definition wsdlDef = info.getCandidateWsdlDef();
    Collection<Service> services = CastUtils.cast(wsdlDef.getServices().values());
    // Stop at the first matching port across all services (consistent with getServiceName and
    // getEndpointName); a plain break would only leave the inner loop and let later services
    // overwrite the match.
    serviceSearch:
    for (Service serv : services) {
      Collection<Port> ports = CastUtils.cast(serv.getPorts().values());
      for (Port pt : ports) {
        if (pt.getBinding().equals(match)) {
          info.setPortName(pt.getName());
          info.setServiceQName(serv.getQName());
          break serviceSearch;
        }
      }
    }

    if (info.getServiceQName() == null) {
      Iterator<?> importLists = wsdlDef.getImports().values().iterator();
      while (info.getServiceQName() == null && importLists.hasNext()) {
        List<?> imports = (List<?>) importLists.next();
        for (java.lang.Object imp : imports) {
          if (imp instanceof Import) {
            Definition importDef = ((Import) imp).getDefinition();
            LOG.log(Level.FINE, "following wsdl import " + importDef.getDocumentBaseURI());
            info.setCandidateWsdlDef(importDef);
            populateEprInfo(info);
            if (info.getServiceQName() != null) {
              break;
            }
          }
        }
      }
    }
  }
}
Example #15
0
public abstract class AbstractResourceInfo {
  /** Bus property key under which per-service constructor proxies are stored. */
  public static final String CONSTRUCTOR_PROXY_MAP = "jaxrs-constructor-proxy-map";
  private static final Logger LOG = LogUtils.getL7dLogger(AbstractResourceInfo.class);
  /** Bus property key under which per-service field proxies are stored. */
  private static final String FIELD_PROXY_MAP = "jaxrs-field-proxy-map";
  /** Bus property key under which per-service setter-method proxies are stored. */
  private static final String SETTER_PROXY_MAP = "jaxrs-setter-proxy-map";

  /**
   * Names of context types treated as standard; injecting any other interface type logs (at FINE)
   * that a ContextProvider is required.
   */
  private static final Set<String> STANDARD_CONTEXT_CLASSES = new HashSet<String>();

  static {
    // JAX-RS 1.0-1.1
    STANDARD_CONTEXT_CLASSES.add(Application.class.getName());
    STANDARD_CONTEXT_CLASSES.add(UriInfo.class.getName());
    STANDARD_CONTEXT_CLASSES.add(HttpHeaders.class.getName());
    STANDARD_CONTEXT_CLASSES.add(Request.class.getName());
    STANDARD_CONTEXT_CLASSES.add(SecurityContext.class.getName());
    STANDARD_CONTEXT_CLASSES.add(Providers.class.getName());
    STANDARD_CONTEXT_CLASSES.add(ContextResolver.class.getName());
    STANDARD_CONTEXT_CLASSES.add("javax.servlet.http.HttpServletRequest");
    STANDARD_CONTEXT_CLASSES.add("javax.servlet.http.HttpServletResponse");
    STANDARD_CONTEXT_CLASSES.add("javax.servlet.ServletContext");
    // JAX-RS 2.0
    STANDARD_CONTEXT_CLASSES.add("javax.ws.rs.container.ResourceContext");
    STANDARD_CONTEXT_CLASSES.add("javax.ws.rs.container.ResourceInfo");
    STANDARD_CONTEXT_CLASSES.add("javax.ws.rs.core.Configuration");
  }

  protected boolean root;
  protected Class<?> resourceClass;
  protected Class<?> serviceClass;

  /** @Context-annotated fields, keyed by service class. */
  private Map<Class<?>, List<Field>> contextFields;
  /** @Context-annotated setters, keyed by service class and then by context type. */
  private Map<Class<?>, Map<Class<?>, Method>> contextMethods;
  private Bus bus;
  private boolean constructorProxiesAvailable;
  private boolean contextsAvailable;

  protected AbstractResourceInfo(Bus bus) {
    this.bus = bus;
  }

  protected AbstractResourceInfo(
      Class<?> resourceClass,
      Class<?> serviceClass,
      boolean isRoot,
      boolean checkContexts,
      Bus bus) {
    this(resourceClass, serviceClass, isRoot, checkContexts, null, bus, null);
  }

  /**
   * @param checkContexts when true and {@code resourceClass} is non-null, scans
   *     {@code serviceClass} for @Context fields and setters
   * @param constructorProxies constructor proxies to register for {@code serviceClass}; may be null
   * @param provider existing instance whose already-injected proxies are reused, or null
   */
  protected AbstractResourceInfo(
      Class<?> resourceClass,
      Class<?> serviceClass,
      boolean isRoot,
      boolean checkContexts,
      Map<Class<?>, ThreadLocalProxy<?>> constructorProxies,
      Bus bus,
      Object provider) {
    this.bus = bus;
    this.serviceClass = serviceClass;
    this.resourceClass = resourceClass;
    root = isRoot;
    if (checkContexts && resourceClass != null) {
      findContexts(serviceClass, provider, constructorProxies);
    }
  }

  /** Discovers @Context fields/setters on {@code cls} and records constructor proxies, if any. */
  private void findContexts(
      Class<?> cls, Object provider, Map<Class<?>, ThreadLocalProxy<?>> constructorProxies) {
    findContextFields(cls, provider);
    findContextSetterMethods(cls, provider);
    if (constructorProxies != null) {
      Map<Class<?>, Map<Class<?>, ThreadLocalProxy<?>>> proxies = getConstructorProxyMap(true);
      proxies.put(serviceClass, constructorProxies);
      constructorProxiesAvailable = true;
    }

    contextsAvailable =
        (contextFields != null && !contextFields.isEmpty())
            || (contextMethods != null && !contextMethods.isEmpty())
            || constructorProxiesAvailable;
  }

  /** True if any context fields, setters or constructor proxies were discovered. */
  public boolean contextsAvailable() {
    return contextsAvailable;
  }

  public Bus getBus() {
    return bus;
  }

  /**
   * Sets the resource class; when the service class is an interface and the new resource class is
   * concrete, the resource class is scanned for contexts as well.
   */
  public void setResourceClass(Class<?> rClass) {
    resourceClass = rClass;
    if (serviceClass.isInterface() && resourceClass != null && !resourceClass.isInterface()) {
      findContexts(resourceClass, null, null);
    }
  }

  public Class<?> getServiceClass() {
    return serviceClass;
  }

  /** Records @Context fields of {@code cls} and its superclasses (stopping at Object). */
  private void findContextFields(Class<?> cls, Object provider) {
    if (cls == Object.class || cls == null) {
      return;
    }
    for (Field f : cls.getDeclaredFields()) {
      if (f.isAnnotationPresent(Context.class)) {
        contextFields = addContextField(contextFields, f);
        // Only interface-typed contexts get a thread-local proxy.
        if (f.getType().isInterface()) {
          checkContextClass(f.getType());
          addToMap(getFieldProxyMap(true), f, getFieldThreadLocalProxy(f, provider));
        }
      }
    }
    findContextFields(cls.getSuperclass(), provider);
  }

  /**
   * Returns the proxy already held in {@code f} on {@code provider} if it is a ThreadLocalProxy;
   * otherwise creates one (injecting it into {@code provider} when there is one).
   */
  private static ThreadLocalProxy<?> getFieldThreadLocalProxy(Field f, Object provider) {
    if (provider != null) {
      Object proxy = null;
      synchronized (provider) {
        try {
          proxy = InjectionUtils.extractFieldValue(f, provider);
        } catch (Throwable t) {
          // best effort: fall through and create a fresh proxy below
        }
        if (!(proxy instanceof ThreadLocalProxy)) {
          proxy = InjectionUtils.createThreadLocalProxy(f.getType());
          InjectionUtils.injectFieldValue(f, provider, proxy);
        }
      }
      return (ThreadLocalProxy<?>) proxy;
    } else {
      return InjectionUtils.createThreadLocalProxy(f.getType());
    }
  }

  /** Setter-method counterpart of {@link #getFieldThreadLocalProxy}. */
  private static ThreadLocalProxy<?> getMethodThreadLocalProxy(Method m, Object provider) {
    if (provider != null) {
      Object proxy = null;
      synchronized (provider) {
        try {
          proxy =
              InjectionUtils.extractFromMethod(
                  provider, InjectionUtils.getGetterFromSetter(m), false);
        } catch (Throwable t) {
          // best effort: fall through and create a fresh proxy below
        }
        if (!(proxy instanceof ThreadLocalProxy)) {
          proxy = InjectionUtils.createThreadLocalProxy(m.getParameterTypes()[0]);
          InjectionUtils.injectThroughMethod(provider, m, proxy);
        }
      }
      return (ThreadLocalProxy<?>) proxy;
    } else {
      return InjectionUtils.createThreadLocalProxy(m.getParameterTypes()[0]);
    }
  }

  /**
   * Returns the proxy map stored on the bus under {@code prop}. If absent and {@code create} is
   * true, a synchronized weak-keyed map is created and stored; otherwise null is returned.
   */
  @SuppressWarnings("unchecked")
  private <T> Map<Class<?>, Map<T, ThreadLocalProxy<?>>> getProxyMap(String prop, boolean create) {
    Object property = null;
    synchronized (bus) {
      property = bus.getProperty(prop);
      if (property == null && create) {
        Map<Class<?>, Map<T, ThreadLocalProxy<?>>> map =
            Collections.synchronizedMap(new WeakHashMap<Class<?>, Map<T, ThreadLocalProxy<?>>>(2));
        bus.setProperty(prop, map);
        property = map;
      }
    }
    return (Map<Class<?>, Map<T, ThreadLocalProxy<?>>>) property;
  }

  /** Returns the constructor proxies registered for the service class, or null if none. */
  public Map<Class<?>, ThreadLocalProxy<?>> getConstructorProxies() {
    if (!constructorProxiesAvailable) {
      return null;
    }
    Map<Class<?>, Map<Class<?>, ThreadLocalProxy<?>>> proxies = getConstructorProxyMap(false);
    return proxies == null ? null : proxies.get(serviceClass);
  }

  /**
   * Returns the constructor proxy map from the bus; honours {@code create} and uses the same bus
   * locking as the field and setter maps.
   */
  private Map<Class<?>, Map<Class<?>, ThreadLocalProxy<?>>> getConstructorProxyMap(boolean create) {
    return getProxyMap(CONSTRUCTOR_PROXY_MAP, create);
  }

  private Map<Class<?>, Map<Field, ThreadLocalProxy<?>>> getFieldProxyMap(boolean create) {
    return getProxyMap(FIELD_PROXY_MAP, create);
  }

  private Map<Class<?>, Map<Method, ThreadLocalProxy<?>>> getSetterProxyMap(boolean create) {
    return getProxyMap(SETTER_PROXY_MAP, create);
  }

  /**
   * Records single-argument public {@code set*} methods annotated with @Context on {@code cls},
   * its interfaces and its superclasses (stopping at Object).
   */
  private void findContextSetterMethods(Class<?> cls, Object provider) {

    for (Method m : cls.getMethods()) {

      if (!m.getName().startsWith("set") || m.getParameterTypes().length != 1) {
        continue;
      }
      if (m.isAnnotationPresent(Context.class)) {
        checkContextMethod(m, provider);
      }
    }
    Class<?>[] interfaces = cls.getInterfaces();
    for (Class<?> i : interfaces) {
      findContextSetterMethods(i, provider);
    }
    Class<?> superCls = cls.getSuperclass();
    if (superCls != null && superCls != Object.class) {
      findContextSetterMethods(superCls, provider);
    }
  }

  /** Registers {@code m} when its parameter type is an interface or {@code Application}. */
  private void checkContextMethod(Method m, Object provider) {
    Class<?> type = m.getParameterTypes()[0];
    if (type.isInterface() || type == Application.class) {
      checkContextClass(type);
      addContextMethod(type, m, provider);
    }
  }

  /** Logs at FINE when {@code type} is not one of the standard context types. */
  private void checkContextClass(Class<?> type) {
    if (!STANDARD_CONTEXT_CLASSES.contains(type.getName())) {
      LOG.fine(
          "Injecting a custom context "
              + type.getName()
              + ", ContextProvider is required for this type");
    }
  }

  /** Returns an unmodifiable view of the context setters for the service class (never null). */
  public Map<Class<?>, Method> getContextMethods() {
    Map<Class<?>, Method> methods =
        contextMethods == null ? null : contextMethods.get(getServiceClass());
    if (methods == null) {
      return Collections.<Class<?>, Method>emptyMap();
    }
    return Collections.unmodifiableMap(methods);
  }

  private void addContextMethod(Class<?> contextClass, Method m, Object provider) {
    if (contextMethods == null) {
      contextMethods = new HashMap<Class<?>, Map<Class<?>, Method>>();
    }
    addToMap(contextMethods, contextClass, m);
    // Application is injected directly; every other context type gets a thread-local proxy.
    if (m.getParameterTypes()[0] != Application.class) {
      addToMap(getSetterProxyMap(true), m, getMethodThreadLocalProxy(m, provider));
    }
  }

  public boolean isRoot() {
    return root;
  }

  public Class<?> getResourceClass() {
    return resourceClass;
  }

  /** Returns an unmodifiable list of the service class's context fields (never null). */
  public List<Field> getContextFields() {
    return getList(contextFields);
  }

  /** Returns the proxy registered for {@code f}, or null; lookups never create bus maps. */
  public ThreadLocalProxy<?> getContextFieldProxy(Field f) {
    return getProxy(getFieldProxyMap(false), f);
  }

  /** Returns the proxy registered for {@code m}, or null; lookups never create bus maps. */
  public ThreadLocalProxy<?> getContextSetterProxy(Method m) {
    return getProxy(getSetterProxyMap(false), m);
  }

  public abstract boolean isSingleton();

  /** Clears the field, setter and constructor proxy maps on the thread-default bus, if any. */
  @SuppressWarnings("rawtypes")
  public static void clearAllMaps() {
    Bus bus = BusFactory.getThreadDefaultBus(false);
    if (bus != null) {
      Object property = bus.getProperty(FIELD_PROXY_MAP);
      if (property != null) {
        ((Map) property).clear();
      }
      property = bus.getProperty(SETTER_PROXY_MAP);
      if (property != null) {
        ((Map) property).clear();
      }
      property = bus.getProperty(CONSTRUCTOR_PROXY_MAP);
      if (property != null) {
        ((Map) property).clear();
      }
    }
  }

  /** Calls {@code remove()} on every proxy registered for the service class. */
  public void clearThreadLocalProxies() {
    clearProxies(getFieldProxyMap(false));
    clearProxies(getSetterProxyMap(false));
    clearProxies(getConstructorProxyMap(false));
  }

  private <T> void clearProxies(Map<Class<?>, Map<T, ThreadLocalProxy<?>>> tlps) {
    Map<T, ThreadLocalProxy<?>> proxies = tlps == null ? null : tlps.get(getServiceClass());
    if (proxies == null) {
      return;
    }
    for (ThreadLocalProxy<?> tlp : proxies.values()) {
      if (tlp != null) {
        tlp.remove();
      }
    }
  }

  /** Adds {@code f} (once) to the service class's field list, creating the map if needed. */
  private Map<Class<?>, List<Field>> addContextField(
      Map<Class<?>, List<Field>> theFields, Field f) {
    if (theFields == null) {
      theFields = new HashMap<Class<?>, List<Field>>();
    }

    List<Field> fields = theFields.get(serviceClass);
    if (fields == null) {
      fields = new ArrayList<Field>();
      theFields.put(serviceClass, fields);
    }
    if (!fields.contains(f)) {
      fields.add(f);
    }
    return theFields;
  }

  /** Puts {@code f -> proxy} under the service class unless {@code f} is already present. */
  private <T, V> void addToMap(Map<Class<?>, Map<T, V>> proxyMap, T f, V proxy) {
    Map<T, V> proxies = proxyMap.get(serviceClass);
    if (proxies == null) {
      proxies = Collections.synchronizedMap(new WeakHashMap<T, V>());
      proxyMap.put(serviceClass, proxies);
    }
    if (!proxies.containsKey(f)) {
      proxies.put(f, proxy);
    }
  }

  private List<Field> getList(Map<Class<?>, List<Field>> fields) {
    List<Field> ret = fields == null ? null : fields.get(getServiceClass());
    if (ret != null) {
      ret = Collections.unmodifiableList(ret);
    } else {
      ret = Collections.emptyList();
    }
    return ret;
  }

  private <T> ThreadLocalProxy<?> getProxy(
      Map<Class<?>, Map<T, ThreadLocalProxy<?>>> proxies, T key) {

    Map<?, ThreadLocalProxy<?>> theMap = proxies == null ? null : proxies.get(getServiceClass());
    ThreadLocalProxy<?> ret = null;
    if (theMap != null) {
      ret = theMap.get(key);
    }
    return ret;
  }
}
Example #16
0
/**
 * Retransmission queue holding, per source sequence, the outgoing messages that have not yet been
 * acknowledged, and scheduling their resends on the {@link RMManager}'s timer when one exists.
 *
 * <p>Candidate lists are keyed by the sequence identifier value. A sequence's list lives either in
 * {@code candidates} (active) or in {@code suspendedCandidates} (suspended). Both maps, their lists
 * and {@code unacknowledgedCount} are accessed under this object's monitor.
 */
public class RetransmissionQueueImpl implements RetransmissionQueue {

  private static final Logger LOG = LogUtils.getL7dLogger(RetransmissionQueueImpl.class);

  /** Active sequences: identifier value to resend candidates. Accessed under {@code this}. */
  private Map<String, List<ResendCandidate>> candidates =
      new HashMap<String, List<ResendCandidate>>();
  /** Suspended sequences: identifier value to candidates. Accessed under {@code this}. */
  private Map<String, List<ResendCandidate>> suspendedCandidates =
      new HashMap<String, List<ResendCandidate>>();
  private Resender resender;
  private RMManager manager;

  /** Number of cached candidates across all sequences. Accessed under {@code this}. */
  private int unacknowledgedCount;

  public RetransmissionQueueImpl(RMManager m) {
    manager = m;
  }

  public RMManager getManager() {
    return manager;
  }

  public void setManager(RMManager m) {
    manager = m;
  }

  /**
   * Caches an outgoing message as a resend candidate.
   *
   * @param message the outgoing message
   */
  public void addUnacknowledged(Message message) {
    cacheUnacknowledged(message);
  }

  /**
   * @param seq the sequence under consideration
   * @return the number of unacknowledged messages for that sequence
   */
  public synchronized int countUnacknowledged(SourceSequence seq) {
    List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
    return sequenceCandidates == null ? 0 : sequenceCandidates.size();
  }

  /**
   * @return the total number of cached, unacknowledged messages across all sequences
   */
  public synchronized int countUnacknowledged() {
    // synchronized so the read sees updates made under the monitor by other threads
    return unacknowledgedCount;
  }

  /** @return true if no active sequence has cached candidates */
  public synchronized boolean isEmpty() {
    return 0 == getUnacknowledged().size();
  }

  /**
   * Purge all candidates for the given sequence that have been acknowledged.
   *
   * @param seq the sequence object.
   */
  public void purgeAcknowledged(SourceSequence seq) {
    purgeCandidates(seq, false);
  }

  /**
   * Purge all candidates for the given sequence. This method is used to terminate the sequence by
   * force and release the resource associated with the sequence.
   *
   * @param seq the sequence object.
   */
  public void purgeAll(SourceSequence seq) {
    purgeCandidates(seq, true);
  }

  /**
   * Removes candidates of {@code seq} from the queue, then deletes them from the store and
   * notifies the reliable endpoint outside the monitor.
   *
   * @param seq the sequence whose candidates are purged
   * @param any true to purge every candidate, false to purge only acknowledged ones
   */
  private void purgeCandidates(SourceSequence seq, boolean any) {
    Collection<Long> purged = new ArrayList<Long>();
    Collection<ResendCandidate> resends = new ArrayList<ResendCandidate>();
    Identifier sid = seq.getIdentifier();
    synchronized (this) {
      LOG.fine("Start purging resend candidates.");
      List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
      if (null != sequenceCandidates) {
        // iterate backwards so removal by index does not shift unvisited elements
        for (int i = sequenceCandidates.size() - 1; i >= 0; i--) {
          ResendCandidate candidate = sequenceCandidates.get(i);
          long m = candidate.getNumber();
          if (any || seq.isAcknowledged(m)) {
            sequenceCandidates.remove(i);
            candidate.resolved();
            unacknowledgedCount--;
            purged.add(m);
            resends.add(candidate);
          }
        }
        // An emptied list of a suspended sequence is intentionally kept in suspendedCandidates
        // so that messages cached later for that sequence are still created suspended.
        if (sequenceCandidates.isEmpty()) {
          candidates.remove(sid.getValue());
        }
      }
      LOG.fine("Completed purging resend candidates.");
    }
    if (purged.size() > 0) {
      RMStore store = manager.getStore();
      if (null != store) {
        store.removeMessages(sid, purged, true);
      }
      RMEndpoint rmEndpoint = seq.getSource().getReliableEndpoint();
      for (ResendCandidate resend : resends) {
        rmEndpoint.handleAcknowledgment(sid.getValue(), resend.getNumber(), resend.getMessage());
      }
    }
  }

  /**
   * @param seq the sequence under consideration
   * @return the message numbers of the cached candidates, in cache order
   */
  public synchronized List<Long> getUnacknowledgedMessageNumbers(SourceSequence seq) {
    List<Long> unacknowledged = new ArrayList<Long>();
    List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
    if (null != sequenceCandidates) {
      for (int i = 0; i < sequenceCandidates.size(); i++) {
        ResendCandidate candidate = sequenceCandidates.get(i);
        unacknowledged.add(candidate.getNumber());
      }
    }
    return unacknowledged;
  }

  /**
   * @param seq the sequence under consideration
   * @param num the message number
   * @return the retry status of that message, or null if it is not cached
   */
  public synchronized RetryStatus getRetransmissionStatus(SourceSequence seq, long num) {
    List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
    if (null != sequenceCandidates) {
      for (int i = 0; i < sequenceCandidates.size(); i++) {
        ResendCandidate candidate = sequenceCandidates.get(i);
        if (num == candidate.getNumber()) {
          return candidate;
        }
      }
    }
    return null;
  }

  /**
   * @param seq the sequence under consideration
   * @return a fresh map from message number to retry status for every cached candidate
   */
  public synchronized Map<Long, RetryStatus> getRetransmissionStatuses(SourceSequence seq) {
    Map<Long, RetryStatus> cp = new HashMap<Long, RetryStatus>();
    List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
    if (null != sequenceCandidates) {
      for (int i = 0; i < sequenceCandidates.size(); i++) {
        ResendCandidate candidate = sequenceCandidates.get(i);
        cp.put(candidate.getNumber(), candidate);
      }
    }
    return cp;
  }

  /** Initiate resends: installs the default resender unless one is already set. */
  public void start() {
    if (null != resender) {
      return;
    }
    LOG.fine("Starting retransmission queue");

    // setup resender

    resender = getDefaultResender();
  }

  /** Stops resending messages for the specified source sequence. */
  public void stop(SourceSequence seq) {
    synchronized (this) {
      List<ResendCandidate> sequenceCandidates = getSequenceCandidates(seq);
      if (null != sequenceCandidates) {
        for (int i = sequenceCandidates.size() - 1; i >= 0; i--) {
          ResendCandidate candidate = sequenceCandidates.get(i);
          candidate.cancel();
        }
        LOG.log(Level.FINE, "Cancelled resends for sequence {0}.", seq.getIdentifier().getValue());
      }
    }
  }

  /** Package-private no-op. */
  void stop() {}

  /**
   * Suspends the candidates of an active sequence and moves its list to the suspended map.
   *
   * @param seq the sequence to suspend
   */
  public void suspend(SourceSequence seq) {
    synchronized (this) {
      String key = seq.getIdentifier().getValue();
      List<ResendCandidate> sequenceCandidates = candidates.remove(key);
      if (null != sequenceCandidates) {
        for (int i = sequenceCandidates.size() - 1; i >= 0; i--) {
          ResendCandidate candidate = sequenceCandidates.get(i);
          candidate.suspend();
        }
        suspendedCandidates.put(key, sequenceCandidates);
        LOG.log(Level.FINE, "Suspended resends for sequence {0}.", key);
      }
    }
  }

  /**
   * Resumes the candidates of a suspended sequence and moves its list back to the active map.
   *
   * @param seq the sequence to resume
   */
  public void resume(SourceSequence seq) {
    synchronized (this) {
      String key = seq.getIdentifier().getValue();
      List<ResendCandidate> sequenceCandidates = suspendedCandidates.remove(key);
      if (null != sequenceCandidates) {
        for (int i = 0; i < sequenceCandidates.size(); i++) {
          ResendCandidate candidate = sequenceCandidates.get(i);
          candidate.resume();
        }
        candidates.put(key, sequenceCandidates);
        LOG.log(Level.FINE, "Resumed resends for sequence {0}.", key);
      }
    }
  }

  /** @return the exponential backoff */
  protected int getExponentialBackoff() {
    return DEFAULT_EXPONENTIAL_BACKOFF;
  }

  /**
   * @param message the message context
   * @return a ResendCandidate
   */
  protected ResendCandidate createResendCandidate(SoapMessage message) {
    return new ResendCandidate(message);
  }

  /**
   * Accepts a new resend candidate.
   *
   * @param message the outgoing message; must carry RM properties with a sequence
   * @return ResendCandidate
   */
  protected ResendCandidate cacheUnacknowledged(Message message) {
    RMProperties rmps = RMContextUtils.retrieveRMProperties(message, true);
    SequenceType st = rmps.getSequence();
    Identifier sid = st.getIdentifier();
    String key = sid.getValue();

    ResendCandidate candidate = null;

    synchronized (this) {
      List<ResendCandidate> sequenceCandidates = getSequenceCandidates(key);
      if (null == sequenceCandidates) {
        sequenceCandidates = new ArrayList<ResendCandidate>();
        candidates.put(key, sequenceCandidates);
      }
      candidate = new ResendCandidate(message);
      if (isSequenceSuspended(key)) {
        candidate.suspend();
      }
      sequenceCandidates.add(candidate);
      unacknowledgedCount++;
    }
    LOG.fine("Cached unacknowledged message.");
    try {
      RMEndpoint rme = manager.getReliableEndpoint(message);
      rme.handleAccept(key, st.getMessageNumber(), message);
    } catch (RMException e) {
      // keep the cause so the failure can be diagnosed
      LOG.log(Level.WARNING, "Could not find reliable endpoint for message", e);
    }
    return candidate;
  }

  /** @return a map relating sequence ID to a lists of un-acknowledged messages for that sequence */
  protected Map<String, List<ResendCandidate>> getUnacknowledged() {
    return candidates;
  }

  /**
   * @param seq the sequence under consideration
   * @return the list of resend candidates for that sequence
   * @pre called with mutex held
   */
  protected List<ResendCandidate> getSequenceCandidates(SourceSequence seq) {
    return getSequenceCandidates(seq.getIdentifier().getValue());
  }

  /**
   * @param key the sequence identifier under consideration
   * @return the list of resend candidates for that sequence, active or suspended
   * @pre called with mutex held
   */
  protected List<ResendCandidate> getSequenceCandidates(String key) {
    List<ResendCandidate> sc = candidates.get(key);
    if (null == sc) {
      sc = suspendedCandidates.get(key);
    }
    return sc;
  }

  /**
   * @param key the sequence identifier under consideration
   * @return true if the sequence is currently suspended; false otherwise
   * @pre called with mutex held
   */
  protected boolean isSequenceSuspended(String key) {
    return suspendedCandidates.containsKey(key);
  }

  /** Represents a candidate for resend, i.e. an unacked outgoing message. */
  protected class ResendCandidate implements Runnable, RetryStatus {
    private Message message;
    private long number;
    private Date next;
    private TimerTask nextTask;
    private int retries;
    private int maxRetries;
    private long nextInterval;
    private long backoff;
    private boolean pending;
    private boolean suspended;
    private boolean includeAckRequested;

    /**
     * Creates the candidate and, when a timer is available and retries are allowed, schedules
     * its first resend.
     *
     * @param m the unacknowledged outgoing message
     */
    protected ResendCandidate(Message m) {
      message = m;
      retries = 0;
      RMConfiguration cfg = manager.getEffectiveConfiguration(message);
      long baseRetransmissionInterval = cfg.getBaseRetransmissionInterval().longValue();
      backoff = cfg.isExponentialBackoff() ? RetransmissionQueue.DEFAULT_EXPONENTIAL_BACKOFF : 1;
      next = new Date(System.currentTimeMillis() + baseRetransmissionInterval);
      nextInterval = baseRetransmissionInterval * backoff;
      RetryPolicyType rmrp =
          null != manager.getSourcePolicy() ? manager.getSourcePolicy().getRetryPolicy() : null;
      maxRetries = null != rmrp ? rmrp.getMaxRetries() : -1;

      // The message number is needed for purging and status queries even when no resend is
      // scheduled, so it must be set before the anonymous-target early return below.
      RMProperties rmprops = RMContextUtils.retrieveRMProperties(message, true);
      if (null != rmprops) {
        number = rmprops.getSequence().getMessageNumber();
      }

      AddressingProperties maps = RMContextUtils.retrieveMAPs(message, false, true);
      AttributedURIType to = null;
      if (null != maps) {
        to = maps.getTo();
      }
      if (to != null && RMUtils.getAddressingConstants().getAnonymousURI().equals(to.getValue())) {
        LOG.log(Level.INFO, "Cannot resend to anonymous target.  Not scheduling a resend.");
        return;
      }
      if (null != manager.getTimer() && maxRetries != 0) {
        schedule();
      }
    }

    /**
     * Initiate resend asynchronously on the endpoint's executor, falling back to the service's
     * executor and then to a synchronous executor.
     *
     * @param requestAcknowledge true if a AckRequest header is to be sent with resend
     */
    protected void initiate(boolean requestAcknowledge) {
      includeAckRequested = requestAcknowledge;
      pending = true;
      Endpoint ep = message.getExchange().get(Endpoint.class);
      Executor executor = ep.getExecutor();
      if (null == executor) {
        executor = ep.getService().getExecutor();
        if (executor == null) {
          executor = SynchronousExecutor.getInstance();
        } else {
          LOG.log(Level.FINE, "Using service executor {0}", executor.getClass().getName());
        }
      } else {
        LOG.log(Level.FINE, "Using endpoint executor {0}", executor.getClass().getName());
      }

      try {
        executor.execute(this);
      } catch (RejectedExecutionException ex) {
        LOG.log(Level.SEVERE, "RESEND_INITIATION_FAILED_MSG", ex);
      }
    }

    /** Performs the resend if still pending, then schedules the next attempt. */
    public void run() {
      try {
        // ensure ACK wasn't received while this task was enqueued
        // on executor
        if (isPending()) {
          resender.resend(message, includeAckRequested);
          includeAckRequested = false;
        }
      } finally {
        attempted();
      }
    }

    public long getNumber() {
      return number;
    }

    /** @return number of resend attempts */
    public int getRetries() {
      return retries;
    }

    /** @return number of max resend attempts */
    public int getMaxRetries() {
      return maxRetries;
    }

    /** @return date of next resend, or null once the candidate has been resolved */
    public Date getNext() {
      return next;
    }

    /**
     * @return date of previous resend, or null if no attempt has been made yet or the candidate
     *     has been resolved (at which point the next date is cleared)
     */
    public Date getPrevious() {
      // read once: resolved() may clear the field concurrently
      Date current = next;
      if (retries > 0 && current != null) {
        return new Date(current.getTime() - nextInterval / backoff);
      }
      return null;
    }

    public long getNextInterval() {
      return nextInterval;
    }

    public long getBackoff() {
      return backoff;
    }

    public boolean isSuspended() {
      return suspended;
    }

    /** @return if resend attempt is pending */
    public synchronized boolean isPending() {
      return pending;
    }

    /** ACK has been received for this candidate. */
    protected synchronized void resolved() {
      pending = false;
      next = null;
      if (null != nextTask) {
        nextTask.cancel();
        releaseSavedMessage();
      }
    }

    /** Cancel further resend (although no ACK has been received). */
    protected synchronized void cancel() {
      if (null != nextTask) {
        nextTask.cancel();
        releaseSavedMessage();
      }
    }

    /** Marks the candidate suspended and cancels its scheduled task, keeping the saved message. */
    protected synchronized void suspend() {
      suspended = true;
      pending = false;
      // TODO release the message and later reload it upon resume
      // cancel();
      if (null != nextTask) {
        nextTask.cancel();
      }
    }

    /** Clears the suspended flag and reschedules relative to the current time. */
    protected synchronized void resume() {
      suspended = false;
      next = new Date(System.currentTimeMillis());
      attempted();
    }

    /** Releases the saved message content, if any. */
    private void releaseSavedMessage() {
      RewindableInputStream is =
          (RewindableInputStream) message.get(RMMessageConstants.SAVED_CONTENT);
      if (is != null) {
        is.release();
      }
    }

    /** @return associated message context */
    protected Message getMessage() {
      return message;
    }

    /** A resend has been attempted. Schedule the next attempt. */
    protected synchronized void attempted() {
      pending = false;
      retries++;
      if (null != next && maxRetries != retries) {
        next = new Date(next.getTime() + nextInterval);
        nextInterval *= backoff;
        schedule();
      }
    }

    /** Schedules a timer task at {@code next} that initiates a resend unless one is pending. */
    protected final synchronized void schedule() {
      if (null == manager.getTimer()) {
        return;
      }
      class ResendTask extends TimerTask {
        ResendCandidate candidate;

        ResendTask(ResendCandidate c) {
          candidate = c;
        }

        @Override
        public void run() {
          if (!candidate.isPending()) {
            candidate.initiate(includeAckRequested);
          }
        }
      }
      nextTask = new ResendTask(this);
      try {
        manager.getTimer().schedule(nextTask, next);
      } catch (IllegalStateException ex) {
        LOG.log(Level.WARNING, "SCHEDULE_RESEND_FAILED_MSG", ex);
      }
    }
  }

  /** Encapsulates actual resend logic (pluggable to facilitate unit testing) */
  public interface Resender {
    /**
     * Resend mechanics.
     *
     * @param message the message to resend
     * @param requestAcknowledge true if an AckRequest should be included
     */
    void resend(Message message, boolean requestAcknowledge);
  }

  /**
   * Create default Resender logic.
   *
   * @return default Resender
   */
  protected final Resender getDefaultResender() {
    return new Resender() {
      public void resend(Message message, boolean requestAcknowledge) {
        RMProperties properties = RMContextUtils.retrieveRMProperties(message, true);
        SequenceType st = properties.getSequence();
        if (st != null) {
          LOG.log(Level.INFO, "RESEND_MSG", st.getMessageNumber());
        }
        if (message instanceof SoapMessage) {
          doResend((SoapMessage) message);
        } else {
          doResend(new SoapMessage(message));
        }
      }
    };
  }

  /**
   * Plug in replacement resend logic (facilitates unit testing).
   *
   * @param replacement resend logic
   */
  protected void replaceResender(Resender replacement) {
    resender = replacement;
  }

  @SuppressWarnings("unchecked")
  protected JaxbAssertion<RMAssertion> getAssertion(AssertionInfo ai) {
    return (JaxbAssertion<RMAssertion>) ai.getAssertion();
  }

  /**
   * Reads the part of the document preceding the SOAP body into a DOM and adds every child
   * element of each SOAP Header element to the message's headers.
   *
   * @param xmlReader reader positioned at the document's first element
   * @param message the message receiving the headers
   */
  private void readHeaders(XMLStreamReader xmlReader, SoapMessage message)
      throws XMLStreamException {

    // read header portion of SOAP document into DOM
    SoapVersion version = message.getVersion();
    XMLStreamReader filteredReader = new PartialXMLStreamReader(xmlReader, version.getBody());
    Node nd = message.getContent(Node.class);
    W3CDOMStreamWriter writer = message.get(W3CDOMStreamWriter.class);
    Document doc = null;
    if (writer != null) {
      StaxUtils.copy(filteredReader, writer);
      doc = writer.getDocument();
    } else if (nd instanceof Document) {
      doc = (Document) nd;
      StaxUtils.readDocElements(doc, doc, filteredReader, false, false);
    } else {
      doc = StaxUtils.read(filteredReader);
      message.setContent(Node.class, doc);
    }

    // get the actual SOAP header
    Element element = doc.getDocumentElement();
    QName header = version.getHeader();
    List<Element> elemList =
        DOMUtils.findAllElementsByTagNameNS(
            element, header.getNamespaceURI(), header.getLocalPart());
    for (Element elem : elemList) {

      // set all child elements as headers for message transmission
      Element hel = DOMUtils.getFirstElement(elem);
      while (hel != null) {
        SoapHeader sheader = new SoapHeader(DOMUtils.getElementQName(hel), hel);
        message.getHeaders().add(sheader);
        hel = DOMUtils.getNextElement(hel);
      }
    }
  }

  /**
   * Resends a message: rebuilds headers from the saved stream and replaces the marshal phase
   * with a copy of the saved body. It then runs the retransmit chain. Failures are logged, not
   * thrown.
   *
   * @param message the message to resend
   */
  private void doResend(SoapMessage message) {
    try {

      // initialize copied interceptor chain for message
      PhaseInterceptorChain retransmitChain = manager.getRetransmitChain(message);
      ProtocolVariation protocol = RMContextUtils.getProtocolVariation(message);
      Endpoint endpoint = manager.getReliableEndpoint(message).getEndpoint(protocol);
      boolean after = true;
      if (retransmitChain == null) {

        // no saved retransmit chain, so construct one from scratch
        // (won't work for WS-Security on server, so need to fix)
        retransmitChain = buildRetransmitChain(endpoint, new PhaseChainCache());
        after = false;
      }
      message.setInterceptorChain(retransmitChain);

      // clear flag for SOAP out interceptor so envelope will be written
      message.remove(SoapOutInterceptor.WROTE_ENVELOPE_START);

      // discard all saved content
      Set<Class<?>> formats = message.getContentFormats();
      List<CachedOutputStreamCallback> callbacks = null;
      for (Class<?> clas : formats) {
        Object content = message.getContent(clas);
        if (content != null) {
          LOG.info(
              "Removing "
                  + clas.getName()
                  + " content of actual type "
                  + content.getClass().getName());
          message.removeContent(clas);
          if (clas == OutputStream.class && content instanceof WriteOnCloseOutputStream) {
            callbacks = ((WriteOnCloseOutputStream) content).getCallbacks();
          }
        }
      }

      // read SOAP headers from saved input stream
      RewindableInputStream is =
          (RewindableInputStream) message.get(RMMessageConstants.SAVED_CONTENT);
      is.rewind();
      XMLStreamReader reader = StaxUtils.createXMLStreamReader(is, "UTF-8");
      message.getHeaders().clear();
      if (reader.getEventType() != XMLStreamConstants.START_ELEMENT
          && reader.nextTag() != XMLStreamConstants.START_ELEMENT) {
        throw new IllegalStateException("No document found");
      }
      readHeaders(reader, message);
      int event;
      while ((event = reader.nextTag()) != XMLStreamConstants.START_ELEMENT) {
        if (event == XMLStreamConstants.END_ELEMENT) {
          throw new IllegalStateException("No body content present");
        }
      }

      // set message addressing properties
      AddressingProperties maps = new MAPCodec().unmarshalMAPs(message);
      RMContextUtils.storeMAPs(maps, message, true, MessageUtils.isRequestor(message));
      AttributedURIType to = null;
      if (null != maps) {
        to = maps.getTo();
      }
      if (null == to) {
        LOG.log(Level.SEVERE, "NO_ADDRESS_FOR_RESEND_MSG");
        return;
      }
      if (RMUtils.getAddressingConstants().getAnonymousURI().equals(to.getValue())) {
        LOG.log(Level.FINE, "Cannot resend to anonymous target");
        return;
      }

      // initialize conduit for new message
      Conduit c = message.getExchange().getConduit(message);
      if (c == null) {
        c = buildConduit(message, endpoint, to);
      }
      c.prepare(message);

      // replace standard message marshaling with copy from saved stream
      ListIterator<Interceptor<? extends Message>> iterator = retransmitChain.getIterator();
      while (iterator.hasNext()) {
        Interceptor<? extends Message> incept = iterator.next();

        // remove JAX-WS interceptors which handle message modes and such
        if (incept.getClass().getName().startsWith("org.apache.cxf.jaxws.interceptors")) {
          retransmitChain.remove(incept);
        } else if (incept instanceof PhaseInterceptor
            && (((PhaseInterceptor<?>) incept).getPhase() == Phase.MARSHAL)) {

          // remove any interceptors from the marshal phase
          retransmitChain.remove(incept);
        }
      }
      retransmitChain.add(new CopyOutInterceptor(reader));

      // restore callbacks on output stream
      if (callbacks != null) {
        OutputStream os = message.getContent(OutputStream.class);
        if (os != null) {
          WriteOnCloseOutputStream woc;
          if (os instanceof WriteOnCloseOutputStream) {
            woc = (WriteOnCloseOutputStream) os;
          } else {
            woc = new WriteOnCloseOutputStream(os);
            message.setContent(OutputStream.class, woc);
          }
          for (CachedOutputStreamCallback cb : callbacks) {
            woc.registerCallback(cb);
          }
        }
      }

      // send the message
      message.put(RMMessageConstants.RM_RETRANSMISSION, Boolean.TRUE);
      if (after) {
        retransmitChain.doInterceptStartingAfter(message, RMCaptureOutInterceptor.class.getName());
      } else {
        retransmitChain.doIntercept(message);
      }
      if (LOG.isLoggable(Level.INFO)) {
        // guard so a missing sequence cannot turn a successful resend into a logged failure
        RMProperties rmps = RMContextUtils.retrieveRMProperties(message, true);
        SequenceType seq = null == rmps ? null : rmps.getSequence();
        if (null != seq) {
          LOG.log(
              Level.INFO,
              "Retransmitted message "
                  + seq.getMessageNumber()
                  + " in sequence "
                  + seq.getIdentifier().getValue());
        }
      }

    } catch (Exception ex) {
      LOG.log(Level.SEVERE, "RESEND_FAILED_MSG", ex);
    }
  }

  /**
   * Selects a conduit for the resend that targets {@code to}, installs an observer that ignores
   * responses, and stores the conduit on the exchange.
   *
   * @param message the message being resent
   * @param endpoint the reliable endpoint
   * @param to the resend target address
   * @return the selected conduit
   */
  protected Conduit buildConduit(
      SoapMessage message, final Endpoint endpoint, AttributedURIType to) {
    Conduit c;
    final String address = to.getValue();
    ConduitSelector cs =
        new DeferredConduitSelector() {
          @Override
          public synchronized Conduit selectConduit(Message message) {
            Conduit conduit = null;
            EndpointInfo endpointInfo = endpoint.getEndpointInfo();
            EndpointReferenceType original = endpointInfo.getTarget();
            try {
              // temporarily point the endpoint at the resend address while selecting
              if (null != address) {
                endpointInfo.setAddress(address);
              }
              conduit = super.selectConduit(message);
            } finally {
              endpointInfo.setAddress(original);
            }
            return conduit;
          }
        };

    cs.setEndpoint(endpoint);
    c = cs.selectConduit(message);
    // REVISIT
    // use application endpoint message observer instead?
    c.setMessageObserver(
        new MessageObserver() {
          public void onMessage(Message message) {
            LOG.fine("Ignoring response to resent message.");
          }
        });

    message.getExchange().setConduit(c);
    return c;
  }

  /**
   * Builds an outbound chain from the bus, endpoint and binding interceptors.
   *
   * @param endpoint the reliable endpoint
   * @param cache the chain cache to build through
   * @return the outbound interceptor chain
   */
  protected PhaseInterceptorChain buildRetransmitChain(
      final Endpoint endpoint, PhaseChainCache cache) {
    PhaseInterceptorChain retransmitChain;
    Bus bus = getManager().getBus();
    List<Interceptor<? extends Message>> i1 = bus.getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by bus: " + i1);
    }
    List<Interceptor<? extends Message>> i2 = endpoint.getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by endpoint: " + i2);
    }
    List<Interceptor<? extends Message>> i3 = endpoint.getBinding().getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by binding: " + i3);
    }
    PhaseManager pm = bus.getExtension(PhaseManager.class);
    retransmitChain = cache.get(pm.getOutPhases(), i1, i2, i3);
    return retransmitChain;
  }

  /** Marshal-phase interceptor that copies the saved body from a reader to the message writer. */
  public static class CopyOutInterceptor extends AbstractOutDatabindingInterceptor {
    private final XMLStreamReader reader;

    public CopyOutInterceptor(XMLStreamReader rdr) {
      super(Phase.MARSHAL);
      reader = rdr;
    }

    @Override
    public void handleMessage(Message message) throws Fault {
      try {
        XMLStreamWriter writer = message.getContent(XMLStreamWriter.class);
        StaxUtils.copy(reader, writer);
      } catch (XMLStreamException e) {
        throw new Fault("COULD_NOT_READ_XML_STREAM", LOG, e);
      }
    }
  }
}
Example #17
0
public final class ColocUtil {
  private static final Logger LOG = LogUtils.getL7dLogger(ColocUtil.class);

  private ColocUtil() {
    // Completge
  }

  public static void setPhases(SortedSet<Phase> list, String start, String end) {
    Phase startPhase = new Phase(start, 1);
    Phase endPhase = new Phase(end, 2);
    Iterator<Phase> iter = list.iterator();
    boolean remove = true;
    while (iter.hasNext()) {
      Phase p = iter.next();
      if (remove && p.getName().equals(startPhase.getName())) {
        remove = false;
      } else if (p.getName().equals(endPhase.getName())) {
        remove = true;
      } else if (remove) {
        iter.remove();
      }
    }
  }

  public static InterceptorChain getOutInterceptorChain(Exchange ex, SortedSet<Phase> phases) {
    Bus bus = ex.get(Bus.class);
    PhaseInterceptorChain chain = new PhaseInterceptorChain(phases);

    Endpoint ep = ex.get(Endpoint.class);
    List<Interceptor> il = ep.getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by endpoint: " + il);
    }
    chain.add(il);
    il = ep.getService().getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by service: " + il);
    }
    chain.add(il);
    il = bus.getOutInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by bus: " + il);
    }
    chain.add(il);

    if (ep.getService().getDataBinding() instanceof InterceptorProvider) {
      il = ((InterceptorProvider) ep.getService().getDataBinding()).getOutInterceptors();
      if (LOG.isLoggable(Level.FINE)) {
        LOG.fine("Interceptors contributed by databinding: " + il);
      }
      chain.add(il);
    }

    return chain;
  }

  public static InterceptorChain getInInterceptorChain(Exchange ex, SortedSet<Phase> phases) {
    Bus bus = ex.get(Bus.class);
    PhaseInterceptorChain chain = new PhaseInterceptorChain(phases);

    Endpoint ep = ex.get(Endpoint.class);
    List<Interceptor> il = ep.getInInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by endpoint: " + il);
    }
    chain.add(il);
    il = ep.getService().getInInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by service: " + il);
    }
    chain.add(il);
    il = bus.getInInterceptors();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Interceptors contributed by bus: " + il);
    }
    chain.add(il);

    if (ep.getService().getDataBinding() instanceof InterceptorProvider) {
      il = ((InterceptorProvider) ep.getService().getDataBinding()).getInInterceptors();
      if (LOG.isLoggable(Level.FINE)) {
        LOG.fine("Interceptors contributed by databinding: " + il);
      }
      chain.add(il);
    }
    chain.setFaultObserver(new ColocOutFaultObserver(bus));

    return chain;
  }

  public static boolean isSameOperationInfo(OperationInfo oi1, OperationInfo oi2) {
    return oi1.getName().equals(oi2.getName())
        && isSameMessageInfo(oi1.getInput(), oi2.getInput())
        && isSameMessageInfo(oi1.getOutput(), oi2.getOutput())
        && isSameFaultInfo(oi1.getFaults(), oi2.getFaults());
  }

  public static boolean isSameMessageInfo(MessageInfo mi1, MessageInfo mi2) {
    if ((mi1 == null && mi2 != null) || (mi1 != null && mi2 == null)) {
      return false;
    }

    if (mi1 != null && mi2 != null) {
      List<MessagePartInfo> mpil1 = mi1.getMessageParts();
      List<MessagePartInfo> mpil2 = mi2.getMessageParts();
      if (mpil1.size() != mpil2.size()) {
        return false;
      }
      int idx = 0;
      for (MessagePartInfo mpi1 : mpil1) {
        MessagePartInfo mpi2 = mpil2.get(idx);
        if (!mpi1.getTypeClass().equals(mpi2.getTypeClass())) {
          return false;
        }
        ++idx;
      }
    }
    return true;
  }

  public static boolean isSameFaultInfo(Collection<FaultInfo> fil1, Collection<FaultInfo> fil2) {
    if ((fil1 == null && fil2 != null) || (fil1 != null && fil2 == null)) {
      return false;
    }

    if (fil1 != null && fil2 != null) {
      if (fil1.size() != fil2.size()) {
        return false;
      }
      for (FaultInfo fi1 : fil1) {
        Iterator<FaultInfo> iter = fil2.iterator();
        Class<?> fiClass1 = fi1.getProperty(Class.class.getName(), Class.class);
        boolean match = false;
        while (iter.hasNext()) {
          FaultInfo fi2 = iter.next();
          Class<?> fiClass2 = fi2.getProperty(Class.class.getName(), Class.class);
          // Sender/Receiver Service Model not same for faults wr.t message names.
          // So Compare Exception Class Instance.
          if (fiClass1.equals(fiClass2)) {
            match = true;
            break;
          }
        }
        if (!match) {
          return false;
        }
      }
    }
    return true;
  }
}
Example #18
0
/**
 * Outbound interceptor that turns a client call into an in-process (colocated) call when a server
 * exposing the same service and endpoint is registered on the same {@link Bus}.
 *
 * <p>When such a server also has a compatible binding operation, the outbound chain is aborted,
 * the message is delivered directly to the server endpoint through a {@link MessageObserver}, and
 * (unless the exchange is one-way) the response is run through a client-side inbound chain.
 * Otherwise the message is marked as not colocated and the remote call continues normally.
 */
public class ColocOutInterceptor extends AbstractPhaseInterceptor<Message> {
  private static final ResourceBundle BUNDLE = BundleUtils.getBundle(ColocOutInterceptor.class);
  // Log under this interceptor's own name (not ClientImpl's) so colocation decisions are
  // attributed to, and configurable for, this class.
  private static final Logger LOG = LogUtils.getL7dLogger(ColocOutInterceptor.class);
  // Message property: TRUE when dispatched as a colocated call, FALSE when dispatched remotely.
  private static final String COLOCATED = Message.class.getName() + ".COLOCATED";
  // Created lazily on the first colocated call and reused afterwards (see invokeColocObserver).
  private MessageObserver colocObserver;
  private Bus bus;

  /** Creates the interceptor in the POST_LOGICAL phase; the bus is resolved per message. */
  public ColocOutInterceptor() {
    super(Phase.POST_LOGICAL);
  }

  /**
   * Creates the interceptor in the POST_LOGICAL phase bound to the given bus.
   *
   * @param b the bus whose server registry is searched for colocated servers
   */
  public ColocOutInterceptor(Bus b) {
    super(Phase.POST_LOGICAL);
    bus = b;
  }

  public void setBus(Bus bus) {
    this.bus = bus;
  }

  /**
   * Dispatches the message as a colocated call if a matching server is registered on the bus.
   *
   * <p>If no bus has been set, it is taken from the exchange, falling back to the default bus.
   *
   * @param message the outbound client message
   * @throws Fault if no bus, server registry, sender endpoint or binding operation is available
   */
  public void handleMessage(Message message) throws Fault {
    Exchange exchange = message.getExchange();
    if (bus == null) {
      bus = exchange.getBus();
      if (bus == null) {
        bus = BusFactory.getDefaultBus(false);
      }
      if (bus == null) {
        throw new Fault(new org.apache.cxf.common.i18n.Message("BUS_NOT_FOUND", BUNDLE));
      }
    }

    ServerRegistry registry = bus.getExtension(ServerRegistry.class);

    if (registry == null) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("SERVER_REGISTRY_NOT_FOUND", BUNDLE));
    }

    Endpoint senderEndpoint = exchange.getEndpoint();

    if (senderEndpoint == null) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("ENDPOINT_NOT_FOUND", BUNDLE));
    }

    BindingOperationInfo boi = exchange.getBindingOperationInfo();

    if (boi == null) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("OPERATIONINFO_NOT_FOUND", BUNDLE));
    }

    Server srv = isColocated(registry.getServers(), senderEndpoint, boi);

    if (srv != null) {
      if (LOG.isLoggable(Level.FINE)) {
        LOG.fine("Operation:" + boi.getName() + " dispatched as colocated call.");
      }

      // Stop the remote (transport) part of the outbound chain; delivery happens in-process.
      InterceptorChain outChain = message.getInterceptorChain();
      outChain.abort();
      exchange.put(Bus.class, bus);
      message.put(COLOCATED, Boolean.TRUE);
      message.put(Message.WSDL_OPERATION, boi.getName());
      message.put(Message.WSDL_INTERFACE, boi.getBinding().getInterface().getName());
      invokeColocObserver(message, srv.getEndpoint());
      if (!exchange.isOneWay()) {
        invokeInboundChain(exchange, senderEndpoint);
      }
    } else {
      if (LOG.isLoggable(Level.FINE)) {
        LOG.fine("Operation:" + boi.getName() + " dispatched as remote call.");
      }

      message.put(COLOCATED, Boolean.FALSE);
    }
  }

  /**
   * Hands the outbound message to the server side. The observer is created for
   * {@code inboundEndpoint} on first use and then reused for later calls.
   *
   * @param outMsg the outbound client message
   * @param inboundEndpoint the colocated server's endpoint
   */
  protected void invokeColocObserver(Message outMsg, Endpoint inboundEndpoint) {
    if (colocObserver == null) {
      colocObserver = new ColocMessageObserver(inboundEndpoint, bus);
    }
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Invoke on Coloc Observer.");
    }

    colocObserver.onMessage(outMsg);
  }

  /**
   * Builds the client-side inbound message from the exchange's in-fault message (or, if there is
   * none, its in message) and processes it: faults go to a {@link ColocInFaultObserver}; normal
   * responses run through an inbound chain built from the USER_LOGICAL..PRE_INVOKE phase range.
   * Finally the exchange is marked as finished.
   *
   * @param ex the exchange of the current call
   * @param ep the client (sender) endpoint whose binding creates the inbound message
   */
  protected void invokeInboundChain(Exchange ex, Endpoint ep) {
    Message m = getInBoundMessage(ex);
    Message inMsg = ep.getBinding().createMessage();
    MessageImpl.copyContent(m, inMsg);

    // Copy Response Context to Client inBound Message
    // TODO a Context Filter Strategy required.
    inMsg.putAll(m);

    inMsg.put(Message.REQUESTOR_ROLE, Boolean.TRUE);
    inMsg.put(Message.INBOUND_MESSAGE, Boolean.TRUE);
    inMsg.setExchange(ex);

    Exception exc = inMsg.getContent(Exception.class);
    if (exc != null) {
      ex.setInFaultMessage(inMsg);
      ColocInFaultObserver observer = new ColocInFaultObserver(bus);
      observer.onMessage(inMsg);
    } else {
      // Handle Response
      ex.setInMessage(inMsg);
      PhaseManager pm = bus.getExtension(PhaseManager.class);
      SortedSet<Phase> phases = new TreeSet<Phase>(pm.getInPhases());
      ColocUtil.setPhases(phases, Phase.USER_LOGICAL, Phase.PRE_INVOKE);

      InterceptorChain chain = ColocUtil.getInInterceptorChain(ex, phases);
      inMsg.setInterceptorChain(chain);
      chain.doIntercept(inMsg);
    }
    ex.put(ClientImpl.FINISHED, Boolean.TRUE);
  }

  /** Returns the exchange's in-fault message if present, otherwise its in message. */
  protected Message getInBoundMessage(Exchange ex) {
    return (ex.getInFaultMessage() != null) ? ex.getInFaultMessage() : ex.getInMessage();
  }

  /** Replaces the observer used for colocated delivery (e.g. for testing). */
  protected void setMessageObserver(MessageObserver observer) {
    colocObserver = observer;
  }

  /**
   * Finds a registered server whose service and endpoint names equal the sender's and whose
   * binding has a compatible operation with the same name as {@code boi}.
   *
   * @param servers the servers registered on the bus, may be {@code null}
   * @param endpoint the sender endpoint
   * @param boi the operation being invoked
   * @return the first matching server, or {@code null} if the call must go remote
   */
  protected Server isColocated(List<Server> servers, Endpoint endpoint, BindingOperationInfo boi) {
    if (servers != null) {
      Service senderService = endpoint.getService();
      EndpointInfo senderEI = endpoint.getEndpointInfo();
      for (Server s : servers) {
        Endpoint receiverEndpoint = s.getEndpoint();
        Service receiverService = receiverEndpoint.getService();
        EndpointInfo receiverEI = receiverEndpoint.getEndpointInfo();
        if (receiverService.getName().equals(senderService.getName())
            && receiverEI.getName().equals(senderEI.getName())) {
          // Check For Operation Match.
          BindingOperationInfo receiverOI = receiverEI.getBinding().getOperation(boi.getName());
          if (receiverOI != null && isCompatibleOperationInfo(boi, receiverOI)) {
            return s;
          }
        }
      }
    }

    return null;
  }

  /** Delegates to {@code ColocUtil.isSameOperationInfo} on the underlying operation models. */
  protected boolean isSameOperationInfo(
      BindingOperationInfo sender, BindingOperationInfo receiver) {
    return ColocUtil.isSameOperationInfo(sender.getOperationInfo(), receiver.getOperationInfo());
  }

  /** Delegates to {@code ColocUtil.isCompatibleOperationInfo} on the underlying operation models. */
  protected boolean isCompatibleOperationInfo(
      BindingOperationInfo sender, BindingOperationInfo receiver) {
    return ColocUtil.isCompatibleOperationInfo(
        sender.getOperationInfo(), receiver.getOperationInfo());
  }

  /**
   * Stores the endpoint, its service and binding, and the bus (this interceptor's, or the default
   * bus if none is set) on the exchange.
   */
  public void setExchangeProperties(Exchange exchange, Endpoint ep) {
    exchange.put(Endpoint.class, ep);
    exchange.put(Service.class, ep.getService());
    exchange.put(Binding.class, ep.getBinding());
    exchange.put(Bus.class, bus == null ? BusFactory.getDefaultBus(false) : bus);
  }
}
/**
 * Base class for JAX-RS providers whose consumed/produced media types, buffering and streaming
 * modes can be configured, and which may use a {@link Bus} to resolve resources.
 */
public abstract class AbstractConfigurableProvider {
  protected static final ResourceBundle BUNDLE = BundleUtils.getBundle(AbstractJAXBProvider.class);
  protected static final Logger LOG = LogUtils.getL7dLogger(AbstractJAXBProvider.class);

  private List<String> consumeMediaTypes;
  private List<String> produceMediaTypes;
  private boolean enableBuffering;
  private boolean enableStreaming;
  private Bus bus;

  /**
   * Sets the Bus. A {@code null} argument is ignored, so a previously set bus is kept and
   * {@link #getBus()} otherwise keeps falling back to the thread default bus.
   *
   * @param b the bus to use; ignored when {@code null}
   */
  public void setBus(Bus b) {
    // The check must be on the argument: testing the (initially null) field meant a bus could
    // never be set on a freshly created provider.
    if (b != null) {
      bus = b;
    }
  }

  /**
   * Gets the Bus. Providers may use the bus to resolve resource references. Example:
   * ResourceUtils.getResourceStream(reference, this.getBus())
   *
   * @return the configured bus, or the thread default bus when none has been set
   */
  public Bus getBus() {
    return bus != null ? bus : BusFactory.getThreadDefaultBus();
  }

  /**
   * Sets custom Consumes media types; can be used to override static {@link Consumes} annotation
   * value set on the provider.
   *
   * @param types the media types
   */
  public void setConsumeMediaTypes(List<String> types) {
    consumeMediaTypes = types;
  }

  /**
   * Gets the custom Consumes media types
   *
   * @return media types
   */
  public List<String> getConsumeMediaTypes() {
    return consumeMediaTypes;
  }

  /**
   * Sets custom Produces media types; can be used to override static {@link Produces} annotation
   * value set on the provider.
   *
   * @param types the media types
   */
  public void setProduceMediaTypes(List<String> types) {
    produceMediaTypes = types;
  }

  /**
   * Gets the custom Produces media types
   *
   * @return media types
   */
  public List<String> getProduceMediaTypes() {
    return produceMediaTypes;
  }

  /**
   * Enables the buffering mode. If set to true then the runtime will ensure that the provider
   * writes to a cached stream.
   *
   * <p>For example, the JAXB marshalling process may fail after the initial XML tags have already
   * been written out to the HTTP output stream. Enabling the buffering ensures no incomplete
   * payloads are sent back to clients in case of marshalling errors at the cost of the initial
   * buffering - which might be negligible for small payloads.
   *
   * @param enableBuf the value of the buffering mode, false is default.
   */
  public void setEnableBuffering(boolean enableBuf) {
    enableBuffering = enableBuf;
  }

  /**
   * Gets the value of the buffering mode
   *
   * @return true if the buffering is enabled
   */
  public boolean getEnableBuffering() {
    return enableBuffering;
  }

  /**
   * Enables the support for streaming. XML-aware providers which prefer writing to Stax
   * XMLStreamWriter can set this value to true. Additionally, if the streaming and the buffering
   * modes are enabled, the runtime will ensure the XMLStreamWriter events are cached properly.
   *
   * @param enableStream the value of the streaming mode, false is default.
   */
  public void setEnableStreaming(boolean enableStream) {
    enableStreaming = enableStream;
  }

  /**
   * Gets the value of the streaming mode
   *
   * @return true if the streaming is enabled
   */
  public boolean getEnableStreaming() {
    return enableStreaming;
  }

  /**
   * Gives providers a chance to introspect the JAX-RS model classes. For example, the JAXB provider
   * may use the model classes to create a single composite JAXBContext supporting all the
   * JAXB-annotated root resource classes/types. The default implementation does nothing.
   *
   * @param resources the root resource class models
   */
  public void init(List<ClassResourceInfo> resources) {
    // complete
  }

  /**
   * Checks whether the request headers indicate an empty payload.
   *
   * @param headers the request headers, may be {@code null}
   * @return {@code false} when {@code headers} is {@code null}, otherwise the result of
   *     {@link #isPayloadEmpty(MultivaluedMap)} on the request headers
   */
  protected boolean isPayloadEmpty(HttpHeaders headers) {
    return headers != null && isPayloadEmpty(headers.getRequestHeaders());
  }

  /** Delegates to {@code HttpUtils.isPayloadEmpty}. */
  protected boolean isPayloadEmpty(MultivaluedMap<String, String> headers) {
    return HttpUtils.isPayloadEmpty(headers);
  }

  /**
   * Logs the localized EMPTY_BODY message as a warning and throws it.
   *
   * @throws NoContentException always
   */
  protected void reportEmptyContentLength() throws NoContentException {
    String message = new org.apache.cxf.common.i18n.Message("EMPTY_BODY", BUNDLE).toString();
    LOG.warning(message);
    throw new NoContentException(message);
  }
}
Example #20
0
/**
 * Validates that references inside a WSDL definition (and the definitions it imports) resolve:
 * service ports to bindings, bindings to port types, operations and their input/output/fault
 * messages, and message parts to schema elements or types.
 *
 * <p>Run {@link #isValid()}; errors and warnings are collected in the {@link ValidationResult}
 * returned by {@link #getValidationResults()}.
 */
public class WSDLRefValidator extends AbstractDefinitionValidator {
  protected static final Logger LOG = LogUtils.getL7dLogger(WSDLRefValidator.class);
  // Expected nodes; each one must be found in one of the WSDL documents (see isValid()).
  protected List<XNode> vNodes = new ArrayList<XNode>();

  private Set<QName> portTypeRefNames = new HashSet<QName>();
  private Set<QName> messageRefNames = new HashSet<QName>();
  private Map<QName, Service> services = new HashMap<QName, Service>();

  private ValidationResult vResults = new ValidationResult();

  private Definition definition;
  private Document baseDoc;

  private List<Definition> importedDefinitions;
  private SchemaCollection schemaCollection = new SchemaCollection();

  private boolean suppressWarnings;

  /** Creates a validator that uses the thread default bus. */
  public WSDLRefValidator(Definition wsdl, Document doc) {
    this(wsdl, doc, BusFactory.getThreadDefaultBus());
  }

  /**
   * Creates a validator, collecting all transitively imported definitions and loading the schemas.
   *
   * @param wsdl the root definition
   * @param doc the parsed root document, or {@code null} to load it from the document base URI
   * @param bus the bus whose {@code WSDLManager} caches the schemas
   * @throws ToolException if the schemas cannot be processed or the target namespace is invalid
   */
  public WSDLRefValidator(Definition wsdl, Document doc, Bus bus) {
    this.definition = wsdl;
    baseDoc = doc;
    importedDefinitions = new ArrayList<Definition>();
    parseImports(wsdl);
    processSchemas(bus);
  }

  /** Builds the schema collection for the definition and registers it with the WSDLManager. */
  private void getSchemas(Bus bus) {
    Map<String, Element> schemaList = new HashMap<String, Element>();
    SchemaUtil schemaUtil = new SchemaUtil(bus, schemaList);
    List<SchemaInfo> si = new ArrayList<SchemaInfo>();
    schemaUtil.getSchemas(definition, schemaCollection, si);
    ServiceSchemaInfo ssi = new ServiceSchemaInfo();
    ssi.setSchemaCollection(schemaCollection);
    ssi.setSchemaInfoList(si);
    ssi.setSchemaElementList(schemaList);
    bus.getExtension(WSDLManager.class).putSchemasForDefinition(definition, ssi);
  }

  /** Reuses cached schemas when available, otherwise builds them; then checks the namespace. */
  private void processSchemas(Bus bus) {
    try {
      ServiceSchemaInfo info =
          bus.getExtension(WSDLManager.class).getSchemasForDefinition(definition);
      if (info == null) {
        getSchemas(bus);
      } else {
        schemaCollection = info.getSchemaCollection();
      }
      checkTargetNamespace(this.definition.getTargetNamespace());
    } catch (Exception ex) {
      throw new ToolException(ex);
    }
  }

  private Collection<Import> getImports(final Definition wsdlDef) {
    Collection<Import> importList = new ArrayList<Import>();
    Map<?, ?> imports = wsdlDef.getImports();
    for (Map.Entry<?, ?> entry : imports.entrySet()) {
      List<Import> lst = CastUtils.cast((List<?>) entry.getValue());
      importList.addAll(lst);
    }

    return importList;
  }

  /** Recursively adds every imported definition once (the contains check stops import cycles). */
  private void parseImports(Definition def) {
    for (Import impt : getImports(def)) {
      if (!importedDefinitions.contains(impt.getDefinition())) {
        importedDefinitions.add(impt.getDefinition());
        parseImports(impt.getDefinition());
      }
    }
  }

  /** Rejects a URL-shaped target namespace whose path contains ':'; non-URLs are not checked. */
  private void checkTargetNamespace(String path) {
    try {
      if (new URL(path).getPath().indexOf(":") != -1) {
        throw new ToolException(": is not a valid char in the targetNamespace");
      }
    } catch (MalformedURLException e) {
      // do nothing
    }
  }

  public void setSuppressWarnings(boolean s) {
    this.suppressWarnings = s;
  }

  public ValidationResult getValidationResults() {
    return this.vResults;
  }

  private Document getWSDLDocument(final String wsdl) throws URISyntaxException {
    return new Stax2DOM().getDocument(wsdl);
  }

  private Document getWSDLDocument() throws Exception {
    if (baseDoc != null) {
      return baseDoc;
    }
    return getWSDLDocument(this.definition.getDocumentBaseURI());
  }

  /**
   * Loads the root document and all imported documents. Best effort: on failure the documents
   * loaded so far are returned and the problem is logged.
   */
  private List<Document> getWSDLDocuments() {
    List<Document> docs = new ArrayList<Document>();
    try {
      docs.add(getWSDLDocument());

      if (null != importedDefinitions) {
        for (Definition d : importedDefinitions) {
          docs.add(getWSDLDocument(d.getDocumentBaseURI()));
        }
      }
    } catch (Exception e) {
      LOG.log(Level.WARNING, "Could not load all WSDL documents for reference validation", e);
    }

    return docs;
  }

  private boolean isExist(List<Document> docs, XNode vNode) {
    for (Document doc : docs) {
      if (vNode.matches(doc)) {
        return true;
      }
    }
    return false;
  }

  /**
   * Locates {@code fNode} in the documents via XPath and returns its recorded source location.
   *
   * @return the location, or {@code null} if the node or its location data cannot be found
   */
  private FailureLocation getFailureLocation(List<Document> docs, XNode fNode) {
    if (fNode == null) {
      return null;
    }

    XPathUtils xpather = new XPathUtils(fNode.getNSMap());
    for (Document doc : docs) {
      Node node = (Node) xpather.getValue(fNode.toString(), doc, XPathConstants.NODE);
      if (null != node) {
        try {
          return new FailureLocation((Location) node.getUserData("location"), doc.getDocumentURI());
        } catch (Exception ex) {
          // ignore, probably not DOM level 3
        }
      }
    }
    return null;
  }

  /**
   * Runs the validation.
   *
   * @return {@code true} if no errors were recorded; any exception is recorded as an error and
   *     makes the result {@code false}
   */
  public boolean isValid() {
    try {
      loadServices();

      collectValidationPoints();

      List<Document> wsdlDocs = getWSDLDocuments();
      for (XNode vNode : vNodes) {
        if (!isExist(wsdlDocs, vNode)) {
          addFailure(vNode, getFailureLocation(wsdlDocs, vNode.getFailurePoint()));
        }
      }
    } catch (Exception e) {
      // Some exceptions (e.g. NullPointerException) carry no message; never record a null error.
      String msg = e.getMessage();
      this.vResults.addError(msg != null ? msg : e.toString());
      return false;
    }
    return vResults.isSuccessful();
  }

  /**
   * Records a FAILED_AT_POINT error for {@code vNode}. When the failure location (or its line and
   * column data) is unavailable, -1 is reported for line and column instead of failing with a
   * NullPointerException.
   */
  private void addFailure(XNode vNode, FailureLocation loc) {
    Location location = loc == null ? null : loc.getLocation();
    int line = location == null ? -1 : location.getLineNumber();
    int column = location == null ? -1 : location.getColumnNumber();
    String documentURI = loc == null ? null : loc.getDocumentURI();
    vResults.addError(
        new Message("FAILED_AT_POINT", LOG, line, column, documentURI, vNode.getPlainText()));
  }

  /** Adds the services declared directly in {@code wsdlDef}. */
  private void addServices(final Definition wsdlDef) {
    Iterator<QName> sNames = CastUtils.cast(wsdlDef.getServices().keySet().iterator());
    while (sNames.hasNext()) {
      QName sName = sNames.next();
      // Look the service up on the definition that declares it.
      services.put(sName, wsdlDef.getService(sName));
    }
  }

  private void loadServices() {
    addServices(this.definition);
    if (importedDefinitions != null) {
      for (Definition d : importedDefinitions) {
        addServices(d);
      }
    }
  }

  /**
   * Maps each port's binding name to the XNode of the port that references it.
   *
   * @throws ToolException if the service has no ports or a binding uses the WSDL 1.1 namespace
   */
  private Map<QName, XNode> getBindings(Service service) {
    Map<QName, XNode> bindings = new HashMap<QName, XNode>();

    if (service.getPorts().isEmpty()) {
      throw new ToolException(
          "Service " + service.getQName() + " does not contain any usable ports");
    }
    Collection<Port> ports = CastUtils.cast(service.getPorts().values());
    for (Port port : ports) {
      Binding binding = port.getBinding();
      bindings.put(binding.getQName(), getXNode(service, port));
      if (WSDLConstants.NS_WSDL11.equals(binding.getQName().getNamespaceURI())) {
        throw new ToolException(
            "Binding " + binding.getQName().getLocalPart() + " namespace set improperly.");
      }
    }

    return bindings;
  }

  /** Keys each operation by a QName in the port type's namespace. */
  private Map<QName, Operation> getOperations(PortType portType) {
    Map<QName, Operation> operations = new HashMap<QName, Operation>();
    Collection<Operation> pops = CastUtils.cast(portType.getOperations());
    for (Operation op : pops) {
      operations.put(new QName(portType.getQName().getNamespaceURI(), op.getName()), op);
    }
    return operations;
  }

  private XNode getXNode(Service service, Port port) {
    XNode vService = getXNode(service);

    XPort pNode = new XPort();
    pNode.setName(port.getName());
    pNode.setParentNode(vService);
    return pNode;
  }

  private XNode getXNode(Service service) {
    XDef xdef = new XDef();
    xdef.setTargetNamespace(service.getQName().getNamespaceURI());

    XService sNode = new XService();
    sNode.setName(service.getQName().getLocalPart());
    sNode.setParentNode(xdef);
    return sNode;
  }

  private XNode getXNode(Binding binding) {
    XDef xdef = new XDef();
    xdef.setTargetNamespace(binding.getQName().getNamespaceURI());

    XBinding bNode = new XBinding();
    bNode.setName(binding.getQName().getLocalPart());
    bNode.setParentNode(xdef);
    return bNode;
  }

  private XNode getXNode(PortType portType) {
    XDef xdef = new XDef();
    xdef.setTargetNamespace(portType.getQName().getNamespaceURI());

    XPortType pNode = new XPortType();
    pNode.setName(portType.getQName().getLocalPart());
    pNode.setParentNode(xdef);
    return pNode;
  }

  private XNode getOperationXNode(XNode pNode, String opName) {
    XOperation node = new XOperation();
    node.setName(opName);
    node.setParentNode(pNode);
    return node;
  }

  /** Input node; flagged as the default name when it equals the operation name + "Request". */
  private XNode getInputXNode(XNode opVNode, String name) {
    XInput oNode = new XInput();
    oNode.setName(name);
    oNode.setParentNode(opVNode);

    if (name != null && name.equals(opVNode.getAttributeValue() + "Request")) {
      oNode.setDefaultAttributeValue(true);
    }
    return oNode;
  }

  /** Output node; flagged as the default name when it equals the operation name + "Response". */
  private XNode getOutputXNode(XNode opVNode, String name) {
    XOutput oNode = new XOutput();
    oNode.setName(name);
    oNode.setParentNode(opVNode);
    if (name != null && name.equals(opVNode.getAttributeValue() + "Response")) {
      oNode.setDefaultAttributeValue(true);
    }
    return oNode;
  }

  private XNode getFaultXNode(XNode opVNode, String name) {
    XFault oNode = new XFault();
    oNode.setName(name);
    oNode.setParentNode(opVNode);
    return oNode;
  }

  private XNode getXNode(javax.wsdl.Message msg) {
    XDef xdef = new XDef();
    xdef.setTargetNamespace(msg.getQName().getNamespaceURI());

    XMessage mNode = new XMessage();
    mNode.setName(msg.getQName().getLocalPart());
    mNode.setParentNode(xdef);
    return mNode;
  }

  private void addWarning(String warningMsg) {
    if (suppressWarnings) {
      return;
    }
    vResults.addWarning(warningMsg);
  }

  /**
   * Collects the nodes to verify. Without services, all port types of the root definition are
   * checked; otherwise only the port types reachable from the services' bindings.
   */
  private void collectValidationPoints() throws Exception {
    if (services.isEmpty()) {
      LOG.log(
          Level.WARNING,
          "WSDL document "
              + this.definition.getDocumentBaseURI()
              + " does not define any services");
      Collection<QName> ports = CastUtils.cast(this.definition.getAllPortTypes().keySet());
      portTypeRefNames.addAll(ports);
    } else {
      collectValidationPointsForBindings();
    }

    collectValidationPointsForPortTypes();
    collectValidationPointsForMessages();
  }

  /**
   * Adds binding, port type, operation and named input/output/fault nodes for every binding
   * referenced by a service port.
   *
   * @throws Exception if a referenced binding is not defined in the root definition
   */
  private void collectValidationPointsForBindings() throws Exception {
    Map<QName, XNode> vBindingNodes = new HashMap<QName, XNode>();
    for (Service service : services.values()) {
      vBindingNodes.putAll(getBindings(service));
    }

    for (Map.Entry<QName, XNode> entry : vBindingNodes.entrySet()) {
      QName bName = entry.getKey();
      Binding binding = this.definition.getBinding(bName);
      if (binding == null) {
        LOG.log(
            Level.SEVERE,
            bName.toString()
                + " is not correct, please check that the correct namespace is being used");
        throw new Exception(
            bName.toString()
                + " is not correct, please check that the correct namespace is being used");
      }
      XNode vBindingNode = getXNode(binding);
      vBindingNode.setFailurePoint(entry.getValue());
      vNodes.add(vBindingNode);

      if (binding.getPortType() == null) {
        continue;
      }
      portTypeRefNames.add(binding.getPortType().getQName());

      XNode vPortTypeNode = getXNode(binding.getPortType());
      vPortTypeNode.setFailurePoint(vBindingNode);
      vNodes.add(vPortTypeNode);
      Collection<BindingOperation> bops = CastUtils.cast(binding.getBindingOperations());
      for (BindingOperation bop : bops) {
        XNode vOpNode = getOperationXNode(vPortTypeNode, bop.getName());
        XNode vBopNode = getOperationXNode(vBindingNode, bop.getName());
        vOpNode.setFailurePoint(vBopNode);
        vNodes.add(vOpNode);
        if (bop.getBindingInput() != null) {
          String inName = bop.getBindingInput().getName();
          if (!StringUtils.isEmpty(inName)) {
            XNode vInputNode = getInputXNode(vOpNode, inName);
            vInputNode.setFailurePoint(getInputXNode(vBopNode, inName));
            vNodes.add(vInputNode);
          }
        }
        if (bop.getBindingOutput() != null) {
          String outName = bop.getBindingOutput().getName();
          if (!StringUtils.isEmpty(outName)) {
            XNode vOutputNode = getOutputXNode(vOpNode, outName);
            vOutputNode.setFailurePoint(getOutputXNode(vBopNode, outName));
            vNodes.add(vOutputNode);
          }
        }
        for (Iterator<?> iter1 = bop.getBindingFaults().keySet().iterator(); iter1.hasNext(); ) {
          String faultName = (String) iter1.next();
          XNode vFaultNode = getFaultXNode(vOpNode, faultName);
          vFaultNode.setFailurePoint(getFaultXNode(vBopNode, faultName));
          vNodes.add(vFaultNode);
        }
      }
    }
  }

  /** Looks the message up in the root definition, then in the imported definitions. */
  private javax.wsdl.Message getMessage(QName msgName) {
    javax.wsdl.Message message = this.definition.getMessage(msgName);
    if (message == null) {
      for (Definition d : importedDefinitions) {
        message = d.getMessage(msgName);
        if (message != null) {
          break;
        }
      }
    }
    return message;
  }

  /**
   * Checks every part of every referenced message: a part must reference exactly one of an
   * element or a type, and that reference must resolve in the schema collection.
   */
  private void collectValidationPointsForMessages() {
    for (QName msgName : messageRefNames) {
      javax.wsdl.Message message = getMessage(msgName);
      if (message == null) {
        // Nothing to inspect; the unresolved reference is presumably reported by the message
        // XNode check in isValid(). Previously this dereferenced null.
        continue;
      }
      for (Iterator<?> iter = message.getParts().values().iterator(); iter.hasNext(); ) {
        Part part = (Part) iter.next();
        QName elementName = part.getElementName();
        QName typeName = part.getTypeName();

        if (elementName == null && typeName == null) {
          vResults.addError(new Message("PART_NO_TYPES", LOG));
          continue;
        }

        if (elementName != null && typeName != null) {
          vResults.addError(new Message("PART_NOT_UNIQUE", LOG));
          continue;
        }

        // Exactly one of elementName/typeName is set at this point.
        boolean isElement = elementName != null;
        QName ref = isElement ? elementName : typeName;
        if (!validatePartType(ref.getNamespaceURI(), ref.getLocalPart(), isElement)) {
          vResults.addError(
              new Message("TYPE_REF_NOT_FOUND", LOG, message.getQName(), part.getName(), ref));
        }
      }
    }
  }

  /** Looks the port type up in the root definition, then in the imported definitions. */
  private PortType getPortType(QName ptName) {
    PortType portType = this.definition.getPortType(ptName);
    if (portType == null) {
      for (Definition d : importedDefinitions) {
        portType = d.getPortType(ptName);
        if (portType != null) {
          break;
        }
      }
    }
    return portType;
  }

  /**
   * Adds message nodes for the input, output and fault messages of every referenced port type's
   * operations, and records those messages for part validation.
   */
  private void collectValidationPointsForPortTypes() {
    for (QName ptName : portTypeRefNames) {
      PortType portType = getPortType(ptName);
      if (portType == null) {
        vResults.addError(new Message("NO_PORTTYPE", LOG, ptName));
        continue;
      }

      XNode vPortTypeNode = getXNode(portType);
      for (Operation operation : getOperations(portType).values()) {
        XNode vOperationNode = getOperationXNode(vPortTypeNode, operation.getName());
        if (operation.getInput() == null) {
          vResults.addError(
              new Message("WRONG_MEP", LOG, operation.getName(), portType.getQName()));
          continue;
        }
        javax.wsdl.Message inMsg = operation.getInput().getMessage();
        if (inMsg == null) {
          addWarning(
              "Operation "
                  + operation.getName()
                  + " in PortType: "
                  + portType.getQName()
                  + " has no input message");
        } else {
          XNode vInMsgNode = getXNode(inMsg);
          vInMsgNode.setFailurePoint(getInputXNode(vOperationNode, operation.getInput().getName()));
          vNodes.add(vInMsgNode);
          messageRefNames.add(inMsg.getQName());
        }

        if (operation.getOutput() != null) {
          javax.wsdl.Message outMsg = operation.getOutput().getMessage();

          if (outMsg == null) {
            addWarning(
                "Operation "
                    + operation.getName()
                    + " in PortType: "
                    + portType.getQName()
                    + " has no output message");
          } else {
            XNode vOutMsgNode = getXNode(outMsg);
            vOutMsgNode.setFailurePoint(
                getOutputXNode(vOperationNode, operation.getOutput().getName()));
            vNodes.add(vOutMsgNode);
            messageRefNames.add(outMsg.getQName());
          }
        }
        for (Iterator<?> iter = operation.getFaults().values().iterator(); iter.hasNext(); ) {
          Fault fault = (Fault) iter.next();
          javax.wsdl.Message faultMsg = fault.getMessage();
          if (faultMsg == null) {
            // Handled like a missing input/output message instead of a NullPointerException.
            addWarning(
                "Operation "
                    + operation.getName()
                    + " in PortType: "
                    + portType.getQName()
                    + " has no fault message");
            continue;
          }
          XNode vFaultMsgNode = getXNode(faultMsg);
          vFaultMsgNode.setFailurePoint(getFaultXNode(vOperationNode, fault.getName()));
          vNodes.add(vFaultMsgNode);
          messageRefNames.add(faultMsg.getQName());
        }
      }
    }
  }

  /**
   * Checks whether a part's element or type reference resolves in the schema collection.
   * {@code xsd:anyType} is always accepted as a type.
   *
   * @param namespace the namespace of the reference
   * @param name the local name of the reference
   * @param isElement {@code true} for an element reference, {@code false} for a type reference
   * @return {@code true} if the reference resolves
   */
  private boolean validatePartType(String namespace, String name, boolean isElement) {
    if (isElement) {
      return schemaCollection.getElementByQName(new QName(namespace, name)) != null;
    }
    if (WSDLConstants.NS_SCHEMA_XSD.equals(namespace) && "anyType".equals(name)) {
      return true;
    }
    return schemaCollection.getTypeByQName(new QName(namespace, name)) != null;
  }

  public String getErrorMessage() {
    return vResults.toString();
  }

  public Definition getDefinition() {
    return this.definition;
  }
}
public class NMRWrapperInInterceptor extends AbstractInDatabindingInterceptor {

  private static final Logger LOG = LogUtils.getL7dLogger(NMRWrapperInInterceptor.class);

  private static final ResourceBundle BUNDLE = LOG.getResourceBundle();

  public NMRWrapperInInterceptor() {
    super(Phase.UNMARSHAL);
  }

  public void handleMessage(Message message) throws Fault {
    if (isGET(message)) {
      LOG.info("JbiMessageInInterceptor skipped in HTTP GET method");
      return;
    }
    XMLStreamReader xsr = message.getContent(XMLStreamReader.class);

    DepthXMLStreamReader reader = new DepthXMLStreamReader(xsr);

    Endpoint ep = message.getExchange().get(Endpoint.class);
    BindingInfo binding = ep.getEndpointInfo().getBinding();
    if (!(binding instanceof NMRBindingInfo)) {
      throw new IllegalStateException(
          new org.apache.cxf.common.i18n.Message("NEED_JBIBINDING", BUNDLE).toString());
    }

    if (!StaxUtils.toNextElement(reader)) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("NO_OPERATION_ELEMENT", BUNDLE));
    }

    Exchange ex = message.getExchange();
    QName startQName = reader.getName();

    // handling jbi fault message
    if (startQName.getLocalPart().equals(NMRFault.NMR_FAULT_ROOT)) {
      message.getInterceptorChain().abort();

      if (ep.getInFaultObserver() != null) {
        ep.getInFaultObserver().onMessage(message);
        return;
      }
    }

    // handling xml normal inbound message
    if (!startQName.equals(NMRConstants.JBI_WRAPPER_MESSAGE)) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("NO_JBI_MESSAGE_ELEMENT", BUNDLE));
    }

    try {
      BindingOperationInfo bop = ex.get(BindingOperationInfo.class);
      DataReader<XMLStreamReader> dr = getDataReader(message);
      List<Object> parameters = new ArrayList<Object>();
      reader.next();
      BindingMessageInfo messageInfo = !isRequestor(message) ? bop.getInput() : bop.getOutput();
      message.put(MessageInfo.class, messageInfo.getMessageInfo());
      for (MessagePartInfo part : messageInfo.getMessageParts()) {
        if (!StaxUtils.skipToStartOfElement(reader)) {
          throw new Fault(new org.apache.cxf.common.i18n.Message("NOT_ENOUGH_PARTS", BUNDLE));
        }
        startQName = reader.getName();
        if (!startQName.equals(NMRConstants.JBI_WRAPPER_PART)) {
          throw new Fault(new org.apache.cxf.common.i18n.Message("NO_JBI_PART_ELEMENT", BUNDLE));
        }
        if (part.isElement()) {
          reader.next();
          if (!StaxUtils.toNextElement(reader)) {
            throw new Fault(
                new org.apache.cxf.common.i18n.Message("EXPECTED_ELEMENT_IN_PART", BUNDLE));
          }
        }
        parameters.add(dr.read(part, reader));
        // skip end element
        if (part.isElement()) {
          reader.next();
        }
      }
      int ev = reader.getEventType();
      while (ev != XMLStreamConstants.END_ELEMENT
          && ev != XMLStreamConstants.START_ELEMENT
          && ev != XMLStreamConstants.END_DOCUMENT) {
        ev = reader.next();
      }
      message.setContent(List.class, parameters);
    } catch (XMLStreamException e) {
      throw new Fault(new org.apache.cxf.common.i18n.Message("STAX_READ_EXC", BUNDLE), e);
    }
  }
}
Example #22
0
/**
 * Static helpers for building JAX-RS resource models ({@code ClassResourceInfo} /
 * {@code OperationResourceInfo}) from annotated classes or from user-supplied XML models,
 * locating resources by a location string, and turning a JAX-RS {@code Application} into a
 * {@code JAXRSServerFactoryBean}.
 */
public final class ResourceUtils {

  private static final Logger LOG = LogUtils.getL7dLogger(ResourceUtils.class);
  private static final ResourceBundle BUNDLE = BundleUtils.getBundle(ResourceUtils.class);
  /** Location prefix that forces a classpath lookup in {@link #getResourceURL}. */
  private static final String CLASSPATH_PREFIX = "classpath:";
  /**
   * Names of interfaces whose (direct) implementation marks a class as a server-side provider
   * even without a {@code @Provider} annotation; consulted by {@link #isValidProvider}.
   */
  private static final Set<String> SERVER_PROVIDER_CLASS_NAMES;

  static {
    SERVER_PROVIDER_CLASS_NAMES = new HashSet<String>();
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.MessageBodyWriter");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.MessageBodyReader");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.ExceptionMapper");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.ContextResolver");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.ReaderInterceptor");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.WriterInterceptor");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.ext.ParamConverterProvider");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.container.ContainerRequestFilter");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.container.ContainerResponseFilter");
    SERVER_PROVIDER_CLASS_NAMES.add("javax.ws.rs.container.DynamicFeature");
    SERVER_PROVIDER_CLASS_NAMES.add("org.apache.cxf.jaxrs.ext.ContextResolver");
  }

  private ResourceUtils() {}

  /**
   * Finds the first {@code @PostConstruct}-annotated method of {@code c} or its type hierarchy.
   *
   * @param c the class to search
   * @return the method, or {@code null} if none is found
   */
  public static Method findPostConstructMethod(Class<?> c) {
    return findPostConstructMethod(c, null);
  }

  /**
   * Finds a post-construct method of {@code c} or its type hierarchy.
   *
   * @param c the class to search; {@code null} or {@code Object.class} yields {@code null}
   * @param name when non-null, the method is matched by this name instead of by the
   *     {@code @PostConstruct} annotation
   * @return the first match (declared methods, then the superclass chain, then interfaces), or
   *     {@code null}
   */
  public static Method findPostConstructMethod(Class<?> c, String name) {
    return findLifecycleMethod(c, name, PostConstruct.class);
  }

  /**
   * Finds the first {@code @PreDestroy}-annotated method of {@code c} or its type hierarchy.
   *
   * @param c the class to search
   * @return the method, or {@code null} if none is found
   */
  public static Method findPreDestroyMethod(Class<?> c) {
    return findPreDestroyMethod(c, null);
  }

  /**
   * Finds a pre-destroy method of {@code c} or its type hierarchy.
   *
   * @param c the class to search; {@code null} or {@code Object.class} yields {@code null}
   * @param name when non-null, the method is matched by this name instead of by the
   *     {@code @PreDestroy} annotation
   * @return the first match (declared methods, then the superclass chain, then interfaces), or
   *     {@code null}
   */
  public static Method findPreDestroyMethod(Class<?> c, String name) {
    return findLifecycleMethod(c, name, PreDestroy.class);
  }

  /**
   * Shared depth-first search behind the post-construct / pre-destroy lookups: matches a
   * declared method by {@code name} when given, otherwise by {@code annotationClass}.
   */
  private static Method findLifecycleMethod(
      Class<?> c, String name, Class<? extends Annotation> annotationClass) {
    if (Object.class == c || null == c) {
      return null;
    }
    for (Method m : c.getDeclaredMethods()) {
      if (name != null) {
        if (m.getName().equals(name)) {
          return m;
        }
      } else if (m.getAnnotation(annotationClass) != null) {
        return m;
      }
    }
    Method m = findLifecycleMethod(c.getSuperclass(), name, annotationClass);
    if (m != null) {
      return m;
    }
    for (Class<?> i : c.getInterfaces()) {
      m = findLifecycleMethod(i, name, annotationClass);
      if (m != null) {
        return m;
      }
    }
    return null;
  }

  /**
   * Builds a resource model from a user-supplied {@code UserResource} description.
   *
   * @param resources all user resources keyed by class name, used to resolve subresources
   * @param model the description of this resource
   * @param defaultClass the service class to use; when {@code null} the class named by
   *     {@code model.getName()} is loaded
   * @param isRoot whether this is a root resource
   * @param enableStatic whether static subresource resolution is enabled
   * @param bus the bus passed on to the created resource info
   * @return the resource info, or {@code null} if no operation could be bound
   */
  public static ClassResourceInfo createClassResourceInfo(
      Map<String, UserResource> resources,
      UserResource model,
      Class<?> defaultClass,
      boolean isRoot,
      boolean enableStatic,
      Bus bus) {
    Class<?> sClass = defaultClass != null ? defaultClass : loadClass(model.getName());
    return createServiceClassResourceInfo(resources, model, sClass, isRoot, enableStatic, bus);
  }

  /**
   * Builds a resource model for {@code sClass} driven by the operations listed in {@code model}
   * rather than by JAX-RS annotations.
   *
   * @param resources all user resources keyed by class name, used to resolve subresources
   * @param model the description of this resource; must not be {@code null}
   * @param sClass the service class whose methods implement the operations
   * @param isRoot whether this is a root resource
   * @param enableStatic whether static subresource resolution is enabled
   * @param bus the bus passed on to the created resource info
   * @return the resource info, or {@code null} if no operation could be bound
   * @throws RuntimeException if {@code model} is {@code null}
   */
  public static ClassResourceInfo createServiceClassResourceInfo(
      Map<String, UserResource> resources,
      UserResource model,
      Class<?> sClass,
      boolean isRoot,
      boolean enableStatic,
      Bus bus) {
    if (model == null) {
      throw new RuntimeException("Resource class " + sClass.getName() + " has no model info");
    }
    ClassResourceInfo cri =
        new ClassResourceInfo(
            sClass,
            sClass,
            isRoot,
            enableStatic,
            true,
            model.getConsumes(),
            model.getProduces(),
            bus);
    URITemplate t = URITemplate.createTemplate(model.getPath());
    cri.setURITemplate(t);

    MethodDispatcher md = new MethodDispatcher();
    Map<String, UserOperation> ops = model.getOperationsAsMap();

    Method defaultMethod = null;
    // Overloaded methods collapse to a single entry per name (the last one put wins).
    Map<String, Method> methodNames = new HashMap<String, Method>();
    for (Method m : cri.getServiceClass().getMethods()) {
      if (m.getAnnotation(DefaultMethod.class) != null) {
        // if needed we can also support multiple default methods
        defaultMethod = m;
      }
      methodNames.put(m.getName(), m);
    }

    for (Map.Entry<String, UserOperation> entry : ops.entrySet()) {
      UserOperation op = entry.getValue();
      Method actualMethod = methodNames.get(op.getName());
      if (actualMethod == null) {
        // operations without a same-named Java method fall back to the @DefaultMethod, if any
        actualMethod = defaultMethod;
      }
      if (actualMethod == null) {
        continue;
      }
      OperationResourceInfo ori =
          new OperationResourceInfo(
              actualMethod,
              cri,
              URITemplate.createTemplate(op.getPath()),
              op.getVerb(),
              op.getConsumes(),
              op.getProduces(),
              op.getParameters(),
              op.isOneway());
      String rClassName = actualMethod.getReturnType().getName();
      if (op.getVerb() == null) {
        // no HTTP verb: a subresource locator, bound only if its return type is a known resource
        if (resources.containsKey(rClassName)) {
          // a locator returning its own resource type reuses cri instead of recursing forever
          ClassResourceInfo subCri =
              rClassName.equals(model.getName())
                  ? cri
                  : createServiceClassResourceInfo(
                      resources,
                      resources.get(rClassName),
                      actualMethod.getReturnType(),
                      false,
                      enableStatic,
                      bus);
          if (subCri != null) {
            cri.addSubClassResourceInfo(subCri);
            md.bind(ori, actualMethod);
          }
        }
      } else {
        md.bind(ori, actualMethod);
      }
    }

    cri.setMethodDispatcher(md);
    return checkMethodDispatcher(cri) ? cri : null;
  }

  /**
   * Builds an annotation-driven resource model using the thread-default bus.
   *
   * @return the resource info, or {@code null} if it exposes no resource methods
   */
  public static ClassResourceInfo createClassResourceInfo(
      final Class<?> rClass, final Class<?> sClass, boolean root, boolean enableStatic) {
    return createClassResourceInfo(
        rClass, sClass, root, enableStatic, BusFactory.getThreadDefaultBus());
  }

  /**
   * Builds an annotation-driven resource model with no parent resource.
   *
   * @return the resource info, or {@code null} if it exposes no resource methods
   */
  public static ClassResourceInfo createClassResourceInfo(
      final Class<?> rClass, final Class<?> sClass, boolean root, boolean enableStatic, Bus bus) {
    return createClassResourceInfo(rClass, sClass, null, root, enableStatic, bus);
  }

  /**
   * Builds an annotation-driven resource model.
   *
   * @param rClass the resource class
   * @param sClass the service class whose methods are scanned
   * @param parent the enclosing resource for a subresource, or {@code null}
   * @param root whether this is a root resource; only roots get a URI template from their path
   * @param enableStatic whether subresource locators are resolved eagerly
   * @param bus the bus passed on to the created resource info
   * @return the resource info, or {@code null} if it exposes no resource methods
   */
  public static ClassResourceInfo createClassResourceInfo(
      final Class<?> rClass,
      final Class<?> sClass,
      ClassResourceInfo parent,
      boolean root,
      boolean enableStatic,
      Bus bus) {
    ClassResourceInfo cri = new ClassResourceInfo(rClass, sClass, root, enableStatic, bus);
    cri.setParent(parent);

    if (root) {
      URITemplate t = URITemplate.createTemplate(cri.getPath());
      cri.setURITemplate(t);
    }

    evaluateResourceClass(cri, enableStatic);
    return checkMethodDispatcher(cri) ? cri : null;
  }

  /**
   * Binds every public method carrying an HTTP-method or {@code @Path} annotation to an
   * operation and, when {@code enableStatic} is set, resolves subresource locators eagerly.
   */
  private static void evaluateResourceClass(ClassResourceInfo cri, boolean enableStatic) {
    MethodDispatcher md = new MethodDispatcher();
    Class<?> serviceClass = cri.getServiceClass();

    boolean isFineLevelLoggable = LOG.isLoggable(Level.FINE);
    for (Method m : serviceClass.getMethods()) {

      Method annotatedMethod = AnnotationUtils.getAnnotatedMethod(serviceClass, m);

      String httpMethod = AnnotationUtils.getHttpMethodValue(annotatedMethod);
      Path path = AnnotationUtils.getMethodAnnotation(annotatedMethod, Path.class);

      if (httpMethod != null || path != null) {
        md.bind(createOperationInfo(m, annotatedMethod, cri, path, httpMethod), m);
        if (httpMethod == null) {
          // subresource locator
          Class<?> subClass = m.getReturnType();
          if (enableStatic) {
            ClassResourceInfo subCri = cri.findResource(subClass, subClass);
            if (subCri == null) {
              // reuse an ancestor of the same type so self-referencing locators do not recurse
              ClassResourceInfo ancestor = getAncestorWithSameServiceClass(cri, subClass);
              subCri =
                  ancestor != null
                      ? ancestor
                      : createClassResourceInfo(
                          subClass, subClass, cri, false, enableStatic, cri.getBus());
            }

            if (subCri != null) {
              cri.addSubClassResourceInfo(subCri);
            }
          }
        }
      } else if (isFineLevelLoggable) {
        LOG.fine(
            new org.apache.cxf.common.i18n.Message(
                    "NOT_RESOURCE_METHOD", BUNDLE, m.getDeclaringClass().getName(), m.getName())
                .toString());
      }
    }
    cri.setMethodDispatcher(md);
  }

  /** Walks the parent chain from {@code parent} and returns the first with service class {@code subClass}. */
  private static ClassResourceInfo getAncestorWithSameServiceClass(
      ClassResourceInfo parent, Class<?> subClass) {
    if (parent == null) {
      return null;
    }
    if (parent.getServiceClass() == subClass) {
      return parent;
    }
    return getAncestorWithSameServiceClass(parent.getParent(), subClass);
  }

  /**
   * Selects the public constructor used to instantiate a resource or provider.
   *
   * <p>For {@code perRequest == false} (singletons) every parameter must be annotated with
   * {@code @Context}; otherwise every parameter must carry annotations accepted by
   * {@code AnnotationUtils.isValidParamAnnotations}. Among qualifying constructors the one with
   * the most parameters wins.
   *
   * @return the chosen constructor, or {@code null} if none qualifies
   */
  public static Constructor<?> findResourceConstructor(Class<?> resourceClass, boolean perRequest) {
    List<Constructor<?>> cs = new LinkedList<Constructor<?>>();
    for (Constructor<?> c : resourceClass.getConstructors()) {
      Class<?>[] params = c.getParameterTypes();
      Annotation[][] anns = c.getParameterAnnotations();
      boolean match = true;
      for (int i = 0; i < params.length; i++) {
        if (!perRequest) {
          if (AnnotationUtils.getAnnotation(anns[i], Context.class) == null) {
            match = false;
            break;
          }
        } else if (!AnnotationUtils.isValidParamAnnotations(anns[i])) {
          match = false;
          break;
        }
      }
      if (match) {
        cs.add(c);
      }
    }
    // descending by parameter count; the sort is stable so ties keep declaration order
    Collections.sort(
        cs,
        new Comparator<Constructor<?>>() {

          public int compare(Constructor<?> c1, Constructor<?> c2) {
            int p1 = c1.getParameterTypes().length;
            int p2 = c2.getParameterTypes().length;
            return p1 > p2 ? -1 : p1 < p2 ? 1 : 0;
          }
        });
    return cs.isEmpty() ? null : cs.get(0);
  }

  /**
   * Describes each parameter of {@code resourceMethod} via {@link #getParameter}.
   *
   * @return one {@code Parameter} per method parameter, in declaration order
   */
  public static List<Parameter> getParameters(Method resourceMethod) {
    Annotation[][] paramAnns = resourceMethod.getParameterAnnotations();
    if (paramAnns.length == 0) {
      return CastUtils.cast(Collections.emptyList(), Parameter.class);
    }
    Class<?>[] types = resourceMethod.getParameterTypes();
    List<Parameter> params = new ArrayList<Parameter>(paramAnns.length);
    for (int i = 0; i < paramAnns.length; i++) {
      Parameter p = getParameter(i, paramAnns[i], types[i]);
      params.add(p);
    }
    return params;
  }

  // CHECKSTYLE:OFF
  /**
   * Classifies a single parameter from its annotations.
   *
   * <p>Precedence: {@code @Context}, {@code @BeanParam}, {@code @PathParam}, {@code @QueryParam},
   * {@code @MatrixParam}, {@code @FormParam}, {@code @HeaderParam}, {@code @CookieParam}; a
   * parameter with none of these is the request body.
   *
   * @param index the parameter position
   * @param anns the parameter's annotations
   * @param type the parameter's declared type (not consulted here)
   * @return the parameter descriptor
   */
  public static Parameter getParameter(int index, Annotation[] anns, Class<?> type) {

    Context ctx = AnnotationUtils.getAnnotation(anns, Context.class);
    if (ctx != null) {
      return new Parameter(ParameterType.CONTEXT, index, null);
    }

    boolean isEncoded = AnnotationUtils.getAnnotation(anns, Encoded.class) != null;

    BeanParam bp = AnnotationUtils.getAnnotation(anns, BeanParam.class);
    if (bp != null) {
      return new Parameter(ParameterType.BEAN, index, null, isEncoded, null);
    }

    String dValue = AnnotationUtils.getDefaultParameterValue(anns);

    PathParam a = AnnotationUtils.getAnnotation(anns, PathParam.class);
    if (a != null) {
      return new Parameter(ParameterType.PATH, index, a.value(), isEncoded, dValue);
    }
    QueryParam q = AnnotationUtils.getAnnotation(anns, QueryParam.class);
    if (q != null) {
      return new Parameter(ParameterType.QUERY, index, q.value(), isEncoded, dValue);
    }
    MatrixParam m = AnnotationUtils.getAnnotation(anns, MatrixParam.class);
    if (m != null) {
      return new Parameter(ParameterType.MATRIX, index, m.value(), isEncoded, dValue);
    }

    FormParam f = AnnotationUtils.getAnnotation(anns, FormParam.class);
    if (f != null) {
      return new Parameter(ParameterType.FORM, index, f.value(), isEncoded, dValue);
    }

    HeaderParam h = AnnotationUtils.getAnnotation(anns, HeaderParam.class);
    if (h != null) {
      return new Parameter(ParameterType.HEADER, index, h.value(), isEncoded, dValue);
    }

    CookieParam c = AnnotationUtils.getAnnotation(anns, CookieParam.class);
    if (c != null) {
      return new Parameter(ParameterType.COOKIE, index, c.value(), isEncoded, dValue);
    }

    return new Parameter(ParameterType.REQUEST_BODY, index, null);
  }
  // CHECKSTYLE:ON

  /** Creates an operation for {@code m} with the given path template and HTTP method. */
  private static OperationResourceInfo createOperationInfo(
      Method m, Method annotatedMethod, ClassResourceInfo cri, Path path, String httpMethod) {
    OperationResourceInfo ori = new OperationResourceInfo(m, annotatedMethod, cri);
    URITemplate t = URITemplate.createTemplate(path);
    ori.setURITemplate(t);
    ori.setHttpMethod(httpMethod);
    return ori;
  }

  /** Returns {@code false}, logging a warning, when {@code cr} has no bound operations. */
  private static boolean checkMethodDispatcher(ClassResourceInfo cr) {
    if (cr.getMethodDispatcher().getOperationResourceInfos().isEmpty()) {
      LOG.warning(
          new org.apache.cxf.common.i18n.Message(
                  "NO_RESOURCE_OP_EXC", BUNDLE, cr.getServiceClass().getName())
              .toString());
      return false;
    }
    return true;
  }

  /** Loads {@code cName} (trimmed), converting a missing class into a RuntimeException. */
  private static Class<?> loadClass(String cName) {
    try {
      return ClassLoaderUtils.loadClass(cName.trim(), ResourceUtils.class);
    } catch (ClassNotFoundException ex) {
      throw new RuntimeException("No class " + cName.trim() + " can be found", ex);
    }
  }

  /**
   * Reads a user resource model from {@code loc}.
   *
   * <p>Any failure is logged (with its cause) rather than propagated. The opened stream is
   * always closed.
   *
   * @param loc the model location, resolved by {@link #getResourceURL}
   * @param bus the bus used for resource lookup; may be {@code null}
   * @return the parsed resources, or {@code null} if the location cannot be resolved or read
   */
  public static List<UserResource> getUserResources(String loc, Bus bus) {
    InputStream is = null;
    try {
      is = ResourceUtils.getResourceStream(loc, bus);
      if (is == null) {
        return null;
      }
      return getUserResources(is);
    } catch (Exception ex) {
      LOG.log(Level.WARNING, "Problem with processing a user model at " + loc, ex);
    } finally {
      closeQuietly(is);
    }

    return null;
  }

  /**
   * Opens a stream to the resource at {@code loc}; the caller owns and must close it.
   *
   * @return the stream, or {@code null} if the location cannot be resolved
   * @throws Exception if resolving or opening the resource fails
   */
  public static InputStream getResourceStream(String loc, Bus bus) throws Exception {
    URL url = getResourceURL(loc, bus);
    return url == null ? null : url.openStream();
  }

  /**
   * Resolves {@code loc} to a URL.
   *
   * <p>A {@code classpath:} prefix forces a classpath lookup. Anything else is first parsed as a
   * URL; if that fails it is tried as a classpath resource and then as a file-system path. A
   * warning is logged when nothing is found.
   *
   * @return the URL, or {@code null} if the resource cannot be found
   */
  public static URL getResourceURL(String loc, Bus bus) throws Exception {
    URL url = null;
    if (loc.startsWith(CLASSPATH_PREFIX)) {
      String path = loc.substring(CLASSPATH_PREFIX.length());
      url = ResourceUtils.getClasspathResourceURL(path, ResourceUtils.class, bus);
    } else {
      try {
        url = new URL(loc);
      } catch (Exception ex) {
        // it can be either a classpath or file resource without a scheme
        url = ResourceUtils.getClasspathResourceURL(loc, ResourceUtils.class, bus);
        if (url == null) {
          File file = new File(loc);
          if (file.exists()) {
            url = file.toURI().toURL();
          }
        }
      }
    }
    if (url == null) {
      LOG.warning("No resource " + loc + " is available");
    }
    return url;
  }

  /**
   * Opens a classpath resource, falling back to the bus {@code ResourceManager}.
   *
   * @return the stream, or {@code null} if not found
   */
  public static InputStream getClasspathResourceStream(
      String path, Class<?> callingClass, Bus bus) {
    InputStream is = ClassLoaderUtils.getResourceAsStream(path, callingClass);
    return is == null ? getResource(path, InputStream.class, bus) : is;
  }

  /**
   * Locates a classpath resource, falling back to the bus {@code ResourceManager}.
   *
   * @return the URL, or {@code null} if not found
   */
  public static URL getClasspathResourceURL(String path, Class<?> callingClass, Bus bus) {
    URL url = ClassLoaderUtils.getResource(path, callingClass);
    return url == null ? getResource(path, URL.class, bus) : url;
  }

  /**
   * Resolves {@code path} through the bus {@code ResourceManager} extension.
   *
   * @return the resolved resource, or {@code null} if there is no bus, no resource manager, or
   *     no match
   */
  public static <T> T getResource(String path, Class<T> resourceClass, Bus bus) {
    if (bus != null) {
      ResourceManager rm = bus.getExtension(ResourceManager.class);
      if (rm != null) {
        return rm.resolveResource(path, resourceClass);
      }
    }
    return null;
  }

  /**
   * Loads a properties file from {@code propertiesLocation}. The stream is always closed.
   *
   * @param propertiesLocation the location, resolved by {@link #getResourceURL}
   * @param bus the bus used for resource lookup; may be {@code null}
   * @return the loaded properties
   * @throws IllegalArgumentException if the location cannot be resolved
   * @throws Exception if the resource cannot be opened or parsed
   */
  public static Properties loadProperties(String propertiesLocation, Bus bus) throws Exception {
    Properties props = new Properties();
    InputStream is = getResourceStream(propertiesLocation, bus);
    if (is == null) {
      // previously surfaced as an opaque NullPointerException from Properties.load(null)
      throw new IllegalArgumentException("No resource " + propertiesLocation + " is available");
    }
    try {
      props.load(is);
    } finally {
      closeQuietly(is);
    }
    return props;
  }

  /**
   * Closes {@code is} if non-null, ignoring any failure so that it cannot mask the primary
   * result or exception of the caller.
   */
  private static void closeQuietly(InputStream is) {
    if (is == null) {
      return;
    }
    try {
      is.close();
    } catch (Exception ignored) {
      // best effort: the stream has already been consumed or has failed
    }
  }

  /** Reads a user resource model from {@code loc} using the thread-default bus. */
  public static List<UserResource> getUserResources(String loc) {
    return getUserResources(loc, BusFactory.getThreadDefaultBus());
  }

  /**
   * Parses a UTF-8 user resource model document. The caller remains responsible for closing
   * {@code is}.
   */
  public static List<UserResource> getUserResources(InputStream is) throws Exception {
    Document doc = StaxUtils.read(new InputStreamReader(is, "UTF-8"));
    return getResourcesFromElement(doc.getDocumentElement());
  }

  /** Converts every {@code jaxrs:resource} element below {@code modelEl} into a UserResource. */
  public static List<UserResource> getResourcesFromElement(Element modelEl) {
    List<UserResource> resources = new ArrayList<UserResource>();
    List<Element> resourceEls =
        DOMUtils.findAllElementsByTagNameNS(modelEl, "http://cxf.apache.org/jaxrs", "resource");
    for (Element e : resourceEls) {
      resources.add(getResourceFromElement(e));
    }
    return resources;
  }

  /** Collects request/response entity types of {@code cris} using the default JAXB writer. */
  public static ResourceTypes getAllRequestResponseTypes(
      List<ClassResourceInfo> cris, boolean jaxbOnly) {
    return getAllRequestResponseTypes(cris, jaxbOnly, null);
  }

  /**
   * Collects request/response entity types across all resources and their subresources.
   *
   * @param jaxbOnly when true, only types {@code jaxbWriter} can write as XML are recorded
   * @param jaxbWriter the writer used for the JAXB check; {@code null} selects a default
   *     {@code JAXBElementProvider}
   */
  public static ResourceTypes getAllRequestResponseTypes(
      List<ClassResourceInfo> cris, boolean jaxbOnly, MessageBodyWriter<?> jaxbWriter) {
    ResourceTypes types = new ResourceTypes();
    for (ClassResourceInfo resource : cris) {
      getAllTypesForResource(resource, types, jaxbOnly, jaxbWriter);
    }
    return types;
  }

  /**
   * Returns the {@code @ElementClass} override for {@code resourceMethod} (request or response
   * side), or {@code type} if there is none or the override is {@code Object.class}.
   */
  public static Class<?> getActualJaxbType(Class<?> type, Method resourceMethod, boolean inbound) {
    ElementClass element = resourceMethod.getAnnotation(ElementClass.class);
    if (element != null) {
      Class<?> cls = inbound ? element.request() : element.response();
      if (cls != Object.class) {
        return cls;
      }
    }
    return type;
  }

  /** Records return and request-body types of {@code resource}, then recurses into subresources. */
  private static void getAllTypesForResource(
      ClassResourceInfo resource,
      ResourceTypes types,
      boolean jaxbOnly,
      MessageBodyWriter<?> jaxbWriter) {
    for (OperationResourceInfo ori : resource.getMethodDispatcher().getOperationResourceInfos()) {
      Method method = ori.getMethodToInvoke();
      Class<?> realReturnType = method.getReturnType();
      Class<?> cls = realReturnType;
      if (cls == Response.class) {
        // a raw Response hides the entity type; @ElementClass may name it
        cls = getActualJaxbType(cls, method, false);
      }
      Type type = method.getGenericReturnType();
      if (jaxbOnly) {
        checkJaxbType(
            resource.getServiceClass(),
            cls,
            realReturnType == Response.class ? cls : type,
            types,
            method.getAnnotations(),
            jaxbWriter);
      } else {
        types.getAllTypes().put(cls, type);
      }

      for (Parameter pm : ori.getParameters()) {
        if (pm.getType() == ParameterType.REQUEST_BODY) {
          Class<?> inType = method.getParameterTypes()[pm.getIndex()];
          Type paramType = method.getGenericParameterTypes()[pm.getIndex()];
          if (jaxbOnly) {
            checkJaxbType(
                resource.getServiceClass(),
                inType,
                paramType,
                types,
                method.getParameterAnnotations()[pm.getIndex()],
                jaxbWriter);
          } else {
            types.getAllTypes().put(inType, paramType);
          }
        }
      }
    }

    for (ClassResourceInfo sub : resource.getSubResources()) {
      // skip subresources that are the resource itself or one of its ancestors
      if (!isRecursiveSubResource(resource, sub)) {
        getAllTypesForResource(sub, types, jaxbOnly, jaxbWriter);
      }
    }
  }

  /** Returns true if {@code sub} is {@code parent} or one of its ancestors (by identity). */
  private static boolean isRecursiveSubResource(ClassResourceInfo parent, ClassResourceInfo sub) {
    if (parent == null) {
      return false;
    }
    if (parent == sub) {
      return true;
    }
    return isRecursiveSubResource(parent.getParent(), sub);
  }

  /**
   * Records {@code type} (and the actual generic type, if different) when the writer can write
   * it as XML. Collection/array element types go to the collection map, others to the XML name
   * map, keyed with the {@code @XMLName} QName if present.
   */
  private static void checkJaxbType(
      Class<?> serviceClass,
      Class<?> type,
      Type genericType,
      ResourceTypes types,
      Annotation[] anns,
      MessageBodyWriter<?> jaxbWriter) {
    boolean isCollection = false;
    if (InjectionUtils.isSupportedCollectionOrArray(type)) {
      type = InjectionUtils.getActualType(genericType);
      isCollection = true;
    }
    if (type == Object.class && !(genericType instanceof Class)) {
      // resolve a type variable against the service class
      Type theType =
          InjectionUtils.processGenericTypeIfNeeded(serviceClass, Object.class, genericType);
      type = InjectionUtils.getActualType(theType);
    }
    if (type == null
        || InjectionUtils.isPrimitive(type)
        || JAXBElement.class.isAssignableFrom(type)
        || Response.class.isAssignableFrom(type)
        || type.isInterface()) {
      return;
    }

    MessageBodyWriter<?> writer = jaxbWriter;
    if (writer == null) {
      JAXBElementProvider<Object> defaultWriter = new JAXBElementProvider<Object>();
      defaultWriter.setMarshallAsJaxbElement(true);
      defaultWriter.setXmlTypeAsJaxbElementOnly(true);
      writer = defaultWriter;
    }
    if (writer.isWriteable(type, type, anns, MediaType.APPLICATION_XML_TYPE)) {
      types.getAllTypes().put(type, type);
      Class<?> genCls = InjectionUtils.getActualType(genericType);
      if (genCls != type
          && genCls != null
          && genCls != Object.class
          && !InjectionUtils.isSupportedCollectionOrArray(genCls)) {
        types.getAllTypes().put(genCls, genCls);
      }

      XMLName name = AnnotationUtils.getAnnotation(anns, XMLName.class);
      QName qname = name != null ? JAXRSUtils.convertStringToQName(name.value()) : null;
      if (isCollection) {
        types.getCollectionMap().put(type, qname);
      } else {
        types.getXmlNameMap().put(type, qname);
      }
    }
  }

  /** Builds a UserResource from a {@code jaxrs:resource} element and its operations. */
  private static UserResource getResourceFromElement(Element e) {
    UserResource resource = new UserResource();
    resource.setName(e.getAttribute("name"));
    resource.setPath(e.getAttribute("path"));
    resource.setConsumes(e.getAttribute("consumes"));
    resource.setProduces(e.getAttribute("produces"));
    List<Element> operEls =
        DOMUtils.findAllElementsByTagNameNS(e, "http://cxf.apache.org/jaxrs", "operation");
    List<UserOperation> opers = new ArrayList<UserOperation>(operEls.size());
    for (Element operEl : operEls) {
      opers.add(getOperationFromElement(operEl));
    }
    resource.setOperations(opers);
    return resource;
  }

  /**
   * Builds a UserOperation from a {@code jaxrs:operation} element; its {@code jaxrs:param}
   * children become parameters indexed by document order.
   *
   * @throws RuntimeException if a parameter's {@code class} attribute cannot be loaded
   */
  private static UserOperation getOperationFromElement(Element e) {
    UserOperation op = new UserOperation();
    op.setName(e.getAttribute("name"));
    op.setVerb(e.getAttribute("verb"));
    op.setPath(e.getAttribute("path"));
    op.setOneway(Boolean.parseBoolean(e.getAttribute("oneway")));
    op.setConsumes(e.getAttribute("consumes"));
    op.setProduces(e.getAttribute("produces"));
    List<Element> paramEls =
        DOMUtils.findAllElementsByTagNameNS(e, "http://cxf.apache.org/jaxrs", "param");
    List<Parameter> params = new ArrayList<Parameter>(paramEls.size());
    for (int i = 0; i < paramEls.size(); i++) {
      Element paramEl = paramEls.get(i);
      Parameter p = new Parameter(paramEl.getAttribute("type"), i, paramEl.getAttribute("name"));
      p.setEncoded(Boolean.parseBoolean(paramEl.getAttribute("encoded")));
      p.setDefaultValue(paramEl.getAttribute("defaultValue"));
      String pClass = paramEl.getAttribute("class");
      if (!StringUtils.isEmpty(pClass)) {
        try {
          p.setJavaType(ClassLoaderUtils.loadClass(pClass, ResourceUtils.class));
        } catch (Exception ex) {
          throw new RuntimeException(ex);
        }
      }
      params.add(p);
    }
    op.setParameters(params);
    return op;
  }

  /** Creates constructor arguments without pre-supplied context values. */
  public static Object[] createConstructorArguments(
      Constructor<?> c, Message m, boolean perRequest) {
    return createConstructorArguments(c, m, perRequest, null);
  }

  /**
   * Creates the argument array for invoking constructor {@code c}.
   *
   * <p>{@code @Context} parameters take a matching entry from {@code contextValues} when
   * present. Otherwise they get a freshly created context value (per-request) or a thread-local
   * proxy (singleton). Other parameters are resolved as HTTP parameters from {@code m}.
   *
   * @param m the current message; may be {@code null}, in which case no template values are read
   * @param contextValues optional context values keyed by parameter type
   */
  public static Object[] createConstructorArguments(
      Constructor<?> c, Message m, boolean perRequest, Map<Class<?>, Object> contextValues) {
    Class<?>[] params = c.getParameterTypes();
    Annotation[][] anns = c.getParameterAnnotations();
    Type[] genericTypes = c.getGenericParameterTypes();
    @SuppressWarnings("unchecked")
    MultivaluedMap<String, String> templateValues =
        m == null ? null : (MultivaluedMap<String, String>) m.get(URITemplate.TEMPLATE_PARAMETERS);
    Object[] values = new Object[params.length];
    for (int i = 0; i < params.length; i++) {
      if (AnnotationUtils.getAnnotation(anns[i], Context.class) != null) {
        Object contextValue = contextValues != null ? contextValues.get(params[i]) : null;
        if (contextValue == null) {
          if (perRequest) {
            values[i] = JAXRSUtils.createContextValue(m, genericTypes[i], params[i]);
          } else {
            values[i] = InjectionUtils.createThreadLocalProxy(params[i]);
          }
        } else {
          values[i] = contextValue;
        }
      } else {
        // this branch won't execute for singletons given that the found constructor
        // is guaranteed to have only Context parameters, if any, for singletons
        Parameter p = ResourceUtils.getParameter(i, anns[i], params[i]);
        values[i] =
            JAXRSUtils.createHttpParameterValue(
                p, params[i], genericTypes[i], anns[i], m, templateValues, null);
      }
    }
    return values;
  }

  /** Creates a server factory bean from {@code app} without static subresource resolution. */
  public static JAXRSServerFactoryBean createApplication(Application app, boolean ignoreAppPath) {
    return createApplication(app, ignoreAppPath, false);
  }

  /**
   * Creates a server factory bean from a JAX-RS {@code Application}.
   *
   * <p>Classes from {@code getClasses()} become providers, features or per-request resources.
   * Objects from {@code getSingletons()} become providers, features or singleton resources.
   * Classes also registered as singletons are ignored. The address comes from
   * {@code @ApplicationPath} unless {@code ignoreAppPath} is set, and is always made to start
   * with "/".
   *
   * @throws RuntimeException if two singletons share a class, or a provider/feature cannot be
   *     created
   */
  @SuppressWarnings("unchecked")
  public static JAXRSServerFactoryBean createApplication(
      Application app, boolean ignoreAppPath, boolean staticSubresourceResolution) {

    Set<Object> singletons = app.getSingletons();
    verifySingletons(singletons);

    List<Class<?>> resourceClasses = new ArrayList<Class<?>>();
    List<Object> providers = new ArrayList<Object>();
    List<Feature> features = new ArrayList<Feature>();
    Map<Class<?>, ResourceProvider> map = new HashMap<Class<?>, ResourceProvider>();

    // Note, app.getClasses() returns a list of per-request classes
    // or singleton provider classes
    for (Class<?> cls : app.getClasses()) {
      if (isValidApplicationClass(cls, singletons)) {
        if (isValidProvider(cls)) {
          providers.add(createProviderInstance(cls));
        } else if (Feature.class.isAssignableFrom(cls)) {
          features.add(createFeatureInstance((Class<? extends Feature>) cls));
        } else {
          resourceClasses.add(cls);
          map.put(cls, new PerRequestResourceProvider(cls));
        }
      }
    }

    // we can get either a provider or resource class here
    for (Object o : singletons) {
      if (isValidProvider(o.getClass())) {
        providers.add(o);
      } else if (o instanceof Feature) {
        features.add((Feature) o);
      } else {
        resourceClasses.add(o.getClass());
        map.put(o.getClass(), new SingletonResourceProvider(o));
      }
    }

    JAXRSServerFactoryBean bean = new JAXRSServerFactoryBean();
    String address = "/";
    if (!ignoreAppPath) {
      ApplicationPath appPath = app.getClass().getAnnotation(ApplicationPath.class);
      if (appPath != null) {
        address = appPath.value();
      }
    }
    if (!address.startsWith("/")) {
      address = "/" + address;
    }
    bean.setAddress(address);
    bean.setStaticSubresourceResolution(staticSubresourceResolution);
    bean.setResourceClasses(resourceClasses);
    bean.setProviders(providers);
    bean.setFeatures(features);
    for (Map.Entry<Class<?>, ResourceProvider> entry : map.entrySet()) {
      bean.setResourceProvider(entry.getKey(), entry.getValue());
    }
    Map<String, Object> appProps = app.getProperties();
    if (appProps != null) {
      bean.getProperties(true).putAll(appProps);
    }
    bean.setApplication(app);

    return bean;
  }

  /**
   * Instantiates a provider class via its best {@code @Context}-only constructor.
   *
   * @return a new instance when the chosen constructor takes no arguments; otherwise the
   *     {@code Constructor} itself, so that its context arguments can be supplied later
   * @throws RuntimeException if no suitable constructor exists or instantiation fails
   */
  public static Object createProviderInstance(Class<?> cls) {
    try {
      Constructor<?> c = ResourceUtils.findResourceConstructor(cls, false);
      if (c == null) {
        // fail with a descriptive cause instead of a NullPointerException
        throw new RuntimeException("No valid constructor found for " + cls.getName());
      }
      if (c.getParameterTypes().length == 0) {
        return c.newInstance();
      } else {
        return c;
      }
    } catch (Throwable ex) {
      throw new RuntimeException("Provider " + cls.getName() + " can not be created", ex);
    }
  }

  /**
   * Instantiates a feature class via its best {@code @Context}-only constructor.
   *
   * @throws RuntimeException if no suitable constructor exists or instantiation fails
   */
  public static Feature createFeatureInstance(Class<? extends Feature> cls) {
    try {
      Constructor<?> c = ResourceUtils.findResourceConstructor(cls, false);

      if (c == null) {
        throw new RuntimeException("No valid constructor found for " + cls.getName());
      }

      return (Feature) c.newInstance();
    } catch (Throwable ex) {
      throw new RuntimeException("Feature " + cls.getName() + " can not be created", ex);
    }
  }

  /**
   * Returns true if {@code c} or a superclass is annotated with {@code @Provider} or directly
   * implements one of {@link #SERVER_PROVIDER_CLASS_NAMES}.
   */
  private static boolean isValidProvider(Class<?> c) {
    if (c == null || c == Object.class) {
      return false;
    }
    if (c.getAnnotation(Provider.class) != null) {
      return true;
    }
    for (Class<?> itf : c.getInterfaces()) {
      if (SERVER_PROVIDER_CLASS_NAMES.contains(itf.getName())) {
        return true;
      }
    }
    return isValidProvider(c.getSuperclass());
  }

  /**
   * Rejects two singletons of the same class (compared by class name).
   *
   * @throws RuntimeException on the first duplicate
   */
  private static void verifySingletons(Set<Object> singletons) {
    if (singletons.isEmpty()) {
      return;
    }
    Set<String> seen = new HashSet<String>();
    for (Object s : singletons) {
      String className = s.getClass().getName();
      if (!seen.add(className)) {
        throw new RuntimeException(
            "More than one instance of the same singleton class "
                + className
                + " is available");
      }
    }
  }

  /** Returns false, logging at INFO, for interfaces and abstract classes. */
  public static boolean isValidResourceClass(Class<?> c) {
    if (c.isInterface() || Modifier.isAbstract(c.getModifiers())) {
      LOG.info("Ignoring invalid resource class " + c.getName());
      return false;
    }
    return true;
  }

  /** Returns false for invalid resource classes and classes also registered as singletons. */
  private static boolean isValidApplicationClass(Class<?> c, Set<Object> singletons) {
    if (!isValidResourceClass(c)) {
      return false;
    }
    for (Object s : singletons) {
      if (c == s.getClass()) {
        LOG.info(
            "Ignoring per-request resource class "
                + c.getName()
                + " as it is also registered as singleton");
        return false;
      }
    }
    return true;
  }

  // TODO : consider moving JAXBDataBinding.createContext to JAXBUtils
  /**
   * Creates a JAXB context for {@code classes} after letting {@code JAXBUtils.scanPackages}
   * augment the set.
   *
   * @return the context, or {@code null} if {@code classes} is empty or creation fails (the
   *     failure is logged at WARNING)
   */
  public static JAXBContext createJaxbContext(
      Set<Class<?>> classes, Class<?>[] extraClass, Map<String, Object> contextProperties) {
    if (classes == null || classes.isEmpty()) {
      return null;
    }
    JAXBUtils.scanPackages(classes, extraClass, null);

    try {
      return JAXBContext.newInstance(classes.toArray(new Class[classes.size()]), contextProperties);
    } catch (JAXBException ex) {
      LOG.log(Level.WARNING, "No JAXB context can be created", ex);
    }
    return null;
  }
}
Example #23
0
/**
 * Marshal-phase interceptor that serializes the outcome of a JAX-RS invocation.
 *
 * <p>The first element of the outbound {@code MessageContentsList} is either a {@code Response}
 * or a plain entity. It is normalised to a {@code ResponseImpl}, passed through the container
 * response filters and written with the writer-interceptor chain created by the
 * {@code ServerProviderFactory}. A failure on the first attempt is converted into an error
 * response which is serialized once more; a failure on that second attempt (or when no error
 * response can be produced) surfaces as a {@code Fault} with status 500.
 */
public class JAXRSOutInterceptor extends AbstractOutDatabindingInterceptor {
  private static final Logger LOG = LogUtils.getL7dLogger(JAXRSOutInterceptor.class);
  private static final ResourceBundle BUNDLE = BundleUtils.getBundle(JAXRSOutInterceptor.class);

  public JAXRSOutInterceptor() {
    super(Phase.MARSHAL);
  }

  /**
   * Serializes the response and always releases the per-request provider state afterwards.
   *
   * @param message the outbound message
   */
  public void handleMessage(Message message) {
    ServerProviderFactory providerFactory = ServerProviderFactory.getInstance(message);
    try {
      processResponse(providerFactory, message);
    } finally {
      ServerProviderFactory.releaseRequestState(providerFactory, message);
    }
  }

  /**
   * Normalises the first outbound content object into a {@code Response} and serializes it.
   * Nothing is done when the response was already committed or redirected, or when there is no
   * outbound content.
   */
  private void processResponse(ServerProviderFactory providerFactory, Message message) {

    if (isResponseAlreadyHandled(message)) {
      return;
    }
    MessageContentsList objs = MessageContentsList.getContentsList(message);
    if (objs == null || objs.isEmpty()) {
      return;
    }

    Object responseObj = objs.get(0);

    Response response = null;
    if (responseObj instanceof Response) {
      response = (Response) responseObj;
      // A 500 produced by an exception mapper is reported as-is without serializing a body.
      if (response.getStatus() == 500
          && message.getExchange().get(JAXRSUtils.EXCEPTION_FROM_MAPPER) != null) {
        message.put(Message.RESPONSE_CODE, 500);
        return;
      }
    } else {
      // A plain return value: 200 with an entity, 204 without, unless a code was set explicitly.
      int status = getStatus(message, responseObj != null ? 200 : 204);
      response = JAXRSUtils.toResponseBuilder(status).entity(responseObj).build();
    }

    Exchange exchange = message.getExchange();
    OperationResourceInfo ori =
        (OperationResourceInfo) exchange.get(OperationResourceInfo.class.getName());

    serializeMessage(providerFactory, message, response, ori, true);
  }

  /** Returns the response code stored on the exchange, or {@code defaultValue} if none. */
  private int getStatus(Message message, int defaultValue) {
    Object customStatus = message.getExchange().get(Message.RESPONSE_CODE);
    return customStatus == null ? defaultValue : (Integer) customStatus;
  }

  /**
   * Writes {@code theResponse} into {@code message}: prepares headers, runs the container
   * response filters, negotiates the final content type and invokes the writer chain.
   *
   * @param providerFactory provider factory used for filters and writer selection
   * @param message the outbound message
   * @param theResponse the response to serialize
   * @param ori the invoked operation, or {@code null} (e.g. when writing an error response)
   * @param firstTry {@code false} when serializing an error response produced from an earlier
   *     failure; such attempts skip user headers, the Date header and output buffering
   */
  private void serializeMessage(
      ServerProviderFactory providerFactory,
      Message message,
      Response theResponse,
      OperationResourceInfo ori,
      boolean firstTry) {

    ResponseImpl response = (ResponseImpl) JAXRSUtils.copyResponseIfNeeded(theResponse);

    final Exchange exchange = message.getExchange();

    // A successful response to a HEAD request must not carry a body.
    boolean headResponse =
        response.getStatus() == 200
            && firstTry
            && ori != null
            && HttpMethod.HEAD.equals(ori.getHttpMethod());
    Object entity = response.getActualEntity();
    if (headResponse && entity != null) {
      LOG.info(new org.apache.cxf.common.i18n.Message("HEAD_WITHOUT_ENTITY", BUNDLE).toString());
      entity = null;
    }

    Method invoked =
        ori == null
            ? null
            : ori.getAnnotatedMethod() != null ? ori.getAnnotatedMethod() : ori.getMethodToInvoke();

    // Merge the operation's static output annotations with those supplied on the response.
    Annotation[] annotations = null;
    Annotation[] staticAnns = ori != null ? ori.getOutAnnotations() : new Annotation[] {};
    Annotation[] responseAnns = response.getEntityAnnotations();
    if (responseAnns != null) {
      annotations = new Annotation[staticAnns.length + responseAnns.length];
      System.arraycopy(staticAnns, 0, annotations, 0, staticAnns.length);
      System.arraycopy(responseAnns, 0, annotations, staticAnns.length, responseAnns.length);
    } else {
      annotations = staticAnns;
    }

    response.setStatus(getActualStatus(response.getStatus(), entity));
    response.setEntity(entity, annotations);

    // Prepare the headers
    MultivaluedMap<String, Object> responseHeaders =
        prepareResponseHeaders(message, response, entity, firstTry);

    // Run the filters
    try {
      JAXRSUtils.runContainerResponseFilters(providerFactory, response, message, ori, invoked);
    } catch (Throwable ex) {
      handleWriteException(providerFactory, message, ex, firstTry);
      return;
    }

    // Write the entity. It is re-read from the response because the filters were given the
    // response object and may have changed it.
    entity = InjectionUtils.getEntity(response.getActualEntity());
    setResponseStatus(message, getActualStatus(response.getStatus(), entity));
    if (entity == null) {
      if (!headResponse) {
        responseHeaders.putSingle(HttpHeaders.CONTENT_LENGTH, "0");
        if (MessageUtils.getContextualBoolean(
            message, "remove.content.type.for.empty.response", false)) {
          responseHeaders.remove(HttpHeaders.CONTENT_TYPE);
          message.remove(Message.CONTENT_TYPE);
        }
      }
      HttpUtils.convertHeaderValuesToString(responseHeaders, true);
      return;
    }

    // When writers are ignored, the entity's toString() is written directly as UTF-8.
    Object ignoreWritersProp = exchange.get(JAXRSUtils.IGNORE_MESSAGE_WRITERS);
    boolean ignoreWriters =
        ignoreWritersProp != null && Boolean.parseBoolean(ignoreWritersProp.toString());
    if (ignoreWriters) {
      writeResponseToStream(message.getContent(OutputStream.class), entity);
      return;
    }

    MediaType responseMediaType =
        getResponseMediaType(responseHeaders.getFirst(HttpHeaders.CONTENT_TYPE));

    // invoked != null implies ori != null.
    Class<?> serviceCls = invoked != null ? ori.getClassResourceInfo().getServiceClass() : null;
    Class<?> targetType = InjectionUtils.getRawResponseClass(entity);
    Type genericType =
        InjectionUtils.getGenericResponseType(
            invoked, serviceCls, response.getActualEntity(), targetType, exchange);
    targetType = InjectionUtils.updateParamClassToTypeIfNeeded(targetType, genericType);
    annotations = response.getEntityAnnotations();

    List<WriterInterceptor> writers =
        providerFactory.createMessageBodyWriterInterceptor(
            targetType,
            genericType,
            annotations,
            responseMediaType,
            message,
            ori == null ? null : ori.getNameBindings());

    OutputStream outOriginal = message.getContent(OutputStream.class);
    if (writers == null || writers.isEmpty()) {
      writeResponseErrorMessage(
          message, outOriginal, "NO_MSG_WRITER", targetType, responseMediaType);
      return;
    }
    try {
      // A wildcard subtype is narrowed using the writer's own @Produces only when neither the
      // resource method nor its class declares @Produces.
      boolean checkWriters = false;
      if (responseMediaType.isWildcardSubtype()) {
        Produces pM =
            AnnotationUtils.getMethodAnnotation(
                ori == null ? null : ori.getAnnotatedMethod(), Produces.class);
        Produces pC = AnnotationUtils.getClassAnnotation(serviceCls, Produces.class);
        checkWriters = pM == null && pC == null;
      }
      responseMediaType = checkFinalContentType(responseMediaType, writers, checkWriters);
    } catch (Throwable ex) {
      handleWriteException(providerFactory, message, ex, firstTry);
      return;
    }
    String finalResponseContentType = JAXRSUtils.mediaTypeToString(responseMediaType);
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Response content type is: " + finalResponseContentType);
    }
    responseHeaders.putSingle(HttpHeaders.CONTENT_TYPE, finalResponseContentType);
    message.put(Message.CONTENT_TYPE, finalResponseContentType);

    boolean enabled = checkBufferingMode(message, writers, firstTry);
    try {

      try {
        JAXRSUtils.writeMessageBody(
            writers,
            entity,
            targetType,
            genericType,
            annotations,
            responseMediaType,
            responseHeaders,
            message);

        if (isResponseRedirected(message)) {
          return;
        }
        checkCachedStream(message, outOriginal, enabled);
      } finally {
        // When buffering swapped the output content, always restore the original stream.
        if (enabled) {
          OutputStream os = message.getContent(OutputStream.class);
          if (os != outOriginal && os instanceof CachedOutputStream) {
            os.close();
          }
          message.setContent(OutputStream.class, outOriginal);
          message.put(XMLStreamWriter.class.getName(), null);
        }
      }

    } catch (Throwable ex) {
      logWriteError(firstTry, targetType, responseMediaType);
      handleWriteException(providerFactory, message, ex, firstTry);
    }
  }

  /**
   * Builds the response headers and stores them as the message's protocol headers.
   *
   * <p>On the first attempt, headers already present under {@code Message.PROTOCOL_HEADERS} are
   * merged in. When an entity is present, the content type is reconciled: an explicit
   * Content-Type header is copied to {@code Message.CONTENT_TYPE}, otherwise the message's
   * content type (if any) becomes the header.
   */
  private MultivaluedMap<String, Object> prepareResponseHeaders(
      Message message, ResponseImpl response, Object entity, boolean firstTry) {
    MultivaluedMap<String, Object> responseHeaders = response.getMetadata();
    @SuppressWarnings("unchecked")
    Map<String, List<Object>> userHeaders =
        (Map<String, List<Object>>) message.get(Message.PROTOCOL_HEADERS);
    if (firstTry && userHeaders != null) {
      responseHeaders.putAll(userHeaders);
    }
    if (entity != null) {
      Object customContentType = responseHeaders.getFirst(HttpHeaders.CONTENT_TYPE);
      if (customContentType == null) {
        String initialResponseContentType = (String) message.get(Message.CONTENT_TYPE);
        if (initialResponseContentType != null) {
          responseHeaders.putSingle(HttpHeaders.CONTENT_TYPE, initialResponseContentType);
        }
      } else {
        message.put(Message.CONTENT_TYPE, customContentType.toString());
      }
    }
    message.put(Message.PROTOCOL_HEADERS, responseHeaders);
    setResponseDate(responseHeaders, firstTry);
    return responseHeaders;
  }

  /** Converts a Content-Type header value to a {@code MediaType}; absent means {@code *}{@code /*}. */
  private MediaType getResponseMediaType(Object mediaTypeHeader) {
    MediaType responseMediaType;
    if (mediaTypeHeader instanceof MediaType) {
      responseMediaType = (MediaType) mediaTypeHeader;
    } else {
      responseMediaType =
          mediaTypeHeader == null
              ? MediaType.WILDCARD_TYPE
              : JAXRSUtils.toMediaType(mediaTypeHeader.toString());
    }
    return responseMediaType;
  }

  /** Resolves an unset status ({@code -1}) to 204 without an entity or 200 with one. */
  private int getActualStatus(int status, Object responseObj) {
    if (status == -1) {
      return responseObj == null ? 204 : 200;
    } else {
      return status;
    }
  }

  /**
   * Decides whether output buffering is enabled and, if so, swaps in a caching sink: a
   * {@code CachingXmlEventWriter} when the writer has streaming enabled, a
   * {@code CachedOutputStream} otherwise. Never buffers on a retry.
   *
   * @return whether buffering was enabled
   */
  private boolean checkBufferingMode(Message m, List<WriterInterceptor> writers, boolean firstTry) {
    if (!firstTry) {
      return false;
    }
    // The actual MessageBodyWriter is wrapped by the last interceptor in the chain.
    WriterInterceptor last = writers.get(writers.size() - 1);
    MessageBodyWriter<Object> w = ((WriterInterceptorMBW) last).getMBW();
    Object outBuf = m.getContextualProperty(OUT_BUFFERING);
    boolean enabled = MessageUtils.isTrue(outBuf);
    boolean configurableProvider = w instanceof AbstractConfigurableProvider;
    // An explicit contextual setting wins; otherwise defer to the provider's own configuration.
    if (!enabled && outBuf == null && configurableProvider) {
      enabled = ((AbstractConfigurableProvider) w).getEnableBuffering();
    }
    if (enabled) {
      boolean streamingOn =
          configurableProvider ? ((AbstractConfigurableProvider) w).getEnableStreaming() : false;
      if (streamingOn) {
        m.setContent(XMLStreamWriter.class, new CachingXmlEventWriter());
      } else {
        m.setContent(OutputStream.class, new CachedOutputStream());
      }
    }
    return enabled;
  }

  /**
   * Flushes any buffered output to the original stream: cached XML events are replayed through
   * a new XMLStreamWriter, and a cached byte stream is copied over.
   */
  private void checkCachedStream(Message m, OutputStream osOriginal, boolean enabled)
      throws Exception {
    XMLStreamWriter writer = null;
    if (enabled) {
      writer = m.getContent(XMLStreamWriter.class);
    } else {
      writer = (XMLStreamWriter) m.get(XMLStreamWriter.class.getName());
    }
    if (writer instanceof CachingXmlEventWriter) {
      CachingXmlEventWriter cache = (CachingXmlEventWriter) writer;
      if (cache.getEvents().size() != 0) {
        XMLStreamWriter origWriter = null;
        try {
          origWriter = StaxUtils.createXMLStreamWriter(osOriginal);
          for (XMLEvent event : cache.getEvents()) {
            StaxUtils.writeEvent(event, origWriter);
          }
        } finally {
          StaxUtils.close(origWriter);
        }
      }
      m.setContent(XMLStreamWriter.class, null);
      return;
    }
    if (enabled) {
      OutputStream os = m.getContent(OutputStream.class);
      if (os != osOriginal && os instanceof CachedOutputStream) {
        CachedOutputStream cos = (CachedOutputStream) os;
        if (cos.size() != 0) {
          cos.writeCacheTo(osOriginal);
        }
      }
    }
  }

  /** Logs a writer problem, but only for the first attempt. */
  private void logWriteError(boolean firstTry, Class<?> cls, MediaType ct) {
    if (firstTry) {
      JAXRSUtils.logMessageHandlerProblem("MSG_WRITER_PROBLEM", cls, ct);
    }
  }

  /**
   * Handles a failure while filtering or writing. On the first attempt the exception is mapped
   * to a response which is then serialized as a retry. If no response can be produced, or this
   * already was the retry, status 500 is set and a {@code Fault} is thrown.
   *
   * @throws Fault wrapping {@code ex} when no error response can be written
   */
  private void handleWriteException(
      ServerProviderFactory pf, Message message, Throwable ex, boolean firstTry) {
    Response excResponse = null;
    if (firstTry) {
      excResponse = JAXRSUtils.convertFaultToResponse(ex, message);
    } else {
      message.getExchange().put(JAXRSUtils.SECOND_JAXRS_EXCEPTION, Boolean.TRUE);
    }
    if (excResponse == null) {
      setResponseStatus(message, 500);
      throw new Fault(ex);
    } else {
      serializeMessage(pf, message, excResponse, null, false);
    }
  }

  /**
   * Marks the response as a 500 {@code text/plain} error and, when an output stream is
   * available, writes the logged diagnostic message to it as UTF-8.
   */
  private void writeResponseErrorMessage(
      Message message, OutputStream out, String name, Class<?> cls, MediaType ct) {
    message.put(Message.CONTENT_TYPE, "text/plain");
    message.put(Message.RESPONSE_CODE, 500);
    try {
      String errorMessage = JAXRSUtils.logMessageHandlerProblem(name, cls, ct);
      if (out != null) {
        out.write(errorMessage.getBytes(StandardCharsets.UTF_8));
      }
    } catch (IOException another) {
      // Best effort: status and content type are already set, so the failure is not propagated,
      // but it is recorded rather than silently dropped.
      LOG.log(Level.FINE, "Failed to write the response error message", another);
    }
  }

  /**
   * Computes the final response media type. When {@code checkWriters} is set, the type is
   * intersected with the {@code @Produces} of the selected writer's class. A remaining wildcard
   * becomes {@code application/octet-stream} for {@code *}{@code /*} or {@code application/*};
   * any other wildcard subtype is rejected as not acceptable.
   */
  private MediaType checkFinalContentType(
      MediaType mt, List<WriterInterceptor> writers, boolean checkWriters) {
    if (checkWriters) {
      // The actual MessageBodyWriter is wrapped by the last interceptor in the chain.
      MessageBodyWriter<Object> writer =
          ((WriterInterceptorMBW) writers.get(writers.size() - 1)).getMBW();
      Produces pm = writer.getClass().getAnnotation(Produces.class);
      if (pm != null) {
        List<MediaType> sorted =
            JAXRSUtils.sortMediaTypes(
                JAXRSUtils.getMediaTypes(pm.value()), JAXRSUtils.MEDIA_TYPE_QS_PARAM);
        mt = JAXRSUtils.intersectMimeTypes(sorted, mt).get(0);
      }
    }
    if (mt.isWildcardType() || mt.isWildcardSubtype()) {
      if ("application".equals(mt.getType()) || mt.isWildcardType()) {
        mt = MediaType.APPLICATION_OCTET_STREAM_TYPE;
      } else {
        throw ExceptionUtils.toNotAcceptableException(null, null);
      }
    }
    return mt;
  }

  /** Adds a Date header on the first attempt unless one is already present. */
  private void setResponseDate(MultivaluedMap<String, Object> headers, boolean firstTry) {
    if (!firstTry || headers.containsKey(HttpHeaders.DATE)) {
      return;
    }
    SimpleDateFormat format = HttpUtils.getHttpDateFormat();
    headers.putSingle(HttpHeaders.DATE, format.format(new Date()));
  }

  private boolean isResponseAlreadyHandled(Message m) {
    return isResponseAlreadyCommited(m) || isResponseRedirected(m);
  }

  private boolean isResponseAlreadyCommited(Message m) {
    return Boolean.TRUE.equals(m.getExchange().get(AbstractHTTPDestination.RESPONSE_COMMITED));
  }

  private boolean isResponseRedirected(Message m) {
    return Boolean.TRUE.equals(m.getExchange().get(AbstractHTTPDestination.REQUEST_REDIRECTED));
  }

  /**
   * Writes {@code responseObj.toString()} as UTF-8 directly to {@code os}, bypassing the
   * message body writers.
   *
   * @throws RuntimeException wrapping any failure (including a {@code null} stream)
   */
  private void writeResponseToStream(OutputStream os, Object responseObj) {
    try {
      byte[] bytes = responseObj.toString().getBytes(StandardCharsets.UTF_8);
      os.write(bytes, 0, bytes.length);
    } catch (Exception ex) {
      // Route the stack trace through the logger instead of printing it to stderr.
      LOG.log(Level.SEVERE, "Problem with writing the data to the output stream", ex);
      throw new RuntimeException(ex);
    }
  }

  /**
   * Records the response code on the message and, if the response headers have already been
   * copied to the {@code HttpServletResponse}, updates its status as well.
   */
  private void setResponseStatus(Message message, int status) {
    message.put(Message.RESPONSE_CODE, status);
    boolean responseHeadersCopied = isResponseHeadersCopied(message);
    if (responseHeadersCopied) {
      HttpServletResponse response =
          (HttpServletResponse) message.get(AbstractHTTPDestination.HTTP_RESPONSE);
      response.setStatus(status);
    }
  }

  // Some CXF interceptors, such as FIStaxOutInterceptor, indirectly trigger an early copy of the
  // response code and headers into the HttpServletResponse.
  // TODO: moving the filter processing and the copying of response headers into, say, the
  // PRE_LOGICAL and PREPARE_SEND phases would most likely be a good thing. However, JAX-RS
  // MessageBodyWriters are also allowed to add response headers, which is why the
  // MultivaluedMap parameter of MessageBodyWriter.writeTo is modifiable. We therefore need to
  // know whether the initial copy has already happened; for now this is only used to make sure
  // the correct status is set.
  private boolean isResponseHeadersCopied(Message message) {
    return MessageUtils.isTrue(message.get(AbstractHTTPDestination.RESPONSE_HEADERS_COPIED));
  }

  /** Intentionally empty: no fault-specific cleanup is needed here. */
  public void handleFault(Message message) {
    // complete
  }
}
Example #24
0
public class RMManager {

  /** Name of the contextual message property supplying the WS-ReliableMessaging namespace. */
  public static final String WSRM_VERSION_PROPERTY = "org.apache.cxf.ws.rm.namespace";

  /**
   * Name of the contextual message property supplying the WS-Addressing namespace that the WS-RM
   * implementation should use.
   */
  public static final String WSRM_WSA_VERSION_PROPERTY = "org.apache.cxf.ws.rm.wsa-namespace";

  /**
   * Name of the contextual message property carrying the last-message flag (a {@code Boolean}).
   */
  public static final String WSRM_LAST_MESSAGE_PROPERTY = "org.apache.cxf.ws.rm.last-message";

  /**
   * Name of the contextual message property carrying the WS-RM inactivity timeout (a
   * {@code Long}).
   */
  public static final String WSRM_INACTIVITY_TIMEOUT_PROPERTY =
      "org.apache.cxf.ws.rm.inactivity-timeout";

  /**
   * Name of the contextual message property carrying the WS-RM base retransmission interval (a
   * {@code Long}).
   */
  public static final String WSRM_RETRANSMISSION_INTERVAL_PROPERTY =
      "org.apache.cxf.ws.rm.retransmission-interval";

  /**
   * Name of the contextual message property carrying the WS-RM exponential backoff flag (a
   * {@code Boolean}).
   */
  public static final String WSRM_EXPONENTIAL_BACKOFF_PROPERTY =
      "org.apache.cxf.ws.rm.exponential-backoff";

  /**
   * Name of the contextual message property carrying the WS-RM acknowledgement interval (a
   * {@code Long}).
   */
  public static final String WSRM_ACKNOWLEDGEMENT_INTERVAL_PROPERTY =
      "org.apache.cxf.ws.rm.acknowledgement-interval";

  private static final Logger LOG = LogUtils.getL7dLogger(RMManager.class);

  // Key derived from this class's name; not referenced in the part of the class reviewed here.
  private static final String WSRM_RETRANSMIT_CHAIN =
      RMManager.class.getName() + ".retransmitChain";

  // Collaborators, injected or set through the accessors below.
  private Bus bus;
  private RMStore store;
  private SequenceIdentifierGenerator idGenerator;
  private RetransmissionQueue retransmissionQueue;

  // RM state per application endpoint.
  private Map<Endpoint, RMEndpoint> reliableEndpoints =
      new ConcurrentHashMap<Endpoint, RMEndpoint>();

  // Lazily created on first use; see getTimer(boolean).
  private AtomicReference<Timer> timer = new AtomicReference<Timer>();

  // Configuration and policies.
  private RMConfiguration configuration;
  private SourcePolicyType sourcePolicy;
  private DestinationPolicyType destinationPolicy;

  // Management support.
  private InstrumentationManager instrumentationManager;
  private ManagedRMManager managedManager;

  // ---- ServerLifeCycleListener callbacks ----

  /**
   * Attempts to recover persisted RM state for the endpoint of a server that has just started.
   * No conduit is passed for the server side.
   */
  public void startServer(Server server) {
    Endpoint serverEndpoint = server.getEndpoint();
    recoverReliableEndpoint(serverEndpoint, (Conduit) null);
  }

  /** Intentionally a no-op. */
  public void stopServer(Server server) {}

  // ---- ClientLifeCycleListener callbacks ----

  /**
   * Recovers persisted RM state for a newly created client. Recovery is attempted only when both
   * a store and a retransmission queue are configured and the store holds at least one source
   * sequence for the client's endpoint.
   */
  public void clientCreated(final Client client) {
    boolean recoveryConfigured = store != null && retransmissionQueue != null;
    if (!recoveryConfigured) {
      return;
    }
    Endpoint clientEndpoint = client.getEndpoint();
    String endpointId = RMUtils.getEndpointIdentifier(clientEndpoint, getBus());
    Collection<SourceSequence> sourceSequences =
        store.getSourceSequences(endpointId /*, protocol*/);
    if (sourceSequences == null || sourceSequences.isEmpty()) {
      return;
    }
    LOG.log(Level.FINE, "Number of source sequences: {0}", sourceSequences.size());
    recoverReliableEndpoint(clientEndpoint, client.getConduit() /*, protocol*/);
  }

  /** Intentionally a no-op. */
  public void clientDestroyed(Client client) {}

  // ---- Configuration ----

  /** Sets the WS-ReliableMessaging namespace URI on the base configuration. */
  public void setRMNamespace(String uri) {
    RMConfiguration baseConfig = getConfiguration();
    baseConfig.setRMNamespace(uri);
  }

  /** Sets the WS-Addressing namespace used with WS-RM 1.0 from the URI held by {@code addrns}. */
  public void setRM10AddressingNamespace(RM10AddressingNamespaceType addrns) {
    RMConfiguration baseConfig = getConfiguration();
    baseConfig.setRM10AddressingNamespace(addrns.getUri());
  }

  /** @return the bus set on this manager, possibly {@code null} */
  public Bus getBus() {
    return this.bus;
  }

  /**
   * Injects the bus. A non-null bus also gets this instance registered as its
   * {@code RMManager} extension.
   */
  @Resource
  public void setBus(Bus b) {
    this.bus = b;
    if (b == null) {
      return;
    }
    b.setExtension(this, RMManager.class);
  }

  /** @return the store used to persist RM sequences, possibly {@code null} */
  public RMStore getStore() {
    return this.store;
  }

  public void setStore(RMStore s) {
    this.store = s;
  }

  /** @return the retransmission queue, possibly {@code null} */
  public RetransmissionQueue getRetransmissionQueue() {
    return this.retransmissionQueue;
  }

  public void setRetransmissionQueue(RetransmissionQueue rq) {
    this.retransmissionQueue = rq;
  }

  /** @return the sequence identifier generator */
  public SequenceIdentifierGenerator getIdGenerator() {
    return this.idGenerator;
  }

  public void setIdGenerator(SequenceIdentifierGenerator generator) {
    this.idGenerator = generator;
  }

  /**
   * Returns the manager's daemon timer, lazily creating it when {@code create} is set. Racing
   * creators use compare-and-set; a candidate that loses the race is cancelled immediately.
   *
   * @param create whether to create the timer if none exists yet
   * @return the timer, or {@code null} if none exists and {@code create} is {@code false}
   */
  private Timer getTimer(boolean create) {
    if (create && timer.get() == null) {
      Timer candidate = new Timer("RMManager-Timer-" + System.identityHashCode(this), true);
      boolean installed = timer.compareAndSet(null, candidate);
      if (!installed) {
        candidate.cancel();
      }
    }
    return timer.get();
  }

  /** @return the manager's timer, creating it on first use */
  public Timer getTimer() {
    final boolean createIfMissing = true;
    return getTimer(createIfMissing);
  }

  /** Creates a SOAP fault factory for {@code binding}. */
  public BindingFaultFactory getBindingFaultFactory(Binding binding) {
    return new SoapFaultFactory(binding);
  }

  /**
   * Applies a delivery assurance to the base configuration. ExactlyOnce is chosen when set
   * explicitly or when both AtLeastOnce and AtMostOnce are set. Otherwise, whichever single
   * assurance is set is used, or {@code null} when none is.
   *
   * @param dat The deliveryAssurance to set.
   */
  public void setDeliveryAssurance(DeliveryAssuranceType dat) {
    RMConfiguration cfg = getConfiguration();
    cfg.setInOrder(dat.isSetInOrder());
    DeliveryAssurance assurance;
    if (dat.isSetExactlyOnce()) {
      assurance = DeliveryAssurance.EXACTLY_ONCE;
    } else if (dat.isSetAtLeastOnce()) {
      assurance =
          dat.isSetAtMostOnce() ? DeliveryAssurance.EXACTLY_ONCE : DeliveryAssurance.AT_LEAST_ONCE;
    } else if (dat.isSetAtMostOnce()) {
      assurance = DeliveryAssurance.AT_MOST_ONCE;
    } else {
      assurance = null;
    }
    cfg.setDeliveryAssurance(assurance);
  }

  /** @return the configured destination policy, possibly {@code null} */
  public DestinationPolicyType getDestinationPolicy() {
    return this.destinationPolicy;
  }

  /** @param destinationPolicy the destination policy to use */
  public void setDestinationPolicy(DestinationPolicyType destinationPolicy) {
    this.destinationPolicy = destinationPolicy;
  }

  /**
   * Returns the manager's base configuration, creating a default one on first access. Endpoint
   * policies are applied on top of it to obtain the effective configuration.
   *
   * @return configuration (non-<code>null</code>)
   */
  public RMConfiguration getConfiguration() {
    if (configuration != null) {
      return configuration;
    }
    setConfiguration(new RMConfiguration());
    return configuration;
  }

  /**
   * Installs a base configuration, first filling in the default base retransmission interval and
   * the WS-RM 1.0 namespace when those are unset.
   *
   * @param configuration (non-<code>null</code>)
   */
  public void setConfiguration(RMConfiguration configuration) {
    if (null == configuration.getBaseRetransmissionInterval()) {
      configuration.setBaseRetransmissionInterval(
          Long.valueOf(RetransmissionQueue.DEFAULT_BASE_RETRANSMISSION_INTERVAL));
    }
    if (null == configuration.getRMNamespace()) {
      configuration.setRMNamespace(RM10Constants.NAMESPACE_URI);
    }
    this.configuration = configuration;
  }

  /**
   * Returns the configuration obtained by applying the message's policies to the base
   * configuration.
   *
   * @param msg the current message
   * @return configuration (non-<code>null</code>)
   */
  public RMConfiguration getEffectiveConfiguration(Message msg) {
    RMConfiguration base = getConfiguration();
    return RMPolicyUtilities.getRMConfiguration(base, msg);
  }

  /** @param rma an RM assertion intersected with the current base configuration */
  public void setRMAssertion(org.apache.cxf.ws.rmp.v200502.RMAssertion rma) {
    RMConfiguration merged = RMPolicyUtilities.intersect(rma, getConfiguration());
    setConfiguration(merged);
  }

  /** @return the configured source policy, possibly {@code null} */
  public SourcePolicyType getSourcePolicy() {
    return this.sourcePolicy;
  }

  /**
   * Installs a source policy. A {@code null} argument is replaced with an empty policy, and a
   * missing sequence termination policy defaults to one that terminates on shutdown.
   *
   * @param sp The sourcePolicy to set.
   */
  public void setSourcePolicy(SourcePolicyType sp) {
    SourcePolicyType policy = sp != null ? sp : new SourcePolicyType();
    if (policy.getSequenceTerminationPolicy() == null) {
      SequenceTerminationPolicyType termination = new SequenceTerminationPolicyType();
      termination.setTerminateOnShutdown(true);
      policy.setSequenceTerminationPolicy(termination);
    }
    this.sourcePolicy = policy;
  }

  // The real stuff ...

  /**
   * Returns the {@code RMEndpoint} associated with the message's endpoint, creating and
   * initialising it the first time that endpoint is seen.
   *
   * <p>Side effect: the following values are written into the base configuration returned by
   * {@link #getConfiguration()}:
   *
   * <ul>
   *   <li>the WS-RM and WS-Addressing namespaces, taken from contextual properties or, failing
   *       that, from the message's RM and addressing properties;
   *   <li>the inactivity timeout, base retransmission interval, exponential backoff flag and
   *       acknowledgement interval, taken from contextual properties.
   * </ul>
   *
   * NOTE(review): that object is shared by all messages — confirm this mutation is intended.
   *
   * @param message the current message; its exchange must have an endpoint
   * @return the reliable endpoint for the message's (unwrapped) endpoint
   * @throws RMException if no protocol variant matches the WS-RM / WS-Addressing namespaces
   */
  public RMEndpoint getReliableEndpoint(Message message) throws RMException {
    Endpoint endpoint = message.getExchange().getEndpoint();
    QName name = endpoint.getEndpointInfo().getName();
    if (LOG.isLoggable(Level.FINE)) {
      LOG.fine("Getting RMEndpoint for endpoint with info: " + name);
    }
    // Endpoints carrying an RM port name are WrappedEndpoints; RM state is keyed on the
    // wrapped endpoint.
    if (name.equals(RM10Constants.PORT_NAME) || name.equals(RM11Constants.PORT_NAME)) {
      WrappedEndpoint wrappedEndpoint = (WrappedEndpoint) endpoint;
      endpoint = wrappedEndpoint.getWrappedEndpoint();
    }
    // Explicit contextual properties take precedence over the namespaces recorded on the message.
    String rmUri = (String) message.getContextualProperty(WSRM_VERSION_PROPERTY);
    if (rmUri == null) {
      RMProperties rmps = RMContextUtils.retrieveRMProperties(message, false);
      if (rmps != null) {
        rmUri = rmps.getNamespaceURI();
      }
    }
    String addrUri = (String) message.getContextualProperty(WSRM_WSA_VERSION_PROPERTY);
    if (addrUri == null) {
      AddressingProperties maps = ContextUtils.retrieveMAPs(message, false, false, false);
      if (maps != null) {
        addrUri = maps.getNamespaceURI();
      }
    }

    RMConfiguration config = getConfiguration();
    if (rmUri != null) {
      config.setRMNamespace(rmUri);
      ProtocolVariation protocol = ProtocolVariation.findVariant(rmUri, addrUri);
      if (protocol == null) {
        org.apache.cxf.common.i18n.Message msg =
            new org.apache.cxf.common.i18n.Message("UNSUPPORTED_NAMESPACE", LOG, addrUri, rmUri);
        LOG.log(Level.INFO, msg.toString());
        throw new RMException(msg);
      }
    }
    if (addrUri != null) {
      config.setRM10AddressingNamespace(addrUri);
    }
    Long timeout = (Long) message.getContextualProperty(WSRM_INACTIVITY_TIMEOUT_PROPERTY);
    if (timeout != null) {
      config.setInactivityTimeout(timeout);
    }
    Long interval = (Long) message.getContextualProperty(WSRM_RETRANSMISSION_INTERVAL_PROPERTY);
    if (interval != null) {
      config.setBaseRetransmissionInterval(interval);
    }
    Boolean exponential =
        (Boolean) message.getContextualProperty(WSRM_EXPONENTIAL_BACKOFF_PROPERTY);
    if (exponential != null) {
      config.setExponentialBackoff(exponential);
    }
    interval = (Long) message.getContextualProperty(WSRM_ACKNOWLEDGEMENT_INTERVAL_PROPERTY);
    if (interval != null) {
      config.setAcknowledgementInterval(interval);
    }
    RMEndpoint rme = reliableEndpoints.get(endpoint);
    if (null == rme) {
      // Re-check under the endpoint's monitor so only one RMEndpoint is created per endpoint.
      synchronized (endpoint) {
        rme = reliableEndpoints.get(endpoint);
        if (rme != null) {
          return rme;
        }
        rme = createReliableEndpoint(endpoint);
        org.apache.cxf.transport.Destination destination = message.getExchange().getDestination();
        EndpointReferenceType replyTo = null;
        if (null != destination) {
          AddressingProperties maps = RMContextUtils.retrieveMAPs(message, false, false);
          // The MAPs may be absent; treat that like the no-destination case (no ReplyTo)
          // instead of failing with a NullPointerException.
          replyTo = maps == null ? null : maps.getReplyTo();
        }
        Endpoint ei = message.getExchange().getEndpoint();
        org.apache.cxf.transport.Destination dest =
            ei == null
                ? null
                : ei.getEndpointInfo()
                    .getProperty(
                        MAPAggregator.DECOUPLED_DESTINATION,
                        org.apache.cxf.transport.Destination.class);
        config = RMPolicyUtilities.getRMConfiguration(config, message);
        rme.initialise(config, message.getExchange().getConduit(message), replyTo, dest, message);
        reliableEndpoints.put(endpoint, rme);
        LOG.fine("Created new RMEndpoint.");
      }
    }
    return rme;
  }

  /**
   * @param message the current message
   * @return the RM destination of the message's reliable endpoint, or {@code null} if there is
   *     no reliable endpoint
   * @throws RMException if the reliable endpoint cannot be obtained
   */
  public Destination getDestination(Message message) throws RMException {
    RMEndpoint reliable = getReliableEndpoint(message);
    return reliable == null ? null : reliable.getDestination();
  }

  /**
   * @param message the current message
   * @return the RM source of the message's reliable endpoint, or {@code null} if there is no
   *     reliable endpoint
   * @throws RMException if the reliable endpoint cannot be obtained
   */
  public Source getSource(Message message) throws RMException {
    RMEndpoint reliable = getReliableEndpoint(message);
    return reliable == null ? null : reliable.getSource();
  }

  /**
   * Returns the current source sequence associated with {@code inSeqId}, creating a new one if
   * there is none or the current one has expired.
   *
   * <p>Creation sends a CreateSequence request through the reliable endpoint's proxy, then blocks
   * in {@code source.awaitCurrent(inSeqId)} for the resulting sequence and sets its target.
   *
   * <ul>
   *   <li>Server side: the target is the inbound ReplyTo and AcksTo is the inbound To address.
   *       RelatesTo carries the correlation ID of the inbound destination sequence, if found.
   *   <li>Client side: the target is {@code maps.getTo()} and AcksTo is {@code maps.getReplyTo()}.
   *       When that ReplyTo is the WS-Addressing "none" URI, AcksTo is replaced with the decoupled
   *       destination's address, or an anonymous reference when there is none.
   * </ul>
   *
   * @param inSeqId identifier of the associated inbound sequence, possibly {@code null}
   * @param message the current message
   * @param maps the outbound addressing properties; only used on the client side
   * @return the current (possibly newly created) source sequence
   * @throws RMException if the target address is rejected by ContextUtils.isGenericAddress
   */
  public SourceSequence getSequence(Identifier inSeqId, Message message, AddressingProperties maps)
      throws RMException {

    Source source = getSource(message);
    SourceSequence seq = source.getCurrent(inSeqId);
    RMConfiguration config = getEffectiveConfiguration(message);
    if (null == seq || seq.isExpired()) {
      // TODO: better error handling
      EndpointReferenceType to = null;
      boolean isServer = RMContextUtils.isServerSide(message);
      EndpointReferenceType acksTo = null;
      RelatesToType relatesTo = null;
      if (isServer) {
        // NOTE(review): assumes the inbound MAPs are present; a null here would NPE.
        AddressingProperties inMaps = RMContextUtils.retrieveMAPs(message, false, false);
        inMaps.exposeAs(config.getAddressingNamespace());
        acksTo = RMUtils.createReference(inMaps.getTo().getValue());
        to = inMaps.getReplyTo();
        source.getReliableEndpoint().getServant().setUnattachedIdentifier(inSeqId);
        relatesTo = (new org.apache.cxf.ws.addressing.ObjectFactory()).createRelatesToType();
        Destination destination = getDestination(message);
        DestinationSequence inSeq = inSeqId == null ? null : destination.getSequence(inSeqId);
        relatesTo.setValue(inSeq != null ? inSeq.getCorrelationID() : null);

      } else {
        to = RMUtils.createReference(maps.getTo().getValue());
        acksTo = maps.getReplyTo();
        // A "none" ReplyTo cannot receive acknowledgements; fall back to the decoupled
        // destination, or to an anonymous reference if none is configured.
        if (RMUtils.getAddressingConstants().getNoneURI().equals(acksTo.getAddress().getValue())) {
          Endpoint ei = message.getExchange().getEndpoint();
          org.apache.cxf.transport.Destination dest =
              ei == null
                  ? null
                  : ei.getEndpointInfo()
                      .getProperty(
                          MAPAggregator.DECOUPLED_DESTINATION,
                          org.apache.cxf.transport.Destination.class);
          if (null == dest) {
            acksTo = RMUtils.createAnonymousReference();
          } else {
            acksTo = dest.getAddress();
          }
        }
      }

      // NOTE(review): presumably rejects anonymous/none targets a sequence cannot be created for.
      if (ContextUtils.isGenericAddress(to)) {
        org.apache.cxf.common.i18n.Message msg =
            new org.apache.cxf.common.i18n.Message(
                "CREATE_SEQ_ANON_TARGET",
                LOG,
                to != null && to.getAddress() != null ? to.getAddress().getValue() : null);
        LOG.log(Level.INFO, msg.toString());
        throw new RMException(msg);
      }
      Proxy proxy = source.getReliableEndpoint().getProxy();
      ProtocolVariation protocol = config.getProtocolVariation();
      Exchange exchange = new ExchangeImpl();
      // Only contextual properties whose keys start with "ws-security" are forwarded to the
      // CreateSequence request context.
      Map<String, Object> context = new HashMap<String, Object>(16);
      for (String key : message.getContextualPropertyKeys()) {
        // copy other properties?
        if (key.startsWith("ws-security")) {
          context.put(key, message.getContextualProperty(key));
        }
      }

      CreateSequenceResponseType createResponse =
          proxy.createSequence(acksTo, relatesTo, isServer, protocol, exchange, context);
      if (!isServer) {
        // On the client side the response is handed to the servant directly.
        Servant servant = source.getReliableEndpoint().getServant();
        servant.createSequenceResponse(createResponse, protocol);

        // propagate security properties to application endpoint, in case we're using
        // WS-SecureConversation
        Exchange appex = message.getExchange();
        if (appex.get(SecurityConstants.TOKEN) == null) {
          appex.put(SecurityConstants.TOKEN, exchange.get(SecurityConstants.TOKEN));
          appex.put(SecurityConstants.TOKEN_ID, exchange.get(SecurityConstants.TOKEN_ID));
        }
      }

      // Blocks until the new sequence becomes current for inSeqId.
      seq = source.awaitCurrent(inSeqId);
      seq.setTarget(to);
    }

    return seq;
  }

  /**
   * Shuts down every reliable endpoint still registered with this manager. If a timer was ever
   * created, it is then purged and cancelled.
   */
  @PreDestroy
  public void shutdown() {
    if (!reliableEndpoints.isEmpty()) {
      Integer remaining = Integer.valueOf(reliableEndpoints.size());
      LOG.log(
          Level.FINE,
          "Shutting down RMManager with {0} remaining endpoints.",
          new Object[] {remaining});
      for (RMEndpoint endpoint : reliableEndpoints.values()) {
        endpoint.shutdown();
      }
    }

    // Purge first so the tasks cancelled by the endpoint shutdowns become collectable.
    Timer existing = getTimer(false);
    if (existing == null) {
      return;
    }
    existing.purge();
    existing.cancel();

    // Unregistering this managed bean from the server is done by the bus itself.
  }

  /**
   * Shuts down and unregisters the reliable endpoint for {@code e}, if one exists. The timer, if
   * one was created, is purged in between.
   */
  void shutdownReliableEndpoint(Endpoint e) {
    RMEndpoint rme = reliableEndpoints.get(e);
    if (rme != null) {
      rme.shutdown();

      // Drop references to timer tasks cancelled by the shutdown so they can be collected.
      Timer existing = getTimer(false);
      if (existing != null) {
        existing.purge();
      }

      reliableEndpoints.remove(e);
    }
  }

  /**
   * Recovers the persisted source and destination sequences of {@code endpoint} and registers a
   * new reliable endpoint for it in {@code reliableEndpoints}, then starts the retransmission
   * queue.
   *
   * <p>Does nothing when no store or retransmission queue is configured, or when the store holds
   * no sequences for this endpoint. The store may return {@code null} for either collection; that
   * is treated as empty, so a {@code null} on one side no longer causes a NullPointerException
   * when the other side has sequences.
   *
   * @param endpoint the endpoint whose sequences are recovered
   * @param conduit the conduit used for the requestor (client) role; {@code null} on the server
   *     side (recovered messages are then not marked as requestor messages)
   */
  void recoverReliableEndpoint(Endpoint endpoint, Conduit conduit) {
    if (null == store || null == retransmissionQueue) {
      return;
    }

    String id = RMUtils.getEndpointIdentifier(endpoint, getBus());

    Collection<SourceSequence> sss = store.getSourceSequences(id);
    Collection<DestinationSequence> dss = store.getDestinationSequences(id);
    int sourceCount = null == sss ? 0 : sss.size();
    int destinationCount = null == dss ? 0 : dss.size();
    if (0 == sourceCount && 0 == destinationCount) {
      return;
    }
    LOG.log(Level.FINE, "Number of source sequences: {0}", sourceCount);
    LOG.log(Level.FINE, "Number of destination sequences: {0}", destinationCount);

    // A non-null conduit means the requestor (client) role: recoverSourceSequence marks the
    // recovered messages with Message.REQUESTOR_ROLE exactly in that case.
    LOG.log(
        Level.FINE,
        "Recovering {0} endpoint with id: {1}",
        new Object[] {null == conduit ? "server" : "client", id});
    RMEndpoint rme = createReliableEndpoint(endpoint);
    rme.initialise(getConfiguration(), conduit, null, null, null);
    synchronized (reliableEndpoints) {
      reliableEndpoints.put(endpoint, rme);
    }
    if (null != sss) {
      for (SourceSequence ss : sss) {
        recoverSourceSequence(endpoint, conduit, rme.getSource(), ss);
      }
    }

    if (null != dss) {
      for (DestinationSequence ds : dss) {
        reconverDestinationSequence(endpoint, conduit, rme.getDestination(), ds);
      }
    }
    retransmissionQueue.start();
  }

  /**
   * Re-registers one persisted source sequence with {@code s} and queues each of its stored
   * messages for retransmission. If the store holds no messages for the sequence, the sequence is
   * removed from the store instead and nothing is registered or queued.
   *
   * @param endpoint the endpoint placed on every rebuilt exchange (with its service and binding)
   * @param conduit when non-null, set on each rebuilt exchange and the message is marked with
   *     {@code Message.REQUESTOR_ROLE}; when null, the stored "to" address of each message is
   *     restored into its addressing properties instead
   * @param s the reliable source the sequence is added to
   * @param ss the persisted source sequence being recovered
   */
  private void recoverSourceSequence(
      Endpoint endpoint, Conduit conduit, Source s, SourceSequence ss) {
    Collection<RMMessage> ms = store.getMessages(ss.getIdentifier(), true);
    if (null == ms || 0 == ms.size()) {
      // nothing left to resend: drop the sequence from persistent storage
      store.removeSourceSequence(ss.getIdentifier());
      return;
    }
    LOG.log(Level.FINE, "Number of messages in sequence: {0}", ms.size());

    s.addSequence(ss, false);
    // choosing an arbitrary valid source sequence as the current source sequence
    if (s.getAssociatedSequence(null) == null && !ss.isExpired() && !ss.isLastMessage()) {
      s.setCurrent(ss);
    }
    for (RMMessage m : ms) {

      // rebuild a minimal outbound message/exchange pair carrying the stored message
      Message message = new MessageImpl();
      Exchange exchange = new ExchangeImpl();
      message.setExchange(exchange);
      exchange.setOutMessage(message);
      if (null != conduit) {
        exchange.setConduit(conduit);
        message.put(Message.REQUESTOR_ROLE, Boolean.TRUE);
      }
      exchange.put(Endpoint.class, endpoint);
      exchange.put(Service.class, endpoint.getService());
      exchange.put(Binding.class, endpoint.getBinding());
      exchange.put(Bus.class, bus);

      // WS-RM Sequence header for this message, exposed in the sequence's protocol namespace
      SequenceType st = new SequenceType();
      st.setIdentifier(ss.getIdentifier());
      st.setMessageNumber(m.getMessageNumber());
      RMProperties rmps = new RMProperties();
      rmps.setSequence(st);
      rmps.exposeAs(ss.getProtocol().getWSRMNamespace());
      // the stored message matching the current (last) message number of a sequence flagged as
      // last-message also carries a CloseSequence element
      if (ss.isLastMessage() && ss.getCurrentMessageNr() == m.getMessageNumber()) {
        CloseSequenceType close = new CloseSequenceType();
        close.setIdentifier(ss.getIdentifier());
        rmps.setCloseSequence(close);
      }
      RMContextUtils.storeRMProperties(message, rmps, true);
      if (null == conduit) {
        // no conduit: restore the persisted destination address into the addressing properties
        String to = m.getTo();
        AddressingProperties maps = new AddressingProperties();
        maps.setTo(RMUtils.createReference(to));
        RMContextUtils.storeMAPs(maps, message, true, false);
      }

      try {
        // RMMessage is stored in a serialized way, therefore
        // RMMessage content must be split into the SOAP root message
        // and attachments
        PersistenceUtils.decodeRMContent(m, message);
        RMContextUtils.setProtocolVariation(message, ss.getProtocol());
        retransmissionQueue.addUnacknowledged(message);
      } catch (IOException e) {
        // a message whose content cannot be decoded is logged and skipped; the others are queued
        LOG.log(Level.SEVERE, "Error reading persisted message data", e);
      }
    }
  }

  /**
   * Re-registers a persisted destination sequence with {@code destination}. Redelivery of stored
   * messages is not implemented yet; {@code endpoint} and {@code conduit} are currently unused.
   */
  private void reconverDestinationSequence(
      Endpoint endpoint, Conduit conduit, Destination destination, DestinationSequence sequence) {
    final boolean persist = false;
    destination.addSequence(sequence, persist);
    // TODO add the redelivery code
  }

  /**
   * Creates an RMEndpoint for {@code endpoint}. Before creating it, a cleanup hook is added to the
   * endpoint that calls {@code shutdownReliableEndpoint(endpoint)} when the hook is closed. The new
   * RMEndpoint is returned but not added to {@code reliableEndpoints} by this method.
   */
  RMEndpoint createReliableEndpoint(final Endpoint endpoint) {
    Closeable cleanupHook =
        new Closeable() {
          @Override
          public void close() throws IOException {
            shutdownReliableEndpoint(endpoint);
          }
        };
    endpoint.addCleanupHook(cleanupHook);
    return new RMEndpoint(this, endpoint);
  }

  /**
   * Binds this manager to {@code b} and then runs, in order, {@code initialise()} and
   * {@code registerListeners()}.
   *
   * @param b the bus to use
   */
  public void init(Bus b) {
    this.setBus(b);
    this.initialise();
    this.registerListeners();
  }

  /**
   * Fills in defaults for every collaborator that has not been set:
   * <ul>
   *   <li>configuration: exponential backoff enabled;</li>
   *   <li>delivery assurance: AT_LEAST_ONCE;</li>
   *   <li>source policy: {@code setSourcePolicy(null)};</li>
   *   <li>destination policy: a DestinationPolicyType with an empty AcksPolicyType;</li>
   *   <li>retransmission queue and sequence identifier generator: the default implementations.</li>
   * </ul>
   * When a bus is set, also creates the ManagedRMManager and registers it with the bus's
   * InstrumentationManager, if present; a registration failure is logged as a warning.
   */
  @PostConstruct
  void initialise() {
    if (null == configuration) {
      // getConfiguration() is relied on to create the configuration when it is missing
      getConfiguration().setExponentialBackoff(true);
    }
    if (null == configuration.getDeliveryAssurance()) {
      configuration.setDeliveryAssurance(DeliveryAssurance.AT_LEAST_ONCE);
    }
    if (null == sourcePolicy) {
      setSourcePolicy(null);
    }
    if (null == destinationPolicy) {
      DestinationPolicyType defaultDestinationPolicy = new DestinationPolicyType();
      defaultDestinationPolicy.setAcksPolicy(new AcksPolicyType());
      setDestinationPolicy(defaultDestinationPolicy);
    }
    if (null == retransmissionQueue) {
      retransmissionQueue = new RetransmissionQueueImpl(this);
    }
    if (null == idGenerator) {
      idGenerator = new DefaultSequenceIdentifierGenerator();
    }

    if (null == bus) {
      return;
    }
    managedManager = new ManagedRMManager(this);
    instrumentationManager = bus.getExtension(InstrumentationManager.class);
    if (null == instrumentationManager) {
      return;
    }
    try {
      instrumentationManager.register(managedManager);
    } catch (JMException jmex) {
      LOG.log(Level.WARNING, "Registering ManagedRMManager failed.", jmex);
    }
  }

  /**
   * Registers listeners that forward lifecycle events to this manager. Server start/stop events
   * come from the bus's ServerLifeCycleManager and client create/destroy events from its
   * ClientLifeCycleManager; each listener is registered only if that extension is present.
   * Does nothing when no bus is set.
   */
  @PostConstruct
  void registerListeners() {
    if (null == bus) {
      return;
    }

    ServerLifeCycleManager serverLifeCycle = bus.getExtension(ServerLifeCycleManager.class);
    if (serverLifeCycle != null) {
      ServerLifeCycleListener serverForwarder =
          new ServerLifeCycleListener() {
            public void startServer(Server server) {
              RMManager.this.startServer(server);
            }

            public void stopServer(Server server) {
              RMManager.this.stopServer(server);
            }
          };
      serverLifeCycle.registerListener(serverForwarder);
    }

    ClientLifeCycleManager clientLifeCycle = bus.getExtension(ClientLifeCycleManager.class);
    if (clientLifeCycle != null) {
      ClientLifeCycleListener clientForwarder =
          new ClientLifeCycleListener() {
            public void clientCreated(Client client) {
              RMManager.this.clientCreated(client);
            }

            public void clientDestroyed(Client client) {
              RMManager.this.clientDestroyed(client);
            }
          };
      clientLifeCycle.registerListener(clientForwarder);
    }
  }

  /** Returns the live map of endpoints to reliable endpoints (the field itself, not a copy). */
  Map<Endpoint, RMEndpoint> getReliableEndpointsMap() {
    return this.reliableEndpoints;
  }

  /** Replaces the map of endpoints to reliable endpoints; {@code map} is stored as-is, uncopied. */
  void setReliableEndpointsMap(Map<Endpoint, RMEndpoint> map) {
    this.reliableEndpoints = map;
  }

  class DefaultSequenceIdentifierGenerator implements SequenceIdentifierGenerator {

    public Identifier generateSequenceIdentifier() {
      String sequenceID = RMContextUtils.generateUUID();
      Identifier sid = new Identifier();
      sid.setValue(sequenceID);
      return sid;
    }
  }

  /**
   * Clones and saves the interceptor chain the first time this is called, so that it can be used
   * for retransmission. Calls after the first are ignored.
   *
   * <p>The clone is stored on the message's endpoint under {@code WSRM_RETRANSMIT_CHAIN}; the
   * check and store happen while holding the endpoint's monitor.
   *
   * @param msg message whose interceptor chain (cast to PhaseInterceptorChain) is cloned
   */
  public void initializeInterceptorChain(Message msg) {
    Endpoint endpoint = msg.getExchange().getEndpoint();
    synchronized (endpoint) {
      if (endpoint.get(WSRM_RETRANSMIT_CHAIN) != null) {
        // already saved by an earlier call
        return;
      }
      LOG.info("Setting retransmit chain from message");
      PhaseInterceptorChain current = (PhaseInterceptorChain) msg.getInterceptorChain();
      endpoint.put(WSRM_RETRANSMIT_CHAIN, current.cloneChain());
    }
  }

  /**
   * Get interceptor chain for retransmitting a message.
   *
   * @param msg message whose endpoint holds the saved chain
   * @return a fresh clone of the saved chain, or <code>null</code> if none has been set
   */
  public PhaseInterceptorChain getRetransmitChain(Message msg) {
    PhaseInterceptorChain saved =
        (PhaseInterceptorChain) msg.getExchange().getEndpoint().get(WSRM_RETRANSMIT_CHAIN);
    return saved == null ? null : saved.cloneChain();
  }
}
/**
 * Exercises the temporary-credentials (request token) endpoint served by {@code OAuthServer} at
 * {@link #TEMPORARY_CREDENTIALS_URL}, for every parameter style and every signature method in
 * {@code OAuthTestUtils.SIGN_METHOD}.
 */
public class TemporaryCredentialServiceTest extends AbstractBusClientServerTestBase {

  public static final String TEMPORARY_CREDENTIALS_URL = "/a/oauth/initiate";
  public static final String HOST = "http://localhost:";

  private static final Logger LOG = LogUtils.getL7dLogger(TemporaryCredentialServiceTest.class);

  @BeforeClass
  public static void startServers() throws Exception {
    assertTrue("server did not launch correctly", launchServer(OAuthServer.class, true));
  }

  /**
   * For each parameter style and signature method, requests temporary credentials and checks:
   * the response is form encoded, has no "Authenticate" header, confirms the callback, and carries
   * a non-empty token and token secret. It then checks that an unknown consumer key is rejected
   * with {@code CONSUMER_KEY_UNKNOWN}.
   */
  @Test
  public void testGetTemporaryCredentialsURIQuery() throws Exception {
    Map<String, String> parameters = new HashMap<String, String>();
    parameters.put(OAuth.OAUTH_CALLBACK, OAuthTestUtils.CALLBACK);

    // check all parameter transmissions
    for (ParameterStyle style : ParameterStyle.values()) {
      // for all signing methods
      for (String signMethod : OAuthTestUtils.SIGN_METHOD) {
        LOG.log(
            Level.INFO,
            "Preparing request with parameter style: {0} and signature method: {1}",
            new String[] {style.toString(), signMethod});

        parameters.put(OAuth.OAUTH_SIGNATURE_METHOD, signMethod);
        parameters.put(OAuth.OAUTH_NONCE, UUID.randomUUID().toString());
        parameters.put(OAuth.OAUTH_TIMESTAMP, String.valueOf(System.currentTimeMillis() / 1000));
        // reset the consumer key: the negative check below overwrites it with "wrong"
        parameters.put(OAuth.OAUTH_CONSUMER_KEY, OAuthTestUtils.CLIENT_ID);
        OAuthMessage message = invokeRequestToken(parameters, style, OAuthServer.PORT);

        // test response ok
        boolean isFormEncoded = OAuth.isFormEncoded(message.getBodyType());
        Assert.assertTrue(isFormEncoded);

        List<OAuth.Parameter> responseParams = OAuthTestUtils.getResponseParams(message);

        // NOTE(review): the standard challenge header is "WWW-Authenticate"; confirm that
        // "Authenticate" is really the header this assertion is meant to check.
        String wwwHeader = message.getHeader("Authenticate");
        Assert.assertNull(wwwHeader);

        String callbackConfirmed =
            OAuthTestUtils.findOAuthParameter(responseParams, OAuth.OAUTH_CALLBACK_CONFIRMED)
                .getValue();
        Assert.assertEquals("true", callbackConfirmed);

        // Assert on the parameter *values*. The keys are the parameter names that were looked up
        // and would make these checks pass trivially.
        String oauthToken =
            OAuthTestUtils.findOAuthParameter(responseParams, OAuth.OAUTH_TOKEN).getValue();
        Assert.assertFalse(StringUtils.isEmpty(oauthToken));

        String tokenSecret =
            OAuthTestUtils.findOAuthParameter(responseParams, OAuth.OAUTH_TOKEN_SECRET)
                .getValue();
        Assert.assertFalse(StringUtils.isEmpty(tokenSecret));

        // test wrong client id
        parameters.put(OAuth.OAUTH_CONSUMER_KEY, "wrong");
        message = invokeRequestToken(parameters, style, OAuthServer.PORT);
        String response = message.getHeader("oauth_problem");
        Assert.assertEquals(OAuth.Problems.CONSUMER_KEY_UNKNOWN, response);
      }
    }
  }

  /**
   * POSTs {@code parameters} to the temporary-credentials URL on {@code localhost:port}.
   *
   * @return the server's response message
   */
  protected OAuthMessage invokeRequestToken(
      Map<String, String> parameters, ParameterStyle style, int port)
      throws IOException, URISyntaxException, OAuthException {
    String uri = HOST + port + TEMPORARY_CREDENTIALS_URL;
    return OAuthTestUtils.access(uri, OAuthMessage.POST, parameters, style);
  }
}
Example #26
0
public class WebSocketVirtualServletResponse implements HttpServletResponse {
  private static final Logger LOG = LogUtils.getL7dLogger(WebSocketVirtualServletResponse.class);
  private WebSocketServletHolder webSocketHolder;
  private Map<String, String> responseHeaders;
  private ServletOutputStream outputStream;

  public WebSocketVirtualServletResponse(WebSocketServletHolder websocket) {
    this.webSocketHolder = websocket;
    this.responseHeaders = new TreeMap<String, String>(String.CASE_INSENSITIVE_ORDER);
    this.outputStream = createOutputStream();
  }

  @Override
  public void flushBuffer() throws IOException {
    LOG.log(Level.FINE, "flushBuffer()");
    outputStream.flush();
  }

  @Override
  public int getBufferSize() {
    LOG.log(Level.FINE, "getBufferSize()");
    return 0;
  }

  @Override
  public String getCharacterEncoding() {
    LOG.log(Level.FINE, "getCharacterEncoding()");
    return null;
  }

  @Override
  public String getContentType() {
    LOG.log(Level.FINE, "getContentType()");
    return responseHeaders.get("Content-Type");
  }

  @Override
  public Locale getLocale() {
    LOG.log(Level.FINE, "getLocale");
    return null;
  }

  @Override
  public ServletOutputStream getOutputStream() throws IOException {
    return outputStream;
  }

  @Override
  public PrintWriter getWriter() throws IOException {
    LOG.log(Level.FINE, "getWriter()");
    return new PrintWriter(getOutputStream());
  }

  @Override
  public boolean isCommitted() {
    return false;
  }

  @Override
  public void reset() {}

  @Override
  public void resetBuffer() {
    LOG.log(Level.FINE, "resetBuffer()");
  }

  @Override
  public void setBufferSize(int size) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setBufferSize({0})", size);
    }
  }

  @Override
  public void setCharacterEncoding(String charset) {
    // TODO
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setCharacterEncoding({0})", charset);
    }
  }

  @Override
  public void setContentLength(int len) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setContentLength({0})", len);
    }
    responseHeaders.put("Content-Length", Integer.toString(len));
  }

  @Override
  public void setContentType(String type) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setContentType({0})", type);
    }
    responseHeaders.put("Content-Type", type);
  }

  @Override
  public void setLocale(Locale loc) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setLocale({0})", loc);
    }
  }

  @Override
  public void addCookie(Cookie cookie) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "addCookie({0})", cookie);
    }
  }

  @Override
  public void addDateHeader(String name, long date) {
    // TODO
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "addDateHeader({0}, {1})", new Object[] {name, date});
    }
  }

  @Override
  public void addHeader(String name, String value) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "addHeader({0}, {1})", new Object[] {name, value});
    }
    responseHeaders.put(name, value);
  }

  @Override
  public void addIntHeader(String name, int value) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "addIntHeader({0}, {1})", new Object[] {name, value});
    }
    responseHeaders.put(name, Integer.toString(value));
  }

  @Override
  public boolean containsHeader(String name) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "containsHeader({0})", name);
    }
    return responseHeaders.containsKey(name);
  }

  @Override
  public String encodeRedirectURL(String url) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "encodeRedirectURL({0})", url);
    }
    return null;
  }

  @Override
  public String encodeRedirectUrl(String url) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "encodeRedirectUrl({0})", url);
    }
    return null;
  }

  @Override
  public String encodeURL(String url) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "encodeURL({0})", url);
    }
    return null;
  }

  @Override
  public String encodeUrl(String url) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "encodeUrl({0})", url);
    }
    return null;
  }

  @Override
  public String getHeader(String name) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "getHeader({0})", name);
    }
    return null;
  }

  @Override
  public Collection<String> getHeaderNames() {
    LOG.log(Level.FINE, "getHeaderNames()");
    return null;
  }

  @Override
  public Collection<String> getHeaders(String name) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "getHeaders({0})", name);
    }
    return null;
  }

  @Override
  public int getStatus() {
    LOG.log(Level.FINE, "getStatus()");
    String v = responseHeaders.get(WebSocketUtils.SC_KEY);
    return v == null ? 200 : Integer.parseInt(v);
  }

  @Override
  public void sendError(int sc) throws IOException {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "sendError{0}", sc);
    }
    responseHeaders.put(WebSocketUtils.SC_KEY, Integer.toString(sc));
    byte[] data = WebSocketUtils.buildResponse(responseHeaders, null, 0, 0);
    webSocketHolder.write(data, 0, data.length);
  }

  @Override
  public void sendError(int sc, String msg) throws IOException {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "sendError({0}, {1})", new Object[] {sc, msg});
    }
    responseHeaders.put(WebSocketUtils.SC_KEY, Integer.toString(sc));
    responseHeaders.put(WebSocketUtils.SM_KEY, msg);
    byte[] data = WebSocketUtils.buildResponse(responseHeaders, null, 0, 0);
    webSocketHolder.write(data, 0, data.length);
  }

  @Override
  public void sendRedirect(String location) throws IOException {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "sendRedirect({0})", location);
    }
  }

  @Override
  public void setDateHeader(String name, long date) {
    // ignore
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setDateHeader({0}, {1})", new Object[] {name, date});
    }
  }

  @Override
  public void setHeader(String name, String value) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setHeader({0}, {1})", new Object[] {name, value});
    }
    responseHeaders.put(name, value);
  }

  @Override
  public void setIntHeader(String name, int value) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setIntHeader({0}, {1})", new Object[] {name, value});
    }
  }

  @Override
  public void setStatus(int sc) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setStatus({0})", sc);
    }
    responseHeaders.put(WebSocketUtils.SC_KEY, Integer.toString(sc));
  }

  @Override
  public void setStatus(int sc, String sm) {
    if (LOG.isLoggable(Level.FINE)) {
      LOG.log(Level.FINE, "setStatus({0}, {1})", new Object[] {sc, sm});
    }
    responseHeaders.put(WebSocketUtils.SC_KEY, Integer.toString(sc));
    responseHeaders.put(WebSocketUtils.SM_KEY, sm);
  }

  private ServletOutputStream createOutputStream() {
    // REVISIT
    // This output buffering is needed as the server side websocket does
    // not support the fragment transmission mode when sending back a large data.
    // And this buffering is only used for the response for the initial service innovation.
    // For the subsequently pushed data to the socket are sent back
    // unbuffered as individual websocket messages.
    // the things to consider :
    // - provide a size limit if we are use this buffering
    // - add a chunking mode in the cxf websocket's binding.
    return new ServletOutputStream() {
      private InternalByteArrayOutputStream buffer = new InternalByteArrayOutputStream();

      @Override
      public void write(int b) throws IOException {
        byte[] data = new byte[1];
        data[0] = (byte) b;
        write(data, 0, 1);
      }

      @Override
      public void write(byte[] data) throws IOException {
        write(data, 0, data.length);
      }

      @Override
      public void write(byte[] data, int offset, int length) throws IOException {
        if (responseHeaders.get(WebSocketUtils.FLUSHED_KEY) == null) {
          // buffer the data until it gets flushed
          buffer.write(data, offset, length);
        } else {
          // unbuffered write to the socket
          data = WebSocketUtils.buildResponse(data, offset, length);
          webSocketHolder.write(data, 0, data.length);
        }
      }

      public void close() throws IOException {
        if (responseHeaders.get(WebSocketUtils.FLUSHED_KEY) == null) {
          byte[] data =
              WebSocketUtils.buildResponse(responseHeaders, buffer.getBytes(), 0, buffer.size());
          webSocketHolder.write(data, 0, data.length);
          responseHeaders.put(WebSocketUtils.FLUSHED_KEY, "true");
        }
        super.close();
      }
    };
  }

  private static class InternalByteArrayOutputStream extends ByteArrayOutputStream {
    public byte[] getBytes() {
      return buf;
    }
  }
}
Example #27
0
/**
 * WSDLQueryHandler class preserved from CXF 2.3. CXF 2.5 removed the class and now relies on the
 * WSDLGetInterceptor to perform the same logic, but that interceptor removes the message content
 * from the exchange in its finally clause, so the content cannot be retrieved afterwards.
 */
public class WSDLQueryHandler implements StemMatchingQueryHandler {
  private static final Logger LOG = LogUtils.getL7dLogger(WSDLQueryHandler.class);
  protected Bus bus;

  /**
   * Creates a handler bound to a bus.
   *
   * @param b bus used to look up the WSDL manager, the OASIS catalog manager and resources
   */
  public WSDLQueryHandler(Bus b) {
    this.bus = b;
  }

  /**
   * Returns {@code "text/xml"} when the lower-cased {@code baseUri} contains {@code ?wsdl} or
   * {@code ?xsd=}, and {@code null} otherwise. {@code ctx} is not used.
   */
  @Override
  public String getResponseContentType(String baseUri, String ctx) {
    String lowered = baseUri.toLowerCase();
    boolean isDescriptionQuery = lowered.contains("?wsdl") || lowered.contains("?xsd=");
    return isDescriptionQuery ? "text/xml" : null;
  }

  /**
   * Decides whether {@code baseUri} is a WSDL/XSD query for this endpoint. The URI must have a
   * query string with a {@code wsdl} or {@code xsd} parameter. The endpoint address must then
   * contain {@code ctx} (exact context matching) or the stem of the URI's path (stem matching).
   */
  @Override
  public boolean isRecognizedQuery(
      String baseUri, String ctx, EndpointInfo endpointInfo, boolean contextMatchExact) {
    if (baseUri == null || !baseUri.contains("?")) {
      return false;
    }
    String lowered = baseUri.toLowerCase();
    if (!lowered.contains("wsdl") && !lowered.contains("xsd=")) {
      return false;
    }

    int queryStart = baseUri.indexOf("?");
    Map<String, String> query = UrlUtils.parseQueryString(baseUri.substring(queryStart + 1));
    if (!query.containsKey("wsdl") && !query.containsKey("xsd")) {
      return false;
    }

    String address = endpointInfo.getAddress();
    // contextMatchStrategy is "stem" when the match is not exact
    return contextMatchExact
        ? address.contains(ctx)
        : address.contains(UrlUtils.getStem(baseUri.substring(0, queryStart)));
  }

  /**
   * Writes the WSDL document or XML schema selected by the {@code wsdl} or {@code xsd} query
   * parameter of {@code baseUri} to {@code os}.
   *
   * <p>Known WSDL definitions and schema references are kept in maps cached as properties of the
   * endpoint's service. The keys are this class name and this class name + {@code ".Schemas"}.
   * The maps are populated from {@link #getDefinition(EndpointInfo)} when the root definition
   * (key {@code ""}) is missing. Locations in the written document are rewritten against the
   * {@code publishedEndpointUrl} endpoint property when set, otherwise against {@code baseUri}
   * without its query string.
   *
   * @param baseUri request URI including its query string
   * @param ctxUri request context; not used by this implementation
   * @param endpointInfo endpoint whose WSDL is requested
   * @param os stream the XML document is written to; the XML writer is flushed, not closed
   * @throws WSDLQueryException if the requested WSDL or schema is unknown, or if producing the
   *     document fails for any other reason
   */
  @Override
  public void writeResponse(
      String baseUri, String ctxUri, EndpointInfo endpointInfo, OutputStream os) {
    try {
      // Locate '?' on baseUri itself: lower-casing is not needed to find it, and a case mapping
      // may change the string's length, making the index invalid for substring() on baseUri.
      int idx = baseUri.indexOf('?');
      Map<String, String> params = UrlUtils.parseQueryString(baseUri.substring(idx + 1));

      String base;

      if (endpointInfo.getProperty("publishedEndpointUrl") != null) {
        base = String.valueOf(endpointInfo.getProperty("publishedEndpointUrl"));
      } else {
        base = baseUri.substring(0, idx);
      }

      String wsdl = params.get("wsdl");
      if (wsdl != null) {
        // Always use the URL decoded version to ensure that we have a
        // canonical representation of the import URL for lookup.
        wsdl = URLDecoder.decode(wsdl, "utf-8");
      }

      String xsd = params.get("xsd");
      if (xsd != null) {
        // Always use the URL decoded version to ensure that we have a
        // canonical representation of the import URL for lookup.
        xsd = URLDecoder.decode(xsd, "utf-8");
      }

      String definitionsKey = WSDLQueryHandler.class.getName();
      String schemasKey = definitionsKey + ".Schemas";

      Map<String, Definition> mp =
          CastUtils.cast((Map) endpointInfo.getService().getProperty(definitionsKey));
      Map<String, SchemaReference> smp =
          CastUtils.cast((Map) endpointInfo.getService().getProperty(schemasKey));

      // The maps are re-read after being stored so that the instance actually held by the
      // service is the one used.
      if (mp == null) {
        endpointInfo
            .getService()
            .setProperty(definitionsKey, new ConcurrentHashMap<String, Definition>());
        mp = CastUtils.cast((Map) endpointInfo.getService().getProperty(definitionsKey));
      }
      if (smp == null) {
        endpointInfo
            .getService()
            .setProperty(schemasKey, new ConcurrentHashMap<String, SchemaReference>());
        smp = CastUtils.cast((Map) endpointInfo.getService().getProperty(schemasKey));
      }

      if (!mp.containsKey("")) {
        Definition def = getDefinition(endpointInfo);

        mp.put("", def);
        updateDefinition(def, mp, smp, base, endpointInfo);
      }

      Document doc;
      if (xsd == null) {
        Definition def = mp.get(wsdl);
        if (def == null) {
          String wsdl2 =
              resolveWithCatalogs(OASISCatalogManager.getCatalogManager(bus), wsdl, base);
          if (wsdl2 != null) {
            def = mp.get(wsdl2);
          }
        }
        if (def == null) {
          throw new WSDLQueryException(
              new org.apache.cxf.common.i18n.Message("WSDL_NOT_FOUND", LOG, wsdl), null);
        }

        synchronized (def) {
          // writing a def is not threadsafe.  Sync on it to make sure
          // we don't get any ConcurrentModificationExceptions
          if (endpointInfo.getProperty("publishedEndpointUrl") != null) {
            String publishingUrl = String.valueOf(endpointInfo.getProperty("publishedEndpointUrl"));
            updatePublishedEndpointUrl(publishingUrl, def, endpointInfo.getName());
          }

          WSDLWriter wsdlWriter =
              bus.getExtension(WSDLManager.class).getWSDLFactory().newWSDLWriter();
          def.setExtensionRegistry(bus.getExtension(WSDLManager.class).getExtensionRegistry());
          doc = wsdlWriter.getDocument(def);
        }
      } else {
        SchemaReference si = smp.get(xsd);
        if (si == null) {
          String xsd2 = resolveWithCatalogs(OASISCatalogManager.getCatalogManager(bus), xsd, base);
          if (xsd2 != null) {
            si = smp.get(xsd2);
          }
        }
        if (si == null) {
          // report the schema that was requested (previously the wsdl parameter was reported)
          throw new WSDLQueryException(
              new org.apache.cxf.common.i18n.Message("SCHEMA_NOT_FOUND", LOG, xsd), null);
        }

        String uri = si.getReferencedSchema().getDocumentBaseURI();
        uri =
            resolveWithCatalogs(
                OASISCatalogManager.getCatalogManager(bus),
                uri,
                si.getReferencedSchema().getDocumentBaseURI());
        if (uri == null) {
          uri = si.getReferencedSchema().getDocumentBaseURI();
        }
        ResourceManagerWSDLLocator rml = new ResourceManagerWSDLLocator(uri, bus);

        InputSource src = rml.getBaseInputSource();
        doc = XMLUtils.getParser().parse(src);
      }

      updateDoc(doc, base, mp, smp, endpointInfo);
      String enc = null;
      try {
        enc = doc.getXmlEncoding();
      } catch (Exception ex) {
        // ignore - not dom level 3
      }
      if (enc == null) {
        enc = "utf-8";
      }

      XMLStreamWriter writer = StaxUtils.createXMLStreamWriter(os, enc);
      StaxUtils.writeNode(doc, writer, true);
      writer.flush();
    } catch (WSDLQueryException wex) {
      throw wex;
    } catch (Exception wex) {
      throw new WSDLQueryException(
          new org.apache.cxf.common.i18n.Message("COULD_NOT_PROVIDE_WSDL", LOG, baseUri), wex);
    }
  }

  /**
   * Builds the WSDL definition of the endpoint's service with ServiceWSDLBuilder. Protected so
   * subclasses can supply the definition differently.
   */
  protected Definition getDefinition(EndpointInfo endpointInfo) throws WSDLException {
    ServiceWSDLBuilder builder = new ServiceWSDLBuilder(bus, endpointInfo.getService());
    return builder.build();
  }

  /**
   * Rewrites references in {@code doc} so they point back at this handler.
   *
   * <p>The {@code schemaLocation} of every {@code xsd:import}, {@code xsd:include} and
   * {@code xsd:redefine} whose URL-decoded value is a key of {@code smp} is replaced via
   * {@link #rewriteSchemaLocation}. The {@code location} of every {@code wsdl:import} whose
   * decoded value is a key of {@code mp} becomes {@code base + "?wsdl=" + location}, with spaces
   * encoded as {@code %20}. Finally the SOAP address is rewritten, if configured, and the document
   * is marked standalone where the DOM supports it.
   *
   * @throws WSDLQueryException if a location cannot be URL-decoded
   */
  protected void updateDoc(
      Document doc,
      String base,
      Map<String, Definition> mp,
      Map<String, SchemaReference> smp,
      EndpointInfo ei) {
    Element root = doc.getDocumentElement();
    String[] schemaReferenceTags = {"import", "include", "redefine"};

    try {
      for (String tag : schemaReferenceTags) {
        List<Element> schemaRefs =
            DOMUtils.findAllElementsByTagNameNS(root, "http://www.w3.org/2001/XMLSchema", tag);
        for (Element ref : schemaRefs) {
          String location = ref.getAttribute("schemaLocation");
          if (smp.containsKey(URLDecoder.decode(location, "utf-8"))) {
            ref.setAttribute("schemaLocation", rewriteSchemaLocation(base, location));
          }
        }
      }

      List<Element> wsdlImports =
          DOMUtils.findAllElementsByTagNameNS(root, "http://schemas.xmlsoap.org/wsdl/", "import");
      for (Element imp : wsdlImports) {
        String location = imp.getAttribute("location");
        if (mp.containsKey(URLDecoder.decode(location, "utf-8"))) {
          imp.setAttribute("location", base + "?wsdl=" + location.replace(" ", "%20"));
        }
      }
    } catch (UnsupportedEncodingException e) {
      throw new WSDLQueryException(
          new org.apache.cxf.common.i18n.Message("COULD_NOT_PROVIDE_WSDL", LOG, base), e);
    }

    rewriteOperationAddress(ei, doc, base);

    try {
      doc.setXmlStandalone(true);
    } catch (Exception ex) {
      // likely not DOM level 3
    }
  }

  /**
   * Returns the URL under which this handler serves {@code schemaLocation}:
   * {@code base + "?xsd=" + schemaLocation}, with spaces encoded as {@code %20}.
   */
  protected String rewriteSchemaLocation(String base, String schemaLocation) {
    String encodedLocation = schemaLocation.replace(" ", "%20");
    return new StringBuilder().append(base).append("?xsd=").append(encodedLocation).toString();
  }

  /**
   * Sets the {@code location} of this endpoint's SOAP address to {@code base} when the endpoint
   * property {@code autoRewriteSoapAddress} is {@code true}; otherwise leaves {@code doc} as is.
   *
   * <p>The port is looked up only inside the {@code wsdl:service} whose {@code name} matches the
   * endpoint's service local name, so a same-named port of another service is not modified. Only
   * the SOAP 1.1 {@code soap:address} ({@code http://schemas.xmlsoap.org/wsdl/soap/}) is updated.
   * A matching port without one is skipped instead of throwing NoSuchElementException.
   *
   * @param ei endpoint whose service and port names select the address to rewrite
   * @param doc WSDL document, modified in place
   * @param base new value of the address {@code location} attribute
   */
  protected void rewriteOperationAddress(EndpointInfo ei, Document doc, String base) {
    Boolean rewriteSoapAddress = ei.getProperty("autoRewriteSoapAddress", Boolean.class);
    if (rewriteSoapAddress == null || !rewriteSoapAddress.booleanValue()) {
      return;
    }

    List<Element> serviceList =
        DOMUtils.findAllElementsByTagNameNS(
            doc.getDocumentElement(), "http://schemas.xmlsoap.org/wsdl/", "service");
    for (Element serviceEl : serviceList) {
      String serviceName = serviceEl.getAttribute("name");
      if (!serviceName.equals(ei.getService().getName().getLocalPart())) {
        continue;
      }
      // wsdl:port elements are children of wsdl:service: search only within this service
      List<Element> portList =
          DOMUtils.findAllElementsByTagNameNS(
              serviceEl, "http://schemas.xmlsoap.org/wsdl/", "port");
      for (Element portEl : portList) {
        String name = portEl.getAttribute("name");
        if (!name.equals(ei.getName().getLocalPart())) {
          continue;
        }
        List<Element> addresses =
            DOMUtils.findAllElementsByTagNameNS(
                portEl, "http://schemas.xmlsoap.org/wsdl/soap/", "address");
        if (!addresses.isEmpty()) {
          addresses.get(0).setAttribute("location", base);
        }
      }
    }
  }

  /**
   * Resolves {@code start} through the OASIS catalogs, trying a system-id lookup, then a URI
   * lookup, then a public-id lookup (with {@code base} as the system id). Returns the first
   * non-null result. Returns {@code null} when {@code catalogs} is null, nothing matches, or any
   * lookup throws (errors are deliberately ignored).
   */
  static String resolveWithCatalogs(OASISCatalogManager catalogs, String start, String base) {
    if (catalogs == null) {
      return null;
    }
    try {
      String resolved = catalogs.resolveSystem(start);
      if (resolved != null) {
        return resolved;
      }
      resolved = catalogs.resolveURI(start);
      if (resolved != null) {
        return resolved;
      }
      return catalogs.resolvePublic(start, base);
    } catch (Exception ignored) {
      // best effort: an unresolvable location is reported as null
      return null;
    }
  }

  /**
   * Recursively records every WSDL definition imported by {@code def} in {@code done}, then lets
   * {@code updateSchemaImports} process each inline schema of {@code def}.
   *
   * <p>Each import is keyed by its URL-decoded location URI. If that location resolves through
   * the OASIS catalogs, the definition is additionally keyed by the resolved location. An import
   * that does not resolve and is already an absolute URL (parses as java.net.URL) is not recorded
   * or followed.
   *
   * @param def definition whose imports and types are processed
   * @param done definitions seen so far, keyed by location; mutated
   * @param doneSchemas schema references seen so far; passed on to updateSchemaImports
   * @param base base URL used for catalog resolution and schema rewriting
   * @param ei endpoint being served (passed through to recursive calls)
   * @throws WSDLQueryException if an import location cannot be URL-decoded
   */
  protected void updateDefinition(
      Definition def,
      Map<String, Definition> done,
      Map<String, SchemaReference> doneSchemas,
      String base,
      EndpointInfo ei) {
    OASISCatalogManager catalogs = OASISCatalogManager.getCatalogManager(bus);

    Collection<List<?>> imports = CastUtils.cast((Collection<?>) def.getImports().values());
    for (List<?> lst : imports) {
      List<Import> impLst = CastUtils.cast(lst);
      for (Import imp : impLst) {

        String start = imp.getLocationURI();
        String decodedStart = null;
        // Always use the URL decoded version to ensure that we have a
        // canonical representation of the import URL for lookup.
        try {
          decodedStart = URLDecoder.decode(start, "utf-8");
        } catch (UnsupportedEncodingException e) {
          throw new WSDLQueryException(
              new org.apache.cxf.common.i18n.Message("COULD_NOT_PROVIDE_WSDL", LOG, start), e);
        }

        String resolvedSchemaLocation = resolveWithCatalogs(catalogs, start, base);

        // In both branches the definition is put into `done` BEFORE recursing, and recursion
        // happens only if no entry existed for the key (put returned null); this is what stops
        // cyclic imports from recursing forever.
        if (resolvedSchemaLocation == null) {
          try {
            // check to see if it's already in a URL format.  If so, leave it.
            new URL(start);
          } catch (MalformedURLException e) {
            // relative location: record it and follow its own imports
            if (done.put(decodedStart, imp.getDefinition()) == null) {
              updateDefinition(imp.getDefinition(), done, doneSchemas, base, ei);
            }
          }
        } else {
          if (done.put(decodedStart, imp.getDefinition()) == null) {
            // also make the definition reachable under its catalog-resolved location
            done.put(resolvedSchemaLocation, imp.getDefinition());
            updateDefinition(imp.getDefinition(), done, doneSchemas, base, ei);
          }
        }
      }
    }

    /* This doesn't actually work.   Setting setSchemaLocationURI on the import
     * for some reason doesn't actually result in the new URI being written
     * */
    // NOTE(review): the note above appears to refer to the approach used by updateSchemaImports;
    // schemaLocation attributes are also rewritten on the serialized DOM by updateDoc.
    Types types = def.getTypes();
    if (types != null) {
      for (ExtensibilityElement el :
          CastUtils.cast(types.getExtensibilityElements(), ExtensibilityElement.class)) {
        if (el instanceof Schema) {
          Schema see = (Schema) el;
          updateSchemaImports(see, doneSchemas, base);
        }
      }
    }
  }

  /**
   * Rewrites the SOAP address location of published ports to {@code publishingUrl}.
   *
   * <p>When {@code name} is {@code null}, only the first port of the first service that has any
   * ports is updated. Otherwise, every port in every service whose name equals
   * {@code name.getLocalPart()} is updated.
   *
   * @param publishingUrl the address to publish
   * @param def the definition whose services are updated in place
   * @param name the port to target, or {@code null} for "the first port found"
   */
  protected void updatePublishedEndpointUrl(String publishingUrl, Definition def, QName name) {
    Collection<Service> services = CastUtils.cast(def.getAllServices().values());
    for (Service service : services) {
      Collection<Port> ports = CastUtils.cast(service.getPorts().values());
      if (ports.isEmpty()) {
        continue;
      }
      if (name == null) {
        // No specific port is targeted: the first available port gets the URL and we are done.
        setSoapAddressLocationOn(ports.iterator().next(), publishingUrl);
        return;
      }
      String wantedPort = name.getLocalPart();
      for (Port port : ports) {
        if (wantedPort.equals(port.getName())) {
          setSoapAddressLocationOn(port, publishingUrl);
        }
      }
    }
  }

  /**
   * Sets {@code url} as the location URI of every {@code SOAP12Address} or {@code SOAPAddress}
   * extensibility element attached to {@code port}; other extensors are ignored.
   */
  private void setSoapAddressLocationOn(Port port, String url) {
    List<?> portExtensors = port.getExtensibilityElements();
    for (Object ext : portExtensors) {
      if (ext instanceof SOAP12Address) {
        ((SOAP12Address) ext).setLocationURI(url);
        continue;
      }
      if (ext instanceof SOAPAddress) {
        ((SOAPAddress) ext).setLocationURI(url);
      }
    }
  }

  /**
   * Recursively walks the imports, includes and redefines of {@code schema}, recording each
   * referenced schema in {@code doneSchemas} (keyed by URL-decoded location and, when an OASIS
   * catalog maps it, by the catalog result too) so that it is only processed once.
   *
   * <p>A reference whose location is not catalog-mapped but passes {@link #checkSchemaUrl} (by
   * default: already parses as a URL) is not recorded or descended into. References without a
   * location are skipped.
   *
   * @param schema the schema whose references are walked
   * @param doneSchemas visited schema references, keyed by location; updated in place
   * @param base forwarded to catalog resolution and to the recursive calls
   * @throws WSDLQueryException wrapping the {@code UnsupportedEncodingException} raised if UTF-8
   *     URL decoding is unavailable (now thrown consistently for imports, includes and redefines)
   */
  protected void updateSchemaImports(
      Schema schema, Map<String, SchemaReference> doneSchemas, String base) {
    OASISCatalogManager catalogs = OASISCatalogManager.getCatalogManager(bus);
    Collection<List<?>> imports = CastUtils.cast((Collection<?>) schema.getImports().values());
    for (List<?> lst : imports) {
      List<SchemaImport> impLst = CastUtils.cast(lst);
      for (SchemaImport imp : impLst) {
        String start = imp.getSchemaLocationURI();
        if (start == null) {
          continue;
        }
        String decodedStart = decodeSchemaLocation(start);
        if (doneSchemas.containsKey(decodedStart)) {
          continue;
        }

        String resolvedSchemaLocation = resolveWithCatalogs(catalogs, start, base);
        if (resolvedSchemaLocation == null) {
          try {
            checkSchemaUrl(doneSchemas, start, decodedStart, imp);
          } catch (MalformedURLException e) {
            if (doneSchemas.put(decodedStart, imp) == null) {
              updateSchemaImports(imp.getReferencedSchema(), doneSchemas, base);
            }
          }
        } else if (doneSchemas.put(decodedStart, imp) == null) {
          doneSchemas.put(resolvedSchemaLocation, imp);
          updateSchemaImports(imp.getReferencedSchema(), doneSchemas, base);
        }
      }
    }

    List<SchemaReference> includes = CastUtils.cast(schema.getIncludes());
    updateSchemaReferenceList(includes, doneSchemas, base, catalogs);
    List<SchemaReference> redefines = CastUtils.cast(schema.getRedefines());
    updateSchemaReferenceList(redefines, doneSchemas, base, catalogs);
  }

  /**
   * Walks a list of schema includes or redefines, recording each new location in
   * {@code doneSchemas} and recursing into the referenced schema.
   *
   * <p>For a catalog-mapped reference, the schema is (re)walked while either its decoded or its
   * resolved location is still missing from {@code doneSchemas}.
   */
  private void updateSchemaReferenceList(
      List<SchemaReference> references,
      Map<String, SchemaReference> doneSchemas,
      String base,
      OASISCatalogManager catalogs) {
    for (SchemaReference ref : references) {
      String start = ref.getSchemaLocationURI();
      if (start == null) {
        continue;
      }
      String decodedStart = decodeSchemaLocation(start);

      String resolvedSchemaLocation = resolveWithCatalogs(catalogs, start, base);
      if (resolvedSchemaLocation == null) {
        if (!doneSchemas.containsKey(decodedStart)) {
          try {
            checkSchemaUrl(doneSchemas, start, decodedStart, ref);
          } catch (MalformedURLException e) {
            if (doneSchemas.put(decodedStart, ref) == null) {
              updateSchemaImports(ref.getReferencedSchema(), doneSchemas, base);
            }
          }
        }
      } else if (!doneSchemas.containsKey(decodedStart)
          || !doneSchemas.containsKey(resolvedSchemaLocation)) {
        doneSchemas.put(decodedStart, ref);
        doneSchemas.put(resolvedSchemaLocation, ref);
        updateSchemaImports(ref.getReferencedSchema(), doneSchemas, base);
      }
    }
  }

  /**
   * URL-decodes (UTF-8) a schema location so that differently encoded spellings of the same
   * location yield one canonical lookup key.
   */
  private String decodeSchemaLocation(String start) {
    try {
      return URLDecoder.decode(start, "utf-8");
    } catch (UnsupportedEncodingException e) {
      throw new WSDLQueryException(
          new org.apache.cxf.common.i18n.Message("COULD_NOT_PROVIDE_WSDL", LOG, start), e);
    }
  }

  /**
   * Hook called while walking schema references whose location no catalog maps.
   *
   * <p>This default implementation only checks whether {@code start} parses as a URL; returning
   * normally tells the caller to leave the reference untouched. The remaining parameters are
   * unused here and exist for subclasses that override this hook.
   *
   * @param doneSchemas schema references visited so far
   * @param start the location as written in the schema
   * @param decodedStart the URL-decoded form of {@code start}
   * @param imp the reference being examined
   * @throws MalformedURLException if {@code start} is not a URL, which the caller treats as a
   *     location that must be recorded and walked
   */
  protected void checkSchemaUrl(
      Map<String, SchemaReference> doneSchemas,
      String start,
      String decodedStart,
      SchemaReference imp)
      throws MalformedURLException {
    // Constructing the URL is the whole check; the resulting instance is not needed.
    new URL(start);
  }

  /**
   * Three-argument form of {@code isRecognizedQuery}: delegates to the four-argument overload,
   * passing {@code false} as its final argument.
   */
  @Override
  public boolean isRecognizedQuery(String baseUri, String ctx, EndpointInfo endpointInfo) {
    return this.isRecognizedQuery(baseUri, ctx, endpointInfo, false);
  }
}
/**
 * JAXB provider that runs XSLT templates over the XML it reads and writes.
 *
 * <p>Templates may come from a global in/out template, per-media-type templates, the
 * {@code xslt.template} contextual property, or an {@link XSLTTransform} annotation. When
 * {@code supportJaxbOnly} is set and no template applies, plain JAXB processing is used instead.
 */
@Produces({"application/xml", "application/*+xml", "text/xml", "text/html"})
@Consumes({"application/xml", "application/*+xml", "text/xml", "text/html"})
@Provider
public class XSLTJaxbProvider<T> extends JAXBElementProvider<T> {

  private static final Logger LOG = LogUtils.getL7dLogger(XSLTJaxbProvider.class);

  private static final String ABSOLUTE_PATH_PARAMETER = "absolute.path";
  private static final String BASE_PATH_PARAMETER = "base.path";
  private static final String RELATIVE_PATH_PARAMETER = "relative.path";
  private static final String XSLT_TEMPLATE_PROPERTY = "xslt.template";
  private SAXTransformerFactory factory;
  private Templates inTemplates;
  private Templates outTemplates;
  private Map<String, Templates> inMediaTemplates;
  private Map<String, Templates> outMediaTemplates;
  // Templates loaded from @XSLTTransform annotations, keyed by the annotation value.
  private ConcurrentHashMap<String, Templates> annotationTemplates =
      new ConcurrentHashMap<String, Templates>();

  private List<String> inClassesToHandle;
  private List<String> outClassesToHandle;
  private Map<String, Object> inParamsMap;
  private Map<String, Object> outParamsMap;
  private Map<String, String> inProperties;
  private Map<String, String> outProperties;
  private URIResolver uriResolver;
  private String systemId;

  private boolean supportJaxbOnly;
  private boolean refreshTemplates;

  /** When {@code true}, types without an applicable template are handled by plain JAXB. */
  public void setSupportJaxbOnly(boolean support) {
    this.supportJaxbOnly = support;
  }

  /**
   * A type is readable only if plain JAXB can read it. Collections/arrays are then readable only
   * in JAXB-only mode; other types need an input template unless they are excluded from template
   * handling, in which case JAXB-only mode decides.
   */
  @Override
  public boolean isReadable(Class<?> type, Type genericType, Annotation[] anns, MediaType mt) {
    if (!super.isReadable(type, genericType, anns, mt)) {
      return false;
    }

    if (InjectionUtils.isSupportedCollectionOrArray(type)) {
      return supportJaxbOnly;
    }

    // if the user has set the list of in classes and a given class
    // is in that list then it can only be handled by the template
    if (inClassCanBeHandled(type.getName()) || (inClassesToHandle == null && !supportJaxbOnly)) {
      return inTemplatesAvailable(type, anns, mt);
    } else {
      return supportJaxbOnly;
    }
  }

  /** Output-side counterpart of {@link #isReadable}. */
  @Override
  public boolean isWriteable(Class<?> type, Type genericType, Annotation[] anns, MediaType mt) {
    // JAXB support is required
    if (!super.isWriteable(type, genericType, anns, mt)) {
      return false;
    }
    if (InjectionUtils.isSupportedCollectionOrArray(type)) {
      return supportJaxbOnly;
    }

    // if the user has set the list of out classes and a given class
    // is in that list then it can only be handled by the template
    if (outClassCanBeHandled(type.getName()) || (outClassesToHandle == null && !supportJaxbOnly)) {
      return outTemplatesAvailable(type, anns, mt);
    } else {
      return supportJaxbOnly;
    }
  }

  /** Whether a global, media-type specific, or annotation-based input template exists. */
  protected boolean inTemplatesAvailable(Class<?> cls, Annotation[] anns, MediaType mt) {
    return inTemplates != null
        || (inMediaTemplates != null
            && inMediaTemplates.containsKey(mt.getType() + "/" + mt.getSubtype()))
        || getTemplatesFromAnnotation(cls, anns, mt) != null;
  }

  /** Whether a global, media-type specific, or annotation-based output template exists. */
  protected boolean outTemplatesAvailable(Class<?> cls, Annotation[] anns, MediaType mt) {
    return outTemplates != null
        || (outMediaTemplates != null
            && outMediaTemplates.containsKey(mt.getType() + "/" + mt.getSubtype()))
        || getTemplatesFromAnnotation(cls, anns, mt) != null;
  }

  /**
   * Returns (and caches) the templates named by an applicable {@link XSLTTransform} annotation.
   *
   * <p>The annotation value is first loaded as a {@code classpath:} resource; if that fails it is
   * looked up relative to {@code cls} via {@code ClassLoaderUtils.getResource}. Cached templates
   * are reloaded when {@code refreshTemplates} is set.
   *
   * @return the templates, or {@code null} if no annotation applies or none could be loaded
   */
  protected Templates getTemplatesFromAnnotation(Class<?> cls, Annotation[] anns, MediaType mt) {
    Templates t = null;
    XSLTTransform ann = getXsltTransformAnn(anns, mt);
    if (ann != null) {
      t = annotationTemplates.get(ann.value());
      if (t == null || refreshTemplates) {
        String path = ann.value();
        final String cp = "classpath:";
        if (!path.startsWith(cp)) {
          path = cp + path;
        }
        t = createTemplates(path);
        if (t == null) {
          // Fall back to a lookup relative to the annotated class. The result must be kept:
          // previously it was discarded, so this fallback never took effect.
          t = createTemplates(ClassLoaderUtils.getResource(ann.value(), cls));
        }
        if (t != null) {
          annotationTemplates.put(ann.value(), t);
        }
      }
    }
    return t;
  }

  /** Returns already-cached templates for the {@link XSLTTransform} annotation, if any. */
  protected Templates getAnnotationTemplates(Annotation[] anns) {
    Templates t = null;
    XSLTTransform ann = AnnotationUtils.getAnnotation(anns, XSLTTransform.class);
    if (ann != null) {
      t = annotationTemplates.get(ann.value());
    }
    return t;
  }

  /**
   * Returns the {@link XSLTTransform} annotation if it is not client-only and, when it lists media
   * types, one of them is compatible with {@code mt}; otherwise {@code null}.
   */
  protected XSLTTransform getXsltTransformAnn(Annotation[] anns, MediaType mt) {
    XSLTTransform ann = AnnotationUtils.getAnnotation(anns, XSLTTransform.class);
    if (ann != null && ann.type() != XSLTTransform.TransformType.CLIENT) {
      if (ann.mediaTypes().length > 0) {
        for (String s : ann.mediaTypes()) {
          if (mt.isCompatible(JAXRSUtils.toMediaType(s))) {
            return ann;
          }
        }
        return null;
      }
      return ann;
    }
    return null;
  }

  /**
   * Selects the input templates: the contextual {@code xslt.template} property first, then the
   * global input template, then the media-type template, then cached annotation templates.
   */
  protected Templates getInTemplates(Annotation[] anns, MediaType mt) {
    Templates t = createTemplatesFromContext();
    if (t != null) {
      return t;
    }
    t =
        inTemplates != null
            ? inTemplates
            : inMediaTemplates != null
                ? inMediaTemplates.get(mt.getType() + "/" + mt.getSubtype())
                : null;
    if (t == null) {
      t = getAnnotationTemplates(anns);
    }
    return t;
  }

  /** Output-side counterpart of {@link #getInTemplates}. */
  protected Templates getOutTemplates(Annotation[] anns, MediaType mt) {
    Templates t = createTemplatesFromContext();
    if (t != null) {
      return t;
    }
    t =
        outTemplates != null
            ? outTemplates
            : outMediaTemplates != null
                ? outMediaTemplates.get(mt.getType() + "/" + mt.getSubtype())
                : null;
    if (t == null) {
      t = getAnnotationTemplates(anns);
    }
    return t;
  }

  /**
   * Transforms the incoming XML with the input templates and unmarshals the result, or falls back
   * to plain JAXB when no template applies in JAXB-only mode.
   */
  @Override
  protected Object unmarshalFromInputStream(
      Unmarshaller unmarshaller, InputStream is, Annotation[] anns, MediaType mt)
      throws JAXBException {
    try {

      Templates t = createTemplates(getInTemplates(anns, mt), inParamsMap, inProperties);
      if (t == null && supportJaxbOnly) {
        return super.unmarshalFromInputStream(unmarshaller, is, anns, mt);
      }

      if (unmarshaller.getClass().getName().contains("eclipse")) {
        // eclipse MOXy doesn't work properly with the XMLFilter/Reader thing
        // so we need to bounce through a DOM
        Source reader = new StaxSource(StaxUtils.createXMLStreamReader(is));
        DOMResult dom = new DOMResult();
        t.newTransformer().transform(reader, dom);
        return unmarshaller.unmarshal(dom.getNode());
      }
      XMLFilter filter = null;
      try {
        filter = factory.newXMLFilter(t);
      } catch (TransformerConfigurationException ex) {
        // Some processors only accept the Templates they compiled themselves; unwrap and
        // apply the parameters/properties afterwards.
        TemplatesImpl ti = (TemplatesImpl) t;
        filter = factory.newXMLFilter(ti.getTemplates());
        trySettingProperties(filter, ti);
      }
      XMLReader reader = new StaxSource(StaxUtils.createXMLStreamReader(is));
      filter.setParent(reader);
      SAXSource source = new SAXSource();
      source.setXMLReader(filter);
      if (systemId != null) {
        source.setSystemId(systemId);
      }
      return unmarshaller.unmarshal(source);
    } catch (TransformerException ex) {
      LOG.warning("Transformation exception : " + ex.getMessage());
      throw ExceptionUtils.toInternalServerErrorException(ex, null);
    }
  }

  private void trySettingProperties(Object filter, TemplatesImpl ti) {
    try {
      // Saxon doesn't allow creating a Filter or Handler from anything other than it's original
      // Templates.  That then requires setting the parameters after the fact, but there
      // isn't a standard API for that, so we have to grab the Transformer via reflection to
      // set the parameters.
      Transformer tr = (Transformer) filter.getClass().getMethod("getTransformer").invoke(filter);
      tr.setURIResolver(ti.resolver);
      for (Map.Entry<String, Object> entry : ti.transformParameters.entrySet()) {
        tr.setParameter(entry.getKey(), entry.getValue());
      }
      for (Map.Entry<String, String> entry : ti.outProps.entrySet()) {
        tr.setOutputProperty(entry.getKey(), entry.getValue());
      }
    } catch (Exception e) {
      LOG.log(Level.WARNING, "Could not set properties for transfomer", e);
    }
  }

  /** Buffers the StAX input as a byte stream and delegates to {@link #unmarshalFromInputStream}. */
  @Override
  protected Object unmarshalFromReader(
      Unmarshaller unmarshaller, XMLStreamReader reader, Annotation[] anns, MediaType mt)
      throws JAXBException {
    CachedOutputStream out = new CachedOutputStream();
    try {
      XMLStreamWriter writer = StaxUtils.createXMLStreamWriter(out);
      StaxUtils.copy(new StaxSource(reader), writer);
      writer.writeEndDocument();
      writer.flush();
      writer.close();
      return unmarshalFromInputStream(unmarshaller, out.getInputStream(), anns, mt);
    } catch (Exception ex) {
      throw ExceptionUtils.toBadRequestException(ex, null);
    }
  }

  /** Marshals (with transformation) into a buffer, then copies the result to {@code writer}. */
  @Override
  protected void marshalToWriter(
      Marshaller ms, Object obj, XMLStreamWriter writer, Annotation[] anns, MediaType mt)
      throws Exception {
    CachedOutputStream out = new CachedOutputStream();
    marshalToOutputStream(ms, obj, out, anns, mt);

    StaxUtils.copy(new StreamSource(out.getInputStream()), writer);
  }

  @Override
  protected void addAttachmentMarshaller(Marshaller ms) {
    // complete
  }

  /**
   * Marshals {@code obj} through a transformer built from the output templates, or falls back to
   * plain JAXB when no template applies in JAXB-only mode.
   */
  @Override
  protected void marshalToOutputStream(
      Marshaller ms, Object obj, OutputStream os, Annotation[] anns, MediaType mt)
      throws Exception {

    Templates t = createTemplates(getOutTemplates(anns, mt), outParamsMap, outProperties);
    if (t == null && supportJaxbOnly) {
      super.marshalToOutputStream(ms, obj, os, anns, mt);
      return;
    }
    TransformerHandler th = null;
    try {
      th = factory.newTransformerHandler(t);
    } catch (TransformerConfigurationException ex) {
      TemplatesImpl ti = (TemplatesImpl) t;
      th = factory.newTransformerHandler(ti.getTemplates());
      this.trySettingProperties(th, ti);
    }
    Result result = new StreamResult(os);
    if (systemId != null) {
      result.setSystemId(systemId);
    }
    th.setResult(result);

    if (getContext() == null) {
      th.startDocument();
    }
    ms.marshal(obj, th);
    if (getContext() == null) {
      th.endDocument();
    }
  }

  public void setOutTemplate(String loc) {
    outTemplates = createTemplates(loc);
  }

  public void setInTemplate(String loc) {
    inTemplates = createTemplates(loc);
  }

  /** Sets input templates keyed by "type/subtype"; values are template locations. */
  public void setInMediaTemplates(Map<String, String> map) {
    inMediaTemplates = new HashMap<String, Templates>();
    for (Map.Entry<String, String> entry : map.entrySet()) {
      inMediaTemplates.put(entry.getKey(), createTemplates(entry.getValue()));
    }
  }

  /** Sets output templates keyed by "type/subtype"; values are template locations. */
  public void setOutMediaTemplates(Map<String, String> map) {
    outMediaTemplates = new HashMap<String, Templates>();
    for (Map.Entry<String, String> entry : map.entrySet()) {
      outMediaTemplates.put(entry.getKey(), createTemplates(entry.getValue()));
    }
  }

  public void setResolver(URIResolver resolver) {
    uriResolver = resolver;
    if (factory != null) {
      factory.setURIResolver(uriResolver);
    }
  }

  public void setSystemId(String system) {
    systemId = system;
  }

  public void setInParameters(Map<String, Object> inParams) {
    this.inParamsMap = inParams;
  }

  public void setOutParameters(Map<String, Object> outParams) {
    this.outParamsMap = outParams;
  }

  public void setInProperties(Map<String, String> inProps) {
    this.inProperties = inProps;
  }

  public void setOutProperties(Map<String, String> outProps) {
    this.outProperties = outProps;
  }

  public void setInClassNames(List<String> classNames) {
    inClassesToHandle = classNames;
  }

  public boolean inClassCanBeHandled(String className) {
    return inClassesToHandle != null && inClassesToHandle.contains(className);
  }

  public void setOutClassNames(List<String> classNames) {
    outClassesToHandle = classNames;
  }

  public boolean outClassCanBeHandled(String className) {
    return outClassesToHandle != null && outClassesToHandle.contains(className);
  }

  /**
   * Wraps {@code templates} so that new transformers carry request-derived parameters (path,
   * matrix and query parameters, absolute/relative/base paths), configured parameters and output
   * properties.
   *
   * @return the wrapped templates, or {@code null} if {@code templates} is {@code null} in
   *     JAXB-only mode
   * @throws RuntimeException (internal server error) if {@code templates} is {@code null} and
   *     JAXB-only mode is off
   */
  protected Templates createTemplates(
      Templates templates, Map<String, Object> configuredParams, Map<String, String> outProps) {
    if (templates == null) {
      if (supportJaxbOnly) {
        return null;
      } else {
        LOG.severe("No template is available");
        throw ExceptionUtils.toInternalServerErrorException(null, null);
      }
    }

    TemplatesImpl templ = new TemplatesImpl(templates, uriResolver);
    MessageContext mc = getContext();
    if (mc != null) {
      UriInfo ui = mc.getUriInfo();
      MultivaluedMap<String, String> params = ui.getPathParameters();
      for (Map.Entry<String, List<String>> entry : params.entrySet()) {
        String value = entry.getValue().get(0);
        // strip any matrix parameters appended to the path parameter value
        int ind = value.indexOf(";");
        if (ind > 0) {
          value = value.substring(0, ind);
        }
        templ.setTransformerParameter(entry.getKey(), value);
      }

      List<PathSegment> segments = ui.getPathSegments();
      if (segments.size() > 0) {
        setTransformParameters(templ, segments.get(segments.size() - 1).getMatrixParameters());
      }
      setTransformParameters(templ, ui.getQueryParameters());
      templ.setTransformerParameter(ABSOLUTE_PATH_PARAMETER, ui.getAbsolutePath().toString());
      templ.setTransformerParameter(RELATIVE_PATH_PARAMETER, ui.getPath());
      templ.setTransformerParameter(BASE_PATH_PARAMETER, ui.getBaseUri().toString());
      if (configuredParams != null) {
        for (Map.Entry<String, Object> entry : configuredParams.entrySet()) {
          templ.setTransformerParameter(entry.getKey(), entry.getValue());
        }
      }
    }
    if (outProps != null) {
      templ.setOutProperties(outProps);
    }

    return templ;
  }

  /** Copies the first value of each entry of {@code params} into the transformer parameters. */
  private void setTransformParameters(TemplatesImpl templ, MultivaluedMap<String, String> params) {
    for (Map.Entry<String, List<String>> entry : params.entrySet()) {
      templ.setTransformerParameter(entry.getKey(), entry.getValue().get(0));
    }
  }

  /** Loads templates from a resource location; returns {@code null} (and logs) on failure. */
  protected Templates createTemplates(String loc) {
    try {
      return createTemplates(ResourceUtils.getResourceURL(loc, this.getBus()));
    } catch (Exception ex) {
      LOG.warning("No template can be created : " + ex.getMessage());
    }
    return null;
  }

  /** Loads templates named by the {@code xslt.template} contextual property, if set. */
  protected Templates createTemplatesFromContext() {
    MessageContext mc = getContext();
    if (mc != null) {
      String template = (String) mc.getContextualProperty(XSLT_TEMPLATE_PROPERTY);
      if (template != null) {
        return createTemplates(template);
      }
    }
    return null;
  }

  /**
   * Compiles the XSLT stylesheet at {@code urlStream} (read as UTF-8), lazily creating the shared
   * transformer factory with secure processing enabled.
   *
   * @return the compiled templates, or {@code null} if {@code urlStream} is {@code null} or
   *     compilation failed (the failure is logged)
   */
  protected Templates createTemplates(URL urlStream) {
    if (urlStream == null) {
      return null;
    }
    try {
      Reader r = new BufferedReader(new InputStreamReader(urlStream.openStream(), "UTF-8"));
      try {
        Source source = new StreamSource(r);
        source.setSystemId(urlStream.toExternalForm());
        if (factory == null) {
          factory = (SAXTransformerFactory) TransformerFactory.newInstance();
          factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, Boolean.TRUE);
          if (uriResolver != null) {
            factory.setURIResolver(uriResolver);
          }
        }
        return factory.newTemplates(source);
      } finally {
        // newTemplates has fully consumed the stylesheet; release the underlying stream.
        closeTemplateReader(r);
      }
    } catch (Exception ex) {
      LOG.warning("No template can be created : " + ex.getMessage());
    }
    return null;
  }

  /** Closes a stylesheet reader, logging (not propagating) a failure to close it. */
  private static void closeTemplateReader(Reader reader) {
    try {
      reader.close();
    } catch (java.io.IOException ex) {
      LOG.log(Level.FINE, "Could not close XSLT template reader", ex);
    }
  }

  public void setRefreshTemplates(boolean refresh) {
    this.refreshTemplates = refresh;
  }

  /**
   * Templates wrapper that applies a URI resolver, transformer parameters and output properties
   * to every transformer it creates.
   */
  private static class TemplatesImpl implements Templates {

    private Templates templates;
    private URIResolver resolver;
    private Map<String, Object> transformParameters = new HashMap<String, Object>();
    private Map<String, String> outProps = new HashMap<String, String>();

    TemplatesImpl(Templates templates, URIResolver resolver) {
      this.templates = templates;
      this.resolver = resolver;
    }

    public Templates getTemplates() {
      return templates;
    }

    public void setTransformerParameter(String name, Object value) {
      transformParameters.put(name, value);
    }

    public void setOutProperties(Map<String, String> props) {
      this.outProps = props;
    }

    public Properties getOutputProperties() {
      return templates.getOutputProperties();
    }

    public Transformer newTransformer() throws TransformerConfigurationException {
      Transformer tr = templates.newTransformer();
      tr.setURIResolver(resolver);
      for (Map.Entry<String, Object> entry : transformParameters.entrySet()) {
        tr.setParameter(entry.getKey(), entry.getValue());
      }
      for (Map.Entry<String, String> entry : outProps.entrySet()) {
        tr.setOutputProperty(entry.getKey(), entry.getValue());
      }
      return tr;
    }
  }
}
Example #29
0
/**
 * Logical Handler responsible for aggregating the Message Addressing Properties for outgoing
 * messages.
 */
public class MAPAggregatorImpl extends MAPAggregator {

  // Logger obtained for MAPAggregator.class rather than this class — presumably so the base
  // class's localized message bundle is used (confirm).
  private static final Logger LOG = LogUtils.getL7dLogger(MAPAggregator.class);
  private static final ResourceBundle BUNDLE = LOG.getResourceBundle();

  /**
   * Client lifecycle listener: when a client is destroyed, the Destination stored on its endpoint
   * info under {@code DECOUPLED_DESTINATION} (if any) has its message observer cleared and is
   * shut down.
   */
  private static final ClientLifeCycleListener DECOUPLED_DEST_CLEANER =
      new ClientLifeCycleListener() {
        public void clientCreated(Client client) {
          // nothing to do when a client is created
        }

        public void clientDestroyed(Client client) {
          EndpointInfo info = client.getEndpoint().getEndpointInfo();
          Destination decoupled = info.getProperty(DECOUPLED_DESTINATION, Destination.class);
          if (decoupled == null) {
            return;
          }
          decoupled.setMessageObserver(null);
          decoupled.shutdown();
        }
      };

  /** Creates an aggregator backed by a fresh {@code DefaultMessageIdCache}. */
  public MAPAggregatorImpl() {
    this.messageIdCache = new DefaultMessageIdCache();
  }

  /**
   * Creates an aggregator that copies its settings from {@code mag}: addressing-required flag,
   * message-id cache, using-addressing-advisory flag, duplicate policy and addressing responses.
   * A fresh {@code DefaultMessageIdCache} is used when {@code mag} has none.
   *
   * @param mag the aggregator whose configuration is copied
   */
  public MAPAggregatorImpl(MAPAggregator mag) {
    addressingRequired = mag.isAddressingRequired();
    messageIdCache = mag.getMessageIdCache();
    if (null == messageIdCache) {
      // the source aggregator had no cache configured
      messageIdCache = new DefaultMessageIdCache();
    }
    usingAddressingAdvisory = mag.isUsingAddressingAdvisory();
    allowDuplicates = mag.allowDuplicates();
    addressingResponses = mag.getAddressingResponses();
  }

  /**
   * Invoked for normal processing of inbound and outbound messages.
   *
   * <p>Unless addressing has been disabled through the {@code ADDRESSING_DISABLED} contextual
   * property, the message is mediated. When it is disabled, every addressing-related policy
   * assertion in the message's {@code AssertionInfoMap} (if present) is marked as asserted.
   *
   * @param message the current message
   */
  public void handleMessage(Message message) {
    boolean addressingDisabled =
        MessageUtils.getContextualBoolean(message, ADDRESSING_DISABLED, false);
    if (!addressingDisabled) {
      mediate(message, ContextUtils.isFault(message));
      return;
    }

    // Addressing was switched off manually, so the user is in control of the addressing
    // assertions: assert all of them.
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    if (aim == null) {
      return;
    }
    QName[] addressingAssertions = {
      MetadataConstants.ADDRESSING_ASSERTION_QNAME,
      MetadataConstants.USING_ADDRESSING_2004_QNAME,
      MetadataConstants.USING_ADDRESSING_2005_QNAME,
      MetadataConstants.USING_ADDRESSING_2006_QNAME,
      MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME,
      MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME,
      MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME_0705,
      MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME_0705
    };
    for (QName assertionName : addressingAssertions) {
      assertAssertion(aim, assertionName);
    }
  }

  /**
   * Invoked when the normal interceptor chain unwinds because a fault occurred; stores this
   * aggregator on the message under the {@code MAPAggregator} class name.
   *
   * @param message the current message
   */
  public void handleFault(Message message) {
    String aggregatorKey = MAPAggregator.class.getName();
    message.put(aggregatorKey, this);
  }

  /**
   * Determines whether addressing is being used for the message.
   *
   * <p>On the requestor side this is true if the endpoint declares UsingAddressing, the policy
   * alternative carries an Addressing or UsingAddressing assertion, or — only when
   * {@code usingAddressingAdvisory} is set — {@code WSAContextUtils.retrieveUsingAddressing}
   * reports it. On the responder side it is true when {@code getMAPs(message, false, false)}
   * returns non-null.
   *
   * @param message the current message
   * @pre message is outbound
   */
  private boolean usingAddressing(Message message) {
    if (!ContextUtils.isRequestor(message)) {
      return getMAPs(message, false, false) != null;
    }
    if (hasUsingAddressing(message)
        || hasAddressingAssertion(message)
        || hasUsingAddressingAssertion(message)) {
      return true;
    }
    return usingAddressingAdvisory && WSAContextUtils.retrieveUsingAddressing(message);
  }

  /**
   * Determines whether a UsingAddressing extensor is present on the endpoint, binding or service
   * of the exchange's endpoint.
   *
   * <p>The answer is cached on the endpoint under {@code USING_ADDRESSING}; a cached value is
   * returned without re-inspecting the extensors. Returns {@code false} when the exchange has no
   * endpoint.
   *
   * @param message the current message
   * @pre message is outbound
   * @pre requestor role
   */
  private boolean hasUsingAddressing(Message message) {
    Endpoint endpoint = message.getExchange().getEndpoint();
    if (endpoint == null) {
      return false;
    }
    Boolean cached = (Boolean) endpoint.get(USING_ADDRESSING);
    if (cached != null) {
      return cached.booleanValue();
    }

    EndpointInfo info = endpoint.getEndpointInfo();
    List<ExtensibilityElement> endpointExts = null;
    List<ExtensibilityElement> bindingExts = null;
    List<ExtensibilityElement> serviceExts = null;
    if (info != null) {
      endpointExts = info.getExtensors(ExtensibilityElement.class);
      if (info.getBinding() != null) {
        bindingExts = info.getBinding().getExtensors(ExtensibilityElement.class);
      }
      if (info.getService() != null) {
        serviceExts = info.getService().getExtensors(ExtensibilityElement.class);
      }
    }
    boolean found =
        hasUsingAddressing(endpointExts)
            || hasUsingAddressing(bindingExts)
            || hasUsingAddressing(serviceExts);
    endpoint.put(USING_ADDRESSING, Boolean.valueOf(found));
    return found;
  }

  /**
   * Determines whether the message's {@code AssertionInfoMap} has an entry for the Addressing
   * assertion ({@code ADDRESSING_ASSERTION_QNAME}).
   *
   * @param message the current message
   * @pre message is outbound
   * @pre requestor role
   */
  private boolean hasAddressingAssertion(Message message) {
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    return aim != null && aim.get(MetadataConstants.ADDRESSING_ASSERTION_QNAME) != null;
  }

  /**
   * Determines whether the message's {@code AssertionInfoMap} has an entry for any of the 2004,
   * 2005 or 2006 UsingAddressing assertions.
   *
   * @param message the current message
   * @pre message is outbound
   * @pre requestor role
   */
  private boolean hasUsingAddressingAssertion(Message message) {
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    if (aim == null) {
      return false;
    }
    return aim.get(MetadataConstants.USING_ADDRESSING_2004_QNAME) != null
        || aim.get(MetadataConstants.USING_ADDRESSING_2005_QNAME) != null
        || aim.get(MetadataConstants.USING_ADDRESSING_2006_QNAME) != null;
  }

  /**
   * Returns the first {@code WSAddressingFeature} among the active features of the exchange's
   * endpoint, or {@code null} if there is no exchange, endpoint, feature list or such feature.
   */
  private WSAddressingFeature getWSAddressingFeature(Message message) {
    if (message.getExchange() == null) {
      return null;
    }
    Endpoint endpoint = message.getExchange().getEndpoint();
    if (endpoint == null || endpoint.getActiveFeatures() == null) {
      return null;
    }
    for (Feature active : endpoint.getActiveFeatures()) {
      if (active instanceof WSAddressingFeature) {
        return (WSAddressingFeature) active;
      }
    }
    return null;
  }
  /**
   * Marks the Addressing and UsingAddressing (2004/2005/2006) assertions as asserted; when the
   * Addressing assertion is processed, the anonymous and non-anonymous response assertions are
   * asserted too. Does nothing when the message has no {@code AssertionInfoMap}.
   *
   * <p>Intended for the case isRequestor(message) == true and isAddressRequired() == false.
   *
   * @param message the current message
   */
  private void assertAddressing(Message message) {
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    if (aim == null) {
      return;
    }
    QName[] addressingTypes = {
      MetadataConstants.ADDRESSING_ASSERTION_QNAME,
      MetadataConstants.USING_ADDRESSING_2004_QNAME,
      MetadataConstants.USING_ADDRESSING_2005_QNAME,
      MetadataConstants.USING_ADDRESSING_2006_QNAME
    };

    for (QName addressingType : addressingTypes) {
      assertAssertion(aim, addressingType);
      if (addressingType.equals(MetadataConstants.ADDRESSING_ASSERTION_QNAME)) {
        assertAssertion(aim, MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME);
        assertAssertion(aim, MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME);
      } else if (addressingType.equals(MetadataConstants.ADDRESSING_ASSERTION_QNAME_0705)) {
        // NOTE(review): ADDRESSING_ASSERTION_QNAME_0705 is not in addressingTypes, so this
        // branch never runs — confirm whether the 0705 assertion should be listed above.
        assertAssertion(aim, MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME_0705);
        assertAssertion(aim, MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME_0705);
      }
    }
  }

  /**
   * Asserts the WS-Addressing policy assertions for the current message according to the ReplyTo
   * and FaultTo addresses actually in use. On an inbound responder message, it also rejects
   * addresses that violate the Anonymous/NonAnonymous nested policy.
   *
   * <p>The top-level Addressing and UsingAddressing assertions are always asserted. For both the
   * current and the 0705 Addressing assertion, the nested AnonymousResponses assertion is asserted
   * only when both ReplyTo and FaultTo are generic (anonymous or none) addresses. The nested
   * NonAnonymousResponses assertion is asserted only when neither of them is.
   *
   * @param message the current message; nothing happens if it carries no {@link AssertionInfoMap}
   * @param replyTo the ReplyTo endpoint reference (may be {@code null})
   * @param faultTo the FaultTo endpoint reference; when {@code null}, {@code replyTo} is used
   * @throws SoapFault on an inbound, non-requestor message when an anonymous address is used but
   *     the policy only allows non-anonymous responses, or a non-anonymous address is used but the
   *     policy only allows anonymous responses
   */
  private void assertAddressing(
      Message message, EndpointReferenceType replyTo, EndpointReferenceType faultTo) {
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    if (null == aim) {
      return;
    }
    if (faultTo == null) {
      faultTo = replyTo;
    }
    boolean anonReply = ContextUtils.isGenericAddress(replyTo);
    boolean anonFault = ContextUtils.isGenericAddress(faultTo);
    boolean onlyAnonymous = anonReply && anonFault;
    boolean hasAnonymous = anonReply || anonFault;

    // Both Addressing QNames (current and 0705) must be listed, otherwise the 0705 branch of the
    // loop below can never match and the 0705 nested assertions are never asserted.
    QName[] types =
        new QName[] {
          MetadataConstants.ADDRESSING_ASSERTION_QNAME,
          MetadataConstants.ADDRESSING_ASSERTION_QNAME_0705,
          MetadataConstants.USING_ADDRESSING_2004_QNAME,
          MetadataConstants.USING_ADDRESSING_2005_QNAME,
          MetadataConstants.USING_ADDRESSING_2006_QNAME
        };

    for (QName type : types) {
      assertAssertion(aim, type);
      if (type.equals(MetadataConstants.ADDRESSING_ASSERTION_QNAME)) {
        if (onlyAnonymous) {
          assertAssertion(aim, MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME);
        } else if (!hasAnonymous) {
          assertAssertion(aim, MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME);
        }
      } else if (type.equals(MetadataConstants.ADDRESSING_ASSERTION_QNAME_0705)) {
        if (onlyAnonymous) {
          assertAssertion(aim, MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME_0705);
        } else if (!hasAnonymous) {
          assertAssertion(aim, MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME_0705);
        }
      }
    }
    // The nested-policy violation checks only apply to inbound messages on the responder.
    if (!MessageUtils.isRequestor(message) && !MessageUtils.isOutbound(message)) {
      // need to throw an appropriate fault for these
      Collection<AssertionInfo> aicNonAnon =
          aim.getAssertionInfo(MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME);
      Collection<AssertionInfo> aicNonAnon2 =
          aim.getAssertionInfo(MetadataConstants.NON_ANON_RESPONSES_ASSERTION_QNAME_0705);
      Collection<AssertionInfo> aicAnon =
          aim.getAssertionInfo(MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME);
      Collection<AssertionInfo> aicAnon2 =
          aim.getAssertionInfo(MetadataConstants.ANON_RESPONSES_ASSERTION_QNAME_0705);
      boolean hasAnon =
          (aicAnon != null && !aicAnon.isEmpty()) || (aicAnon2 != null && !aicAnon2.isEmpty());
      boolean hasNonAnon =
          (aicNonAnon != null && !aicNonAnon.isEmpty())
              || (aicNonAnon2 != null && !aicNonAnon2.isEmpty());

      if (hasAnonymous && hasNonAnon && !hasAnon) {
        // policy demands non-anonymous responses, but a generic address was supplied
        message.put(FaultMode.class, FaultMode.UNCHECKED_APPLICATION_FAULT);
        if (isSOAP12(message)) {
          SoapFault soap12Fault =
              new SoapFault(
                  "Found anonymous address but non-anonymous required",
                  Soap12.getInstance().getSender());
          soap12Fault.addSubCode(
              new QName(Names.WSA_NAMESPACE_NAME, "OnlyNonAnonymousAddressSupported"));
          throw soap12Fault;
        }

        throw new SoapFault(
            "Found anonymous address but non-anonymous required",
            new QName(Names.WSA_NAMESPACE_NAME, "OnlyNonAnonymousAddressSupported"));
      } else if (!onlyAnonymous && !hasNonAnon && hasAnon) {
        // policy demands anonymous responses, but a non-generic address was supplied
        message.put(FaultMode.class, FaultMode.UNCHECKED_APPLICATION_FAULT);
        if (isSOAP12(message)) {
          SoapFault soap12Fault =
              new SoapFault(
                  "Found non-anonymous address but only anonymous supported",
                  Soap12.getInstance().getSender());
          soap12Fault.addSubCode(
              new QName(Names.WSA_NAMESPACE_NAME, "OnlyAnonymousAddressSupported"));
          throw soap12Fault;
        }

        throw new SoapFault(
            "Found non-anonymous address but only anonymous supported",
            new QName(Names.WSA_NAMESPACE_NAME, "OnlyAnonymousAddressSupported"));
      }
    }
  }

  /**
   * Marks every {@link AssertionInfo} of the given type in {@code aim} as asserted.
   *
   * @param aim the assertion map of the current message
   * @param type the QName of the assertions to satisfy
   */
  private void assertAssertion(AssertionInfoMap aim, QName type) {
    Iterator<AssertionInfo> matches = aim.getAssertionInfo(type).iterator();
    while (matches.hasNext()) {
      matches.next().setAsserted(true);
    }
  }

  /**
   * Scans a list of WSDL extension elements for a wsaw:UsingAddressing element.
   *
   * @param exts list of extension elements; {@code null} is treated as empty
   * @return true iff an element whose type is {@code Names.WSAW_USING_ADDRESSING_QNAME} is found
   */
  private boolean hasUsingAddressing(List<ExtensibilityElement> exts) {
    if (exts == null) {
      return false;
    }
    for (ExtensibilityElement element : exts) {
      if (Names.WSAW_USING_ADDRESSING_QNAME.equals(element.getElementType())) {
        return true;
      }
    }
    return false;
  }

  /**
   * Mediate message flow.
   *
   * <p>Three cases are handled:
   *
   * <ul>
   *   <li>Outbound messages: MAPs are aggregated when addressing is in use. The requestor then
   *       asserts the addressing policy against the outbound ReplyTo/FaultTo, while the responder
   *       runs {@code checkReplyTo} on the outbound MAPs.
   *   <li>Inbound messages on the responder: the configured addressing-responses mode is checked,
   *       the policy is asserted and the incoming MAPs are validated. The response may be rebased
   *       onto the ReplyTo address, and a {@code SoapFault} is thrown if validation failed.
   *   <li>Inbound messages on the requestor: the addressing policy is asserted. A fault is thrown
   *       if WS-Addressing is required but neither MAPs nor a satisfied addressing assertion are
   *       present. A message marked as a partial response is un-marked when it actually carries
   *       the operation's response content.
   * </ul>
   *
   * @param message the current message
   * @param isFault true if a fault is being mediated
   * @return true if processing should continue on dispatch path; false only when an inbound
   *     responder message has no MAPs and addressing is not required (validation failures throw
   *     instead of returning false)
   * @throws SoapFault when inbound MAP validation fails on the responder, when required
   *     addressing headers are missing on the requestor, or as propagated from the
   *     assertAddressing / checkReplyTo / checkAddressingResponses helpers
   */
  protected boolean mediate(Message message, boolean isFault) {
    boolean continueProcessing = true;
    if (ContextUtils.isOutbound(message)) {
      if (usingAddressing(message)) {
        // request/response MAPs must be aggregated
        aggregate(message, isFault);
      }
      AddressingProperties theMaps =
          ContextUtils.retrieveMAPs(message, false, ContextUtils.isOutbound(message));
      if (null != theMaps) {
        if (ContextUtils.isRequestor(message)) {
          assertAddressing(message, theMaps.getReplyTo(), theMaps.getFaultTo());
        } else {
          checkReplyTo(message, theMaps);
        }
      }
    } else if (!ContextUtils.isRequestor(message)) {
      // responder validates incoming MAPs
      AddressingProperties maps = getMAPs(message, false, false);
      // check responses
      if (maps != null) {
        checkAddressingResponses(maps.getReplyTo(), maps.getFaultTo());
        assertAddressing(message, maps.getReplyTo(), maps.getFaultTo());
      }
      // NOTE(review): computed before the early return below; only used when maps != null.
      boolean isOneway = message.getExchange().isOneWay();
      // no MAPs and addressing optional: nothing to validate, stop mediating this message
      if (null == maps && !addressingRequired) {
        return false;
      }
      continueProcessing = validateIncomingMAPs(maps, message);
      if (maps != null) {
        AddressingProperties theMaps =
            ContextUtils.retrieveMAPs(message, false, ContextUtils.isOutbound(message));
        if (null != theMaps) {
          assertAddressing(message, theMaps.getReplyTo(), theMaps.getFaultTo());
        }

        // one-way requests, and requests with a non-generic ReplyTo, get their response
        // rebased onto the ReplyTo address
        if (isOneway || !ContextUtils.isGenericAddress(maps.getReplyTo())) {
          InternalContextUtils.rebaseResponse(maps.getReplyTo(), maps, message);
        }
        if (!isOneway) {
          if (ContextUtils.isNoneAddress(maps.getReplyTo())) {
            // NOTE(review): "respone" in this log text is a typo; kept as-is here.
            LOG.warning("Detected NONE value in ReplyTo WSA header for request-respone MEP");
          } else {
            // ensure the inbound MAPs are available in both the full & fault
            // response messages (used to determine relatesTo etc.)
            ContextUtils.propogateReceivedMAPs(maps, message.getExchange());
          }
        }
      }
      if (continueProcessing) {
        // any faults thrown from here on can be correlated with this message
        message.put(FaultMode.class, FaultMode.LOGICAL_RUNTIME_FAULT);
      } else {
        // validation failure => dispatch is aborted, response MAPs
        // must be aggregated
        // isFault = true;
        // aggregate(message, isFault);
        // NOTE(review): the aggregation above is commented out; the fault below is built from
        // the fault name/reason that validateIncomingMAPs stored on the message.
        if (isSOAP12(message)) {
          SoapFault soap12Fault =
              new SoapFault(
                  ContextUtils.retrieveMAPFaultReason(message), Soap12.getInstance().getSender());
          soap12Fault.setSubCode(
              new QName(Names.WSA_NAMESPACE_NAME, ContextUtils.retrieveMAPFaultName(message)));
          throw soap12Fault;
        }
        throw new SoapFault(
            ContextUtils.retrieveMAPFaultReason(message),
            new QName(Names.WSA_NAMESPACE_NAME, ContextUtils.retrieveMAPFaultName(message)));
      }
    } else {
      // inbound message on the requestor side (i.e. a response)
      AddressingProperties theMaps =
          ContextUtils.retrieveMAPs(message, false, ContextUtils.isOutbound(message));
      if (null != theMaps) {
        assertAddressing(message, theMaps.getReplyTo(), theMaps.getFaultTo());
      }
      // If the wsa policy is enabled , but the client sets the
      // WSAddressingFeature.isAddressingRequired to false , we need to assert all WSA assertion to
      // true
      if (!ContextUtils.isOutbound(message)
          && ContextUtils.isRequestor(message)
          && getWSAddressingFeature(message) != null
          && !getWSAddressingFeature(message).isAddressingRequired()) {
        assertAddressing(message);
      }
      // CXF-3060 :If wsa policy is not enforced, AddressingProperties map is null and
      // AddressingFeature.isRequired, requestor checks inbound message and throw exception
      if (null == theMaps
          && !ContextUtils.isOutbound(message)
          && ContextUtils.isRequestor(message)
          && getWSAddressingFeature(message) != null
          && getWSAddressingFeature(message).isAddressingRequired()) {
        // the header counts as missing unless some addressing assertion was already asserted
        boolean missingWsaHeader = false;
        AssertionInfoMap aim = message.get(AssertionInfoMap.class);
        if (aim == null || aim.size() == 0) {
          missingWsaHeader = true;
        }
        if (aim != null && aim.size() > 0) {
          missingWsaHeader = true;
          QName[] types =
              new QName[] {
                MetadataConstants.ADDRESSING_ASSERTION_QNAME,
                MetadataConstants.USING_ADDRESSING_2004_QNAME,
                MetadataConstants.USING_ADDRESSING_2005_QNAME,
                MetadataConstants.USING_ADDRESSING_2006_QNAME
              };
          for (QName type : types) {
            for (AssertionInfo assertInfo : aim.getAssertionInfo(type)) {
              if (assertInfo.isAsserted()) {
                missingWsaHeader = false;
              }
            }
          }
        }
        if (missingWsaHeader) {
          throw new SoapFault(
              "MISSING_ACTION_MESSAGE",
              BUNDLE,
              new QName(Names.WSA_NAMESPACE_NAME, Names.HEADER_REQUIRED_NAME));
        }
      }
      if (MessageUtils.isPartialResponse(message)
          && message.getExchange().getOutMessage() != null) {
        // marked as a partial response, let's see if it really is
        MessageInfo min = message.get(MessageInfo.class);
        MessageInfo mout = message.getExchange().getOutMessage().get(MessageInfo.class);
        if (min != null
            && mout != null
            && min.getOperation() == mout.getOperation()
            && message.getContent(List.class) != null) {
          // the in and out messages are on the same operation
          // and we were able to get a response for it.
          message.remove(Message.PARTIAL_RESPONSE_MESSAGE);
        }
      }
    }
    return continueProcessing;
  }

  /**
   * Enforces the configured {@link WSAddressingFeature.AddressingResponses} mode against the
   * incoming ReplyTo and FaultTo addresses.
   *
   * <p>{@code ALL} accepts anything. {@code ANONYMOUS} requires both addresses to be generic.
   * {@code NON_ANONYMOUS} requires a non-generic ReplyTo and either a non-generic FaultTo or no
   * FaultTo address at all, in which case faults go to ReplyTo.
   *
   * @param replyTo the incoming ReplyTo reference
   * @param faultTo the incoming FaultTo reference; may be {@code null} when the header is absent
   * @throws SoapFault carrying the OnlyAnonymous / OnlyNonAnonymous address-supported detail when
   *     the addresses do not satisfy the configured mode
   */
  private void checkAddressingResponses(
      EndpointReferenceType replyTo, EndpointReferenceType faultTo) {
    if (this.addressingResponses == WSAddressingFeature.AddressingResponses.ALL) {
      return;
    }
    boolean anonReply = ContextUtils.isGenericAddress(replyTo);
    boolean anonFault = ContextUtils.isGenericAddress(faultTo);
    // A null faultTo (header absent) is treated like a FaultTo without an address, so it must
    // never be dereferenced.
    boolean faultToAbsent = faultTo == null || faultTo.getAddress() == null;

    boolean passed;
    if (WSAddressingFeature.AddressingResponses.ANONYMOUS == addressingResponses) {
      passed = anonReply && anonFault;
    } else if (WSAddressingFeature.AddressingResponses.NON_ANONYMOUS == addressingResponses) {
      passed = !anonReply && (faultToAbsent || !anonFault);
    } else {
      passed = false;
    }

    if (!passed) {
      String reason = BUNDLE.getString("INVALID_ADDRESSING_PROPERTY_MESSAGE");
      QName detail =
          WSAddressingFeature.AddressingResponses.ANONYMOUS == addressingResponses
              ? Names.ONLY_ANONYMOUS_ADDRESS_SUPPORTED_QNAME
              : Names.ONLY_NONANONYMOUS_ADDRESS_SUPPORTED_QNAME;
      throw new SoapFault(reason, detail);
    }
  }
  /**
   * Performs MAP aggregation: assembles the generic MAPs, adds the role-specific ones, and stores
   * the result on the message.
   *
   * @param message the current message
   * @param isFault true if a fault is being mediated
   */
  private void aggregate(Message message, boolean isFault) {
    final boolean requestor = ContextUtils.isRequestor(message);
    final AddressingProperties aggregated = assembleGeneric(message);
    addRoleSpecific(aggregated, message, requestor, isFault);
    // Always stored under the outbound key. Aggregation is meant for messages that either really
    // are outbound, or are inbound but about to have their dispatch aborted after a MAPs
    // validation failure, and will therefore shortly travel the outbound path.
    ContextUtils.storeMAPs(aggregated, message, true, requestor);
  }

  /**
   * Assembles the MAPs shared by requests and responses: a MessageID and an Action.
   *
   * <p>A MessageID is generated when none is set. When the Action is empty, it is first taken from
   * the message. If it is still empty and the message is outbound, it is derived from the
   * operation metadata via {@code getActionUri(message, true)}.
   *
   * @param message the current message
   * @return AddressingProperties containing the generic MAPs
   */
  private AddressingProperties assembleGeneric(Message message) {
    AddressingProperties result = getMAPs(message, true, true);

    if (result.getMessageID() == null) {
      result.setMessageID(ContextUtils.getAttributedURI(ContextUtils.generateUUID()));
    }

    if (!ContextUtils.hasEmptyAction(result)) {
      return result;
    }
    result.setAction(InternalContextUtils.getAction(message));
    if (ContextUtils.hasEmptyAction(result) && ContextUtils.isOutbound(message)) {
      result.setAction(ContextUtils.getAttributedURI(getActionUri(message, true)));
    }
    return result;
  }

  /**
   * Returns the explicit action declared on the operation's input message.
   *
   * @param operation the operation whose input message is inspected
   * @return the declared input action, or {@code null} when the operation has no input message,
   *     the input carries no extension attributes, or the declared action is empty
   */
  private String getActionFromInputMessage(final OperationInfo operation) {
    MessageInfo inputMessage = operation.getInput();
    // Guard against operations without an input, mirroring getActionFromOutputMessage.
    if (inputMessage != null && inputMessage.getExtensionAttributes() != null) {
      String inputAction = InternalContextUtils.getAction(inputMessage);
      if (!StringUtils.isEmpty(inputAction)) {
        return inputAction;
      }
    }
    return null;
  }

  /**
   * Returns the explicit action declared on the operation's output message.
   *
   * @param operation the operation whose output message is inspected
   * @return the declared output action, or {@code null} when there is no output message, it has
   *     no extension attributes, or the declared action is empty
   */
  private String getActionFromOutputMessage(final OperationInfo operation) {
    final MessageInfo output = operation.getOutput();
    if (output == null || output.getExtensionAttributes() == null) {
      return null;
    }
    final String declared = InternalContextUtils.getAction(output);
    return StringUtils.isEmpty(declared) ? null : declared;
  }

  /**
   * Tells whether {@code faultInfo}'s name (local part) matches {@code faultName} as given or with
   * its first letter lower-cased.
   *
   * @param faultInfo the fault description from the service model
   * @param faultName the fault name derived from the thrown exception; may be {@code null}
   * @return {@code false} if either name is missing or they do not match
   */
  private boolean isSameFault(final FaultInfo faultInfo, String faultName) {
    QName infoName = faultInfo.getName();
    if (infoName == null || faultName == null) {
      return false;
    }
    String local = infoName.getLocalPart();
    return local.equals(faultName) || local.equals(StringUtils.uncapitalize(faultName));
  }

  /** Builds the default action prefix: the operation's namespace followed by its interface name. */
  private String getActionBaseUri(final OperationInfo operation) {
    String interfaceLocalName = operation.getInterface().getName().getLocalPart();
    String operationNamespace = operation.getName().getNamespaceURI();
    return addPath(operationNamespace, interfaceLocalName);
  }

  /**
   * Determines the action for a fault of the given operation.
   *
   * <p>For a declared fault matching {@code faultName}, an explicit non-empty action is used if
   * present. Otherwise the action is {@code <base>/<operation>/Fault/<declared fault name>}. When
   * no declared fault matches, {@code faultName} is used as the last segment.
   *
   * @param operation the operation the fault belongs to
   * @param faultName the fault name derived from the thrown exception
   * @return the fault action URI
   */
  private String getActionFromFaultMessage(final OperationInfo operation, final String faultName) {
    if (operation.getFaults() != null) {
      for (FaultInfo candidate : operation.getFaults()) {
        if (!isSameFault(candidate, faultName)) {
          continue;
        }
        if (candidate.getExtensionAttributes() != null) {
          String declared = InternalContextUtils.getAction(candidate);
          if (!StringUtils.isEmpty(declared)) {
            return declared;
          }
        }
        String faultPrefix =
            addPath(
                addPath(getActionBaseUri(operation), operation.getName().getLocalPart()), "Fault");
        return addPath(faultPrefix, candidate.getFaultName().getLocalPart());
      }
    }
    String fallbackPrefix =
        addPath(addPath(getActionBaseUri(operation), operation.getName().getLocalPart()), "Fault");
    return addPath(fallbackPrefix, faultName);
  }

  /**
   * Derives a fault name from the exception carried by the message.
   *
   * <p>The exception's cause is inspected if there is one, otherwise the exception itself. When
   * the carried exception is a {@code Fault} and that throwable's class has a {@code @WebFault}
   * annotation, the annotation's name is used. Otherwise the class's simple name is returned.
   */
  private String getFaultNameFromMessage(final Message message) {
    Exception carried = message.getContent(Exception.class);
    Throwable root = carried.getCause() != null ? carried.getCause() : carried;
    if (carried instanceof Fault) {
      WebFault annotation = root.getClass().getAnnotation(WebFault.class);
      if (annotation != null) {
        return annotation.name();
      }
    }
    return root.getClass().getSimpleName();
  }

  /**
   * Computes the action URI for the current message.
   *
   * <p>Returns {@code null} when there is no binding operation or it is synthetic. When {@code
   * checkMessage} is set, an action stored on the message (ContextUtils.ACTION, then SOAPAction)
   * wins. Otherwise the action is derived from the operation: the fault action for faults, the
   * input action for the input leg, the output action for the output leg. The input and output
   * actions fall back to {@code <base>/<name>Request} or {@code <base>/<name>Response} (or the
   * declared message name).
   *
   * @param message the current message
   * @param checkMessage whether an action already present on the message should be used
   * @return the action URI, or {@code null}
   */
  protected String getActionUri(Message message, boolean checkMessage) {
    BindingOperationInfo bindingOp = message.getExchange().getBindingOperationInfo();
    if (bindingOp == null
        || Boolean.TRUE.equals(bindingOp.getProperty("operation.is.synthetic"))) {
      return null;
    }
    OperationInfo opInfo = bindingOp.getOperationInfo();
    if (opInfo.isUnwrapped()) {
      opInfo = ((UnwrappedOperationInfo) opInfo).getWrappedOperation();
    }

    if (checkMessage) {
      String configured = (String) message.get(ContextUtils.ACTION);
      if (configured == null) {
        configured = (String) message.get(SoapBindingConstants.SOAP_ACTION);
      }
      if (configured != null) {
        return configured;
      }
    }

    String baseUri = getActionBaseUri(opInfo);
    boolean inbound = !ContextUtils.isOutbound(message);
    boolean requestor = ContextUtils.isRequestor(message);
    // requestor-outbound and responder-inbound both concern the operation's input message
    boolean inputLeg = requestor ^ inbound;

    if (ContextUtils.isFault(message)) {
      return getActionFromFaultMessage(opInfo, getFaultNameFromMessage(message));
    }
    if (inputLeg) {
      String inputAction = getActionFromInputMessage(opInfo);
      if (StringUtils.isEmpty(inputAction)) {
        SoapOperationInfo soapOp = InternalContextUtils.getSoapOperationInfo(bindingOp);
        inputAction = soapOp == null ? null : soapOp.getAction();
      }
      if (!StringUtils.isEmpty(inputAction)) {
        return inputAction;
      }
      String inputName = opInfo.getInputName();
      return inputName == null
          ? addPath(baseUri, opInfo.getName().getLocalPart() + "Request")
          : addPath(baseUri, inputName);
    }
    String outputAction = getActionFromOutputMessage(opInfo);
    if (outputAction != null) {
      return outputAction;
    }
    String outputName = opInfo.getOutputName();
    return outputName == null
        ? addPath(baseUri, opInfo.getName().getLocalPart() + "Response")
        : addPath(baseUri, outputName);
  }

  /** Returns ":" for URIs starting with "urn" and "/" for everything else. */
  private String getDelimiter(String uri) {
    return uri.startsWith("urn") ? ":" : "/";
  }

  /**
   * Appends {@code path} to {@code uri}, inserting the URI's delimiter (see getDelimiter) only
   * when neither the end of {@code uri} nor the start of {@code path} already provides it.
   */
  private String addPath(String uri, String path) {
    String delimiter = getDelimiter(uri);
    boolean needsDelimiter = !uri.endsWith(delimiter) && !path.startsWith(delimiter);
    return needsDelimiter ? uri + delimiter + path : uri + path;
  }

  /**
   * Add MAPs which are specific to the requestor or responder role.
   *
   * <p>Requestor: sets To (from the conduit target, overridden by Message.ENDPOINT_ADDRESS when
   * that differs, or the NONE reference when there is no conduit). ReplyTo is replaced when it is
   * generic: by the decoupled endpoint if one is configured, and by NONE for one-way or ANONYMOUS
   * for other exchanges when needed. FaultTo defaults to ReplyTo and is cleared if it has no
   * address.
   *
   * <p>Responder: exposes the MAPs in the namespace of the incoming MAPs. To is taken from the
   * incoming FaultTo (faults) or ReplyTo, and RelatesTo from the incoming MessageID (or the
   * unspecified relationship for partial responses or a missing MessageID). A default fault action
   * is supplied, and faults with a non-generic FaultTo get a decoupled destination for that
   * address.
   *
   * @param maps the MAPs being assembled
   * @param message the current message
   * @param isRequestor true iff the current messaging role is that of requestor
   * @param isFault true if a fault is being mediated
   */
  private void addRoleSpecific(
      AddressingProperties maps, Message message, boolean isRequestor, boolean isFault) {
    if (isRequestor) {
      Exchange exchange = message.getExchange();

      // add request-specific MAPs
      boolean isOneway = exchange.isOneWay();
      boolean isOutbound = ContextUtils.isOutbound(message);

      // To
      if (maps.getTo() == null) {
        Conduit conduit = null;
        if (isOutbound) {
          conduit = ContextUtils.getConduit(null, message);
        }
        String s = (String) message.get(Message.ENDPOINT_ADDRESS);
        EndpointReferenceType reference =
            conduit != null ? conduit.getTarget() : ContextUtils.getNoneEndpointReference();
        // an explicit endpoint address that differs from the conduit target wins, keeping the
        // target's metadata, reference parameters and extra attributes
        if (conduit != null
            && !StringUtils.isEmpty(s)
            && !reference.getAddress().getValue().equals(s)) {
          EndpointReferenceType ref = new EndpointReferenceType();
          AttributedURIType tp = new AttributedURIType();
          tp.setValue(s);
          ref.setAddress(tp);
          ref.setMetadata(reference.getMetadata());
          ref.setReferenceParameters(reference.getReferenceParameters());
          ref.getOtherAttributes().putAll(reference.getOtherAttributes());
          reference = ref;
        }
        maps.setTo(reference);
      }

      // ReplyTo, set if null in MAPs or if set to a generic address
      // (anonymous or none) that may not be appropriate for the
      // current invocation
      EndpointReferenceType replyTo = maps.getReplyTo();
      if (ContextUtils.isGenericAddress(replyTo)) {
        replyTo = getReplyTo(message, replyTo);
        // NOTE(review): the inner "replyTo == null" test is redundant with the outer one.
        // Effect: a missing ReplyTo becomes ANONYMOUS (NONE for one-way), and a one-way
        // ReplyTo that is not already NONE is forced to NONE.
        if (replyTo == null
            || (isOneway
                && (replyTo == null
                    || replyTo.getAddress() == null
                    || !Names.WSA_NONE_ADDRESS.equals(replyTo.getAddress().getValue())))) {
          AttributedURIType address =
              ContextUtils.getAttributedURI(
                  isOneway ? Names.WSA_NONE_ADDRESS : Names.WSA_ANONYMOUS_ADDRESS);
          replyTo = ContextUtils.WSA_OBJECT_FACTORY.createEndpointReferenceType();
          replyTo.setAddress(address);
        }
        maps.setReplyTo(replyTo);
      }

      // FaultTo
      if (maps.getFaultTo() == null) {
        maps.setFaultTo(maps.getReplyTo());
      } else if (maps.getFaultTo().getAddress() == null) {
        maps.setFaultTo(null);
      }
    } else {
      // add response-specific MAPs
      // NOTE(review): inMAPs is dereferenced without a null check; this assumes the incoming
      // MAPs were stored on the message - confirm for all callers.
      AddressingProperties inMAPs = getMAPs(message, false, false);
      maps.exposeAs(inMAPs.getNamespaceURI());
      // To taken from ReplyTo or FaultTo in incoming MAPs (depending
      // on the fault status of the response)
      if (isFault && inMAPs.getFaultTo() != null) {
        maps.setTo(inMAPs.getFaultTo());
      } else if (maps.getTo() == null && inMAPs.getReplyTo() != null) {
        maps.setTo(inMAPs.getReplyTo());
      }

      // RelatesTo taken from MessageID in incoming MAPs
      if (inMAPs.getMessageID() != null
          && !Boolean.TRUE.equals(message.get(Message.PARTIAL_RESPONSE_MESSAGE))) {
        String inMessageID = inMAPs.getMessageID().getValue();
        maps.setRelatesTo(ContextUtils.getRelatesTo(inMessageID));
      } else {
        maps.setRelatesTo(ContextUtils.getRelatesTo(Names.WSA_UNSPECIFIED_RELATIONSHIP));
      }

      // fallback fault action
      if (isFault && maps.getAction() == null) {
        maps.setAction(ContextUtils.getAttributedURI(Names.WSA_DEFAULT_FAULT_ACTION));
      }

      // faults addressed to a non-generic FaultTo are rebased and sent through a decoupled
      // destination created for that address
      if (isFault && !ContextUtils.isGenericAddress(inMAPs.getFaultTo())) {

        Message m = message.getExchange().getInFaultMessage();
        if (m == null) {
          m = message;
        }
        InternalContextUtils.rebaseResponse(inMAPs.getFaultTo(), inMAPs, m);

        Destination destination =
            InternalContextUtils.createDecoupledDestination(m.getExchange(), inMAPs.getFaultTo());
        m.getExchange().setDestination(destination);
      }
    }
  }

  /**
   * Resolves the ReplyTo reference for an outgoing request from the endpoint's decoupled
   * destination, creating and caching that destination on the EndpointInfo on first use.
   *
   * <p>A relative decoupled address (starting with "/") is made absolute with the
   * DECOUPLED_ENDPOINT_BASE_PROPERTY contextual property when that property is set.
   *
   * @param message the current message
   * @param originalReplyTo the ReplyTo currently in the MAPs
   * @return the decoupled destination's reference, or {@code originalReplyTo} when there is no
   *     endpoint or no decoupled destination
   */
  private EndpointReferenceType getReplyTo(Message message, EndpointReferenceType originalReplyTo) {
    Endpoint endpoint = message.getExchange().getEndpoint();
    if (endpoint == null) {
      return originalReplyTo;
    }
    synchronized (endpoint) {
      EndpointInfo endpointInfo = endpoint.getEndpointInfo();
      Destination decoupled = endpointInfo.getProperty(DECOUPLED_DESTINATION, Destination.class);
      if (decoupled == null) {
        decoupled = createDecoupledDestination(message);
        if (decoupled != null) {
          endpoint.getEndpointInfo().setProperty(DECOUPLED_DESTINATION, decoupled);
        }
      }
      if (decoupled == null) {
        return originalReplyTo;
      }
      final String address = decoupled.getAddress().getAddress().getValue();
      if (address.startsWith("/")) {
        String base =
            (String)
                message.getContextualProperty(WSAContextUtils.DECOUPLED_ENDPOINT_BASE_PROPERTY);
        if (base != null) {
          return EndpointReferenceUtils.getEndpointReference(base + address);
        }
      }
      return decoupled.getAddress();
    }
  }

  /**
   * Creates a decoupled destination for the address in the REPLYTO_PROPERTY contextual property.
   *
   * @return the destination, or {@code null} if the property is not set or creation failed
   */
  private Destination createDecoupledDestination(Message message) {
    String replyToAddress =
        (String) message.getContextualProperty(WSAContextUtils.REPLYTO_PROPERTY);
    return replyToAddress == null
        ? null
        : setUpDecoupledDestination(message.getExchange().getBus(), replyToAddress, message);
  }
  /**
   * Set up the decoupled Destination if necessary.
   *
   * <p>On success the DECOUPLED_DEST_CLEANER listener is registered with the bus's
   * ClientLifeCycleManager. Any exception is logged at WARNING level and {@code null} returned.
   * NOTE(review): if the bus has no ClientLifeCycleManager, the resulting NullPointerException is
   * caught here and the freshly created destination is discarded - confirm this is acceptable.
   *
   * @return the new destination, or {@code null} if the address cannot be parsed or creation fails
   */
  private Destination setUpDecoupledDestination(Bus bus, String replyToAddress, Message message) {
    EndpointReferenceType reference = EndpointReferenceUtils.getEndpointReference(replyToAddress);
    if (reference == null) {
      return null;
    }
    LOG.info("creating decoupled endpoint: " + reference.getAddress().getValue());
    try {
      Destination decoupled = getDestination(bus, replyToAddress, message);
      bus.getExtension(ClientLifeCycleManager.class).registerListener(DECOUPLED_DEST_CLEANER);
      return decoupled;
    } catch (Exception e) {
      // REVISIT move message to localizable Messages.properties
      LOG.log(Level.WARNING, "decoupled endpoint creation failed: ", e);
      return null;
    }
  }

  /**
   * Creates a destination for {@code address} named after the current endpoint with a
   * ".decoupled" suffix. When a conduit is available, its message observer is wrapped in an
   * {@code InterposedMessageObserver} and installed on the new destination.
   *
   * @param bus the bus providing the DestinationFactoryManager
   * @param address the address
   * @param message the current message (supplies the endpoint and conduit)
   * @return a Destination for the address, or {@code null} if no factory handles its URI
   * @throws IOException if the destination factory fails to create the destination
   */
  private Destination getDestination(Bus bus, String address, Message message) throws IOException {
    DestinationFactoryManager manager = bus.getExtension(DestinationFactoryManager.class);
    DestinationFactory factory = manager.getDestinationFactoryForUri(address);
    if (factory == null) {
      return null;
    }
    QName endpointName = message.getExchange().getEndpoint().getEndpointInfo().getName();
    EndpointInfo decoupledInfo = new EndpointInfo();
    decoupledInfo.setName(
        new QName(endpointName.getNamespaceURI(), endpointName.getLocalPart() + ".decoupled"));
    decoupledInfo.setAddress(address);

    Destination destination = factory.getDestination(decoupledInfo, bus);
    Conduit conduit = ContextUtils.getConduit(null, message);
    if (conduit != null) {
      MessageObserver delegate = ((Observable) conduit).getMessageObserver();
      destination.setMessageObserver(new InterposedMessageObserver(bus, delegate));
    }
    return destination;
  }

  protected static class InterposedMessageObserver implements MessageObserver {
    Bus bus;
    MessageObserver observer;

    public InterposedMessageObserver(Bus b, MessageObserver o) {
      bus = b;
      observer = o;
    }

    /**
     * Called for an incoming message.
     *
     * @param inMessage
     */
    public void onMessage(Message inMessage) {
      // disposable exchange, swapped with real Exchange on correlation
      inMessage.setExchange(new ExchangeImpl());
      inMessage.getExchange().put(Bus.class, bus);
      inMessage.put(Message.DECOUPLED_CHANNEL_MESSAGE, Boolean.TRUE);
      inMessage.put(Message.RESPONSE_CODE, HttpURLConnection.HTTP_OK);

      // remove server-specific properties
      // inMessage.remove(AbstractHTTPDestination.HTTP_REQUEST);
      // inMessage.remove(AbstractHTTPDestination.HTTP_RESPONSE);
      inMessage.remove(Message.ASYNC_POST_RESPONSE_DISPATCH);
      updateResponseCode(inMessage);

      // cache this inputstream since it's defer to use in case of async
      try {
        InputStream in = inMessage.getContent(InputStream.class);
        if (in != null) {
          CachedOutputStream cos = new CachedOutputStream();
          IOUtils.copy(in, cos);
          inMessage.setContent(InputStream.class, cos.getInputStream());
        }
        observer.onMessage(inMessage);
      } catch (IOException e) {
        e.printStackTrace();
      }
    }

    private void updateResponseCode(Message message) {
      Object o = message.get("HTTP.RESPONSE");
      if (o != null) {
        try {
          o.getClass()
              .getMethod("setStatus", Integer.TYPE)
              .invoke(o, HttpURLConnection.HTTP_ACCEPTED);
        } catch (Throwable t) {
          // ignore
        }
      }
    }
  }

  /**
   * Get the starting point MAPs: either those set explicitly by the application on the binding
   * provider request context, or (provider context only) a fresh, empty instance.
   *
   * @param message the current message
   * @param isProviderContext true if the binding provider request context available to the client
   *     application as opposed to the message context visible to handlers
   * @param isOutbound true iff the message is outbound
   * @return AddressingProperties retrieved MAPs; never {@code null} when {@code isProviderContext}
   *     is true
   */
  private AddressingProperties getMAPs(
      Message message, boolean isProviderContext, boolean isOutbound) {
    AddressingProperties retrieved =
        ContextUtils.retrieveMAPs(message, isProviderContext, isOutbound);
    LOG.log(Level.FINE, "MAPs retrieved from message {0}", retrieved);
    if (retrieved != null || !isProviderContext) {
      return retrieved;
    }
    AddressingProperties fresh = new AddressingProperties();
    setupNamespace(fresh, message);
    return fresh;
  }

  /**
   * Exposes {@code maps} in the 2004/08 WS-Addressing namespace when the message's policy contains
   * a 2004 UsingAddressing assertion; otherwise leaves the namespace unchanged.
   */
  private void setupNamespace(AddressingProperties maps, Message message) {
    AssertionInfoMap aim = message.get(AssertionInfoMap.class);
    if (aim != null) {
      Collection<AssertionInfo> using2004 =
          aim.getAssertionInfo(MetadataConstants.USING_ADDRESSING_2004_QNAME);
      boolean present = using2004 != null && !using2004.isEmpty();
      if (present) {
        maps.exposeAs(Names200408.WSA_NAMESPACE_NAME);
      }
    }
  }

  /**
   * Validate incoming MAPs.
   *
   * <p>Failures are recorded on the message through ContextUtils.storeMAPFaultName and
   * storeMAPFaultReason, so the caller can build the WS-Addressing fault. Checks, in order:
   *
   * <ol>
   *   <li>the Action MAP is present;
   *   <li>when a SOAPAction is present and the action has not been verified yet, the SOAPAction
   *       must equal the Action MAP. The Action MAP must also match the action derived from the
   *       WSDL/annotations, with an optional "Request" suffix on either side;
   *   <li>non-one-way exchanges must carry a MessageID;
   *   <li>unless duplicates are allowed, the MessageID must not have been seen before. This check
   *       and the caching happen even when an earlier check already failed.
   * </ol>
   *
   * When {@code maps} is {@code null}, validation fails only if {@code usingAddressingAdvisory} is
   * set. A previously stored invalid-cardinality fault always makes the MAPs invalid.
   *
   * @param maps the incoming MAPs
   * @param message the current message
   * @return true if incoming MAPs are valid
   * @pre inbound message, not requestor
   */
  private boolean validateIncomingMAPs(AddressingProperties maps, Message message) {
    boolean valid = true;

    if (maps != null) {
      // WSAB spec, section 4.2 validation (SOAPAction must match Action)
      String sa = SoapActionInInterceptor.getSoapAction(message);
      String s1 = this.getActionUri(message, false);

      if (maps.getAction() == null || maps.getAction().getValue() == null) {
        String reason = BUNDLE.getString("MISSING_ACTION_MESSAGE");

        ContextUtils.storeMAPFaultName(Names.HEADER_REQUIRED_NAME, message);
        ContextUtils.storeMAPFaultReason(reason, message);
        valid = false;
      }

      if (!StringUtils.isEmpty(sa)
          && valid
          && !MessageUtils.isTrue(message.get(MAPAggregator.ACTION_VERIFIED))) {
        if (sa.startsWith("\"")) {
          // The SOAPAction value is client-supplied: a value without a closing quote (e.g. a
          // lone '"') must not make substring() throw StringIndexOutOfBoundsException.
          int closingQuote = sa.lastIndexOf('"');
          sa = closingQuote > 0 ? sa.substring(1, closingQuote) : sa.substring(1);
        }
        String action = maps.getAction() == null ? "" : maps.getAction().getValue();
        if (!StringUtils.isEmpty(sa) && !sa.equals(action)) {
          // don't match, must send fault back....
          String reason = BUNDLE.getString("INVALID_ADDRESSING_PROPERTY_MESSAGE");

          ContextUtils.storeMAPFaultName(Names.ACTION_MISMATCH_NAME, message);
          ContextUtils.storeMAPFaultReason(reason, message);
          valid = false;
        } else if (!StringUtils.isEmpty(s1)
            && !action.equals(s1)
            && !action.equals(s1 + "Request")
            && !s1.equals(action + "Request")) {
          // if java first, it's likely to have "Request", if wsdl first,
          // it will depend if the wsdl:input has a name or not. Thus, we'll
          // check both plain and with the "Request" trailer

          // doesn't match what's in the wsdl/annotations
          String reason =
              BundleUtils.getFormattedString(BUNDLE, "ACTION_NOT_SUPPORTED_MSG", action);

          ContextUtils.storeMAPFaultName(Names.ACTION_NOT_SUPPORTED_NAME, message);
          ContextUtils.storeMAPFaultReason(reason, message);
          valid = false;
        }
      }

      AttributedURIType messageID = maps.getMessageID();

      if (!message.getExchange().isOneWay()
          && (messageID == null || messageID.getValue() == null)
          && valid) {
        String reason = BUNDLE.getString("MISSING_ACTION_MESSAGE");

        ContextUtils.storeMAPFaultName(Names.HEADER_REQUIRED_NAME, message);
        ContextUtils.storeMAPFaultReason(reason, message);

        valid = false;
      }

      // Always cache message IDs, even when the message is not valid for some
      // other reason.
      if (!allowDuplicates
          && messageID != null
          && messageID.getValue() != null
          && !messageIdCache.checkUniquenessAndCacheId(messageID.getValue())) {

        LOG.log(Level.WARNING, "DUPLICATE_MESSAGE_ID_MSG", messageID.getValue());

        // Only throw the fault if something else has not already marked the
        // message as invalid.
        if (valid) {
          String reason = BUNDLE.getString("DUPLICATE_MESSAGE_ID_MSG");
          String l7dReason = MessageFormat.format(reason, messageID.getValue());
          ContextUtils.storeMAPFaultName(Names.DUPLICATE_MESSAGE_ID_NAME, message);
          ContextUtils.storeMAPFaultReason(l7dReason, message);
        }

        valid = false;
      }
    } else if (usingAddressingAdvisory) {
      String reason = BUNDLE.getString("MISSING_ACTION_MESSAGE");

      ContextUtils.storeMAPFaultName(Names.HEADER_REQUIRED_NAME, message);
      ContextUtils.storeMAPFaultReason(reason, message);
      valid = false;
    }

    if (Names.INVALID_CARDINALITY_NAME.equals(ContextUtils.retrieveMAPFaultName(message))) {
      valid = false;
    }

    return valid;
  }

  /**
   * Check for NONE ReplyTo value in request-response MEP.
   *
   * <p>A NONE ReplyTo would normally imply a 202 response, but for a non-one-way exchange that is
   * not a partial response, a fault is more appropriate. NOTE(review): the fault code's local part
   * is Names.WSA_NONE_ADDRESS, which is an address URI rather than an NCName - kept unchanged.
   *
   * @param message the current message
   * @param maps the MAPs to check
   * @throws SoapFault if a request-response exchange uses the NONE ReplyTo address
   */
  private void checkReplyTo(Message message, AddressingProperties maps) {
    boolean expectsResponse =
        !message.getExchange().isOneWay() && !MessageUtils.isPartialResponse(message);
    if (!expectsResponse || !ContextUtils.isNoneAddress(maps.getReplyTo())) {
      return;
    }
    String template = BUNDLE.getString("REPLYTO_NOT_SUPPORTED_MSG");
    String reason = MessageFormat.format(template, maps.getReplyTo().getAddress().getValue());
    throw new SoapFault(reason, new QName(Names.WSA_NAMESPACE_NAME, Names.WSA_NONE_ADDRESS));
  }

  /** Returns true when the exchange's binding is a {@code SoapBinding} using SOAP 1.2. */
  private boolean isSOAP12(Message message) {
    return message.getExchange().getBinding() instanceof SoapBinding
        && ((SoapBinding) message.getExchange().getBinding()).getSoapVersion()
            == Soap12.getInstance();
  }
}
Example #30
0
/** Some functions that avoid problems with Commons XML Schema. */
public final class XmlSchemaUtils {
  /** Attribute text marking an element instance as explicitly nil. */
  public static final String XSI_NIL = "xsi:nil='true'";

  private static final Logger LOG = LogUtils.getL7dLogger(XmlSchemaUtils.class);

  private XmlSchemaUtils() {}

  /**
   * Wrapper around XmlSchemaElement.setRefName that checks for inconsistency with name and QName.
   *
   * <p>The ref target is set only when {@code name} is null or agrees with whichever of the
   * element's existing QName and local name are present.
   *
   * @param element the element whose ref target is set
   * @param name the QName the element should refer to; may be null
   * @throws XmlSchemaInvalidOperation if the element already has a conflicting name or QName
   */
  public static void setElementRefName(XmlSchemaElement element, QName name) {
    if (name != null
        && ((element.getQName() != null && !element.getQName().equals(name))
            || (element.getName() != null && !element.getName().equals(name.getLocalPart())))) {
      LOG.severe("Attempt to set the refName of an element with a name or QName");
      throw new XmlSchemaInvalidOperation(
          "Attempt to set the refName of an element " + "with a name or QName.");
    }
    element.getRef().setTargetQName(name);
  }

  /**
   * Return true if a simple type is a straightforward XML Schema representation of an enumeration.
   * If we discover schemas that are 'enum-like' with more complex structures, we might make this
   * deal with them.
   *
   * <p>The type must be a restriction whose facets are all enumeration facets. A restriction with
   * no facets at all is also reported as an enumeration. The method name is misspelled but kept
   * for compatibility with existing callers.
   *
   * @param type Simple type, possible an enumeration.
   * @return true for an enumeration.
   */
  public static boolean isEumeration(XmlSchemaSimpleType type) {
    XmlSchemaSimpleTypeContent content = type.getContent();
    if (!(content instanceof XmlSchemaSimpleTypeRestriction)) {
      return false;
    }
    XmlSchemaSimpleTypeRestriction restriction = (XmlSchemaSimpleTypeRestriction) content;
    List<XmlSchemaFacet> facets = restriction.getFacets();
    for (XmlSchemaFacet facet : facets) {
      if (!(facet instanceof XmlSchemaEnumerationFacet)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Retrieve the string values for an enumeration.
   *
   * <p>Callers should check {@link #isEumeration(XmlSchemaSimpleType)} first. This method casts
   * the content and every facet unconditionally, so a non-enumeration type fails with a
   * ClassCastException.
   *
   * @param type a simple type that is an enumeration
   * @return the {@code toString()} of each facet value, in facet order
   */
  public static List<String> enumeratorValues(XmlSchemaSimpleType type) {
    XmlSchemaSimpleTypeContent content = type.getContent();
    XmlSchemaSimpleTypeRestriction restriction = (XmlSchemaSimpleTypeRestriction) content;
    List<XmlSchemaFacet> facets = restriction.getFacets();
    List<String> values = new ArrayList<String>(facets.size());
    for (XmlSchemaFacet facet : facets) {
      XmlSchemaEnumerationFacet enumFacet = (XmlSchemaEnumerationFacet) facet;
      values.add(enumFacet.getValue().toString());
    }
    return values;
  }

  /**
   * Is there an import for a particular namespace in a schema?
   *
   * @param schema the schema whose externals are scanned
   * @param namespaceUri the namespace to look for; must not be null
   * @return true if one of the schema's {@code XmlSchemaImport} externals has this namespace
   */
  public static boolean schemaImportsNamespace(XmlSchema schema, String namespaceUri) {
    List<XmlSchemaExternal> externals = schema.getExternals();
    for (XmlSchemaExternal what : externals) {
      if (what instanceof XmlSchemaImport) {
        XmlSchemaImport imp = (XmlSchemaImport) what;
        // already there.
        if (namespaceUri.equals(imp.getNamespace())) {
          return true;
        }
      }
    }
    return false;
  }

  /**
   * Assist in managing the required <import namespace='uri'> for imports of peer schemas.
   *
   * <p>Nothing is added when the namespace is null or empty, is the XSD namespace, is the
   * schema's own target namespace, or is already imported.
   *
   * @param schema the schema that may need the import
   * @param namespaceUri the namespace to import; null is treated like "" (nothing to import)
   */
  public static void addImportIfNeeded(XmlSchema schema, String namespaceUri) {
    // no need to import nothing or the XSD schema, or the schema we are fixing.
    // The target-namespace comparison is written null-safely: a schema without a target
    // namespace must not cause a NullPointerException here.
    if (namespaceUri == null
        || "".equals(namespaceUri)
        || Constants.URI_2001_SCHEMA_XSD.equals(namespaceUri)
        || namespaceUri.equals(schema.getTargetNamespace())) {
      return;
    }
    if (schemaImportsNamespace(schema, namespaceUri)) {
      return;
    }
    // NOTE(review): the new import is not stored anywhere here; this relies on the
    // XmlSchemaImport(schema) constructor attaching it to the schema's externals.
    XmlSchemaImport imp = new XmlSchemaImport(schema);
    imp.setNamespace(namespaceUri);
  }

  /**
   * For convenience, start from a qname, and add the import if it is non-null and has a namespace.
   *
   * @see #addImportIfNeeded(XmlSchema, String)
   * @param schema the schema that may need the import
   * @param qname a name from the namespace to import; null is ignored
   */
  public static void addImportIfNeeded(XmlSchema schema, QName qname) {
    if (qname == null) {
      return;
    }
    if (qname.getNamespaceURI() == null) {
      return;
    }
    addImportIfNeeded(schema, qname.getNamespaceURI());
  }

  /**
   * This copes with an observed phenomenon in the schema built by the ReflectionServiceFactoryBean.
   * It is creating element such that: (a) the type is not set. (b) the refName is set. (c) the
   * namespaceURI in the refName is set empty. This apparently indicates 'same Schema' to everyone
   * else, so this function implements that convention here. It is unclear if that is a correct
   * structure, and if it changes, we can simplify or eliminate this function.
   *
   * @param xmlSchemaCollection the collection to search
   * @param name the ref name; an empty namespace means "same schema as the referencing one"
   * @param referencingURI namespace substituted when {@code name} has an empty namespace
   * @return the element found. A missing element trips the assertion when assertions are
   *     enabled; otherwise null is returned.
   */
  public static XmlSchemaElement findElementByRefName(
      SchemaCollection xmlSchemaCollection, QName name, String referencingURI) {
    String uri = name.getNamespaceURI();
    if ("".equals(uri)) {
      uri = referencingURI;
    }
    QName copyName = new QName(uri, name.getLocalPart());
    XmlSchemaElement target = xmlSchemaCollection.getElementByQName(copyName);
    assert target != null;
    return target;
  }

  /**
   * Returns the base type name of a complex type defined by complex-content extension.
   *
   * @param type the complex type
   * @return the extension's base type name, or null if the type has no content model, no
   *     content, or content that is not an {@code XmlSchemaComplexContentExtension}
   */
  public static QName getBaseType(XmlSchemaComplexType type) {
    XmlSchemaContentModel model = type.getContentModel();
    if (model == null) {
      return null;
    }
    XmlSchemaContent content = model.getContent();
    if (content == null) {
      return null;
    }

    if (!(content instanceof XmlSchemaComplexContentExtension)) {
      return null;
    }

    XmlSchemaComplexContentExtension ext = (XmlSchemaComplexContentExtension) content;
    return ext.getBaseTypeName();
  }

  /**
   * Returns the attributes declared directly on a complex-content extension.
   *
   * @param type the complex type
   * @return the extension's attributes, or null when the type is not a complex-content extension
   *     (same conditions as {@link #getBaseType(XmlSchemaComplexType)})
   */
  public static List<XmlSchemaAttributeOrGroupRef> getContentAttributes(XmlSchemaComplexType type) {
    XmlSchemaContentModel model = type.getContentModel();
    if (model == null) {
      return null;
    }
    XmlSchemaContent content = model.getContent();
    if (content == null) {
      return null;
    }
    if (!(content instanceof XmlSchemaComplexContentExtension)) {
      return null;
    }

    // TODO: the anyAttribute case.
    XmlSchemaComplexContentExtension ext = (XmlSchemaComplexContentExtension) content;
    return ext.getAttributes();
  }

  /**
   * Collects the attributes of a complex type, including those inherited through a chain of
   * complex-content extensions. Base-type attributes come first.
   *
   * <p>NOTE(review): the base type is looked up in {@code collection} and cast to
   * {@code XmlSchemaComplexType}. A missing or non-complex base type fails with an exception.
   *
   * @param type the complex type
   * @param collection the collection used to resolve base type names
   * @return a new list of the collected attributes
   */
  public static List<XmlSchemaAnnotated> getContentAttributes(
      XmlSchemaComplexType type, SchemaCollection collection) {
    List<XmlSchemaAnnotated> results = new ArrayList<XmlSchemaAnnotated>();
    QName baseTypeName = getBaseType(type);
    if (baseTypeName != null) {
      XmlSchemaComplexType baseType =
          (XmlSchemaComplexType) collection.getTypeByQName(baseTypeName);
      // recurse onto the base type ...
      results.addAll(getContentAttributes(baseType, collection));
      // and now process our sequence.
      List<XmlSchemaAttributeOrGroupRef> extAttrs = getContentAttributes(type);
      results.addAll(extAttrs);
      return results;
    } else {
      // no base type, the simple case.
      List<XmlSchemaAttributeOrGroupRef> attrs = type.getAttributes();
      results.addAll(attrs);
      return results;
    }
  }

  /**
   * By convention, an element that is named in its schema's TNS can have a 'name' but no QName.
   * This can get inconvenient for consumers who want to think about qualified names. Unfortunately,
   * XmlSchema elements, unlike types, don't store a reference to their containing schema.
   *
   * @param element the element
   * @param schema the schema whose target namespace qualifies a name-only element
   * @return the element's QName; else its name in {@code schema}'s target namespace; else null
   */
  public static QName getElementQualifiedName(XmlSchemaElement element, XmlSchema schema) {
    if (element.getQName() != null) {
      return element.getQName();
    } else if (element.getName() != null) {
      return new QName(schema.getTargetNamespace(), element.getName());
    } else {
      return null;
    }
  }

  /**
   * Decides whether an attribute's name must be namespace-qualified.
   *
   * <p>The attribute's own form wins when it is QUALIFIED or UNQUALIFIED. Otherwise the schema's
   * attributeFormDefault decides.
   *
   * @param attribute a non-ref attribute
   * @param schema the schema declaring the attribute
   * @return true if the attribute name is qualified
   * @throws RuntimeException if the attribute is a ref=
   */
  public static boolean isAttributeNameQualified(XmlSchemaAttribute attribute, XmlSchema schema) {
    if (attribute.isRef()) {
      throw new RuntimeException("isAttributeNameQualified on attribute with ref=");
    }
    // Constant-first comparisons stay null-safe if no form has been recorded.
    if (XmlSchemaForm.QUALIFIED.equals(attribute.getForm())) {
      return true;
    }
    if (XmlSchemaForm.UNQUALIFIED.equals(attribute.getForm())) {
      return false;
    }
    return XmlSchemaForm.QUALIFIED.equals(schema.getAttributeFormDefault());
  }

  /**
   * Due to a bug, feature, or just plain oddity of JAXB, it isn't good enough to just check the
   * form of an attribute and of its schema. If schema 'a' (default unqualified) has a complex type
   * with a ref= to schema (b) (default unqualified), JAXB seems to expect to see a qualifier,
   * anyway. <br>
   * So, if the attribute is local to a complex type, all we care about is the default form of the
   * schema and the local form of the attribute. <br>
   * If, on the other hand, the attribute is global, we might need to compare namespaces. <br>
   *
   * @param attribute the attribute
   * @param global if this attribute is global (complex type ref= to it, or in a part)
   * @param localSchema the schema of the complex type containing the reference, only used for the
   *     'odd case'; may be null
   * @param attributeSchema the schema for the attribute.
   * @return if the attribute needs to be qualified.
   * @throws RuntimeException if the attribute has no QName or is a ref=
   */
  public static boolean isAttributeQualified(
      XmlSchemaAttribute attribute,
      boolean global,
      XmlSchema localSchema,
      XmlSchema attributeSchema) {
    if (attribute.getQName() == null) {
      throw new RuntimeException("isAttributeQualified on anonymous attribute.");
    }
    if (attribute.isRef()) {
      throw new RuntimeException("isAttributeQualified on the 'from' side of ref=.");
    }

    if (global) {
      // The 'odd case': a global attribute from another namespace is always qualified.
      return isAttributeNameQualified(attribute, attributeSchema)
          || (localSchema != null
              && !(attribute
                  .getQName()
                  .getNamespaceURI()
                  .equals(localSchema.getTargetNamespace())));
    } else {
      return isAttributeNameQualified(attribute, attributeSchema);
    }
  }

  /**
   * Decides whether an element's name must be namespace-qualified.
   *
   * <p>The element's own form wins when it is QUALIFIED or UNQUALIFIED. Otherwise the schema's
   * elementFormDefault decides.
   *
   * @param element a non-ref element
   * @param schema the schema declaring the element
   * @return true if the element name is qualified
   * @throws RuntimeException if the element is a ref=
   */
  public static boolean isElementNameQualified(XmlSchemaElement element, XmlSchema schema) {
    if (element.isRef()) {
      throw new RuntimeException("isElementNameQualified on element with ref=");
    }
    // Constant-first comparisons stay null-safe if no form has been recorded.
    if (XmlSchemaForm.QUALIFIED.equals(element.getForm())) {
      return true;
    }
    if (XmlSchemaForm.UNQUALIFIED.equals(element.getForm())) {
      return false;
    }
    return XmlSchemaForm.QUALIFIED.equals(schema.getElementFormDefault());
  }

  /**
   * Due to a bug, feature, or just plain oddity of JAXB, it isn't good enough to just check the
   * form of an element and of its schema. If schema 'a' (default unqualified) has a complex type
   * with an element with a ref= to schema (b) (default unqualified), JAXB seems to expect to see a
   * qualifier, anyway. <br>
   * So, if the element is local to a complex type, all we care about is the default element form of
   * the schema and the local form of the element. <br>
   * If, on the other hand, the element is global, we might need to compare namespaces. <br>
   *
   * @param element the element.
   * @param global if this element is a global element (complex type ref= to it, or in a part)
   * @param localSchema the schema of the complex type containing the reference, only used for the
   *     'odd case'; may be null
   * @param elementSchema the schema for the element.
   * @return if the element needs to be qualified.
   * @throws RuntimeException if the element is anonymous or is a ref=
   */
  public static boolean isElementQualified(
      XmlSchemaElement element, boolean global, XmlSchema localSchema, XmlSchema elementSchema) {
    // localSchema may legitimately be null (see the check below). Deriving a name-only
    // element's QName from it would then throw a NullPointerException, so fall back to the
    // element's own schema.
    XmlSchema nameSchema = localSchema != null ? localSchema : elementSchema;
    QName qn = getElementQualifiedName(element, nameSchema);
    if (qn == null) {
      throw new RuntimeException("isElementQualified on anonymous element.");
    }
    if (element.isRef()) {
      throw new RuntimeException("isElementQualified on the 'from' side of ref=.");
    }

    if (global) {
      // The 'odd case': a global element from another namespace is always qualified.
      return isElementNameQualified(element, elementSchema)
          || (localSchema != null
              && !(qn.getNamespaceURI().equals(localSchema.getTargetNamespace())));
    } else {
      return isElementNameQualified(element, elementSchema);
    }
  }

  /**
   * @param particle the particle
   * @return true if the particle may occur more than once
   */
  public static boolean isParticleArray(XmlSchemaParticle particle) {
    return particle.getMaxOccurs() > 1;
  }

  /**
   * @param particle the particle
   * @return true if the particle occurs zero or one time (minOccurs 0, maxOccurs 1)
   */
  public static boolean isParticleOptional(XmlSchemaParticle particle) {
    return particle.getMinOccurs() == 0 && particle.getMaxOccurs() == 1;
  }
}