Example #1
0
/**
 * Tree of JAX-RS resource metadata used to generate the JavaScript API (JSAPI).
 *
 * <p>Each node scans a {@link ResourceMethodRegistry}:
 *
 * <ul>
 *   <li>every bound {@link ResourceMethod} becomes a {@link MethodMetaData} entry;
 *   <li>every {@link ResourceLocator} whose return type resolves to a static JAX-RS resource class
 *       becomes a child {@code ServiceRegistry}.
 * </ul>
 *
 * <p>The root node has no parent and no locator.
 *
 * @author Stéphane Épardaud <*****@*****.**>
 */
public class ServiceRegistry {
  private static final Logger logger = Logger.getLogger(ServiceRegistry.class);

  // NOTE(review): this class does not implement Serializable as written, so this constant
  // appears to be unused.
  private static final long serialVersionUID = -1985015444704126795L;

  private final ResourceMethodRegistry registry;

  private final ResteasyProviderFactory providerFactory;

  /** Enclosing registry; {@code null} for the root. */
  private final ServiceRegistry parent;

  /** Resource methods found directly in {@link #registry}; filled by {@link #scanRegistry()}. */
  private List<MethodMetaData> methods;

  /** Child registries, one per usable sub-resource locator; filled by {@link #scanRegistry()}. */
  private List<ServiceRegistry> locators;

  /** Locator that leads to this node; {@code null} for the root. */
  private final ResourceLocator locator;

  /** URI of this node, built from the parent and the locator's {@code @Path}s (non-root only). */
  private String uri;

  /** JavaScript function prefix for this node's methods (non-root only). */
  private String functionPrefix;

  /**
   * Creates a registry node and immediately scans {@code registry} for methods and locators.
   *
   * @param parent enclosing node, or {@code null} for the root
   * @param registry the resource registry to scan
   * @param providerFactory factory used to build registries for sub-resource locators
   * @param locator the locator that leads to this node, or {@code null} for the root
   * @throws IllegalArgumentException if {@code locator} is non-null but {@code parent} is null
   */
  public ServiceRegistry(
      ServiceRegistry parent,
      ResourceMethodRegistry registry,
      ResteasyProviderFactory providerFactory,
      ResourceLocator locator) {
    this.parent = parent;
    this.registry = registry;
    this.providerFactory = providerFactory;
    this.locator = locator;
    if (locator != null) {
      // A locator node is always reached from some enclosing node; without it the URI and
      // function prefix below cannot be built.
      if (parent == null) {
        throw new IllegalArgumentException(
            "A ServiceRegistry for a resource locator requires a parent registry");
      }
      Method method = locator.getMethod();
      Path methodPath = method.getAnnotation(Path.class);
      Class<?> declaringClass = method.getDeclaringClass();
      Path classPath = declaringClass.getAnnotation(Path.class);
      this.uri = MethodMetaData.appendURIFragments(parent, classPath, methodPath);
      // Direct children of the root are prefixed with their declaring class name; deeper
      // nodes extend their parent's prefix.
      if (parent.isRoot()) {
        this.functionPrefix = declaringClass.getSimpleName() + "." + method.getName();
      } else {
        this.functionPrefix = parent.getFunctionPrefix() + "." + method.getName();
      }
    }
    scanRegistry();
  }

  /**
   * Walks every bound invoker in the registry.
   *
   * <p>Resource methods are collected into {@link #methods}. Sub-resource locators are turned
   * into child registries in {@link #locators}.
   */
  private void scanRegistry() {
    methods = new ArrayList<MethodMetaData>();
    locators = new ArrayList<ServiceRegistry>();
    for (Entry<String, List<ResourceInvoker>> entry : registry.getRoot().getBounded().entrySet()) {
      for (ResourceInvoker invoker : entry.getValue()) {
        if (invoker instanceof ResourceMethod) {
          methods.add(new MethodMetaData(this, (ResourceMethod) invoker));
        } else if (invoker instanceof ResourceLocator) {
          ServiceRegistry child = createLocatorRegistry((ResourceLocator) invoker);
          if (child != null) {
            locators.add(child);
          }
        }
      }
    }
  }

  /**
   * Builds the child registry for a sub-resource locator.
   *
   * @return the child registry, or {@code null} (after logging a warning) when the locator's
   *     return type is not a static JAX-RS resource type
   */
  private ServiceRegistry createLocatorRegistry(ResourceLocator subLocator) {
    Method method = subLocator.getMethod();
    Class<?> locatorType = method.getReturnType();
    Class<?> locatorResourceType = GetRestful.getSubResourceClass(locatorType);
    if (locatorResourceType == null) {
      // FIXME: we could generate an error for the client, which would be more informative
      // than just logging this
      if (logger.isWarnEnabled()) {
        logger.warn(
            "Impossible to generate JSAPI for subresource returned by method "
                + method.getDeclaringClass().getName()
                + "."
                + method.getName()
                + " since return type is not a static JAXRS resource type");
      }
      return null;
    }
    ResourceMethodRegistry locatorRegistry = new ResourceMethodRegistry(providerFactory);
    locatorRegistry.addResourceFactory(null, null, locatorResourceType);
    return new ServiceRegistry(this, locatorRegistry, providerFactory, subLocator);
  }

  /** Returns the resource methods found directly in this node's registry (live list). */
  public List<MethodMetaData> getMethodMetaData() {
    return methods;
  }

  /** Returns the child registries created for sub-resource locators (live list). */
  public List<ServiceRegistry> getLocators() {
    return locators;
  }

  /** Returns this node's URI, or {@code null} for the root. */
  public String getUri() {
    return uri;
  }

  /** Returns {@code true} when this node has no parent. */
  public boolean isRoot() {
    return parent == null;
  }

  /** Returns the JavaScript function prefix, or {@code null} for the root. */
  public String getFunctionPrefix() {
    return functionPrefix;
  }

