/** * @param request The request from which to extract parameters and perform the authentication * @return The authenticated user token, or null if authentication is incomplete. */ protected Authentication handleAuthorizationCodeResponse( HttpServletRequest request, HttpServletResponse response) { String authorizationCode = request.getParameter("code"); HttpSession session = request.getSession(); // check for state, if it doesn't match we bail early String storedState = getStoredState(session); if (!Strings.isNullOrEmpty(storedState)) { String state = request.getParameter("state"); if (!storedState.equals(state)) { throw new AuthenticationServiceException( "State parameter mismatch on return. Expected " + storedState + " got " + state); } } // look up the issuer that we set out to talk to String issuer = getStoredSessionString(session, ISSUER_SESSION_VARIABLE); // pull the configurations based on that issuer ServerConfiguration serverConfig = servers.getServerConfiguration(issuer); final RegisteredClient clientConfig = clients.getClientConfiguration(serverConfig); MultiValueMap<String, String> form = new LinkedMultiValueMap<>(); form.add("grant_type", "authorization_code"); form.add("code", authorizationCode); form.setAll(authOptions.getTokenOptions(serverConfig, clientConfig, request)); String redirectUri = getStoredSessionString(session, REDIRECT_URI_SESION_VARIABLE); if (redirectUri != null) { form.add("redirect_uri", redirectUri); } // Handle Token Endpoint interaction HttpClient httpClient = HttpClientBuilder.create() .useSystemProperties() .setDefaultRequestConfig( RequestConfig.custom().setSocketTimeout(httpSocketTimeout).build()) .build(); HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(httpClient); RestTemplate restTemplate; if (SECRET_BASIC.equals(clientConfig.getTokenEndpointAuthMethod())) { // use BASIC auth if configured to do so restTemplate = new RestTemplate(factory) { @Override protected ClientHttpRequest createRequest(URI url, HttpMethod method) throws IOException { ClientHttpRequest httpRequest = super.createRequest(url, method); httpRequest .getHeaders() .add( "Authorization", String.format( "Basic %s", Base64.encode( String.format( "%s:%s", UriUtils.encodePathSegment(clientConfig.getClientId(), "UTF-8"), UriUtils.encodePathSegment( clientConfig.getClientSecret(), "UTF-8"))))); return httpRequest; } }; } else { // we're not doing basic auth, figure out what other flavor we have restTemplate = new RestTemplate(factory); if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) || PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) { // do a symmetric secret signed JWT for auth JWTSigningAndValidationService signer = null; JWSAlgorithm alg = clientConfig.getTokenEndpointAuthSigningAlg(); if (SECRET_JWT.equals(clientConfig.getTokenEndpointAuthMethod()) && (alg.equals(JWSAlgorithm.HS256) || alg.equals(JWSAlgorithm.HS384) || alg.equals(JWSAlgorithm.HS512))) { // generate one based on client secret signer = symmetricCacheService.getSymmetricValidtor(clientConfig.getClient()); } else if (PRIVATE_KEY.equals(clientConfig.getTokenEndpointAuthMethod())) { // needs to be wired in to the bean signer = authenticationSignerService; if (alg == null) { alg = authenticationSignerService.getDefaultSigningAlgorithm(); } } if (signer == null) { throw new AuthenticationServiceException( "Couldn't find required signer service for use with private key auth."); } JWTClaimsSet claimsSet = new JWTClaimsSet(); claimsSet.setIssuer(clientConfig.getClientId()); claimsSet.setSubject(clientConfig.getClientId()); claimsSet.setAudience(Lists.newArrayList(serverConfig.getTokenEndpointUri())); claimsSet.setJWTID(UUID.randomUUID().toString()); // TODO: make this configurable Date exp = new Date(System.currentTimeMillis() + (60 * 1000)); // auth good for 60 seconds claimsSet.setExpirationTime(exp); Date now = new Date(System.currentTimeMillis()); claimsSet.setIssueTime(now); claimsSet.setNotBeforeTime(now); JWSHeader header = new JWSHeader( alg, null, null, null, null, null, null, null, null, null, signer.getDefaultSignerKeyId(), null, null); SignedJWT jwt = new SignedJWT(header, claimsSet); signer.signJwt(jwt, alg); form.add("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"); form.add("client_assertion", jwt.serialize()); } else { // Alternatively use form based auth form.add("client_id", clientConfig.getClientId()); form.add("client_secret", clientConfig.getClientSecret()); } } logger.debug("tokenEndpointURI = " + serverConfig.getTokenEndpointUri()); logger.debug("form = " + form); String jsonString = null; try { jsonString = restTemplate.postForObject(serverConfig.getTokenEndpointUri(), form, String.class); } catch (RestClientException e) { // Handle error logger.error("Token Endpoint error response: " + e.getMessage()); throw new AuthenticationServiceException("Unable to obtain Access Token: " + e.getMessage()); } logger.debug("from TokenEndpoint jsonString = " + jsonString); JsonElement jsonRoot = new JsonParser().parse(jsonString); if (!jsonRoot.isJsonObject()) { throw new AuthenticationServiceException( "Token Endpoint did not return a JSON object: " + jsonRoot); } JsonObject tokenResponse = jsonRoot.getAsJsonObject(); if (tokenResponse.get("error") != null) { // Handle error String error = tokenResponse.get("error").getAsString(); logger.error("Token Endpoint returned: " + error); throw new AuthenticationServiceException( "Unable to obtain Access Token. Token Endpoint returned: " + error); } else { // Extract the id_token to insert into the // OIDCAuthenticationToken // get out all the token strings String accessTokenValue = null; String idTokenValue = null; String refreshTokenValue = null; if (tokenResponse.has("access_token")) { accessTokenValue = tokenResponse.get("access_token").getAsString(); } else { throw new AuthenticationServiceException( "Token Endpoint did not return an access_token: " + jsonString); } if (tokenResponse.has("id_token")) { idTokenValue = tokenResponse.get("id_token").getAsString(); } else { logger.error("Token Endpoint did not return an id_token"); throw new AuthenticationServiceException("Token Endpoint did not return an id_token"); } if (tokenResponse.has("refresh_token")) { refreshTokenValue = tokenResponse.get("refresh_token").getAsString(); } try { JWT idToken = JWTParser.parse(idTokenValue); // validate our ID Token over a number of tests ReadOnlyJWTClaimsSet idClaims = idToken.getJWTClaimsSet(); // check the signature JWTSigningAndValidationService jwtValidator = null; Algorithm tokenAlg = idToken.getHeader().getAlgorithm(); Algorithm clientAlg = clientConfig.getIdTokenSignedResponseAlg(); if (clientAlg != null) { if (!clientAlg.equals(tokenAlg)) { throw new AuthenticationServiceException( "Token algorithm " + tokenAlg + " does not match expected algorithm " + clientAlg); } } if (idToken instanceof PlainJWT) { if (clientAlg == null) { throw new AuthenticationServiceException( "Unsigned ID tokens can only be used if explicitly configured in client."); } if (tokenAlg != null && !tokenAlg.equals(Algorithm.NONE)) { throw new AuthenticationServiceException( "Unsigned token received, expected signature with " + tokenAlg); } } else if (idToken instanceof SignedJWT) { SignedJWT signedIdToken = (SignedJWT) idToken; if (tokenAlg.equals(JWSAlgorithm.HS256) || tokenAlg.equals(JWSAlgorithm.HS384) || tokenAlg.equals(JWSAlgorithm.HS512)) { // generate one based on client secret jwtValidator = symmetricCacheService.getSymmetricValidtor(clientConfig.getClient()); } else { // otherwise load from the server's public key jwtValidator = validationServices.getValidator(serverConfig.getJwksUri()); } if (jwtValidator != null) { if (!jwtValidator.validateSignature(signedIdToken)) { throw new AuthenticationServiceException("Signature validation failed"); } } else { logger.error("No validation service found. Skipping signature validation"); throw new AuthenticationServiceException( "Unable to find an appropriate signature validator for ID Token."); } } // TODO: encrypted id tokens // check the issuer if (idClaims.getIssuer() == null) { throw new AuthenticationServiceException("Id Token Issuer is null"); } else if (!idClaims.getIssuer().equals(serverConfig.getIssuer())) { throw new AuthenticationServiceException( "Issuers do not match, expected " + serverConfig.getIssuer() + " got " + idClaims.getIssuer()); } // check expiration if (idClaims.getExpirationTime() == null) { throw new AuthenticationServiceException( "Id Token does not have required expiration claim"); } else { // it's not null, see if it's expired Date now = new Date(System.currentTimeMillis() - (timeSkewAllowance * 1000)); if (now.after(idClaims.getExpirationTime())) { throw new AuthenticationServiceException( "Id Token is expired: " + idClaims.getExpirationTime()); } } // check not before if (idClaims.getNotBeforeTime() != null) { Date now = new Date(System.currentTimeMillis() + (timeSkewAllowance * 1000)); if (now.before(idClaims.getNotBeforeTime())) { throw new AuthenticationServiceException( "Id Token not valid untill: " + idClaims.getNotBeforeTime()); } } // check issued at if (idClaims.getIssueTime() == null) { throw new AuthenticationServiceException( "Id Token does not have required issued-at claim"); } else { // since it's not null, see if it was issued in the future Date now = new Date(System.currentTimeMillis() + (timeSkewAllowance * 1000)); if (now.before(idClaims.getIssueTime())) { throw new AuthenticationServiceException( "Id Token was issued in the future: " + idClaims.getIssueTime()); } } // check audience if (idClaims.getAudience() == null) { throw new AuthenticationServiceException("Id token audience is null"); } else if (!idClaims.getAudience().contains(clientConfig.getClientId())) { throw new AuthenticationServiceException( "Audience does not match, expected " + clientConfig.getClientId() + " got " + idClaims.getAudience()); } // compare the nonce to our stored claim String nonce = idClaims.getStringClaim("nonce"); if (Strings.isNullOrEmpty(nonce)) { logger.error("ID token did not contain a nonce claim."); throw new AuthenticationServiceException("ID token did not contain a nonce claim."); } String storedNonce = getStoredNonce(session); if (!nonce.equals(storedNonce)) { logger.error( "Possible replay attack detected! The comparison of the nonce in the returned " + "ID Token to the session " + NONCE_SESSION_VARIABLE + " failed. Expected " + storedNonce + " got " + nonce + "."); throw new AuthenticationServiceException( "Possible replay attack detected! The comparison of the nonce in the returned " + "ID Token to the session " + NONCE_SESSION_VARIABLE + " failed. Expected " + storedNonce + " got " + nonce + "."); } // construct an PendingOIDCAuthenticationToken and return a Authentication object w/the // userId and the idToken PendingOIDCAuthenticationToken token = new PendingOIDCAuthenticationToken( idClaims.getSubject(), idClaims.getIssuer(), serverConfig, idToken, accessTokenValue, refreshTokenValue); Authentication authentication = this.getAuthenticationManager().authenticate(token); return authentication; } catch (ParseException e) { throw new AuthenticationServiceException("Couldn't parse idToken: ", e); } } }
@Override public OAuth2AccessTokenEntity createIdToken( ClientDetailsEntity client, OAuth2Request request, Date issueTime, String sub, OAuth2AccessTokenEntity accessToken) { JWSAlgorithm signingAlg = jwtService.getDefaultSigningAlgorithm(); if (client.getIdTokenSignedResponseAlg() != null) { signingAlg = client.getIdTokenSignedResponseAlg(); } OAuth2AccessTokenEntity idTokenEntity = new OAuth2AccessTokenEntity(); JWTClaimsSet.Builder idClaims = new JWTClaimsSet.Builder(); // if the auth time claim was explicitly requested OR if the client always wants the auth time, // put it in if (request.getExtensions().containsKey("max_age") || (request .getExtensions() .containsKey( "idtoken")) // TODO: parse the ID Token claims (#473) -- for now assume it could be // in there || (client.getRequireAuthTime() != null && client.getRequireAuthTime())) { if (request.getExtensions().get(AuthenticationTimeStamper.AUTH_TIMESTAMP) != null) { Long authTimestamp = Long.parseLong( (String) request.getExtensions().get(AuthenticationTimeStamper.AUTH_TIMESTAMP)); if (authTimestamp != null) { idClaims.claim("auth_time", authTimestamp / 1000L); } } else { // we couldn't find the timestamp! logger.warn( "Unable to find authentication timestamp! There is likely something wrong with the configuration."); } } idClaims.issueTime(issueTime); if (client.getIdTokenValiditySeconds() != null) { Date expiration = new Date(System.currentTimeMillis() + (client.getIdTokenValiditySeconds() * 1000L)); idClaims.expirationTime(expiration); idTokenEntity.setExpiration(expiration); } idClaims.issuer(configBean.getIssuer()); idClaims.subject(sub); idClaims.audience(Lists.newArrayList(client.getClientId())); idClaims.jwtID(UUID.randomUUID().toString()); // set a random NONCE in the middle of it String nonce = (String) request.getExtensions().get("nonce"); if (!Strings.isNullOrEmpty(nonce)) { idClaims.claim("nonce", nonce); } Set<String> responseTypes = request.getResponseTypes(); if (responseTypes.contains("token")) { // calculate the token hash Base64URL at_hash = IdTokenHashUtils.getAccessTokenHash(signingAlg, accessToken); idClaims.claim("at_hash", at_hash); } if (client.getIdTokenEncryptedResponseAlg() != null && !client.getIdTokenEncryptedResponseAlg().equals(Algorithm.NONE) && client.getIdTokenEncryptedResponseEnc() != null && !client.getIdTokenEncryptedResponseEnc().equals(Algorithm.NONE) && (!Strings.isNullOrEmpty(client.getJwksUri()) || client.getJwks() != null)) { JWTEncryptionAndDecryptionService encrypter = encrypters.getEncrypter(client); if (encrypter != null) { EncryptedJWT idToken = new EncryptedJWT( new JWEHeader( client.getIdTokenEncryptedResponseAlg(), client.getIdTokenEncryptedResponseEnc()), idClaims.build()); encrypter.encryptJwt(idToken); idTokenEntity.setJwt(idToken); } else { logger.error("Couldn't find encrypter for client: " + client.getClientId()); } } else { JWT idToken; if (signingAlg.equals(Algorithm.NONE)) { // unsigned ID token idToken = new PlainJWT(idClaims.build()); } else { // signed ID token if (signingAlg.equals(JWSAlgorithm.HS256) || signingAlg.equals(JWSAlgorithm.HS384) || signingAlg.equals(JWSAlgorithm.HS512)) { JWSHeader header = new JWSHeader( signingAlg, null, null, null, null, null, null, null, null, null, jwtService.getDefaultSignerKeyId(), null, null); idToken = new SignedJWT(header, idClaims.build()); JWTSigningAndValidationService signer = symmetricCacheService.getSymmetricValidtor(client); // sign it with the client's secret signer.signJwt((SignedJWT) idToken); } else { idClaims.claim("kid", jwtService.getDefaultSignerKeyId()); JWSHeader header = new JWSHeader( signingAlg, null, null, null, null, null, null, null, null, null, jwtService.getDefaultSignerKeyId(), null, null); idToken = new SignedJWT(header, idClaims.build()); // sign it with the server's key jwtService.signJwt((SignedJWT) idToken); } } idTokenEntity.setJwt(idToken); } idTokenEntity.setAuthenticationHolder(accessToken.getAuthenticationHolder()); // create a scope set with just the special "id-token" scope // Set<String> idScopes = new HashSet<String>(token.getScope()); // this would copy the original // token's scopes in, we don't really want that Set<String> idScopes = Sets.newHashSet(SystemScopeService.ID_TOKEN_SCOPE); idTokenEntity.setScope(idScopes); idTokenEntity.setClient(accessToken.getClient()); return idTokenEntity; }