@Test public void testAuthenticate() throws Exception { SAML2LoginAPIAuthenticatorCmd cmd = Mockito.spy(new SAML2LoginAPIAuthenticatorCmd()); Field apiServerField = SAML2LoginAPIAuthenticatorCmd.class.getDeclaredField("_apiServer"); apiServerField.setAccessible(true); apiServerField.set(cmd, apiServer); Field managerField = SAML2LoginAPIAuthenticatorCmd.class.getDeclaredField("_samlAuthManager"); managerField.setAccessible(true); managerField.set(cmd, samlAuthManager); Field accountServiceField = BaseCmd.class.getDeclaredField("_accountService"); accountServiceField.setAccessible(true); accountServiceField.set(cmd, accountService); Field domainMgrField = SAML2LoginAPIAuthenticatorCmd.class.getDeclaredField("_domainMgr"); domainMgrField.setAccessible(true); domainMgrField.set(cmd, domainMgr); Field userAccountDaoField = SAML2LoginAPIAuthenticatorCmd.class.getDeclaredField("_userAccountDao"); userAccountDaoField.setAccessible(true); userAccountDaoField.set(cmd, userAccountDao); String spId = "someSPID"; String url = "someUrl"; KeyPair kp = SAMLUtils.generateRandomKeyPair(); X509Certificate cert = SAMLUtils.generateRandomX509Certificate(kp); SAMLProviderMetadata providerMetadata = new SAMLProviderMetadata(); providerMetadata.setEntityId("random"); providerMetadata.setSigningCertificate(cert); providerMetadata.setEncryptionCertificate(cert); providerMetadata.setKeyPair(kp); providerMetadata.setSsoUrl("http://test.local"); providerMetadata.setSloUrl("http://test.local"); Mockito.when(session.getAttribute(Mockito.anyString())).thenReturn(null); Mockito.when(domain.getId()).thenReturn(1L); Mockito.when(domainMgr.getDomain(Mockito.anyString())).thenReturn(domain); UserAccountVO user = new UserAccountVO(); user.setId(1000L); Mockito.when(userAccountDao.getUserAccount(Mockito.anyString(), Mockito.anyLong())) .thenReturn(user); Mockito.when(apiServer.verifyUser(Mockito.anyLong())).thenReturn(false); Mockito.when(samlAuthManager.getSPMetadata()).thenReturn(providerMetadata); Mockito.when(samlAuthManager.getIdPMetadata(Mockito.anyString())).thenReturn(providerMetadata); Map<String, Object[]> params = new HashMap<String, Object[]>(); // SSO redirection test cmd.authenticate( "command", params, session, InetAddress.getByName("127.0.0.1"), HttpUtils.RESPONSE_TYPE_JSON, new StringBuilder(), req, resp); Mockito.verify(resp, Mockito.times(1)).sendRedirect(Mockito.anyString()); // SSO SAMLResponse verification test, this should throw ServerApiException for auth failure params.put(SAMLPluginConstants.SAML_RESPONSE, new String[] {"Some String"}); Mockito.stub(cmd.processSAMLResponse(Mockito.anyString())).toReturn(buildMockResponse()); try { cmd.authenticate( "command", params, session, InetAddress.getByName("127.0.0.1"), HttpUtils.RESPONSE_TYPE_JSON, new StringBuilder(), req, resp); } catch (ServerApiException ignored) { } Mockito.verify(userAccountDao, Mockito.times(0)) .getUserAccount(Mockito.anyString(), Mockito.anyLong()); Mockito.verify(apiServer, Mockito.times(0)).verifyUser(Mockito.anyLong()); }
@Override public String authenticate( String command, Map<String, Object[]> params, HttpSession session, String remoteAddress, String responseType, StringBuilder auditTrailSb, final HttpServletResponse resp) throws ServerApiException { // FIXME: ported from ApiServlet, refactor and cleanup final String[] username = (String[]) params.get(ApiConstants.USERNAME); final String[] password = (String[]) params.get(ApiConstants.PASSWORD); String[] domainIdArr = (String[]) params.get(ApiConstants.DOMAIN_ID); if (domainIdArr == null) { domainIdArr = (String[]) params.get(ApiConstants.DOMAIN__ID); } final String[] domainName = (String[]) params.get(ApiConstants.DOMAIN); Long domainId = null; if ((domainIdArr != null) && (domainIdArr.length > 0)) { try { // check if UUID is passed in for domain domainId = _apiServer.fetchDomainId(domainIdArr[0]); if (domainId == null) { domainId = Long.parseLong(domainIdArr[0]); } auditTrailSb.append(" domainid=" + domainId); // building the params for POST call } catch (final NumberFormatException e) { s_logger.warn("Invalid domain id entered by user"); auditTrailSb.append( " " + HttpServletResponse.SC_UNAUTHORIZED + " " + "Invalid domain id entered, please enter a valid one"); throw new ServerApiException( ApiErrorCode.UNAUTHORIZED, _apiServer.getSerializedApiError( HttpServletResponse.SC_UNAUTHORIZED, "Invalid domain id entered, please enter a valid one", params, responseType)); } } String domain = null; if (domainName != null) { domain = domainName[0]; auditTrailSb.append(" domain=" + domain); if (domain != null) { // ensure domain starts with '/' and ends with '/' if (!domain.endsWith("/")) { domain += '/'; } if (!domain.startsWith("/")) { domain = "/" + domain; } } } String serializedResponse = null; if (username != null) { final String pwd = ((password == null) ? null : password[0]); try { return ApiResponseSerializer.toSerializedString( _apiServer.loginUser( session, username[0], pwd, domainId, domain, remoteAddress, params), responseType); } catch (final CloudAuthenticationException ex) { // TODO: fall through to API key, or just fail here w/ auth error? (HTTP 401) try { session.invalidate(); } catch (final IllegalStateException ise) { } auditTrailSb.append( " " + ApiErrorCode.ACCOUNT_ERROR + " " + ex.getMessage() != null ? ex.getMessage() : "failed to authenticate user, check if username/password are correct"); serializedResponse = _apiServer.getSerializedApiError( ApiErrorCode.ACCOUNT_ERROR.getHttpCode(), ex.getMessage() != null ? ex.getMessage() : "failed to authenticate user, check if username/password are correct", params, responseType); } } // We should not reach here and if we do we throw an exception throw new ServerApiException(ApiErrorCode.ACCOUNT_ERROR, serializedResponse); }