  /**
   * Appends the locator method of this node, then of each ancestor, to {@code result}.
   *
   * <p>Methods are appended deepest first. The root contributes nothing.
   *
   * @param result list that receives the locator methods
   */
  public void collectResourceMethodsUntilRoot(List<Method> result) {
    if (isRoot()) {
      return;
    }
    result.add(locator.getMethod());
    parent.collectResourceMethodsUntilRoot(result);
  }
}
/**
 * Invokes one JAX-RS resource method and turns its result into a {@link BuiltResponse}.
 *
 * <p>On construction the invoker:
 *
 * <ul>
 *   <li>builds a per-method {@link ResteasyProviderFactory}, derived from the given factory and
 *       configured by the server {@link DynamicFeature}s;
 *   <li>resolves the post-match request filters, response filters and writer interceptors for the
 *       method;
 *   <li>registers itself as a listener on the given factory's registries, so those arrays can be
 *       rebuilt in {@link #registryUpdated(JaxrsInterceptorRegistry)}.
 * </ul>
 *
 * <p>{@link #cleanup()} removes these registrations.
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class ResourceMethodInvoker implements ResourceInvoker, JaxrsInterceptorRegistryListener {
  static final Logger logger = Logger.getLogger(ResourceMethodInvoker.class);

  protected MethodInjector methodInjector;
  protected InjectorFactory injector;
  protected ResourceFactory resource;
  /** Factory passed to the constructor; the per-method factory is derived from it. */
  protected ResteasyProviderFactory parentProviderFactory;
  /** Per-method factory configured by the server's dynamic features. */
  protected ResteasyProviderFactory resourceMethodProviderFactory;
  protected ResourceMethod method;
  protected ContainerRequestFilter[] requestFilters;
  protected ContainerResponseFilter[] responseFilters;
  protected WriterInterceptor[] writerInterceptors;
  /** Invocation counts keyed by HTTP method name; see {@link #getStats()}. */
  protected ConcurrentHashMap<String, AtomicLong> stats =
      new ConcurrentHashMap<String, AtomicLong>();
  /** Validator from the GeneralValidator context resolver; {@code null} if none is registered. */
  protected GeneralValidator validator;
  /** Result of {@code validator.isValidatable(declaringClass)}; false when there is no validator. */
  protected boolean isValidatable;
  /** Result of {@code validator.isMethodValidatable(method)}; false when there is no validator. */
  protected boolean methodIsValidatable;
  protected ResourceInfo resourceInfo;

  /** Whether the method injector expects an entity body; consulted by {@link #doesConsume}. */
  protected boolean expectsBody;

  public ResourceMethodInvoker(
      ResourceMethod method,
      InjectorFactory injector,
      ResourceFactory resource,
      ResteasyProviderFactory providerFactory) {
    this.injector = injector;
    this.resource = resource;
    this.parentProviderFactory = providerFactory;
    this.method = method;

    resourceInfo =
        new ResourceInfo() {
          @Override
          public Method getResourceMethod() {
            return ResourceMethodInvoker.this.method.getAnnotatedMethod();
          }

          @Override
          public Class<?> getResourceClass() {
            return ResourceMethodInvoker.this.method.getResourceClass().getClazz();
          }
        };

    this.resourceMethodProviderFactory = new ResteasyProviderFactory(providerFactory);
    for (DynamicFeature feature : providerFactory.getServerDynamicFeatures()) {
      feature.configure(resourceInfo, new FeatureContextDelegate(resourceMethodProviderFactory));
    }

    this.methodInjector = injector.createMethodInjector(method, resourceMethodProviderFactory);

    // hack for when message contentType == null
    // and @Consumes is on the class
    expectsBody = this.methodInjector.expectsBody();

    requestFilters =
        resourceMethodProviderFactory
            .getContainerRequestFilterRegistry()
            .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());
    responseFilters =
        resourceMethodProviderFactory
            .getContainerResponseFilterRegistry()
            .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());
    writerInterceptors =
        resourceMethodProviderFactory
            .getServerWriterInterceptorRegistry()
            .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());

    // We register with the parent factory's registries to listen for redeploy events.
    providerFactory.getContainerRequestFilterRegistry().getListeners().add(this);
    providerFactory.getContainerResponseFilterRegistry().getListeners().add(this);
    providerFactory.getServerWriterInterceptorRegistry().getListeners().add(this);
    ContextResolver<GeneralValidator> resolver =
        providerFactory.getContextResolver(GeneralValidator.class, MediaType.WILDCARD_TYPE);
    if (resolver != null) {
      // Reuse the resolver we just looked up instead of querying the factory a second time.
      validator = resolver.getContext(null);
    }
    if (validator != null) {
      isValidatable = validator.isValidatable(getMethod().getDeclaringClass());
      methodIsValidatable = validator.isMethodValidatable(getMethod());
    }
  }

  /**
   * Unregisters this invoker from the parent factory's filter/interceptor registries.
   *
   * <p>Also removes any {@link MessageBodyParameterInjector} parameters from the parent's
   * reader-interceptor listeners.
   */
  public void cleanup() {
    parentProviderFactory.getContainerRequestFilterRegistry().getListeners().remove(this);
    parentProviderFactory.getContainerResponseFilterRegistry().getListeners().remove(this);
    parentProviderFactory.getServerWriterInterceptorRegistry().getListeners().remove(this);
    for (ValueInjector param : methodInjector.getParams()) {
      if (param instanceof MessageBodyParameterInjector) {
        parentProviderFactory.getServerReaderInterceptorRegistry().getListeners().remove(param);
      }
    }
  }

  /**
   * Rebuilds the per-method provider factory, then re-resolves only the filter/interceptor array
   * that matches the changed registry's interface.
   *
   * <p>NOTE(review): {@link #methodInjector} was created against the factory built in the
   * constructor and is not rebuilt here — confirm that is intended.
   */
  public void registryUpdated(JaxrsInterceptorRegistry registry) {
    this.resourceMethodProviderFactory = new ResteasyProviderFactory(parentProviderFactory);
    for (DynamicFeature feature : parentProviderFactory.getServerDynamicFeatures()) {
      feature.configure(resourceInfo, new FeatureContextDelegate(resourceMethodProviderFactory));
    }
    if (registry.getIntf().equals(WriterInterceptor.class)) {
      writerInterceptors =
          resourceMethodProviderFactory
              .getServerWriterInterceptorRegistry()
              .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());
    } else if (registry.getIntf().equals(ContainerRequestFilter.class)) {
      requestFilters =
          resourceMethodProviderFactory
              .getContainerRequestFilterRegistry()
              .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());
    } else if (registry.getIntf().equals(ContainerResponseFilter.class)) {
      responseFilters =
          resourceMethodProviderFactory
              .getContainerResponseFilterRegistry()
              .postMatch(method.getResourceClass().getClazz(), method.getAnnotatedMethod());
    }
  }

  /**
   * Increments the invocation counter for {@code httpMethod}.
   *
   * <p>{@code putIfAbsent} settles the race when two threads create the first counter at once.
   */
  protected void incrementMethodCount(String httpMethod) {
    AtomicLong stat = stats.get(httpMethod);
    if (stat == null) {
      stat = new AtomicLong();
      AtomicLong old = stats.putIfAbsent(httpMethod, stat);
      if (old != null) stat = old;
    }
    stat.incrementAndGet();
  }

  /**
   * Key is httpMethod called
   *
   * @return live map from HTTP method name to the number of invocations of this resource method
   */
  public Map<String, AtomicLong> getStats() {
    return stats;
  }

  public ContainerRequestFilter[] getRequestFilters() {
    return requestFilters;
  }

  public ContainerResponseFilter[] getResponseFilters() {
    return responseFilters;
  }

  public WriterInterceptor[] getWriterInterceptors() {
    return writerInterceptors;
  }

  public Type getGenericReturnType() {
    return method.getGenericReturnType();
  }

  public Class<?> getResourceClass() {
    return method.getResourceClass().getClazz();
  }

  public Annotation[] getMethodAnnotations() {
    return method.getAnnotatedMethod().getAnnotations();
  }

  @Override
  public Method getMethod() {
    return method.getMethod();
  }

  /** Creates the resource instance through the resource factory, then invokes the method on it. */
  public BuiltResponse invoke(HttpRequest request, HttpResponse response) {
    Object target = resource.createResource(request, response, resourceMethodProviderFactory);
    return invoke(request, response, target);
  }

  /**
   * Records this invoker on the request, updates stats and URI-matching state, and then invokes
   * the method on {@code target}.
   */
  public BuiltResponse invoke(HttpRequest request, HttpResponse response, Object target) {
    request.setAttribute(ResourceMethodInvoker.class.getName(), this);
    incrementMethodCount(request.getHttpMethod());
    ResteasyUriInfo uriInfo = (ResteasyUriInfo) request.getUri();
    if (method.getPath() != null) {
      uriInfo.pushMatchedURI(uriInfo.getMatchingPath());
    }
    uriInfo.pushCurrentResource(target);
    BuiltResponse rtn = invokeOnTarget(request, response, target);
    return rtn;
  }

  /**
   * Runs post-match request filters, validation and the resource method, and wraps the result.
   *
   * @return the response to write. This is {@code null} when the request was suspended
   *     (asynchronous) or forwarded, or when an exception was handed to a suspended async
   *     response. A filter abort returns the aborting response.
   */
  protected BuiltResponse invokeOnTarget(
      HttpRequest request, HttpResponse response, Object target) {
    ResteasyProviderFactory.pushContext(
        ResourceInfo.class, resourceInfo); // we don't pop so writer interceptors can get at this

    PostMatchContainerRequestContext requestContext =
        new PostMatchContainerRequestContext(request, this);
    for (ContainerRequestFilter filter : requestFilters) {
      try {
        filter.filter(requestContext);
      } catch (IOException e) {
        throw new ApplicationException(e);
      }
      BuiltResponse serverResponse = (BuiltResponse) requestContext.getResponseAbortedWith();
      if (serverResponse != null) {
        return serverResponse;
      }
    }

    if (validator != null) {
      if (isValidatable) {
        validator.validate(request, target);
      }
      if (methodIsValidatable) {
        request.setAttribute(GeneralValidator.class.getName(), validator);
      } else if (isValidatable) {
        validator.checkViolations(request);
      }
    }

    Object rtn = null;
    try {
      rtn = methodInjector.invoke(request, response, target);
    } catch (RuntimeException ex) {
      // For a suspended request the failure is delivered through the async response instead
      // of propagating.
      if (request.getAsyncContext().isSuspended()) {
        try {
          request.getAsyncContext().getAsyncResponse().resume(ex);
        } catch (Exception e) {
          logger.error("Error resuming failed async operation", e);
        }
        return null;
      } else {
        throw ex;
      }
    }

    if (request.getAsyncContext().isSuspended() || request.wasForwarded()) {
      return null;
    }
    if (rtn == null || method.getReturnType().equals(void.class)) {
      BuiltResponse build = (BuiltResponse) Response.noContent().build();
      build.addMethodAnnotations(method.getAnnotatedMethod());
      return build;
    }
    if (Response.class.isAssignableFrom(method.getReturnType()) || rtn instanceof Response) {
      if (!(rtn instanceof BuiltResponse)) {
        // Copy a foreign Response implementation into a BuiltResponse.
        Response r = (Response) rtn;
        Headers<Object> metadata = new Headers<Object>();
        metadata.putAll(r.getMetadata());
        rtn = new BuiltResponse(r.getStatus(), metadata, r.getEntity(), null);
      }
      BuiltResponse rtn1 = (BuiltResponse) rtn;
      rtn1.addMethodAnnotations(method.getAnnotatedMethod());
      applyDefaultGenericType(rtn1);
      return rtn1;
    }

    BuiltResponse jaxrsResponse = (BuiltResponse) Response.ok(rtn).build();
    applyDefaultGenericType(jaxrsResponse);
    jaxrsResponse.addMethodAnnotations(method.getAnnotatedMethod());
    return jaxrsResponse;
  }

  /**
   * Sets the response's generic entity type when it has none.
   *
   * <p>If the Java method is declared to return {@link Response}, the entity's class is used.
   * Otherwise the method's generic return type is used.
   */
  private void applyDefaultGenericType(BuiltResponse response) {
    if (response.getGenericType() != null) {
      return;
    }
    if (getMethod().getReturnType().equals(Response.class)) {
      response.setGenericType(response.getEntityClass());
    } else {
      response.setGenericType(method.getGenericReturnType());
    }
  }

  /** Copies this method's annotations, writer interceptors and response filters to the response. */
  public void initializeAsync(ResteasyAsynchronousResponse asyncResponse) {
    asyncResponse.setAnnotations(method.getAnnotatedMethod().getAnnotations());
    asyncResponse.setWriterInterceptors(writerInterceptors);
    asyncResponse.setResponseFilters(responseFilters);
    asyncResponse.setMethod(this);
  }

  /**
   * Returns {@code true} if any accepted media type is compatible with one of the method's
   * {@code @Produces} types.
   *
   * <p>It also returns {@code true} when there are no accept types or no declared produces types.
   */
  public boolean doesProduce(List<? extends MediaType> accepts) {
    if (accepts == null || accepts.size() == 0) {
      return true;
    }
    if (method.getProduces().length == 0) {
      return true;
    }

    for (MediaType accept : accepts) {
      for (MediaType type : method.getProduces()) {
        if (type.isCompatible(accept)) {
          return true;
        }
      }
    }
    return false;
  }

  /**
   * Returns {@code true} if {@code contentType} is compatible with one of the method's
   * {@code @Consumes} types.
   *
   * <p>It also returns {@code true} when nothing is declared, or when there is no content type and
   * no body is expected. A missing content type on a body-expecting method is treated as
   * application/octet-stream.
   */
  public boolean doesConsume(MediaType contentType) {
    boolean matches = false;
    if (method.getConsumes().length == 0 || (contentType == null && !expectsBody)) return true;

    if (contentType == null) {
      contentType = MediaType.APPLICATION_OCTET_STREAM_TYPE;
    }
    for (MediaType type : method.getConsumes()) {
      if (type.isCompatible(contentType)) {
        matches = true;
        break;
      }
    }
    return matches;
  }

  /**
   * Chooses the response media type. The first rule that applies wins:
   *
   * <ol>
   *   <li>the non-wildcard type chosen during matching (request attribute {@code
   *       SegmentNode.RESTEASY_CHOSEN_ACCEPT});
   *   <li>when there is no Accept header, the first {@code @Produces} type, or wildcard if none is
   *       declared;
   *   <li>when nothing is declared, the first accepted type that has a message body writer;
   *   <li>otherwise, the first declared type compatible with an accepted type.
   * </ol>
   *
   * <p>Falls back to the wildcard type.
   */
  public MediaType resolveContentType(HttpRequest in, Object entity) {
    MediaType chosen = (MediaType) in.getAttribute(SegmentNode.RESTEASY_CHOSEN_ACCEPT);
    if (chosen != null && !chosen.equals(MediaType.WILDCARD_TYPE)) {
      return chosen;
    }

    List<MediaType> accepts = in.getHttpHeaders().getAcceptableMediaTypes();

    if (accepts == null || accepts.size() == 0) {
      if (method.getProduces().length == 0) return MediaType.WILDCARD_TYPE;
      else return method.getProduces()[0];
    }

    if (method.getProduces().length == 0) {
      return resolveContentTypeByAccept(accepts, entity);
    }

    for (MediaType accept : accepts) {
      for (MediaType type : method.getProduces()) {
        if (type.isCompatible(accept)) return type;
      }
    }
    return MediaType.WILDCARD_TYPE;
  }

  /**
   * Returns the first accepted media type for which a {@code MessageBodyWriter} exists for the
   * entity.
   *
   * <p>A {@link GenericEntity} is unwrapped to its raw and generic types first. Falls back to the
   * wildcard type.
   */
  protected MediaType resolveContentTypeByAccept(List<MediaType> accepts, Object entity) {
    if (accepts == null || accepts.size() == 0 || entity == null) {
      return MediaType.WILDCARD_TYPE;
    }
    Class clazz = entity.getClass();
    Type type = this.method.getGenericReturnType();
    if (entity instanceof GenericEntity) {
      GenericEntity gen = (GenericEntity) entity;
      clazz = gen.getRawType();
      type = gen.getType();
    }
    for (MediaType accept : accepts) {
      if (resourceMethodProviderFactory.getMessageBodyWriter(
              clazz, type, method.getAnnotatedMethod().getAnnotations(), accept)
          != null) {
        return accept;
      }
    }
    return MediaType.WILDCARD_TYPE;
  }

  public Set<String> getHttpMethods() {
    return method.getHttpMethods();
  }

  public MediaType[] getProduces() {
    return method.getProduces();
  }

  public MediaType[] getConsumes() {
    return method.getConsumes();
  }
}
/**
 * {@link SrampService} backed by an {@link SrampAtomApiClient}.
 *
 * <p>It uploads deployment archives, and their expanded contents, to an S-RAMP repository, tags
 * them with an {@code arquillian-archive-id} property, and deletes them again on undeploy.
 */
public class SrampServiceImpl implements SrampService {

  private Logger log = Logger.getLogger(SrampServiceImpl.class);

  private SrampAtomApiClient client = null;

