Example #1
0
  /**
   * Signs the message.
   *
   * <p>Ensures the envelope has a header, gives the SOAP body a freshly generated wsu:Id, and signs
   * the elements whose ids are collected in {@code ids}: the body and, when {@code correlationId}
   * is set, the element carrying that id. For the X.509 token profiles (and when {@code profile} is
   * null) a binary security token is added to the message and its certificate is used. For the SAML
   * token profiles the certificate is taken from {@code assertion}.
   *
   * @param soapMessage SOAPMessage that needs to be signed.
   * @param profile Security profile that needs to be used for signing; null selects the X.509 path.
   * @param assertion Security Assertion, used only by the SAML token profiles.
   * @return SOAPMessage signed SOAPMessage.
   * @throws SOAPBindingException if the message has no SOAP body, if no signature element could be
   *     produced (including when {@code profile} matches neither the X.509 nor the SAML profiles),
   *     or if any other error occurs while signing.
   */
  private SOAPMessage signMessage(
      SOAPMessage soapMessage, String profile, SecurityAssertion assertion)
      throws SOAPBindingException {
    try {
      SOAPHeader soapHeader = soapMessage.getSOAPPart().getEnvelope().getHeader();
      if (soapHeader == null) {
        soapMessage.getSOAPPart().getEnvelope().addHeader();
      }
      SOAPBody soapBody = soapMessage.getSOAPPart().getEnvelope().getBody();
      if (soapBody == null) {
        throw new SOAPBindingException(WSSUtils.bundle.getString("nullSOAPBody"));
      }

      // Give the body a fresh wsu:Id so the signature can reference it.
      String bodyId = SAMLUtils.generateID();
      soapBody.setAttributeNS(WSSEConstants.NS_WSU_WSF11, WSSEConstants.WSU_ID, bodyId);
      List ids = new ArrayList();
      ids.add(bodyId);
      if (correlationId != null) {
        ids.add(correlationId);
      }

      Element sigElem = null;
      if (profile == null
          || profile.equals(Message.NULL_X509)
          || profile.equals(Message.TLS_X509)
          || profile.equals(Message.CLIENT_TLS_X509)
          || profile.equals(Message.NULL_X509_WSF11)
          || profile.equals(Message.TLS_X509_WSF11)
          || profile.equals(Message.CLIENT_TLS_X509_WSF11)) {

        // The token is added before serialization, so it is present in the document handed to
        // the signature manager.
        BinarySecurityToken binaryToken = addBinaryToken(soapMessage);
        Certificate cert = SecurityUtils.getCertificate(binaryToken);
        ByteArrayOutputStream bop = new ByteArrayOutputStream();
        soapMessage.writeTo(bop);
        ByteArrayInputStream bin = new ByteArrayInputStream(bop.toByteArray());
        Document doc = XMLUtils.toDOMDocument(bin, WSSUtils.debug);
        sigElem =
            SecurityUtils.getSignatureManager()
                .signWithWSSX509TokenProfile(
                    doc, cert, "", ids, SOAPBindingConstants.WSF_11_VERSION);

      } else if (profile.equals(Message.NULL_SAML)
          || profile.equals(Message.TLS_SAML)
          || profile.equals(Message.CLIENT_TLS_SAML)
          || profile.equals(Message.NULL_SAML_WSF11)
          || profile.equals(Message.TLS_SAML_WSF11)
          || profile.equals(Message.CLIENT_TLS_SAML_WSF11)) {

        Certificate cert = SecurityUtils.getCertificate(assertion);
        ByteArrayOutputStream bop = new ByteArrayOutputStream();
        soapMessage.writeTo(bop);
        ByteArrayInputStream bin = new ByteArrayInputStream(bop.toByteArray());
        Document doc = XMLUtils.toDOMDocument(bin, WSSUtils.debug);
        sigElem =
            SecurityUtils.getSignatureManager()
                .signWithWSSSAMLTokenProfile(
                    doc,
                    cert,
                    assertion.getAssertionID(),
                    "",
                    ids,
                    SOAPBindingConstants.WSF_11_VERSION);
      }

      // Also reached when the profile matches neither group above.
      if (sigElem == null) {
        WSSUtils.debug.error("MessageProcessor.signMessage: " + "SigElement is null");
        throw new SOAPBindingException(WSSUtils.bundle.getString("cannotSignMessage"));
      }

      Element securityHeader = getSecurityHeader(soapMessage);
      securityHeader.appendChild(securityHeader.getOwnerDocument().importNode(sigElem, true));

      return Utils.DocumentToSOAPMessage(sigElem.getOwnerDocument());

    } catch (SOAPBindingException sbe) {
      // Already describes the specific failure (e.g. nullSOAPBody); rethrow it instead of
      // masking it with the generic cannotSignMessage error below.
      throw sbe;
    } catch (Exception ex) {
      WSSUtils.debug.error("MessageProcessor.signMessage: " + "Signing failed.", ex);
      throw new SOAPBindingException(WSSUtils.bundle.getString("cannotSignMessage"));
    }
  }
  /**
   * Initiates <code>SAML</code> web browser POST profile. This method takes in a TARGET in the
   * request, creates a signed SAMLResponse for the destination site configured for that TARGET,
   * and returns an HTML page whose form auto-submits the response to the destination site's POST
   * URL.
   *
   * <p>Users without a session are redirected to the login URL. Other failures are reported via
   * {@code SAMLUtils.sendError}.
   *
   * @param request <code>HttpServletRequest</code> instance
   * @param response <code>HttpServletResponse</code> instance
   * @throws ServletException if there is an error.
   * @throws IOException if there is an error.
   */
  public void doGet(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {
    if ((request == null) || (response == null)) {
      String[] data = {SAMLUtils.bundle.getString("nullInputParameter")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.NULL_PARAMETER, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "nullInputParameter",
          SAMLUtils.bundle.getString("nullInputParameter"));
      return;
    }

    SAMLUtils.checkHTTPContentLength(request);

    // get Session
    Object token = getSession(request);
    if (token == null) {
      response.sendRedirect(SAMLUtils.getLoginRedirectURL(request));
      return;
    }

    // obtain TARGET
    String target = request.getParameter(SAMLConstants.POST_TARGET_PARAM);
    if (target == null || target.length() == 0) {
      String[] data = {SAMLUtils.bundle.getString("missingTargetSite")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.MISSING_TARGET, data, token);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "missingTargetSite",
          SAMLUtils.bundle.getString("missingTargetSite"));
      return;
    }

    // Find the destination site entry for TARGET; its POST URL is the recipient of the response.
    SAMLServiceManager.SiteEntry destSite = getDestSite(target);
    String destSiteUrl = null;
    if ((destSite == null) || ((destSiteUrl = destSite.getPOSTUrl()) == null)) {
      String[] data = {SAMLUtils.bundle.getString("targetForbidden"), target};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.TARGET_FORBIDDEN, data, token);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "targetForbidden",
          SAMLUtils.bundle.getString("targetForbidden") + " " + target);
      return;
    }

    Response samlResponse = null;
    try {
      // A configured "major.minor" version overrides the protocol default.
      String version = destSite.getVersion();
      int majorVersion = SAMLConstants.PROTOCOL_MAJOR_VERSION;
      int minorVersion = SAMLConstants.PROTOCOL_MINOR_VERSION;
      if (version != null) {
        StringTokenizer st = new StringTokenizer(version, ".");
        if (st.countTokens() == 2) {
          majorVersion = Integer.parseInt(st.nextToken().trim());
          minorVersion = Integer.parseInt(st.nextToken().trim());
        }
      }
      // create assertion
      AssertionManager am = AssertionManager.getInstance();
      SessionProvider sessionProvider = SessionManager.getProvider();
      Assertion assertion =
          am.createSSOAssertion(
              sessionProvider.getSessionID(token),
              null,
              request,
              response,
              destSite.getSourceID(),
              target,
              majorVersion + "." + minorVersion);

      // create SAMLResponse
      StatusCode statusCode = new StatusCode(SAMLConstants.STATUS_CODE_SUCCESS);
      Status status = new Status(statusCode);
      List contents = new ArrayList();
      contents.add(assertion);
      samlResponse = new Response(null, status, destSiteUrl, contents);
      samlResponse.setMajorVersion(majorVersion);
      samlResponse.setMinorVersion(minorVersion);
    } catch (SessionException sse) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doGet: Exception " + "Couldn't get SessionProvider:", sse);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "couldNotCreateResponse",
          sse.getMessage());
      return;
    } catch (NumberFormatException ne) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doGet: Exception " + "when creating Response: ", ne);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "couldNotCreateResponse",
          ne.getMessage());
      return;
    } catch (SAMLException se) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doGet: Exception " + "when creating Response: ", se);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "couldNotCreateResponse",
          se.getMessage());
      return;
    }

    // sign the samlResponse
    byte[] signedBytes = null;
    try {
      samlResponse.signXML();
      if (SAMLUtils.debug.messageEnabled()) {
        SAMLUtils.debug.message(
            "SAMLPOSTProfileServlet.doGet: "
                + "signed samlResponse is"
                + samlResponse.toString(true, true, true));
      }
      signedBytes = SAMLUtils.getResponseBytes(samlResponse);
    } catch (Exception e) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doGet: Exception " + "when signing the response:", e);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "errorSigningResponse",
          SAMLUtils.bundle.getString("errorSigningResponse"));
      return;
    }

    // base64 encode the signed samlResponse
    String encodedResponse = null;
    try {
      encodedResponse = Base64.encode(signedBytes, true).trim();
    } catch (Exception e) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doGet: Exception " + "when encoding the response:", e);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "errorEncodeResponse",
          SAMLUtils.bundle.getString("errorEncodeResponse"));
      return;
    }

    if (LogUtils.isAccessLoggable(java.util.logging.Level.FINE)) {
      String[] data = {
        SAMLUtils.bundle.getString("redirectTo"),
        target,
        destSiteUrl,
        new String(signedBytes, "UTF-8")
      };
      LogUtils.access(java.util.logging.Level.FINE, LogUtils.REDIRECT_TO_URL, data, token);
    } else {
      String[] data = {SAMLUtils.bundle.getString("redirectTo"), target, destSiteUrl};
      LogUtils.access(java.util.logging.Level.INFO, LogUtils.REDIRECT_TO_URL, data, token);
    }

    // TARGET comes straight from the request. Every dynamic value is HTML-escaped before being
    // placed in a double-quoted attribute, so a quote or angle bracket cannot break out of the
    // attribute and inject markup (reflected XSS). The browser decodes the entities when it
    // submits the form, so the posted values are unchanged.
    String escapedDestSiteUrl = escapeHtmlAttribute(destSiteUrl);
    String escapedResponse = escapeHtmlAttribute(encodedResponse);
    String escapedTarget = escapeHtmlAttribute(target);

    response.setContentType("text/html; charset=UTF-8");
    PrintWriter out = response.getWriter();
    out.println("<HTML>");
    out.println("<BODY Onload=\"document.forms[0].submit()\">");
    out.println("<FORM METHOD=\"POST\" ACTION=\"" + escapedDestSiteUrl + "\">");
    out.println("<INPUT TYPE=\"HIDDEN\" NAME=\"" + SAMLConstants.POST_SAML_RESPONSE_PARAM + "\" ");
    out.println("VALUE=\"" + escapedResponse + "\">");
    out.println(
        "<INPUT TYPE=\"HIDDEN\" NAME=\""
            + SAMLConstants.POST_TARGET_PARAM
            + "\" VALUE=\""
            + escapedTarget
            + "\"> </FORM>");
    out.println("</BODY></HTML>");
    out.close();
  }

  /**
   * Escapes a value for inclusion inside a double-quoted (or single-quoted) HTML attribute.
   *
   * @param value the raw value; may be null
   * @return the value with {@code & < > " '} replaced by character references, or the empty string
   *     when {@code value} is null
   */
  private static String escapeHtmlAttribute(String value) {
    if (value == null) {
      return "";
    }
    StringBuilder sb = new StringBuilder(value.length() + 16);
    for (int i = 0; i < value.length(); i++) {
      char c = value.charAt(i);
      switch (c) {
        case '&':
          sb.append("&amp;");
          break;
        case '<':
          sb.append("&lt;");
          break;
        case '>':
          sb.append("&gt;");
          break;
        case '"':
          sb.append("&quot;");
          break;
        case '\'':
          sb.append("&#39;");
          break;
        default:
          sb.append(c);
      }
    }
    return sb.toString();
  }
  /**
   * Signs the entity descriptor root element by the following rules:
   *
   * <ul>
   *   <li>Hosted Entity
   *       <ul>
   *         <li>If there is a signature already on the EntityDescriptor, removes it, then signs the
   *             EntityDescriptor.
   *         <li>Simply signs the EntityDescriptor otherwise.
   *       </ul>
   *   <li>Remote Entity
   *       <ul>
   *         <li>If there is a signature already on the EntityDescriptor, then does not change it,
   *             but returns the Document with the original signature.
   *         <li>Simply signs the EntityDescriptor otherwise
   *       </ul>
   * </ul>
   *
   * If there is no extended metadata for the entity, the entity is considered as remote. Before
   * signing, the root element is given a freshly generated ID attribute, which the signature
   * references.
   *
   * @param realm The realm where the EntityDescriptor belongs to.
   * @param descriptor The entity descriptor.
   * @return Signed <code>Document</code> for the entity descriptor or null if no metadata signing
   *     key is found in the configuration.
   * @throws SAML2MetaException if {@code descriptor} is null, if its XML cannot be parsed, or if
   *     the signature cannot be created.
   * @throws JAXBException if the entity descriptor is invalid.
   */
  public static Document sign(String realm, EntityDescriptorElement descriptor)
      throws JAXBException, SAML2MetaException {
    if (descriptor == null) {
      throw new SAML2MetaException("Unable to sign null descriptor");
    }

    SAML2MetaManager metaManager = new SAML2MetaManager();
    EntityConfigElement cfgElem = metaManager.getEntityConfig(realm, descriptor.getEntityID());
    // If there is no EntityConfig, this is considered as a remote entity.
    boolean isHosted = cfgElem != null && cfgElem.isHosted();

    String signingCert = getRealmSetting(METADATA_SIGNING_KEY, realm);
    if (signingCert == null) {
      return null;
    }

    initializeKeyStore();

    String xmlstr = SAML2MetaUtils.convertJAXBToString(descriptor);
    xmlstr = formatBase64BinaryElement(xmlstr);

    Document doc = XMLUtils.toDOMDocument(xmlstr, debug);
    if (doc == null) {
      // Guard against a failed parse instead of failing with a NullPointerException below.
      throw new SAML2MetaException("Unable to parse entity descriptor for signing");
    }

    NodeList childNodes = doc.getDocumentElement().getChildNodes();
    for (int i = 0; i < childNodes.getLength(); i++) {
      Node node = childNodes.item(i);
      // Constant-first comparisons: getNamespaceURI() is null for an element in no namespace.
      if ("Signature".equals(node.getLocalName()) && NS_XMLSIG.equals(node.getNamespaceURI())) {
        if (isHosted) {
          node.getParentNode().removeChild(node);
          break;
        } else {
          // one signature found for this remote entity on the root element,
          // in this case returning the entry with the original signature
          // as that may be judged more accurately
          return doc;
        }
      }
    }

    // we need to sign or re-sign the document, let's generate a new ID
    String descriptorId = SAMLUtils.generateID();
    doc.getDocumentElement().setAttribute(ATTR_ID, descriptorId);

    XMLSignatureManager sigManager = XMLSignatureManager.getInstance();
    try {
      // Selects the first child element of the EntityDescriptor element.
      String xpath =
          "//*[local-name()=\""
              + TAG_ENTITY_DESCRIPTOR
              + "\" and namespace-uri()=\""
              + NS_META
              + "\"]/*[1]";
      sigManager.signXMLUsingKeyPass(
          doc,
          signingCert,
          getRealmSetting(METADATA_SIGNING_KEY_PASS, realm),
          null,
          SAML2Constants.ID,
          descriptorId,
          true,
          xpath);
    } catch (XMLSignatureException xmlse) {
      if (debug.messageEnabled()) {
        debug.message("SAML2MetaSecurityUtils.sign:", xmlse);
      }
      // Returning doc here would hand back an unsigned document (for hosted entities, with the
      // original signature already stripped), so report the failure as documented.
      throw new SAML2MetaException("Unable to sign entity descriptor: " + xmlse.getMessage());
    }

    return doc;
  }
  /**
   * This method processes TARGET and SAMLResponse info from the request, validates the
   * response/assertion(s), then redirects user to the TARGET resource if all are valid.
   *
   * <p>The SAMLResponse parameter is Base64-decoded, parsed, and verified against the request URL.
   * It is then processed into session attributes, from which a session is generated. Finally, the
   * user is either POSTed to TARGET (when {@code SAMLUtils.postYN(target)} is true) or redirected
   * to it. Failures are reported via {@code SAMLUtils.sendError}.
   *
   * @param request <code>HttpServletRequest</code> instance
   * @param response <code>HttpServletResponse</code> instance
   * @throws ServletException if there is an error.
   * @throws IOException if there is an error.
   */
  public void doPost(HttpServletRequest request, HttpServletResponse response)
      throws ServletException, IOException {
    response.setContentType("text/html; charset=UTF-8");

    if ((request == null) || (response == null)) {
      String[] data = {SAMLUtils.bundle.getString("nullInputParameter")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.NULL_PARAMETER, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "nullInputParameter",
          SAMLUtils.bundle.getString("nullInputParameter"));
      return;
    }

    SAMLUtils.checkHTTPContentLength(request);

    // obtain TARGET
    String target = request.getParameter(SAMLConstants.POST_TARGET_PARAM);
    if (target == null || target.length() == 0) {
      String[] data = {SAMLUtils.bundle.getString("missingTargetSite")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.MISSING_TARGET, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "missingTargetSite",
          SAMLUtils.bundle.getString("missingTargetSite"));
      return;
    }

    // obtain SAMLResponse
    String samlResponse = request.getParameter(SAMLConstants.POST_SAML_RESPONSE_PARAM);
    if (samlResponse == null) {
      String[] data = {SAMLUtils.bundle.getString("missingSAMLResponse")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.MISSING_RESPONSE, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "missingSAMLResponse",
          SAMLUtils.bundle.getString("missingSAMLResponse"));
      return;
    }

    // decode the Response
    byte[] raw = null;
    try {
      raw = Base64.decode(samlResponse);
    } catch (Exception e) {
      SAMLUtils.debug.error(
          "SAMLPOSTProfileServlet.doPost: Exception " + "when decoding SAMLResponse:", e);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "errorDecodeResponse",
          SAMLUtils.bundle.getString("errorDecodeResponse"));
      return;
    }

    // Get Response back
    Response sResponse = SAMLUtils.getResponse(raw);
    if (sResponse == null) {
      String[] data = {SAMLUtils.bundle.getString("errorObtainResponse")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.RESPONSE_MESSAGE_ERROR, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "errorObtainResponse",
          SAMLUtils.bundle.getString("errorObtainResponse"));
      return;
    }

    if (SAMLUtils.debug.messageEnabled()) {
      SAMLUtils.debug.message("SAMLPOSTProfileServlet.doPost: Received " + sResponse.toString());
    }

    // verify that Response is correct
    StringBuffer requestUrl = request.getRequestURL();
    if (SAMLUtils.debug.messageEnabled()) {
      SAMLUtils.debug.message("SAMLPOSTProfileServlet.doPost: " + "requestUrl=" + requestUrl);
    }
    boolean valid = SAMLUtils.verifyResponse(sResponse, requestUrl.toString(), request);
    if (!valid) {
      String[] data = {SAMLUtils.bundle.getString("invalidResponse")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.INVALID_RESPONSE, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_BAD_REQUEST,
          "invalidResponse",
          SAMLUtils.bundle.getString("invalidResponse"));
      return;
    }

    try {
      Map sessionAttr = SAMLUtils.processResponse(sResponse, target);
      // The returned session token is not used by this method.
      SAMLUtils.generateSession(request, response, sessionAttr);
    } catch (Exception ex) {
      SAMLUtils.debug.error("generateSession: ", ex);
      String[] data = {SAMLUtils.bundle.getString("failedCreateSSOToken")};
      LogUtils.error(java.util.logging.Level.INFO, LogUtils.FAILED_TO_CREATE_SSO_TOKEN, data);
      SAMLUtils.sendError(
          request,
          response,
          HttpServletResponse.SC_INTERNAL_SERVER_ERROR,
          "failedCreateSSOToken",
          ex.getMessage());
      return;
    }

    if (LogUtils.isAccessLoggable(java.util.logging.Level.FINE)) {
      String[] data = {SAMLUtils.bundle.getString("accessGranted"), new String(raw, "UTF-8")};
      LogUtils.access(java.util.logging.Level.FINE, LogUtils.ACCESS_GRANTED, data);
    } else {
      String[] data = {SAMLUtils.bundle.getString("accessGranted")};
      LogUtils.access(java.util.logging.Level.INFO, LogUtils.ACCESS_GRANTED, data);
    }

    // NOTE(review): nothing in this method populates these, so postToTarget always receives null
    // for both — confirm whether they should be derived from the processed response.
    List assertions = null;
    Map attrMap = null;
    if (SAMLUtils.postYN(target)) {
      if (SAMLUtils.debug.messageEnabled()) {
        SAMLUtils.debug.message("POST to target:" + target);
      }
      SAMLUtils.postToTarget(response, assertions, target, attrMap);
    } else {
      // sendRedirect sets the Location header itself.
      response.sendRedirect(target);
    }
  }