private void getPredigestedSaltedPassword(NameCallback nameCallback) throws SaslException { String passwordType; switch (getMechanismName()) { case SaslMechanismInformation.Names.SCRAM_SHA_1: case SaslMechanismInformation.Names.SCRAM_SHA_1_PLUS: { passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_1; break; } case SaslMechanismInformation.Names.SCRAM_SHA_256: case SaslMechanismInformation.Names.SCRAM_SHA_256_PLUS: { passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_256; break; } case SaslMechanismInformation.Names.SCRAM_SHA_384: case SaslMechanismInformation.Names.SCRAM_SHA_384_PLUS: { passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_384; break; } case SaslMechanismInformation.Names.SCRAM_SHA_512: case SaslMechanismInformation.Names.SCRAM_SHA_512_PLUS: { passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_512; break; } default: throw Assert.impossibleSwitchCase(getMechanismName()); } CredentialCallback credentialCallback = new CredentialCallback(singletonMap(ScramDigestPassword.class, singleton(passwordType))); try { tryHandleCallbacks(nameCallback, credentialCallback); } catch (UnsupportedCallbackException e) { final Callback callback = e.getCallback(); if (callback == nameCallback) { throw log.saslCallbackHandlerDoesNotSupportUserName(getMechanismName(), e); } else if (callback == credentialCallback) { return; // pre digested not supported } else { throw log.saslCallbackHandlerFailedForUnknownReason(getMechanismName(), e); } } Password password = (Password) credentialCallback.getCredential(); if (password instanceof ScramDigestPassword) { // got a scram password final ScramDigestPassword scramDigestPassword = (ScramDigestPassword) password; if (!passwordType.equals(scramDigestPassword.getAlgorithm())) { return; } iterationCount = scramDigestPassword.getIterationCount(); salt = scramDigestPassword.getSalt(); if (iterationCount < minimumIterationCount) { throw log.saslIterationCountIsTooLow( getMechanismName(), iterationCount, minimumIterationCount); } else if (iterationCount > maximumIterationCount) { throw log.saslIterationCountIsTooHigh( getMechanismName(), iterationCount, maximumIterationCount); } if (salt == null) { throw log.saslSaltMustBeSpecified(getMechanismName()); } saltedPassword = scramDigestPassword.getDigest(); } }
protected byte[] evaluateMessage(final int state, final byte[] response) throws SaslException { switch (state) { case ST_CHALLENGE: { final CodePointIterator cpi = CodePointIterator.ofUtf8Bytes(response); final CodePointIterator di = cpi.delimitedBy(0); authorizationID = di.hasNext() ? di.drainToString() : null; cpi.next(); // Skip delimiter userName = di.drainToString(); validateUserName(userName); if ((authorizationID == null) || (authorizationID.isEmpty())) { authorizationID = userName; } validateAuthorizationId(authorizationID); // Construct an OTP extended challenge, where: // OTP extended challenge = <standard OTP challenge> ext[,<extension set id>[, ...]] // standard OTP challenge = otp-<algorithm identifier> <sequence integer> <seed> nameCallback = new NameCallback("Remote authentication name", userName); final CredentialCallback credentialCallback = CredentialCallback.builder() .addSupportedCredentialType( PasswordCredential.class, OneTimePassword.ALGORITHM_OTP_SHA1, OneTimePassword.ALGORITHM_OTP_MD5) .build(); final TimeoutCallback timeoutCallback = new TimeoutCallback(); handleCallbacks(nameCallback, credentialCallback, timeoutCallback); final PasswordCredential credential = (PasswordCredential) credentialCallback.getCredential(); final OneTimePassword previousPassword = (OneTimePassword) credential.getPassword(); if (previousPassword == null) { throw log.mechUnableToRetrievePassword(getMechanismName(), userName).toSaslException(); } previousAlgorithm = previousPassword.getAlgorithm(); validateAlgorithm(previousAlgorithm); previousSeed = new String(previousPassword.getSeed(), StandardCharsets.US_ASCII); validateSeed(previousSeed); previousSequenceNumber = previousPassword.getSequenceNumber(); validateSequenceNumber(previousSequenceNumber); previousHash = previousPassword.getHash(); // Prevent a user from starting multiple simultaneous authentication sessions using the // timeout approach described in https://tools.ietf.org/html/rfc2289#section-9.0 long timeout = timeoutCallback.getTimeout(); time = Instant.now().getEpochSecond(); if (time < timeout) { // An authentication attempt is already in progress for this user throw log.mechMultipleSimultaneousOTPAuthenticationsNotAllowed().toSaslException(); } else { updateTimeout(time + LOCK_TIMEOUT); locked = true; } final ByteStringBuilder challenge = new ByteStringBuilder(); challenge.append(previousAlgorithm); challenge.append(' '); challenge.appendNumber(previousSequenceNumber - 1); challenge.append(' '); challenge.append(previousSeed); challenge.append(' '); challenge.append(EXT); setNegotiationState(ST_PROCESS_RESPONSE); return challenge.toArray(); } case ST_PROCESS_RESPONSE: { if (Instant.now().getEpochSecond() > (time + LOCK_TIMEOUT)) { throw log.mechServerTimedOut(getMechanismName()).toSaslException(); } final CodePointIterator cpi = CodePointIterator.ofUtf8Bytes(response); final CodePointIterator di = cpi.delimitedBy(':'); final String responseType = di.drainToString().toLowerCase(Locale.ENGLISH); final byte[] currentHash; OneTimePasswordSpec passwordSpec; String algorithm; skipDelims(di, cpi, ':'); switch (responseType) { case HEX_RESPONSE: case WORD_RESPONSE: { if (responseType.equals(HEX_RESPONSE)) { currentHash = convertFromHex(di.drainToString()); } else { currentHash = convertFromWords(di.drainToString(), previousAlgorithm); } passwordSpec = new OneTimePasswordSpec( currentHash, previousSeed.getBytes(StandardCharsets.US_ASCII), previousSequenceNumber - 1); algorithm = previousAlgorithm; break; } case INIT_HEX_RESPONSE: case INIT_WORD_RESPONSE: { if (responseType.equals(INIT_HEX_RESPONSE)) { currentHash = convertFromHex(di.drainToString()); } else { currentHash = convertFromWords(di.drainToString(), previousAlgorithm); } try { // Attempt to parse the new params and new OTP skipDelims(di, cpi, ':'); final CodePointIterator si = di.delimitedBy(' '); String newAlgorithm = OTP_PREFIX + si.drainToString(); validateAlgorithm(newAlgorithm); skipDelims(si, di, ' '); int newSequenceNumber = Integer.parseInt(si.drainToString()); validateSequenceNumber(newSequenceNumber); skipDelims(si, di, ' '); String newSeed = si.drainToString(); validateSeed(newSeed); skipDelims(di, cpi, ':'); final byte[] newHash; if (responseType.equals(INIT_HEX_RESPONSE)) { newHash = convertFromHex(di.drainToString()); } else { newHash = convertFromWords(di.drainToString(), newAlgorithm); } passwordSpec = new OneTimePasswordSpec( newHash, newSeed.getBytes(StandardCharsets.US_ASCII), newSequenceNumber); algorithm = newAlgorithm; } catch (SaslException e) { // If the new params or new OTP could not be processed for any reason, the // sequence // number should be decremented if a valid current OTP is provided passwordSpec = new OneTimePasswordSpec( currentHash, previousSeed.getBytes(StandardCharsets.US_ASCII), previousSequenceNumber - 1); algorithm = previousAlgorithm; verifyAndUpdateCredential(currentHash, algorithm, passwordSpec); throw log.mechOTPReinitializationFailed(e).toSaslException(); } break; } default: throw log.mechInvalidOTPResponseType().toSaslException(); } if (cpi.hasNext()) { throw log.mechInvalidMessageReceived(getMechanismName()).toSaslException(); } verifyAndUpdateCredential(currentHash, algorithm, passwordSpec); // Check the authorization id if (authorizationID == null) { authorizationID = userName; } final AuthorizeCallback authorizeCallback = new AuthorizeCallback(userName, authorizationID); handleCallbacks(authorizeCallback); if (!authorizeCallback.isAuthorized()) { throw log.mechAuthorizationFailed(getMechanismName(), userName, authorizationID) .toSaslException(); } negotiationComplete(); return null; } case COMPLETE_STATE: { if (response != null && response.length != 0) { throw log.mechMessageAfterComplete(getMechanismName()).toSaslException(); } return null; } default: throw Assert.impossibleSwitchCase(state); } }
protected byte[] evaluateMessage(final int state, final byte[] response) throws SaslException { boolean trace = log.isTraceEnabled(); boolean ok = false; try { switch (state) { case S_NO_MESSAGE: { if (response == null || response.length == 0) { setNegotiationState(S_FIRST_MESSAGE); // initial challenge ok = true; return NO_BYTES; } // fall through } case S_FIRST_MESSAGE: { if (response == null || response.length == 0) { throw log.saslClientRefusesToInitiateAuthentication(getMechanismName()); } if (trace) log.tracef( "[S] Client first message: %s%n", ByteIterator.ofBytes(response).hexEncode().drainToString()); final ByteStringBuilder b = new ByteStringBuilder(); int c; ByteIterator bi = ByteIterator.ofBytes(response); ByteIterator di = bi.delimitedBy(','); CodePointIterator cpi = di.asUtf8String(); // == parse message == // binding type cbindFlag = bi.next(); if (cbindFlag == 'p' && plus) { assert bindingType != null; // because {@code plus} is true assert bindingData != null; if (bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } if (!bindingType.equals(cpi.drainToString())) { // nope, auth must fail because we cannot acquire the same binding throw log.saslChannelBindingTypeMismatch(getMechanismName()); } bi.next(); // skip delimiter } else if ((cbindFlag == 'y' || cbindFlag == 'n') && !plus) { if (bi.next() != ',') { throw log.saslInvalidClientMessage(getMechanismName()); } } else { throw log.saslInvalidClientMessage(getMechanismName()); } // authorization ID c = bi.next(); if (c == 'a') { if (bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } authorizationID = cpi.drainToString(); bi.next(); // skip delimiter } else if (c != ',') { throw log.saslInvalidClientMessage(getMechanismName()); } clientFirstMessageBareStart = bi.offset(); // user name if (bi.next() == 'n') { if (bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } ByteStringBuilder bsb = new ByteStringBuilder(); StringPrep.encode( cpi.drainToString(), bsb, StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS); userName = new String(bsb.toArray(), StandardCharsets.UTF_8); bi.next(); // skip delimiter } else { throw log.saslInvalidClientMessage(getMechanismName()); } // random nonce if (bi.next() != 'r' || bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } byte[] nonce = di.drain(); if (trace) log.tracef( "[S] Client nonce: %s%n", ByteIterator.ofBytes(nonce).hexEncode().drainToString()); if (bi.hasNext()) { throw log.saslInvalidClientMessage(getMechanismName()); } clientFirstMessage = response; // == send first challenge == // get salted password final NameCallback nameCallback = new NameCallback("Remote authentication name", userName); saltedPassword = null; getPredigestedSaltedPassword(nameCallback); if (saltedPassword == null) { getSaltedPasswordFromTwoWay(nameCallback, b); } if (saltedPassword == null) { getSaltedPasswordFromPasswordCallback(nameCallback, b); } if (saltedPassword == null) { throw log.saslCallbackHandlerDoesNotSupportCredentialAcquisition( getMechanismName(), null); } if (trace) log.tracef("[S] Salt: %s%n", ByteIterator.ofBytes(salt).hexEncode().drainToString()); if (trace) log.tracef( "[S] Salted password: %s%n", ByteIterator.ofBytes(saltedPassword).hexEncode().drainToString()); // nonce (client + server nonce) b.append('r').append('='); b.append(nonce); b.append(ScramUtil.generateNonce(28, getRandom())); b.append(','); // salt b.append('s').append('='); b.appendLatin1(ByteIterator.ofBytes(salt).base64Encode()); b.append(','); b.append('i').append('='); b.append(Integer.toString(iterationCount)); setNegotiationState(S_FINAL_MESSAGE); ok = true; return serverFirstMessage = b.toArray(); } case S_FINAL_MESSAGE: { final ByteStringBuilder b = new ByteStringBuilder(); ByteIterator bi = ByteIterator.ofBytes(response); ByteIterator di = bi.delimitedBy(','); // == parse message == // first comes the channel binding if (bi.next() != 'c' || bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } final ByteIterator bindingIterator = di.base64Decode(); // -- sub-parse of binding data -- if (bindingIterator.next() != cbindFlag) { throw log.saslInvalidClientMessage(getMechanismName()); } switch (cbindFlag) { case 'n': case 'y': { // n,[a=authzid], if (plus) throw log.saslChannelBindingNotProvided(getMechanismName()); parseAuthorizationId(bindingIterator); if (bindingIterator.hasNext()) { // require end throw log.saslInvalidClientMessage(getMechanismName()); } break; } case 'p': { // p=bindingType,[a=authzid],bindingData if (!plus) { throw log.saslChannelBindingNotSupported(getMechanismName()); } if (bindingIterator.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } if (!bindingType.equals( bindingIterator.delimitedBy(',').asUtf8String().drainToString())) { throw log.saslChannelBindingTypeMismatch(getMechanismName()); } parseAuthorizationId(bindingIterator); // following is the raw channel binding data if (!bindingIterator.contentEquals(ByteIterator.ofBytes(bindingData))) { throw log.saslChannelBindingTypeMismatch(getMechanismName()); } if (bindingIterator.hasNext()) { // require end throw log.saslInvalidClientMessage(getMechanismName()); } break; } } bi.next(); // skip delimiter // nonce if (bi.next() != 'r' || bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } while (di.hasNext()) { di.next(); } // proof final int proofOffset = bi.offset(); bi.next(); // skip delimiter if (bi.next() != 'p' || bi.next() != '=') { throw log.saslInvalidClientMessage(getMechanismName()); } byte[] recoveredClientProofEncoded = di.drain(); if (bi.hasNext()) { throw log.saslInvalidClientMessage(getMechanismName()); } // == verify proof == // client key byte[] clientKey; mac.reset(); mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm())); mac.update(Scram.CLIENT_KEY_BYTES); clientKey = mac.doFinal(); if (trace) log.tracef( "[S] Client key: %s%n", ByteIterator.ofBytes(clientKey).hexEncode().drainToString()); // stored key byte[] storedKey; messageDigest.reset(); messageDigest.update(clientKey); storedKey = messageDigest.digest(); if (trace) log.tracef( "[S] Stored key: %s%n", ByteIterator.ofBytes(storedKey).hexEncode().drainToString()); // client signature mac.reset(); mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm())); mac.update( clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart); if (trace) log.tracef( "[S] Using client first message: %s%n", ByteIterator.ofBytes( copyOfRange( clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length)) .hexEncode() .drainToString()); mac.update((byte) ','); mac.update(serverFirstMessage); if (trace) log.tracef( "[S] Using server first message: %s%n", ByteIterator.ofBytes(serverFirstMessage).hexEncode().drainToString()); mac.update((byte) ','); mac.update(response, 0, proofOffset); // client-final-message-without-proof if (trace) log.tracef( "[S] Using client final message without proof: %s%n", ByteIterator.ofBytes(copyOfRange(response, 0, proofOffset)) .hexEncode() .drainToString()); byte[] clientSignature = mac.doFinal(); if (trace) log.tracef( "[S] Client signature: %s%n", ByteIterator.ofBytes(clientSignature).hexEncode().drainToString()); // server key byte[] serverKey; mac.reset(); mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm())); mac.update(Scram.SERVER_KEY_BYTES); serverKey = mac.doFinal(); if (trace) log.tracef( "[S] Server key: %s%n", ByteIterator.ofBytes(serverKey).hexEncode().drainToString()); // server signature byte[] serverSignature; mac.reset(); mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm())); mac.update( clientFirstMessage, clientFirstMessageBareStart, clientFirstMessage.length - clientFirstMessageBareStart); mac.update((byte) ','); mac.update(serverFirstMessage); mac.update((byte) ','); mac.update(response, 0, proofOffset); // client-final-message-without-proof serverSignature = mac.doFinal(); if (trace) log.tracef( "[S] Server signature: %s%n", ByteIterator.ofBytes(serverSignature).hexEncode().drainToString()); if (trace) log.tracef( "[S] Client proof string: %s%n", CodePointIterator.ofUtf8Bytes(recoveredClientProofEncoded).drainToString()); b.setLength(0); byte[] recoveredClientProof = ByteIterator.ofBytes(recoveredClientProofEncoded).base64Decode().drain(); if (trace) log.tracef( "[S] Client proof: %s%n", ByteIterator.ofBytes(recoveredClientProof).hexEncode().drainToString()); // now check the proof byte[] recoveredClientKey = clientSignature.clone(); ScramUtil.xor(recoveredClientKey, recoveredClientProof); if (trace) log.tracef( "[S] Recovered client key: %s%n", ByteIterator.ofBytes(recoveredClientKey).hexEncode().drainToString()); if (!Arrays.equals(recoveredClientKey, clientKey)) { // bad auth, send error if (sendErrors) { b.setLength(0); b.append("e=invalid-proof"); setNegotiationState(FAILED_STATE); return b.toArray(); } throw log.saslAuthenticationRejectedInvalidProof(getMechanismName()); } if (authorizationID == null) { authorizationID = userName; } else { ByteStringBuilder bsb = new ByteStringBuilder(); StringPrep.encode( authorizationID, bsb, StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS); authorizationID = new String(bsb.toArray(), StandardCharsets.UTF_8); } final AuthorizeCallback authorizeCallback = new AuthorizeCallback(userName, authorizationID); try { tryHandleCallbacks(authorizeCallback); } catch (UnsupportedCallbackException e) { throw log.saslAuthorizationUnsupported(getMechanismName(), e); } if (!authorizeCallback.isAuthorized()) { throw log.saslAuthorizationFailed(getMechanismName(), userName, authorizationID); } // == send response == b.setLength(0); b.append('v').append('='); b.appendUtf8(ByteIterator.ofBytes(serverSignature).base64Encode()); setNegotiationState(COMPLETE_STATE); ok = true; return b.toArray(); } case COMPLETE_STATE: { if (response != null && response.length != 0) { throw log.saslClientSentExtraMessage(getMechanismName()); } ok = true; return null; } case FAILED_STATE: { throw log.saslAuthenticationFailed(getMechanismName()); } } throw Assert.impossibleSwitchCase(state); } catch (ArrayIndexOutOfBoundsException | InvalidKeyException ignored) { throw log.saslInvalidClientMessage(getMechanismName()); } finally { if (!ok) { setNegotiationState(FAILED_STATE); } } }