  /**
   * Repository-wide artifact count recorded at the start of the last deploy.
   *
   * <p>It is resynchronized after a mismatch on undeploy and used as an internal consistency check.
   */
  private long artifactCounter = 0;

  /** Creates a client connected to the S-RAMP server described by {@code config}. */
  public SrampServiceImpl(SrampConfiguration config)
      throws SrampClientException, SrampAtomException {

    // create connection to S-RAMP
    this.client =
        new SrampAtomApiClient(
            config.getSrampServerURL(), config.getSrampUsername(), config.getSrampPassword(), true);
  }

  /**
   * Returns the underlying S-RAMP client.
   *
   * @see org.jboss.arquillian.container.sramp.SrampService#getClient()
   */
  public SrampAtomApiClient getClient() {
    return client;
  }

  /**
   * Uploads {@code content} as an artifact of type {@code artifactTypeArg}.
   *
   * <p>The artifact is tagged with {@code arquillian-archive-id = archiveId}. If an expander
   * exists for the type, the archive's expanded contents are also uploaded.
   *
   * <p>Failures are logged and not propagated. The returned artifact may be {@code null} (upload
   * failed) or non-null even if a later step such as the batch upload failed. {@code content} is
   * always closed.
   *
   * @see org.jboss.arquillian.container.sramp.SrampService#deployArchive(java.lang.String,
   *     java.lang.String, java.lang.String, java.io.InputStream)
   */
  public BaseArtifactType deployArchive(
      String archiveId, String archiveName, String artifactTypeArg, InputStream content) {

    assert content != null;

    ZipToSrampArchive expander = null;
    SrampArchive archive = null;
    BaseArtifactType artifact = null;
    File tempResourceFile = null;
    try {
      // internal integrity check: remember the repository size before deploying
      artifactCounter = client.query("/s-ramp").getTotalResults();

      // First, stash the content in a temp file - it is read twice (upload, then expansion).
      tempResourceFile = stashResourceContent(content);
      content = FileUtils.openInputStream(tempResourceFile);

      ArtifactType artifactType = ArtifactType.valueOf(artifactTypeArg);
      if (artifactType.isExtendedType()) {
        artifactType = ArtifactType.ExtendedDocument(artifactType.getExtendedType());
      }

      artifact = client.uploadArtifact(artifactType, content, archiveName);
      IOUtils.closeQuietly(content);

      // for all uploaded files add custom property
      SrampModelUtils.setCustomProperty(artifact, "arquillian-archive-id", archiveId);
      client.updateArtifactMetaData(artifact);

      content = FileUtils.openInputStream(tempResourceFile);

      // Now also add "expanded" content to the s-ramp repository
      expander = ZipToSrampArchiveRegistry.createExpander(artifactType, content);

      if (expander != null) {
        expander.setContextParam(DefaultMetaDataFactory.PARENT_UUID, artifact.getUuid());
        archive = expander.createSrampArchive();
        client.uploadBatch(archive);
      }
    } catch (Exception e) {
      log.error("Upload failure:", e);
    } finally {
      SrampArchive.closeQuietly(archive);
      ZipToSrampArchive.closeQuietly(expander);
      // Close whichever stream `content` currently refers to. On the success path the stream
      // opened for expansion is otherwise never closed, and it must be closed before the temp
      // file it reads from is deleted.
      IOUtils.closeQuietly(content);
      FileUtils.deleteQuietly(tempResourceFile);
    }

    return artifact;
  }

  /**
   * Deletes everything tagged with {@code arquillian-archive-id = archiveId}, in this order:
   *
   * <ol>
   *   <li>artifacts expanded from the archive;
   *   <li>artifacts describing its deployment;
   *   <li>the archive itself.
   * </ol>
   *
   * <p>Afterwards the repository's artifact count is compared with the count recorded at deploy
   * time, and any mismatch is logged.
   *
   * @see org.jboss.arquillian.container.sramp.SrampService#undeployArchives(java.lang.String)
   */
  public void undeployArchives(String archiveId) throws SrampClientException, SrampAtomException {

    log.debug("Deleting expanded artifacts");

    // Delete expanded artifacts
    QueryResultSet rset =
        client
            .buildQuery("/s-ramp[expandedFromDocument[@arquillian-archive-id = ?]]")
            .parameter(archiveId)
            .query();

    for (ArtifactSummary artifactSummary : rset) {
      log.debug("Deleting: " + artifactSummary.getName());
      client.deleteArtifact(artifactSummary.getUuid(), artifactSummary.getType());
    }

    // Delete (un)deployment information
    rset =
        client
            .buildQuery("/s-ramp[describesDeployment[@arquillian-archive-id = ?]]")
            .parameter(archiveId)
            .query();
    for (ArtifactSummary artifactSummary : rset) {
      log.debug("Deleting: " + artifactSummary.getName());
      client.deleteArtifact(artifactSummary.getUuid(), artifactSummary.getType());
    }

    // Delete main archive
    // Related are deleted along with the primary
    rset = client.buildQuery("/s-ramp[@arquillian-archive-id = ?]").parameter(archiveId).query();

    // The archive may be missing, e.g. when deployArchive() logged an upload failure; do not
    // crash on an empty result.
    if (!rset.iterator().hasNext()) {
      log.warn("No archive artifact found with arquillian-archive-id " + archiveId);
    } else {
      ArtifactSummary archiveArtifact = rset.get(0);

      log.debug("Deleting: " + archiveArtifact.getName());

      client.deleteArtifact(archiveArtifact.getUuid(), archiveArtifact.getType());
    }

    // Internal consistency check whether the number of artifacts before
    // deploy and after deploy match
    long artifactCounterTemp = client.query("/s-ramp").getTotalResults();

    if (artifactCounter != artifactCounterTemp) {
      log.warn("Artifact counts does not match!");
      log.warn(
          "Artifacts before deploy: "
              + artifactCounter
              + ". Artifacts after undeploy: "
              + artifactCounterTemp);
      artifactCounter = artifactCounterTemp;
    }
  }

  /**
   * Make a temporary copy of the resource by saving the content to a temp file.
   *
   * <p>{@code resourceInputStream} is always closed. If the copy fails, the partially written temp
   * file is deleted before the exception propagates, since the caller never receives it.
   *
   * @param resourceInputStream content to copy
   * @return the temp file holding a full copy of the content
   * @throws IOException if the temp file cannot be created or written
   */
  private File stashResourceContent(InputStream resourceInputStream) throws IOException {
    File resourceTempFile = null;
    OutputStream oStream = null;
    boolean copied = false;
    try {
      resourceTempFile = File.createTempFile("s-ramp-resource", ".tmp");
      oStream = FileUtils.openOutputStream(resourceTempFile);
      IOUtils.copy(resourceInputStream, oStream);
      copied = true;
    } finally {
      IOUtils.closeQuietly(resourceInputStream);
      IOUtils.closeQuietly(oStream);
      if (!copied) {
        FileUtils.deleteQuietly(resourceTempFile);
      }
    }
    return resourceTempFile;
  }
}
/**
 * {@link Unmarshaller} decorator that parses XML with external-entity resolution disabled.
 *
 * <p>{@code InputStream}, {@code InputSource} and {@code (SAXSource, Class)} inputs are parsed
 * through a hardened SAX reader. DOM {@link Node} input is passed straight to the delegate. All
 * other {@code unmarshal} overloads throw {@link UnsupportedOperationException}. Configuration
 * methods are forwarded to the delegate.
 *
 * @author <a href="mailto:[email protected]">Ron Sigal</a>
 * @version $Revision: 1.1 $ Created Feb 1, 2012
 */
public class ExternalEntityUnmarshaller implements Unmarshaller {
  static final Logger log = Logger.getLogger(ExternalEntityUnmarshaller.class);

  private Unmarshaller delegate;

  public ExternalEntityUnmarshaller(Unmarshaller delegate) {
    this.delegate = delegate;
  }

  @SuppressWarnings("unchecked")
  public <A extends XmlAdapter> A getAdapter(Class<A> type) {
    return delegate.getAdapter(type);
  }

  public AttachmentUnmarshaller getAttachmentUnmarshaller() {
    return delegate.getAttachmentUnmarshaller();
  }

  public ValidationEventHandler getEventHandler() throws JAXBException {
    return delegate.getEventHandler();
  }

  public Listener getListener() {
    return delegate.getListener();
  }

  public Object getProperty(String name) throws PropertyException {
    return delegate.getProperty(name);
  }

  public Schema getSchema() {
    return delegate.getSchema();
  }

  public UnmarshallerHandler getUnmarshallerHandler() {
    return delegate.getUnmarshallerHandler();
  }

  /** @deprecated since 2.0 */
  @Deprecated
  public boolean isValidating() throws JAXBException {
    return delegate.isValidating();
  }

  /** @deprecated since 2.0 */
  @Deprecated
  @SuppressWarnings("unchecked")
  public void setAdapter(XmlAdapter adapter) {
    delegate.setAdapter(adapter);
  }

  @SuppressWarnings("unchecked")
  public <A extends XmlAdapter> void setAdapter(Class<A> type, A adapter) {
    // Forward the explicit type. The one-argument overload registers the adapter under
    // adapter.getClass(), which is wrong when `type` is a supertype of the adapter's class.
    delegate.setAdapter(type, adapter);
  }

  public void setAttachmentUnmarshaller(AttachmentUnmarshaller au) {
    delegate.setAttachmentUnmarshaller(au);
  }

  public void setEventHandler(ValidationEventHandler handler) throws JAXBException {
    delegate.setEventHandler(handler);
  }

  public void setListener(Listener listener) {
    delegate.setListener(listener);
  }

  public void setProperty(String name, Object value) throws PropertyException {
    delegate.setProperty(name, value);
  }

  public void setSchema(Schema schema) {
    delegate.setSchema(schema);
  }

  /** @deprecated since 2.0 */
  @Deprecated
  public void setValidating(boolean validating) throws JAXBException {
    delegate.setValidating(validating);
  }

  public Object unmarshal(File f) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("File"));
  }

  /** Turns off expansion of external entities. */
  public Object unmarshal(InputStream is) throws JAXBException {
    return unmarshal(new InputSource(is));
  }

  public Object unmarshal(Reader reader) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("Reader"));
  }

  public Object unmarshal(URL url) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("URL"));
  }

  /** Turns off expansion of external entities. */
  public Object unmarshal(InputSource source) throws JAXBException {
    try {
      SAXSource saxSource = new SAXSource(createSecureXMLReader(), source);
      return delegate.unmarshal(saxSource);
    } catch (SAXException e) {
      throw new JAXBException(e);
    } catch (ParserConfigurationException e) {
      throw new JAXBException(e);
    }
  }

  public Object unmarshal(Node node) throws JAXBException {
    return delegate.unmarshal(node);
  }

  public Object unmarshal(Source source) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("Source"));
  }

  public Object unmarshal(XMLStreamReader reader) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("XMLStreamReader"));
  }

  public Object unmarshal(XMLEventReader reader) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("XMLEventReader"));
  }

  public <T> JAXBElement<T> unmarshal(Node node, Class<T> declaredType) throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("Node, Class<T>"));
  }

  /**
   * Turns off expansion of external entities for a {@link SAXSource}.
   *
   * <p>The source's XML reader is replaced with a hardened one. Any other {@link Source} kind is
   * rejected.
   */
  public <T> JAXBElement<T> unmarshal(Source source, Class<T> declaredType) throws JAXBException {
    if (source instanceof SAXSource) {
      try {
        ((SAXSource) source).setXMLReader(createSecureXMLReader());
        return delegate.unmarshal(source, declaredType);
      } catch (SAXException e) {
        throw new JAXBException(e);
      } catch (ParserConfigurationException e) {
        throw new JAXBException(e);
      }
    }

    throw new UnsupportedOperationException(errorMessage("Source, Class<T>"));
  }

  public <T> JAXBElement<T> unmarshal(XMLStreamReader reader, Class<T> declaredType)
      throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("XMLStreamReader, Class<T>"));
  }

  public <T> JAXBElement<T> unmarshal(XMLEventReader reader, Class<T> declaredType)
      throws JAXBException {
    throw new UnsupportedOperationException(errorMessage("XMLEventReader, Class<T>"));
  }

  public Unmarshaller getDelegate() {
    return delegate;
  }

  public void setDelegate(Unmarshaller delegate) {
    this.delegate = delegate;
  }

  /**
   * Creates a non-validating SAX reader that does not resolve external general or parameter
   * entities.
   *
   * <p>Where the parser supports it, it also does not load external DTDs.
   */
  private XMLReader createSecureXMLReader() throws SAXException, ParserConfigurationException {
    SAXParserFactory spf = SAXParserFactory.newInstance();
    SAXParser sp = spf.newSAXParser();
    XMLReader xmlReader = sp.getXMLReader();
    xmlReader.setFeature("http://xml.org/sax/features/validation", false);
    xmlReader.setFeature("http://xml.org/sax/features/external-general-entities", false);
    // Parameter entities are a second XXE vector; disabling only general entities is not enough.
    xmlReader.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
    try {
      // Xerces-specific: stop fetching the external DTD subset (SSRF / file disclosure).
      xmlReader.setFeature(
          "http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
    } catch (SAXException ignored) {
      // Parser does not recognize/support this feature; keep its default.
    }
    return xmlReader;
  }

  private String errorMessage(String s) {
    return "ExternalEntityUnmarshallerWrapper: unexpected use of unmarshal(" + s + ")";
  }
}
/**
 * Spring MVC {@link HandlerMapping} that routes requests to RESTEasy's
 * {@link SynchronousDispatcher}.
 *
 * <p>The handler it returns is a {@link ResteasyRequestWrapper} carrying the matched resource
 * invoker, an aborted response from pre-match processing, or an error.
 *
 * @author <a href="mailto:[email protected]">Solomon Duskis</a>
 * @version $Revision: 1 $
 */
public class ResteasyHandlerMapping implements HandlerMapping, Ordered, InitializingBean {
  private static final Logger logger = Logger.getLogger(ResteasyHandlerMapping.class);

  private int order = Integer.MAX_VALUE;
  private SynchronousDispatcher dispatcher;

  /** Path prefix handed to {@code RequestUtil.getRequestWrapper}; empty by default. */
  private String prefix = "";
  private HandlerInterceptor[] interceptors;
  /** When true, a NotFoundException is rethrown instead of returning {@code null}. */
  private boolean throwNotFound = false;

  public ResteasyHandlerMapping(ResteasyDeployment deployment) {
    super();
    this.dispatcher = (SynchronousDispatcher) deployment.getDispatcher();
  }

  public SynchronousDispatcher getDispatcher() {
    return dispatcher;
  }

  public void setOrder(int order) {
    this.order = order;
  }

  public HandlerInterceptor[] getInterceptors() {
    return interceptors;
  }

  public void setInterceptors(HandlerInterceptor[] interceptors) {
    this.interceptors = interceptors;
  }

  /**
   * Wraps the servlet request for RESTEasy and resolves the resource invoker for it.
   *
   * <p>The returned chain's handler is a {@link ResteasyRequestWrapper}. It carries one of:
   *
   * <ul>
   *   <li>a 500 error, for a non-initial (suspended and retried) request;
   *   <li>the response produced by {@code dispatcher.preprocess}, if there is one;
   *   <li>the matched invoker otherwise.
   * </ul>
   *
   * @return the handler chain, or {@code null} when no resource matches and {@code throwNotFound}
   *     is false
   * @throws NotFoundException when no resource matches and {@code throwNotFound} is true
   * @throws Failure for other RESTEasy failures (logged, then rethrown)
   */
  public HandlerExecutionChain getHandler(HttpServletRequest request) throws Exception {
    ResteasyRequestWrapper requestWrapper =
        RequestUtil.getRequestWrapper(request, request.getMethod(), prefix);
    try {
      // NOTE: if an invoker isn't found, RESTEasy throws NoResourceFoundFailure
      HttpRequest httpRequest = requestWrapper.getHttpRequest();
      if (!httpRequest.isInitial()) {
        String message =
            httpRequest.getUri().getPath()
                + " is not initial request.  Its suspended and retried.  Aborting.";
        logger.error(message);
        requestWrapper.setError(500, message);
      } else {
        Response response = dispatcher.preprocess(httpRequest);
        if (response != null) {
          requestWrapper.setAbortedResponse(response);
        } else {
          requestWrapper.setInvoker(getInvoker(httpRequest));
        }
      }
      return new HandlerExecutionChain(requestWrapper, interceptors);
    } catch (NotFoundException e) {
      if (throwNotFound) {
        throw e;
      }
      logger.error("Resource Not Found: " + e.getMessage(), e);
    } catch (Failure e) {
      logger.error("ResourceFailure: " + e.getMessage(), e);
      throw e;
    }
    return null;
  }

  /** Looks up the invoker through the dispatcher; returns {@code null} if there is no dispatcher. */
  private ResourceInvoker getInvoker(HttpRequest httpRequest) {
    if (dispatcher != null) return dispatcher.getInvoker(httpRequest);
    return null;
  }

  public int getOrder() {
    return order;
  }

  public boolean isThrowNotFound() {
    return throwNotFound;
  }

  public void setThrowNotFound(boolean throwNotFound) {
    this.throwNotFound = throwNotFound;
  }

  public String getPrefix() {
    return prefix;
  }

  public void setPrefix(String prefix) {
    this.prefix = prefix;
  }

  /** Logs advice when both the order and throwNotFound are still at their defaults. */
  public void afterPropertiesSet() throws Exception {
    if (!throwNotFound && order == Integer.MAX_VALUE) {
      logger.info(
          "ResteasyHandlerMapping has the default order and throwNotFound settings.  Consider adding explicit ordering to your HandlerMappings, with ResteasyHandlerMapping being last, and set throwNotFound = true.");
    }
  }
}
/**
 * Dispatches HTTP requests to JAX-RS resources.
 *
 * <p>For each request it:
 * <ol>
 *   <li>runs the registered {@link HttpRequestPreprocessor}s and pre-match
 *       {@link ContainerRequestFilter}s;
 *   <li>resolves the {@link ResourceInvoker} from the {@link Registry};
 *   <li>invokes it and writes the resulting {@link Response}.
 * </ol>
 *
 * <p>Failures are routed through the {@link ExceptionMapper}s known to the
 * {@link ResteasyProviderFactory}.
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
@SuppressWarnings("unchecked")
public class SynchronousDispatcher implements Dispatcher {
  protected ResteasyProviderFactory providerFactory;
  protected Registry registry;
  protected List<HttpRequestPreprocessor> requestPreprocessors =
      new ArrayList<HttpRequestPreprocessor>();
  protected Map<Class, Object> defaultContextObjects = new HashMap<Class, Object>();
  /** Class names of wrapper exceptions whose cause {@link #unwrapException} keeps unwrapping. */
  protected Set<String> unwrappedExceptions = new HashSet<String>();

  private static final Logger logger = Logger.getLogger(SynchronousDispatcher.class);

  /**
   * Creates a dispatcher backed by a new {@link ResourceMethodRegistry}.
   *
   * <p>The provider factory, the registry, this dispatcher and the {@link InternalDispatcher}
   * instance are registered as default context objects.
   */
  public SynchronousDispatcher(ResteasyProviderFactory providerFactory) {
    this.providerFactory = providerFactory;
    this.registry = new ResourceMethodRegistry(providerFactory);
    defaultContextObjects.put(Providers.class, providerFactory);
    defaultContextObjects.put(Registry.class, registry);
    defaultContextObjects.put(Dispatcher.class, this);
    defaultContextObjects.put(InternalDispatcher.class, InternalDispatcher.getInstance());
  }

  public ResteasyProviderFactory getProviderFactory() {
    return providerFactory;
  }

  public Registry getRegistry() {
    return registry;
  }

  public Map<Class, Object> getDefaultContextObjects() {
    return defaultContextObjects;
  }

  public Set<String> getUnwrappedExceptions() {
    return unwrappedExceptions;
  }

  /**
   * Runs every registered {@link HttpRequestPreprocessor}, then each pre-match
   * {@link ContainerRequestFilter} in order.
   *
   * @param in the incoming request
   * @return the first response a filter aborted the request with, or {@code null} if no filter
   *     aborted
   * @throws RuntimeException wrapping an {@link IOException} thrown by a filter
   */
  public Response preprocess(HttpRequest in) {
    for (HttpRequestPreprocessor preprocessor : this.requestPreprocessors) {
      preprocessor.preProcess(in);
    }
    ContainerRequestFilter[] requestFilters =
        providerFactory.getContainerRequestFilterRegistry().preMatch();
    PreMatchContainerRequestContext requestContext = new PreMatchContainerRequestContext(in);
    for (ContainerRequestFilter filter : requestFilters) {
      try {
        filter.filter(requestContext);
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
      Response response = requestContext.getResponseAbortedWith();
      if (response != null) return response;
    }
    return null;
  }

  /**
   * Dispatches a request end to end: preprocess, resolve the invoker, invoke, write the response.
   * Context data is always cleared afterwards.
   *
   * <p>Exceptions thrown while preprocessing or resolving the invoker are passed to
   * {@link #handleException} so that registered ExceptionMappers apply. This includes a
   * {@link WebApplicationException} thrown by a pre-match filter.
   */
  public void invoke(HttpRequest request, HttpResponse response) {
    try {
      pushContextObjects(request, response);
      if (preprocessAndWriteAbort(request, response)) {
        return;
      }
      ResourceInvoker invoker;
      try {
        invoker = getInvoker(request);
      } catch (Exception e) {
        handleException(request, response, e);
        return;
      }
      try {
        invoke(request, response, invoker);
      } catch (Failure e) {
        handleException(request, response, e);
      }
    } finally {
      clearContextData();
    }
  }

  /**
   * Same as {@link #invoke(HttpRequest, HttpResponse)}, except that a {@link NotFoundException}
   * raised while resolving the invoker is propagated to the caller instead of being handled. This
   * is used for Filters.
   *
   * @param request the incoming request
   * @param response the response to write to
   * @throws NotFoundException if no resource matches the request
   */
  public void invokePropagateNotFound(HttpRequest request, HttpResponse response)
      throws NotFoundException {
    try {
      pushContextObjects(request, response);
      if (preprocessAndWriteAbort(request, response)) {
        return;
      }
      ResourceInvoker invoker;
      try {
        invoker = getInvoker(request);
      } catch (NotFoundException notFound) {
        throw notFound;
      } catch (Exception failure) {
        handleException(request, response, failure);
        return;
      }
      try {
        invoke(request, response, invoker);
      } catch (Failure e) {
        handleException(request, response, e);
      }
    } finally {
      clearContextData();
    }
  }

  /**
   * Runs {@link #preprocess}. If a pre-match filter aborted the request, writes the abort response.
   *
   * <p>Writing happens outside the preprocess try-block so a failed write is handled exactly once,
   * by {@link #handleWriteResponseException}.
   *
   * @return {@code true} when the request has already been answered (aborted, or preprocessing
   *     failed and was handled) and dispatching must stop
   */
  private boolean preprocessAndWriteAbort(HttpRequest request, HttpResponse response) {
    Response aborted;
    try {
      aborted = preprocess(request);
    } catch (Exception e) {
      handleException(request, response, e);
      return true;
    }
    if (aborted == null) {
      return false;
    }
    try {
      writeJaxrsResponse(request, response, aborted);
    } catch (Exception e) {
      handleWriteResponseException(request, response, e);
    }
    return true;
  }

  /**
   * Resolves the resource invoker for {@code request} from the registry.
   *
   * @throws InternalServerErrorException if the request is not an initial request (it was suspended
   *     and retried)
   * @throws NotFoundException if no JAX-RS resource matches the request path
   */
  public ResourceInvoker getInvoker(HttpRequest request) throws Failure {
    logger.debug("PathInfo: " + request.getUri().getPath());
    if (!request.isInitial()) {
      throw new InternalServerErrorException(
          request.getUri().getPath()
              + " is not initial request.  Its suspended and retried.  Aborting.");
    }
    ResourceInvoker invoker = registry.getResourceInvoker(request);
    if (invoker == null) {
      throw new NotFoundException(
          "Unable to find JAX-RS resource associated with path: " + request.getUri().getPath());
    }
    return invoker;
  }

  /**
   * Called if method invoke was unsuccessful.
   *
   * @param request the request being processed
   * @param response the response to write the failure to
   * @param e the exception thrown by the invoker
   */
  public void handleInvokerException(HttpRequest request, HttpResponse response, Exception e) {
    handleException(request, response, e);
  }

  /**
   * Called if method invoke was successful, but writing the Response after was not.
   *
   * @param request the request being processed
   * @param response the response to write the failure to
   * @param e the exception thrown while writing
   */
  public void handleWriteResponseException(
      HttpRequest request, HttpResponse response, Exception e) {
    handleException(request, response, e);
  }

  /**
   * Converts {@code e} into an HTTP response. Resolution order:
   * <ol>
   *   <li>an ExceptionMapper for the exact exception class;
   *   <li>wrapper exceptions (Application/Writer/Reader) mapped by their cause;
   *   <li>a mapper for the closest superclass;
   *   <li>built-in handling of {@link WebApplicationException} and {@link Failure}.
   * </ol>
   *
   * @throws UnhandledException if nothing can handle {@code e}
   */
  public void handleException(HttpRequest request, HttpResponse response, Throwable e) {
    // See if there is an ExceptionMapper for the exact class of the exception instance being thrown
    if (executeExactExceptionMapper(request, response, e)) return;

    // These are wrapper exceptions so they need to be processed first as they map e.getCause()
    if (e instanceof ApplicationException) {
      handleApplicationException(request, response, (ApplicationException) e);
      return;
    } else if (e instanceof WriterException) {
      handleWriterException(request, response, (WriterException) e);
      return;
    } else if (e instanceof ReaderException) {
      handleReaderException(request, response, (ReaderException) e);
      return;
    }

    // First try and handle it with a mapper
    if (executeExceptionMapper(request, response, e)) {
      return;
    }
    // Otherwise do specific things
    else if (e instanceof WebApplicationException) {
      handleWebApplicationException(request, response, (WebApplicationException) e);
    } else if (e instanceof Failure) {
      handleFailure(request, response, (Failure) e);
    } else {
      logger.error(
          "Unknown exception while executing "
              + request.getHttpMethod()
              + " "
              + request.getUri().getPath(),
          e);
      throw new UnhandledException(e);
    }
  }

  /**
   * Logs {@code failure} at error or debug level, as chosen by {@link Failure#isLoggable()}. Then
   * writes its response, or sends its error code (and message, when present).
   */
  protected void handleFailure(HttpRequest request, HttpResponse response, Failure failure) {
    if (failure.isLoggable()) {
      logger.error(
          "Failed executing " + request.getHttpMethod() + " " + request.getUri().getPath(),
          failure);
    } else {
      logger.debug(
          "Failed executing " + request.getHttpMethod() + " " + request.getUri().getPath(),
          failure);
    }

    if (failure.getResponse() != null) {
      writeFailure(request, response, failure.getResponse());
    } else {
      try {
        if (failure.getMessage() != null) {
          response.sendError(failure.getErrorCode(), failure.getMessage());
        } else {
          response.sendError(failure.getErrorCode());
        }
      } catch (IOException e1) {
        throw new UnhandledException(e1);
      }
    }
  }

  /**
   * If there exists an Exception mapper for exception, execute it, otherwise, do NOT recurse up
   * class hierarchy of exception.
   *
   * @param request the request being processed
   * @param response the response to write the mapped result to
   * @param exception the exception to map
   * @return true if a mapper for the exact class was found and executed
   */
  public boolean executeExactExceptionMapper(
      HttpRequest request, HttpResponse response, Throwable exception) {
    ExceptionMapper mapper = providerFactory.getExceptionMapper(exception.getClass());
    if (mapper == null) return false;
    writeMappedResponse(request, response, mapper, exception);
    return true;
  }

  /**
   * Executes the ExceptionMapper registered for {@code clazz} (not for the exception's own class)
   * on {@code exception}.
   *
   * @return true if a mapper for {@code clazz} was found and executed
   */
  public boolean executeExceptionMapperForClass(
      HttpRequest request, HttpResponse response, Throwable exception, Class clazz) {
    ExceptionMapper mapper = providerFactory.getExceptionMapper(clazz);
    if (mapper == null) return false;
    writeMappedResponse(request, response, mapper, exception);
    return true;
  }

  /**
   * Execute an ExceptionMapper if one exists for the given exception. Recurse to base class if not
   * found.
   *
   * @param request the request being processed
   * @param response the response to write the mapped result to
   * @param exception the exception to map
   * @return true if an ExceptionMapper was found and executed
   */
  public boolean executeExceptionMapper(
      HttpRequest request, HttpResponse response, Throwable exception) {
    ExceptionMapper mapper = null;

    Class causeClass = exception.getClass();
    while (mapper == null) {
      if (causeClass == null) break;
      mapper = providerFactory.getExceptionMapper(causeClass);
      if (mapper == null) causeClass = causeClass.getSuperclass();
    }
    if (mapper != null) {
      writeMappedResponse(request, response, mapper, exception);
      return true;
    }
    return false;
  }

  /**
   * Runs {@code mapper} on {@code exception} and writes the result as a failure response.
   *
   * <p>A null mapper result is written as 204 No Content. Previously only the hierarchical lookup
   * did this; the exact and per-class lookups passed null on and failed with an
   * UnhandledException.
   */
  private void writeMappedResponse(
      HttpRequest request, HttpResponse response, ExceptionMapper mapper, Throwable exception) {
    Response jaxrsResponse = mapper.toResponse(exception);
    if (jaxrsResponse == null) {
      jaxrsResponse = Response.status(204).build();
    }
    writeFailure(request, response, jaxrsResponse);
  }

  /**
   * Handles an {@link ApplicationException}. Uses a mapper registered for ApplicationException if
   * present; otherwise unwraps the cause.
   *
   * @throws UnhandledException if the unwrapped cause cannot be handled
   */
  protected void handleApplicationException(
      HttpRequest request, HttpResponse response, ApplicationException e) {
    // See if there is a mapper for ApplicationException
    if (executeExceptionMapperForClass(request, response, e, ApplicationException.class)) {
      return;
    }
    Throwable unhandled = unwrapException(request, response, e);
    if (unhandled != null) {
      throw new UnhandledException(unhandled);
    }
  }

  /**
   * Tries to handle the cause of {@code e}: by mapper, as a WebApplicationException, or as a
   * Failure. Causes whose class name is in {@link #unwrappedExceptions} are unwrapped further.
   *
   * @return {@code null} if the cause was handled, otherwise the throwable that could not be
   *     handled. That is {@code e} itself when it has no cause.
   */
  protected Throwable unwrapException(HttpRequest request, HttpResponse response, Throwable e) {
    Throwable unwrappedException = e.getCause();
    if (unwrappedException == null) {
      // Nothing to unwrap: report the wrapper itself instead of failing with an NPE below.
      return e;
    }

    if (executeExceptionMapper(request, response, unwrappedException)) {
      return null;
    }
    if (unwrappedException instanceof WebApplicationException) {
      handleWebApplicationException(
          request, response, (WebApplicationException) unwrappedException);
      return null;
    } else if (unwrappedException instanceof Failure) {
      handleFailure(request, response, (Failure) unwrappedException);
      return null;
    } else {
      if (unwrappedExceptions.contains(unwrappedException.getClass().getName())
          && unwrappedException.getCause() != null) {
        return unwrapException(request, response, unwrappedException);
      } else {
        return unwrappedException;
      }
    }
  }

  /**
   * Handles a {@link WriterException}. Resolution order: a mapper for WriterException; then the
   * exception's own response or error code; then its cause. Falls back to 500.
   */
  protected void handleWriterException(
      HttpRequest request, HttpResponse response, WriterException e) {
    // See if there is a general mapper for WriterException
    if (executeExceptionMapperForClass(request, response, e, WriterException.class)) {
      return;
    }
    if (e.getResponse() != null || e.getErrorCode() > -1) {
      handleFailure(request, response, e);
      return;
    } else if (e.getCause() != null) {
      if (unwrapException(request, response, e) == null) return;
    }
    e.setErrorCode(HttpResponseCodes.SC_INTERNAL_SERVER_ERROR);
    handleFailure(request, response, e);
  }

  /**
   * Handles a {@link ReaderException}. Resolution order: a mapper for ReaderException; then the
   * exception's own response or error code; then its cause. Falls back to 400.
   */
  protected void handleReaderException(
      HttpRequest request, HttpResponse response, ReaderException e) {
    // See if there is a general mapper for ReaderException
    if (executeExceptionMapperForClass(request, response, e, ReaderException.class)) {
      return;
    }
    // If a response or error code set, use that, otherwise look at cause.
    if (e.getResponse() != null || e.getErrorCode() > -1) {
      handleFailure(request, response, e);
      return;
    } else if (e.getCause() != null) {
      if (unwrapException(request, response, e) == null) return;
    }
    e.setErrorCode(HttpResponseCodes.SC_BAD_REQUEST);
    handleFailure(request, response, e);
  }

  /**
   * Resets {@code response} and writes {@code jaxrsResponse} into it. If writing throws a
   * WebApplicationException, only that exception's status is written.
   *
   * @throws UnhandledException if the response was already committed or writing fails otherwise
   */
  protected void writeFailure(HttpRequest request, HttpResponse response, Response jaxrsResponse) {
    response.reset();
    try {
      writeJaxrsResponse(request, response, jaxrsResponse);
    } catch (WebApplicationException ex) {
      if (response.isCommitted())
        throw new UnhandledException("Request was committed couldn't handle exception", ex);
      // don't call writeJaxrsResponse again (it could fail infinitely); just write the status
      response.reset();
      response.setStatus(ex.getResponse().getStatus());

    } catch (Exception e1) {
      throw new UnhandledException(e1); // we're screwed, can't handle the exception
    }
  }

  /**
   * Logs {@code wae} unless it is a {@link NoLogWebApplicationException}, then writes its response.
   *
   * @throws UnhandledException if the response was already committed
   */
  protected void handleWebApplicationException(
      HttpRequest request, HttpResponse response, WebApplicationException wae) {
    if (!(wae instanceof NoLogWebApplicationException)) logger.error("failed to execute", wae);
    if (response.isCommitted())
      throw new UnhandledException("Request was committed couldn't handle exception", wae);

    writeFailure(request, response, wae.getResponse());
  }

  /**
   * Publishes the request, response and derived JAX-RS context objects, plus the default context
   * objects, into the current context data map.
   */
  public void pushContextObjects(HttpRequest request, HttpResponse response) {
    Map contextDataMap = ResteasyProviderFactory.getContextDataMap();
    contextDataMap.put(HttpRequest.class, request);
    contextDataMap.put(HttpResponse.class, response);
    contextDataMap.put(HttpHeaders.class, request.getHttpHeaders());
    contextDataMap.put(UriInfo.class, request.getUri());
    contextDataMap.put(Request.class, new RequestImpl(request));
    contextDataMap.put(ResteasyAsynchronousContext.class, request.getAsyncContext());

    contextDataMap.putAll(defaultContextObjects);
  }

  /**
   * Invokes a resource for an internal (in-process) request, using {@code entity} as the body.
   *
   * <p>The work runs in a fresh context-data level. That level is removed afterwards, and the body
   * is popped only if it was pushed.
   *
   * @return the resource's response, or {@code null} if it is handled asynchronously
   */
  public Response internalInvocation(HttpRequest request, HttpResponse response, Object entity) {
    // be extra careful in the clean up process. Only pop if there was an
    // equivalent push.
    ResteasyProviderFactory.addContextDataLevel();
    boolean pushedBody = false;
    try {
      MessageBodyParameterInjector.pushBody(entity);
      pushedBody = true;
      ResourceInvoker invoker = getInvoker(request);
      if (invoker != null) {
        pushContextObjects(request, response);
        return getResponse(request, response, invoker);
      }

      // this should never happen, since getInvoker should throw an exception
      // if invoker is null
      return null;
    } finally {
      ResteasyProviderFactory.removeContextDataLevel();
      if (pushedBody) {
        MessageBodyParameterInjector.popBody();
      }
    }
  }

  /** Clears per-request context data and any bodies left over from internal dispatches. */
  public void clearContextData() {
    ResteasyProviderFactory.clearContextData();
    // just in case there were internalDispatches that need to be cleaned up
    MessageBodyParameterInjector.clearBodies();
  }

  /**
   * Invokes {@code invoker} and writes its response, if any. Write failures go to
   * {@link #handleWriteResponseException}.
   */
  public void invoke(HttpRequest request, HttpResponse response, ResourceInvoker invoker) {
    Response jaxrsResponse = getResponse(request, response, invoker);

    try {
      if (jaxrsResponse != null) writeJaxrsResponse(request, response, jaxrsResponse);
    } catch (Exception e) {
      handleWriteResponseException(request, response, e);
    }
  }

  /**
   * Invokes {@code invoker}; exceptions are passed to {@link #handleInvokerException}.
   *
   * @return the response to write, or {@code null} if the request was suspended (asynchronous) or
   *     the invocation failed
   */
  protected Response getResponse(
      HttpRequest request, HttpResponse response, ResourceInvoker invoker) {
    Response jaxrsResponse = null;
    try {
      jaxrsResponse = invoker.invoke(request, response);
      if (request.getAsyncContext().isSuspended()) {
        /*
         * Callback by the initial calling thread. This callback will probably do nothing in an
         * asynchronous environment but will be used to simulate AsynchronousResponse in vanilla
         * Servlet containers that do not support asychronous HTTP.
         */
        request.getAsyncContext().getAsyncResponse().initialRequestThreadFinished();
        jaxrsResponse = null; // we're handing response asynchronously
      }
    } catch (Exception e) {
      handleInvokerException(request, response, e);
    }
    return jaxrsResponse;
  }

  /**
   * Writes a response produced asynchronously, with the request's context objects pushed for the
   * duration of the write.
   */
  public void asynchronousDelivery(
      HttpRequest request, HttpResponse response, Response jaxrsResponse) {
    try {
      pushContextObjects(request, response);
      try {
        if (jaxrsResponse != null) writeJaxrsResponse(request, response, jaxrsResponse);
      } catch (Exception e) {
        handleWriteResponseException(request, response, e);
      }
    } finally {
      clearContextData();
    }
  }

  /**
   * Writes {@code jaxrsResponse} to {@code response}. If the response has an entity but no
   * Content-Type, one is filled in: from the matched ResourceMethod, or else from the request's
   * Accept header.
   */
  protected void writeJaxrsResponse(
      HttpRequest request, HttpResponse response, Response jaxrsResponse) throws WriterException {
    Object type = jaxrsResponse.getMetadata().getFirst(HttpHeaderNames.CONTENT_TYPE);
    if (type == null && jaxrsResponse.getEntity() != null) {
      ResourceMethod method = (ResourceMethod) request.getAttribute(ResourceMethod.class.getName());
      if (method != null) {
        jaxrsResponse
            .getMetadata()
            .putSingle(
                HttpHeaderNames.CONTENT_TYPE,
                method.resolveContentType(request, jaxrsResponse.getEntity()));
      } else {
        MediaType contentType =
            resolveContentTypeByAccept(
                request.getHttpHeaders().getAcceptableMediaTypes(), jaxrsResponse.getEntity());
        jaxrsResponse.getMetadata().putSingle(HttpHeaderNames.CONTENT_TYPE, contentType);
      }
    }

    ServerResponseWriter.writeResponse(
        (BuiltResponse) jaxrsResponse, request, response, providerFactory);
  }

  /**
   * Returns the first acceptable media type for which a MessageBodyWriter exists for
   * {@code entity}. Returns the wildcard type if none matches or if there is nothing to choose from.
   */
  protected MediaType resolveContentTypeByAccept(List<MediaType> accepts, Object entity) {
    if (accepts == null || accepts.isEmpty() || entity == null) {
      return MediaType.WILDCARD_TYPE;
    }
    Class clazz = entity.getClass();
    Type type = null;
    if (entity instanceof GenericEntity) {
      GenericEntity gen = (GenericEntity) entity;
      clazz = gen.getRawType();
      type = gen.getType();
    }
    for (MediaType accept : accepts) {
      if (providerFactory.getMessageBodyWriter(clazz, type, null, accept) != null) {
        return accept;
      }
    }
    return MediaType.WILDCARD_TYPE;
  }

  /** Registers a preprocessor to be run by {@link #preprocess} for every request. */
  public void addHttpPreprocessor(HttpRequestPreprocessor httpPreprocessor) {
    requestPreprocessors.add(httpPreprocessor);
  }
}
Example #7
0
/**
 * Admin REST resource for the collection of realms visible to an admin user.
 *
 * <p>Lists realms, imports a realm from JSON or from a multipart upload, and locates the
 * per-realm admin sub-resource.
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class RealmsAdminResource {
  protected static final Logger logger = Logger.getLogger(RealmsAdminResource.class);
  protected UserModel admin;

  public RealmsAdminResource(UserModel admin) {
    this.admin = admin;
  }

  public static final CacheControl noCache = new CacheControl();

  static {
    noCache.setNoCache(true);
  }

  @Context protected ResourceContext resourceContext;

  @Context protected KeycloakSession session;

  /** Returns representations of all realms returned by {@code session.getRealms(admin)}. */
  @GET
  @NoCache
  @Produces("application/json")
  public List<RealmRepresentation> getRealms() {
    logger.debug("getRealms()");
    RealmManager realmManager = new RealmManager(session);
    List<RealmModel> realms = session.getRealms(admin);
    List<RealmRepresentation> reps = new ArrayList<RealmRepresentation>();
    for (RealmModel realm : realms) {
      reps.add(realmManager.toRepresentation(realm));
    }
    return reps;
  }

  /** Builds the URI template of a single realm: the realms URL followed by {@code {id}}. */
  public static UriBuilder realmUrl(UriInfo uriInfo) {
    return realmsUrl(uriInfo).path("{id}");
  }

  /** Builds the URI of this resource from the base URI and SaasService's getRealmsAdmin path. */
  public static UriBuilder realmsUrl(UriInfo uriInfo) {
    return uriInfo
        .getBaseUriBuilder()
        .path(SaasService.class)
        .path(SaasService.class, "getRealmsAdmin");
  }

  /**
   * Imports a realm from its JSON representation.
   *
   * @return 201 Created with the new realm's location, or the "exists" error when a realm with the
   *     same name is already present
   */
  @POST
  @Consumes("application/json")
  public Response importRealm(@Context final UriInfo uriInfo, final RealmRepresentation rep) {
    logger.debug("importRealm: {0}", rep.getRealm());
    RealmManager realmManager = new RealmManager(session);
    if (realmManager.getRealm(rep.getRealm()) != null) {
      return Flows.errors().exists("Realm " + rep.getRealm() + " already exists");
    }

    RealmModel realm = realmManager.importRealm(rep, admin);
    URI location = realmUrl(uriInfo).build(realm.getId());
    logger.debug("imported realm success, sending back: {0}", location.toString());
    return Response.created(location).build();
  }

  /**
   * Imports one realm per "file" part of a multipart upload. Each part is read as JSON.
   *
   * @return 204 No Content on success, or 400 Bad Request when the form has no "file" part
   */
  @POST
  @Consumes(MediaType.MULTIPART_FORM_DATA)
  public Response uploadRealm(MultipartFormDataInput input) throws IOException {
    Map<String, List<InputPart>> uploadForm = input.getFormDataMap();
    List<InputPart> inputParts = uploadForm.get("file");
    if (inputParts == null) {
      // A form without a "file" part is a client error; iterating null would throw an NPE (500).
      return Response.status(Response.Status.BAD_REQUEST).build();
    }

    RealmManager realmManager = new RealmManager(session);
    for (InputPart inputPart : inputParts) {
      inputPart.setMediaType(MediaType.APPLICATION_JSON_TYPE);
      RealmRepresentation rep = inputPart.getBody(new GenericType<RealmRepresentation>() {});
      realmManager.importRealm(rep, admin);
    }
    return Response.noContent().build();
  }

  /**
   * Locates the admin sub-resource for realm {@code id}.
   *
   * @throws NotFoundException if no realm with that id exists
   */
  @Path("{id}")
  public RealmAdminResource getRealmAdmin(
      @Context final HttpHeaders headers, @PathParam("id") final String id) {
    RealmManager realmManager = new RealmManager(session);
    RealmModel realm = realmManager.getRealm(id);
    if (realm == null) throw new NotFoundException();

    RealmAdminResource adminResource = new RealmAdminResource(admin, realm);
    resourceContext.initResource(adminResource);
    return adminResource;
  }
}
/**
 * Upgrades a RESTEasy/Netty HTTP request to a WebSocket connection.
 *
 * @author vonnagyi
 */
public class WebSocketChannelInitializer {

  private static final Logger logger = Logger.getLogger(WebSocketChannelInitializer.class);

  /**
   * Performs the WebSocket handshake for {@code request} on the channel of {@code ctx}.
   *
   * <p>If the request is valid, the "resteasyEncoder" and "resteasyDecoder" handlers are removed.
   * Once the handshake succeeds, the pipeline's last handler is replaced by {@code handler},
   * preceded by a {@link WebSocketProtocolHandler}.
   *
   * @param ctx context of the connection being upgraded
   * @param request the incoming HTTP request
   * @param websocketPath path appended to the Host header to form the WebSocket location
   * @param handler application handler installed after a successful handshake
   * @return the handshake future, or {@code null} in any of these cases:
   *     <ul>
   *       <li>the WebSocket version is unsupported (an error response is sent);
   *       <li>the Connection/Upgrade headers are missing or do not request a WebSocket upgrade;
   *       <li>an exception occurred (it is logged).
   *     </ul>
   */
  public static ChannelFuture handshake(
      final ChannelHandlerContext ctx,
      final HttpRequest request,
      final String websocketPath,
      final ChannelHandler handler) {

    final String connHead = request.getHttpHeaders().getHeaderString(HttpHeaders.Names.CONNECTION);
    final String upHead = request.getHttpHeaders().getHeaderString(HttpHeaders.Names.UPGRADE);
    final String sockHead =
        request.getHttpHeaders().getHeaderString(HttpHeaders.Names.SEC_WEBSOCKET_VERSION);
    final String keyHead =
        request.getHttpHeaders().getHeaderString(HttpHeaders.Names.SEC_WEBSOCKET_KEY);

    try {
      DefaultHttpRequest req =
          new DefaultHttpRequest(
              HttpVersion.HTTP_1_0, HttpMethod.GET, request.getUri().getAbsolutePath().toString());
      req.setHeader(HttpHeaders.Names.SEC_WEBSOCKET_VERSION, sockHead);
      req.setHeader(HttpHeaders.Names.SEC_WEBSOCKET_KEY, keyHead);

      final Channel channel = ctx.channel();

      final String location = getWebSocketLocation(channel.pipeline(), request, websocketPath);
      final WebSocketServerHandshakerFactory wsFactory =
          new WebSocketServerHandshakerFactory(location, null, false);

      final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req);

      if (handshaker == null) {
        WebSocketServerHandshakerFactory.sendUnsupportedWebSocketVersionResponse(channel);
        return null;

      } else if (!containsIgnoreCase(connHead, HttpHeaders.Values.UPGRADE)
          || !containsIgnoreCase(upHead, HttpHeaders.Values.WEBSOCKET)) {
        // Not a valid socket open request (includes a missing Connection or Upgrade header)
        logger.info("Invalid request: " + request.getUri());
        return null;

      } else {
        // We need to remove the RESTEasy stuff otherwise the Netty logic to write the handshake to
        // the channel will never make it back to the client
        channel.pipeline().remove("resteasyEncoder");
        channel.pipeline().remove("resteasyDecoder");

        final ChannelFuture handshakeFuture = handshaker.handshake(channel, req);

        handshakeFuture.addListener(
            new ChannelFutureListener() {
              @Override
              public void operationComplete(ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                  ctx.fireExceptionCaught(future.cause());
                } else {
                  final ChannelPipeline pipeline = future.channel().pipeline();
                  pipeline.replace(pipeline.last(), "customSocketHandler", handler);
                  pipeline.addBefore(
                      "customSocketHandler", "socketHandler", new WebSocketProtocolHandler());
                }
              }
            });

        WebSocketProtocolHandler.setHandshaker(ctx, handshaker);
        return handshakeFuture;
      }

    } catch (Exception e) {
      logger.error("Error trying to upgrade the channel to a socket", e);
    }

    return null;
  }

  /**
   * Null-safe, locale-independent check that {@code headerValue} contains {@code token}, ignoring
   * case. A missing header ({@code null}) never matches.
   */
  private static boolean containsIgnoreCase(String headerValue, CharSequence token) {
    if (headerValue == null) {
      return false;
    }
    return headerValue
        .toLowerCase(java.util.Locale.ROOT)
        .contains(token.toString().toLowerCase(java.util.Locale.ROOT));
  }

  /**
   * Builds the WebSocket URL from the Host header and {@code path}. Uses "wss" when an SslHandler
   * is in the pipeline, "ws" otherwise.
   */
  private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) {
    String protocol = "ws";
    if (cp.get(SslHandler.class) != null) {
      // SSL in use so use Secure WebSockets
      protocol = "wss";
    }
    return protocol + "://" + req.getHttpHeaders().getHeaderString(HttpHeaders.Names.HOST) + path;
  }
}
/**
 * Admin REST resource for a single realm.
 *
 * <p>Exposes the realm representation, update and delete, the sub-resources (applications, OAuth
 * clients, roles, users, roles-by-id), session statistics, revocation push, logout-all and audit
 * queries. Access is checked through {@link RealmAuth}.
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class RealmAdminResource {
  protected static final Logger logger = Logger.getLogger(RealmAdminResource.class);
  protected RealmAuth auth;
  protected RealmModel realm;
  private TokenManager tokenManager;

  @Context protected ResourceContext resourceContext;

  @Context protected KeycloakSession session;

  @Context protected ProviderSession providers;

  public RealmAdminResource(RealmAuth auth, RealmModel realm, TokenManager tokenManager) {
    this.auth = auth;
    this.realm = realm;
    this.tokenManager = tokenManager;

    auth.init(RealmAuth.Resource.REALM);
  }

  @Path("applications")
  public ApplicationsResource getApplications() {
    ApplicationsResource applicationsResource = new ApplicationsResource(realm, auth);
    resourceContext.initResource(applicationsResource);
    return applicationsResource;
  }

  @Path("oauth-clients")
  public OAuthClientsResource getOAuthClients() {
    OAuthClientsResource oauth = new OAuthClientsResource(realm, auth, session);
    resourceContext.initResource(oauth);
    return oauth;
  }

  @Path("roles")
  public RoleContainerResource getRoleContainerResource() {
    RoleContainerResource roles = new RoleContainerResource(realm, auth, realm);
    // Initialize like every other sub-resource locator here so @Context fields get injected.
    resourceContext.initResource(roles);
    return roles;
  }

  /**
   * Returns the full realm representation when the caller has view rights. Otherwise, after
   * {@code auth.requireAny()}, returns a representation containing only the realm name.
   */
  @GET
  @NoCache
  @Produces("application/json")
  public RealmRepresentation getRealm() {
    if (auth.hasView()) {
      return ModelToRepresentation.toRepresentation(realm);
    } else {
      auth.requireAny();

      RealmRepresentation rep = new RealmRepresentation();
      rep.setRealm(realm.getName());

      return rep;
    }
  }

  @PUT
  @Consumes("application/json")
  public void updateRealm(final RealmRepresentation rep) {
    auth.requireManage();

    logger.debug("updating realm: " + realm.getName());
    new RealmManager(session).updateRealm(rep, realm);
  }

  /**
   * Deletes this realm.
   *
   * @throws NotFoundException if the realm manager reports that nothing was removed
   */
  @DELETE
  public void deleteRealm() {
    auth.requireManage();

    if (!new RealmManager(session).removeRealm(realm)) {
      throw new NotFoundException();
    }
  }

  @Path("users")
  public UsersResource users() {
    UsersResource users = new UsersResource(realm, auth, tokenManager);
    resourceContext.initResource(users);
    return users;
  }

  @Path("roles-by-id")
  public RoleByIdResource rolesById() {
    RoleByIdResource resource = new RoleByIdResource(realm, auth);
    resourceContext.initResource(resource);
    return resource;
  }

  @Path("push-revocation")
  @POST
  public void pushRevocation() {
    auth.requireManage();
    new ResourceAdminManager().pushRealmRevocationPolicy(realm);
  }

  @Path("logout-all")
  @POST
  public void logoutAll() {
    auth.requireManage();
    new ResourceAdminManager().logoutAll(realm);
  }

  /**
   * Returns session statistics keyed by application name. Applications without a management URL
   * are skipped.
   */
  @Path("session-stats")
  @GET
  @NoCache
  @Produces(MediaType.APPLICATION_JSON)
  public Map<String, SessionStats> getSessionStats() {
    // Per-request trace: debug level rather than info, like the other per-request logs here.
    logger.debug("session-stats");
    auth.requireView();
    Map<String, SessionStats> stats = new HashMap<String, SessionStats>();
    for (ApplicationModel applicationModel : realm.getApplications()) {
      if (applicationModel.getManagementUrl() == null) continue;
      SessionStats appStats =
          new ResourceAdminManager().getSessionStats(realm, applicationModel, false);
      stats.put(applicationModel.getName(), appStats);
    }
    return stats;
  }

  /**
   * Queries audit events for this realm. Every non-null query parameter narrows the query.
   *
   * @param client restrict to this client
   * @param event restrict to this event type
   * @param user restrict to this user
   * @param ipAddress restrict to this IP address
   * @param firstResult offset of the first result
   * @param maxResults maximum number of results
   */
  @Path("audit")
  @GET
  @NoCache
  @Produces(MediaType.APPLICATION_JSON)
  public List<Event> getAudit(
      @QueryParam("client") String client,
      @QueryParam("event") String event,
      @QueryParam("user") String user,
      @QueryParam("ipAddress") String ipAddress,
      @QueryParam("first") Integer firstResult,
      @QueryParam("max") Integer maxResults) {
    auth.init(RealmAuth.Resource.AUDIT).requireView();

    AuditProvider audit = providers.getProvider(AuditProvider.class);

    EventQuery query = audit.createQuery().realm(realm.getId());
    if (client != null) {
      query.client(client);
    }
    if (event != null) {
      query.event(event);
    }
    if (user != null) {
      query.user(user);
    }
    if (ipAddress != null) {
      query.ipAddress(ipAddress);
    }
    if (firstResult != null) {
      query.firstResult(firstResult);
    }
    if (maxResults != null) {
      query.maxResults(maxResults);
    }

    return query.getResultList();
  }
}
/**
 * A single thread will log failures. This is so that we can avoid concurrent writes as we want an
 * accurate failure count
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class BruteForceProtector implements Runnable {
  protected static Logger logger = Logger.getLogger(BruteForceProtector.class);

  protected int maxFailureWaitSeconds = 900;
  protected int minimumQuickLoginWaitSeconds = 60;
  protected int waitIncrementSeconds = 60;
  protected long quickLoginCheckMilliSeconds = 1000;
  protected int maxDeltaTime = 60 * 60 * 24 * 1000;
  protected int failureFactor = 10;
  protected volatile boolean run = true;
  protected KeycloakSessionFactory factory;
  protected CountDownLatch shutdownLatch = new CountDownLatch(1);

  protected volatile long failures;
  protected volatile long lastFailure;
  protected volatile long totalTime;

  protected LinkedBlockingQueue<LoginEvent> queue = new LinkedBlockingQueue<LoginEvent>();
  public static final int TRANSACTION_SIZE = 20;

  protected abstract class LoginEvent implements Comparable<LoginEvent> {
    protected final String realmId;
    protected final String username;
    protected final String ip;

    protected LoginEvent(String realmId, String username, String ip) {
      this.realmId = realmId;
      this.username = username;
      this.ip = ip;
    }

    @Override
    public int compareTo(LoginEvent o) {
      return username.compareTo(o.username);
    }
  }

  protected class SuccessfulLogin extends LoginEvent {
    public SuccessfulLogin(String realmId, String userId, String ip) {
      super(realmId, userId, ip);
    }
  }

  protected class FailedLogin extends LoginEvent {
    protected final CountDownLatch latch = new CountDownLatch(1);

    public FailedLogin(String realmId, String username, String ip) {
      super(realmId, username, ip);
    }
  }

  public BruteForceProtector(KeycloakSessionFactory factory) {
    this.factory = factory;
  }

  public void failure(KeycloakSession session, LoginEvent event) {
    UsernameLoginFailureModel user = getUserModel(session, event);
    if (user == null) return;
    user.setLastIPFailure(event.ip);
    long currentTime = System.currentTimeMillis();
    long last = user.getLastFailure();
    long deltaTime = 0;
    if (last > 0) {
      deltaTime = currentTime - last;
    }
    user.setLastFailure(currentTime);
    if (deltaTime > 0) {
      // if last failure was more than MAX_DELTA clear failures
      if (deltaTime > maxDeltaTime) {
        user.clearFailures();
      }
    }
    user.incrementFailures();

    int waitSeconds = waitIncrementSeconds * (user.getNumFailures() / failureFactor);
    if (waitSeconds == 0) {
      if (deltaTime > quickLoginCheckMilliSeconds) {
        waitSeconds = minimumQuickLoginWaitSeconds;
      }
    }
    waitSeconds = Math.min(maxFailureWaitSeconds, waitSeconds);
    if (waitSeconds > 0) {
      user.setFailedLoginNotBefore((int) (currentTime / 1000) + waitSeconds);
    }
  }

  protected UsernameLoginFailureModel getUserModel(KeycloakSession session, LoginEvent event) {
    RealmModel realm = session.getRealm(event.realmId);
    if (realm == null) return null;
    UsernameLoginFailureModel user = realm.getUserLoginFailure(event.username);
    if (user == null) return null;
    return user;
  }

  public void start() {
    new Thread(this).start();
  }

  public void shutdown() {
    run = false;
    try {
      shutdownLatch.await(5, TimeUnit.SECONDS);
    } catch (InterruptedException e) {
      throw new RuntimeException(e);
    }
  }

  public void run() {
    final ArrayList<LoginEvent> events = new ArrayList<LoginEvent>(TRANSACTION_SIZE + 1);
    while (run) {
      try {
        LoginEvent take = queue.poll(2, TimeUnit.SECONDS);
        if (take == null) {
          continue;
        }
        try {
          events.add(take);
          queue.drainTo(events, TRANSACTION_SIZE);
          for (LoginEvent event : events) {
            if (event instanceof FailedLogin) {
              logFailure(event);
            } else {
              logSuccess(event);
            }
          }

          Collections.sort(
              events); // we sort to avoid deadlock due to ordered updates.  Maybe I'm overthinking
          // this.
          KeycloakSession session = factory.createSession();
          try {
            for (LoginEvent event : events) {
              if (event instanceof FailedLogin) {
                failure(session, event);
              }
            }
            session.getTransaction().commit();
          } catch (Exception e) {
            session.getTransaction().rollback();
            throw e;
          } finally {
            for (LoginEvent event : events) {
              if (event instanceof FailedLogin) {
                ((FailedLogin) event).latch.countDown();
              }
            }
            events.clear();
            session.close();
          }
        } catch (Exception e) {
          logger.error("Failed processing event", e);
        }
      } catch (InterruptedException e) {
        break;
      } finally {
        shutdownLatch.countDown();
      }
    }
  }

  protected void logSuccess(LoginEvent event) {
    logger.warn("login success for user " + event.username + " from ip " + event.ip);
  }

  protected void logFailure(LoginEvent event) {
    logger.warn("login failure for user " + event.username + " from ip " + event.ip);
    failures++;
    long delta = 0;
    if (lastFailure > 0) {
      delta = System.currentTimeMillis() - lastFailure;
      if (delta > maxDeltaTime) {
        totalTime = 0;

      } else {
        totalTime += delta;
      }
    }
  }

  public void successfulLogin(
      RealmModel realm, String username, ClientConnection clientConnection) {
    logger.info(
        "successful login user: "******" from ip " + clientConnection.getRemoteAddr());
  }

  public void invalidUser(RealmModel realm, String username, ClientConnection clientConnection) {
    logger.warn("invalid user: "******" from ip " + clientConnection.getRemoteAddr());
    // todo more?
  }

  public void failedLogin(RealmModel realm, String username, ClientConnection clientConnection) {
    try {
      FailedLogin event =
          new FailedLogin(realm.getId(), username, clientConnection.getRemoteAddr());
      queue.offer(event);
      // wait a minimum of seconds for event to process so that a hacker
      // cannot flood with failed logins and overwhelm the queue and not have notBefore updated to
      // block next requests
      // todo failure HTTP responses should be queued via async HTTP
      event.latch.await(5, TimeUnit.SECONDS);

    } catch (InterruptedException e) {
    }
  }
}
/**
 * Keystone Access token protocol: authenticates a request carrying a PKCS#7-signed Access token in
 * the {@code X-Auth-Signed-Token} header.
 *
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class SignedSkeletonKeyStoneLoginModule extends JBossWebAuthLoginModule {
  private static final Logger log = Logger.getLogger(SignedSkeletonKeyStoneLoginModule.class);
  private static final String SECURITY_DOMAIN = "securityDomain";
  protected String projectId;
  protected String skeletonKeyCertificateAlias;
  /** The Access token of the last successful login; assigned only after full validation. */
  protected Access access;
  /** The SecurityDomain to obtain the KeyStore/TrustStore from */
  private Object domain = null;

  /**
   * Reads the {@code projectId}, {@code skeleton.key.certificate.alias} and {@code securityDomain}
   * options. Then looks the domain up in JNDI, first as a SecurityDomain and then as a
   * JSSESecurityDomain under {@code <domain>/jsse}.
   */
  @Override
  public void initialize(
      Subject subject,
      CallbackHandler callbackHandler,
      Map<String, ?> sharedState,
      Map<String, ?> options) {
    super.initialize(subject, callbackHandler, sharedState, options);

    projectId = (String) options.get("projectId");
    skeletonKeyCertificateAlias = (String) options.get("skeleton.key.certificate.alias");
    // Get the security domain and default to "other"
    String sd = (String) options.get(SECURITY_DOMAIN);
    // Informational; was logged at ERROR level.
    log.debug("Security Domain: " + sd);
    sd = SecurityUtil.unprefixSecurityDomain(sd);
    if (sd == null) {
      sd = "other";
    }

    try {
      Object tempDomain = new InitialContext().lookup(SecurityConstants.JAAS_CONTEXT_ROOT + sd);
      if (tempDomain instanceof SecurityDomain) {
        domain = tempDomain;
      } else {
        tempDomain =
            new InitialContext().lookup(SecurityConstants.JAAS_CONTEXT_ROOT + sd + "/jsse");
        if (tempDomain instanceof JSSESecurityDomain) {
          domain = tempDomain;
        } else {
          log.error(
              "The JSSE security domain "
                  + sd
                  + " is not valid. All authentication using this login module will fail!");
        }
      }
    } catch (NamingException e) {
      log.error("Unable to find the securityDomain named: " + sd, e);
    }
  }

  /**
   * Verifies the signed token against the configured certificate and checks its expiry and project.
   *
   * @return false when the request carries no signed token, true on successful authentication
   * @throws LoginException if the token is present but cannot be verified or fails a check
   */
  @Override
  protected boolean login(Request request, HttpServletResponse response) throws LoginException {
    String tokenHeader = request.getHeader("X-Auth-Signed-Token");
    if (tokenHeader == null) {
      // No signed token: decline rather than fail.
      return false;
    }
    // if we don't have a trust store, we'll just use the key store.
    KeyStore keyStore = resolveKeyStore();
    if (keyStore == null) {
      throw new LoginException("No trust store found");
    }

    Object storedCertificate;
    try {
      storedCertificate = keyStore.getCertificate(skeletonKeyCertificateAlias);
    } catch (KeyStoreException e) {
      throw loginException("Could not get certificate from keyStore", e);
    }
    // Covers both a missing alias (null) and a non-X.509 entry.
    if (!(storedCertificate instanceof X509Certificate)) {
      throw new LoginException(
          "No X509 certificate found in keyStore for alias " + skeletonKeyCertificateAlias);
    }
    X509Certificate certificate = (X509Certificate) storedCertificate;

    Access verified;
    try {
      PKCS7SignatureInput input = new PKCS7SignatureInput(tokenHeader);
      if (!input.verify(certificate)) {
        throw new LoginException("Bad Signature");
      }
      verified = (Access) input.getEntity(Access.class, MediaType.APPLICATION_JSON_TYPE);
    } catch (LoginException le) {
      throw le;
    } catch (Exception e) {
      throw loginException("Bad Token", e);
    }
    if (verified == null || verified.getToken() == null) {
      throw new LoginException("Bad Token");
    }

    if (verified.getToken().expired()) {
      throw new LoginException("Token expired");
    }
    // A missing configured projectId or token project is treated as a mismatch, not an NPE.
    if (projectId == null
        || verified.getToken().getProject() == null
        || !projectId.equals(verified.getToken().getProject().getId())) {
      throw new LoginException("Token project id doesn't match");
    }

    // Publish the token only once every check has passed.
    access = verified;
    this.loginOk = true;
    return true;
  }

  /** Returns the key store of the resolved security domain, or null if there is none. */
  private KeyStore resolveKeyStore() {
    if (domain instanceof SecurityDomain) {
      return ((SecurityDomain) domain).getKeyStore();
    }
    if (domain instanceof JSSESecurityDomain) {
      return ((JSSESecurityDomain) domain).getKeyStore();
    }
    return null;
  }

  /** Builds a LoginException carrying {@code cause}; LoginException has no cause constructor. */
  private static LoginException loginException(String message, Throwable cause) {
    LoginException le = new LoginException(message);
    le.initCause(cause);
    return le;
  }

  @Override
  protected Principal getIdentity() {
    return new UserPrincipal(access.getUser());
  }

  /** Returns a single "Roles" group containing one principal per role of the token's user. */
  @Override
  protected Group[] getRoleSets() throws LoginException {
    SimpleGroup roles = new SimpleGroup("Roles");
    Group[] roleSets = {roles};
    for (Role role : access.getUser().getRoles()) {
      roles.addMember(new SimplePrincipal(role.getName()));
    }
    return roleSets;
  }
}
/**
 * @author <a href="mailto:[email protected]">Bill Burke</a>
 * @version $Revision: 1 $
 */
public class OAuthClientsResource {
  protected static final Logger logger = Logger.getLogger(RealmAdminResource.class);
  protected RealmModel realm;

  protected KeycloakSession session;

  @Context protected ResourceContext resourceContext;
  private RealmAuth auth;

  public OAuthClientsResource(RealmModel realm, RealmAuth auth, KeycloakSession session) {
    this.auth = auth;
    this.realm = realm;
    this.session = session;

    auth.init(RealmAuth.Resource.CLIENT);
  }

  @GET
  @Produces(MediaType.APPLICATION_JSON)
  @NoCache
  public List<OAuthClientRepresentation> getOAuthClients() {
    List<OAuthClientRepresentation> rep = new ArrayList<OAuthClientRepresentation>();
    List<OAuthClientModel> oauthModels = realm.getOAuthClients();

    boolean view = auth.hasView();
    for (OAuthClientModel oauth : oauthModels) {
      if (view) {
        rep.add(OAuthClientManager.toRepresentation(oauth));
      } else {
        OAuthClientRepresentation client = new OAuthClientRepresentation();
        client.setName(oauth.getClientId());
        rep.add(client);
      }
    }
    return rep;
  }

  @POST
  @Consumes(MediaType.APPLICATION_JSON)
  public Response createOAuthClient(
      final @Context UriInfo uriInfo, final OAuthClientRepresentation rep) {
    auth.requireManage();

    OAuthClientManager resourceManager = new OAuthClientManager(realm);
    OAuthClientModel oauth = resourceManager.create(rep);
    return Response.created(uriInfo.getAbsolutePathBuilder().path(oauth.getId()).build()).build();
  }

  @Path("{id}")
  public OAuthClientResource getOAuthClient(final @PathParam("id") String id) {
    auth.requireView();

    OAuthClientModel oauth = realm.getOAuthClientById(id);
    if (oauth == null) {
      throw new NotFoundException();
    }
    OAuthClientResource oAuthClientResource = new OAuthClientResource(realm, auth, oauth, session);
    resourceContext.initResource(oAuthClientResource);
    return oAuthClientResource;
  }
}