Esempio n. 1
0
 /**
  * Tries to obtain a pre-digested {@link ScramDigestPassword} for the mechanism's hash from the
  * callback handler. When one with the expected algorithm is returned, its iteration count, salt
  * and digest are loaded into this server's state.
  *
  * <p>Returns without touching {@code saltedPassword} when the handler does not support the
  * credential callback, supplies something that is not a {@link ScramDigestPassword}, or
  * supplies one for a different algorithm.
  *
  * @param nameCallback the callback carrying the authentication name
  * @throws SaslException if the name callback is unsupported, the handler fails for another
  *     reason, the stored iteration count is outside the configured bounds, or the stored
  *     password has no salt
  */
 private void getPredigestedSaltedPassword(NameCallback nameCallback) throws SaslException {
   // Both the plain and the channel-binding (-PLUS) variant use the same digest algorithm.
   final String passwordType;
   switch (getMechanismName()) {
     case SaslMechanismInformation.Names.SCRAM_SHA_1:
     case SaslMechanismInformation.Names.SCRAM_SHA_1_PLUS:
       passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_1;
       break;
     case SaslMechanismInformation.Names.SCRAM_SHA_256:
     case SaslMechanismInformation.Names.SCRAM_SHA_256_PLUS:
       passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_256;
       break;
     case SaslMechanismInformation.Names.SCRAM_SHA_384:
     case SaslMechanismInformation.Names.SCRAM_SHA_384_PLUS:
       passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_384;
       break;
     case SaslMechanismInformation.Names.SCRAM_SHA_512:
     case SaslMechanismInformation.Names.SCRAM_SHA_512_PLUS:
       passwordType = ScramDigestPassword.ALGORITHM_SCRAM_SHA_512;
       break;
     default:
       throw Assert.impossibleSwitchCase(getMechanismName());
   }

   final CredentialCallback credentialCallback =
       new CredentialCallback(singletonMap(ScramDigestPassword.class, singleton(passwordType)));
   try {
     tryHandleCallbacks(nameCallback, credentialCallback);
   } catch (UnsupportedCallbackException e) {
     final Callback rejected = e.getCallback();
     if (rejected == credentialCallback) {
       // Pre-digested credentials are not available from this handler; leave state untouched.
       return;
     }
     if (rejected == nameCallback) {
       throw log.saslCallbackHandlerDoesNotSupportUserName(getMechanismName(), e);
     }
     throw log.saslCallbackHandlerFailedForUnknownReason(getMechanismName(), e);
   }

   final Password storedPassword = (Password) credentialCallback.getCredential();
   if (!(storedPassword instanceof ScramDigestPassword)) {
     // Not a SCRAM digest (or nothing at all): nothing usable here.
     return;
   }
   final ScramDigestPassword digestPassword = (ScramDigestPassword) storedPassword;
   if (!passwordType.equals(digestPassword.getAlgorithm())) {
     // Digest was produced for a different hash than the negotiated mechanism uses.
     return;
   }

   // Fields are assigned before validation, exactly as the checks below expect.
   iterationCount = digestPassword.getIterationCount();
   salt = digestPassword.getSalt();
   if (iterationCount < minimumIterationCount) {
     throw log.saslIterationCountIsTooLow(
         getMechanismName(), iterationCount, minimumIterationCount);
   }
   if (iterationCount > maximumIterationCount) {
     throw log.saslIterationCountIsTooHigh(
         getMechanismName(), iterationCount, maximumIterationCount);
   }
   if (salt == null) {
     throw log.saslSaltMustBeSpecified(getMechanismName());
   }
   saltedPassword = digestPassword.getDigest();
 }
  /**
   * Drives the server side of the OTP SASL exchange for the given negotiation state.
   *
   * <ul>
   *   <li>{@code ST_CHALLENGE}: parses the initial client response ({@code [authzid] NUL authcid}),
   *       fetches the user's stored one-time password, records a timeout to prevent concurrent
   *       sessions for the same user (RFC 2289 section 9) and returns the OTP extended challenge.
   *   <li>{@code ST_PROCESS_RESPONSE}: parses a {@code hex}/{@code word} response or an
   *       {@code init-hex}/{@code init-word} re-initialization, verifies it, updates the stored
   *       credential and checks authorization.
   *   <li>{@code COMPLETE_STATE}: rejects any further non-empty message.
   * </ul>
   *
   * @param state the current negotiation state
   * @param response the bytes received from the client
   * @return the challenge to send to the client, or {@code null} when there is nothing to send
   * @throws SaslException if the message is malformed, the stored password cannot be obtained,
   *     another OTP authentication for the user is in progress, the server timed out, the
   *     response cannot be verified, or authorization fails
   */
  protected byte[] evaluateMessage(final int state, final byte[] response) throws SaslException {
    switch (state) {
      case ST_CHALLENGE:
        {
          // Initial client response: [authzid] NUL authcid
          final CodePointIterator cpi = CodePointIterator.ofUtf8Bytes(response);
          final CodePointIterator di = cpi.delimitedBy(0);

          authorizationID = di.hasNext() ? di.drainToString() : null;
          if (!cpi.hasNext()) {
            // The NUL separator is mandatory; without this check cpi.next() would run past the
            // end of the input instead of failing with a SaslException.
            throw log.mechInvalidMessageReceived(getMechanismName()).toSaslException();
          }
          cpi.next(); // Skip delimiter
          userName = di.drainToString();
          validateUserName(userName);
          if ((authorizationID == null) || (authorizationID.isEmpty())) {
            authorizationID = userName;
          }
          validateAuthorizationId(authorizationID);

          // Construct an OTP extended challenge, where:
          // OTP extended challenge = <standard OTP challenge> ext[,<extension set id>[, ...]]
          // standard OTP challenge = otp-<algorithm identifier> <sequence integer> <seed>
          nameCallback = new NameCallback("Remote authentication name", userName);
          final CredentialCallback credentialCallback =
              CredentialCallback.builder()
                  .addSupportedCredentialType(
                      PasswordCredential.class,
                      OneTimePassword.ALGORITHM_OTP_SHA1,
                      OneTimePassword.ALGORITHM_OTP_MD5)
                  .build();

          final TimeoutCallback timeoutCallback = new TimeoutCallback();
          handleCallbacks(nameCallback, credentialCallback, timeoutCallback);
          final PasswordCredential credential =
              (PasswordCredential) credentialCallback.getCredential();
          // The handler may supply no credential at all, or a password that is not an OTP; both
          // are reported as "unable to retrieve password" rather than as NPE/ClassCastException.
          if (credential == null || !(credential.getPassword() instanceof OneTimePassword)) {
            throw log.mechUnableToRetrievePassword(getMechanismName(), userName).toSaslException();
          }
          final OneTimePassword previousPassword = (OneTimePassword) credential.getPassword();
          previousAlgorithm = previousPassword.getAlgorithm();
          validateAlgorithm(previousAlgorithm);
          previousSeed = new String(previousPassword.getSeed(), StandardCharsets.US_ASCII);
          validateSeed(previousSeed);
          previousSequenceNumber = previousPassword.getSequenceNumber();
          validateSequenceNumber(previousSequenceNumber);
          previousHash = previousPassword.getHash();

          // Prevent a user from starting multiple simultaneous authentication sessions using the
          // timeout approach described in https://tools.ietf.org/html/rfc2289#section-9.0
          long timeout = timeoutCallback.getTimeout();
          time = Instant.now().getEpochSecond();
          if (time < timeout) {
            // An authentication attempt is already in progress for this user
            throw log.mechMultipleSimultaneousOTPAuthenticationsNotAllowed().toSaslException();
          } else {
            updateTimeout(time + LOCK_TIMEOUT);
            locked = true;
          }

          // The client must answer with the OTP for the next lower sequence number.
          final ByteStringBuilder challenge = new ByteStringBuilder();
          challenge.append(previousAlgorithm);
          challenge.append(' ');
          challenge.appendNumber(previousSequenceNumber - 1);
          challenge.append(' ');
          challenge.append(previousSeed);
          challenge.append(' ');
          challenge.append(EXT);
          setNegotiationState(ST_PROCESS_RESPONSE);
          return challenge.toArray();
        }
      case ST_PROCESS_RESPONSE:
        {
          if (Instant.now().getEpochSecond() > (time + LOCK_TIMEOUT)) {
            throw log.mechServerTimedOut(getMechanismName()).toSaslException();
          }
          // Response: <type>:<current OTP>[:<new params>:<new OTP>]
          final CodePointIterator cpi = CodePointIterator.ofUtf8Bytes(response);
          final CodePointIterator di = cpi.delimitedBy(':');
          final String responseType = di.drainToString().toLowerCase(Locale.ENGLISH);
          final byte[] currentHash;
          OneTimePasswordSpec passwordSpec;
          String algorithm;
          skipDelims(di, cpi, ':');
          switch (responseType) {
            case HEX_RESPONSE:
            case WORD_RESPONSE:
              {
                if (responseType.equals(HEX_RESPONSE)) {
                  currentHash = convertFromHex(di.drainToString());
                } else {
                  currentHash = convertFromWords(di.drainToString(), previousAlgorithm);
                }
                passwordSpec =
                    new OneTimePasswordSpec(
                        currentHash,
                        previousSeed.getBytes(StandardCharsets.US_ASCII),
                        previousSequenceNumber - 1);
                algorithm = previousAlgorithm;
                break;
              }
            case INIT_HEX_RESPONSE:
            case INIT_WORD_RESPONSE:
              {
                if (responseType.equals(INIT_HEX_RESPONSE)) {
                  currentHash = convertFromHex(di.drainToString());
                } else {
                  currentHash = convertFromWords(di.drainToString(), previousAlgorithm);
                }
                try {
                  // Attempt to parse the new params and new OTP
                  skipDelims(di, cpi, ':');
                  final CodePointIterator si = di.delimitedBy(' ');
                  String newAlgorithm = OTP_PREFIX + si.drainToString();
                  validateAlgorithm(newAlgorithm);
                  skipDelims(si, di, ' ');
                  final int newSequenceNumber;
                  try {
                    newSequenceNumber = Integer.parseInt(si.drainToString());
                  } catch (NumberFormatException nfe) {
                    // Route a non-numeric sequence number through the re-initialization failure
                    // handling below instead of letting it escape as an unchecked exception.
                    throw new SaslException("Invalid OTP sequence number", nfe);
                  }
                  validateSequenceNumber(newSequenceNumber);
                  skipDelims(si, di, ' ');
                  String newSeed = si.drainToString();
                  validateSeed(newSeed);
                  skipDelims(di, cpi, ':');
                  final byte[] newHash;
                  if (responseType.equals(INIT_HEX_RESPONSE)) {
                    newHash = convertFromHex(di.drainToString());
                  } else {
                    newHash = convertFromWords(di.drainToString(), newAlgorithm);
                  }
                  passwordSpec =
                      new OneTimePasswordSpec(
                          newHash, newSeed.getBytes(StandardCharsets.US_ASCII), newSequenceNumber);
                  algorithm = newAlgorithm;
                } catch (SaslException e) {
                  // If the new params or new OTP could not be processed for any reason, the
                  // sequence number should be decremented if a valid current OTP is provided.
                  passwordSpec =
                      new OneTimePasswordSpec(
                          currentHash,
                          previousSeed.getBytes(StandardCharsets.US_ASCII),
                          previousSequenceNumber - 1);
                  algorithm = previousAlgorithm;
                  verifyAndUpdateCredential(currentHash, algorithm, passwordSpec);
                  throw log.mechOTPReinitializationFailed(e).toSaslException();
                }
                break;
              }
            default:
              throw log.mechInvalidOTPResponseType().toSaslException();
          }
          // Trailing data after the expected fields makes the whole message invalid.
          if (cpi.hasNext()) {
            throw log.mechInvalidMessageReceived(getMechanismName()).toSaslException();
          }
          verifyAndUpdateCredential(currentHash, algorithm, passwordSpec);

          // Check the authorization id
          if (authorizationID == null) {
            authorizationID = userName;
          }
          final AuthorizeCallback authorizeCallback =
              new AuthorizeCallback(userName, authorizationID);
          handleCallbacks(authorizeCallback);
          if (!authorizeCallback.isAuthorized()) {
            throw log.mechAuthorizationFailed(getMechanismName(), userName, authorizationID)
                .toSaslException();
          }
          negotiationComplete();
          return null;
        }
      case COMPLETE_STATE:
        {
          if (response != null && response.length != 0) {
            throw log.mechMessageAfterComplete(getMechanismName()).toSaslException();
          }
          return null;
        }
      default:
        throw Assert.impossibleSwitchCase(state);
    }
  }
Esempio n. 3
0
  /**
   * Drives the server side of a SCRAM exchange for the given negotiation state.
   *
   * <ul>
   *   <li>{@code S_NO_MESSAGE}: with an empty response, sends an empty initial challenge;
   *       otherwise falls through to first-message processing.
   *   <li>{@code S_FIRST_MESSAGE}: parses the client-first-message (channel-binding flag, optional
   *       authorization id, user name, client nonce), obtains the salted password and returns the
   *       server-first-message (combined nonce, salt, iteration count).
   *   <li>{@code S_FINAL_MESSAGE}: parses the client-final-message (channel binding, nonce,
   *       proof), verifies the client proof, checks authorization and returns the server
   *       signature ({@code v=...}), or {@code e=invalid-proof} when errors are sent to clients.
   *   <li>{@code COMPLETE_STATE}: rejects any further non-empty message.
   *   <li>{@code FAILED_STATE}: always fails.
   * </ul>
   *
   * <p>Any exit that does not set {@code ok} leaves the negotiation in {@code FAILED_STATE}. A
   * message that is truncated or otherwise trips the byte iterators is reported as an invalid
   * client message.
   *
   * @param state the current negotiation state
   * @param response the bytes received from the client
   * @return the next server message, or {@code null} once the exchange is complete
   * @throws SaslException if the client message is invalid, channel binding does not match,
   *     credentials cannot be obtained, the proof is wrong, or authorization fails
   */
  protected byte[] evaluateMessage(final int state, final byte[] response) throws SaslException {
    boolean trace = log.isTraceEnabled();
    boolean ok = false;
    try {
      switch (state) {
        case S_NO_MESSAGE:
          {
            if (response == null || response.length == 0) {
              setNegotiationState(S_FIRST_MESSAGE);
              // initial challenge
              ok = true;
              return NO_BYTES;
            }
            // fall through
          }
        case S_FIRST_MESSAGE:
          {
            if (response == null || response.length == 0) {
              throw log.saslClientRefusesToInitiateAuthentication(getMechanismName());
            }
            if (trace)
              log.tracef(
                  "[S] Client first message: %s%n",
                  ByteIterator.ofBytes(response).hexEncode().drainToString());

            final ByteStringBuilder b = new ByteStringBuilder();
            int c;
            ByteIterator bi = ByteIterator.ofBytes(response);
            ByteIterator di = bi.delimitedBy(',');
            CodePointIterator cpi = di.asUtf8String();

            // == parse message ==

            // binding type: "p=<type>" is only accepted by -PLUS mechanisms, "y"/"n" only by
            // the non-PLUS ones
            cbindFlag = bi.next();
            if (cbindFlag == 'p' && plus) {
              assert bindingType != null; // because {@code plus} is true
              assert bindingData != null;
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              if (!bindingType.equals(cpi.drainToString())) {
                // nope, auth must fail because we cannot acquire the same binding
                throw log.saslChannelBindingTypeMismatch(getMechanismName());
              }
              bi.next(); // skip delimiter
            } else if ((cbindFlag == 'y' || cbindFlag == 'n') && !plus) {
              if (bi.next() != ',') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
            } else {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // authorization ID
            c = bi.next();
            if (c == 'a') {
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              authorizationID = cpi.drainToString();
              bi.next(); // skip delimiter
            } else if (c != ',') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // Start of client-first-message-bare; used later when building the signatures.
            clientFirstMessageBareStart = bi.offset();

            // user name
            if (bi.next() == 'n') {
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              ByteStringBuilder bsb = new ByteStringBuilder();
              StringPrep.encode(
                  cpi.drainToString(),
                  bsb,
                  StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
              userName = new String(bsb.toArray(), StandardCharsets.UTF_8);
              bi.next(); // skip delimiter
            } else {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // random nonce
            if (bi.next() != 'r' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            byte[] nonce = di.drain();
            if (trace)
              log.tracef(
                  "[S] Client nonce: %s%n",
                  ByteIterator.ofBytes(nonce).hexEncode().drainToString());

            if (bi.hasNext()) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            clientFirstMessage = response;

            // == send first challenge ==

            // get salted password: pre-digested first, then two-way password, then password
            // callback
            final NameCallback nameCallback =
                new NameCallback("Remote authentication name", userName);
            saltedPassword = null;
            getPredigestedSaltedPassword(nameCallback);
            if (saltedPassword == null) {
              getSaltedPasswordFromTwoWay(nameCallback, b);
            }
            if (saltedPassword == null) {
              getSaltedPasswordFromPasswordCallback(nameCallback, b);
            }
            if (saltedPassword == null) {
              throw log.saslCallbackHandlerDoesNotSupportCredentialAcquisition(
                  getMechanismName(), null);
            }
            if (trace)
              log.tracef("[S] Salt: %s%n", ByteIterator.ofBytes(salt).hexEncode().drainToString());
            if (trace)
              log.tracef(
                  "[S] Salted password: %s%n",
                  ByteIterator.ofBytes(saltedPassword).hexEncode().drainToString());

            // nonce (client + server nonce)
            b.append('r').append('=');
            b.append(nonce);
            b.append(ScramUtil.generateNonce(28, getRandom()));
            b.append(',');

            // salt
            b.append('s').append('=');
            b.appendLatin1(ByteIterator.ofBytes(salt).base64Encode());
            b.append(',');
            b.append('i').append('=');
            b.append(Integer.toString(iterationCount));

            setNegotiationState(S_FINAL_MESSAGE);
            ok = true;
            return serverFirstMessage = b.toArray();
          }
        case S_FINAL_MESSAGE:
          {
            final ByteStringBuilder b = new ByteStringBuilder();

            ByteIterator bi = ByteIterator.ofBytes(response);
            ByteIterator di = bi.delimitedBy(',');

            // == parse message ==

            // first comes the channel binding
            if (bi.next() != 'c' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            final ByteIterator bindingIterator = di.base64Decode();

            // -- sub-parse of binding data --
            // The decoded header must repeat the flag sent in the first message.
            if (bindingIterator.next() != cbindFlag) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            switch (cbindFlag) {
              case 'n':
              case 'y':
                { // n,[a=authzid],
                  if (plus) throw log.saslChannelBindingNotProvided(getMechanismName());

                  parseAuthorizationId(bindingIterator);

                  if (bindingIterator.hasNext()) { // require end
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  break;
                }
              case 'p':
                { // p=bindingType,[a=authzid],bindingData
                  if (!plus) {
                    throw log.saslChannelBindingNotSupported(getMechanismName());
                  }
                  if (bindingIterator.next() != '=') {
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  if (!bindingType.equals(
                      bindingIterator.delimitedBy(',').asUtf8String().drainToString())) {
                    throw log.saslChannelBindingTypeMismatch(getMechanismName());
                  }
                  parseAuthorizationId(bindingIterator);

                  // following is the raw channel binding data
                  if (!bindingIterator.contentEquals(ByteIterator.ofBytes(bindingData))) {
                    throw log.saslChannelBindingTypeMismatch(getMechanismName());
                  }
                  if (bindingIterator.hasNext()) { // require end
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  break;
                }
            }
            bi.next(); // skip delimiter

            // nonce
            // NOTE(review): the nonce value is skipped here, not compared with the one sent in
            // server-first-message; RFC 5802 says the server must verify it.
            if (bi.next() != 'r' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            while (di.hasNext()) {
              di.next();
            }

            // proof; everything before this offset is client-final-message-without-proof
            final int proofOffset = bi.offset();
            bi.next(); // skip delimiter
            if (bi.next() != 'p' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            byte[] recoveredClientProofEncoded = di.drain();
            if (bi.hasNext()) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // == verify proof ==

            // client key
            byte[] clientKey;
            mac.reset();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(Scram.CLIENT_KEY_BYTES);
            clientKey = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Client key: %s%n",
                  ByteIterator.ofBytes(clientKey).hexEncode().drainToString());

            // stored key
            byte[] storedKey;
            messageDigest.reset();
            messageDigest.update(clientKey);
            storedKey = messageDigest.digest();
            if (trace)
              log.tracef(
                  "[S] Stored key: %s%n",
                  ByteIterator.ofBytes(storedKey).hexEncode().drainToString());

            // client signature over
            // client-first-message-bare "," server-first-message "," client-final-without-proof
            mac.reset();
            mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
            mac.update(
                clientFirstMessage,
                clientFirstMessageBareStart,
                clientFirstMessage.length - clientFirstMessageBareStart);
            if (trace)
              log.tracef(
                  "[S] Using client first message: %s%n",
                  ByteIterator.ofBytes(
                          copyOfRange(
                              clientFirstMessage,
                              clientFirstMessageBareStart,
                              clientFirstMessage.length))
                      .hexEncode()
                      .drainToString());
            mac.update((byte) ',');
            mac.update(serverFirstMessage);
            if (trace)
              log.tracef(
                  "[S] Using server first message: %s%n",
                  ByteIterator.ofBytes(serverFirstMessage).hexEncode().drainToString());
            mac.update((byte) ',');
            mac.update(response, 0, proofOffset); // client-final-message-without-proof
            if (trace)
              log.tracef(
                  "[S] Using client final message without proof: %s%n",
                  ByteIterator.ofBytes(copyOfRange(response, 0, proofOffset))
                      .hexEncode()
                      .drainToString());
            byte[] clientSignature = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Client signature: %s%n",
                  ByteIterator.ofBytes(clientSignature).hexEncode().drainToString());

            // server key
            byte[] serverKey;
            mac.reset();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(Scram.SERVER_KEY_BYTES);
            serverKey = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Server key: %s%n",
                  ByteIterator.ofBytes(serverKey).hexEncode().drainToString());

            // server signature (same message sequence, keyed with the server key)
            byte[] serverSignature;
            mac.reset();
            mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
            mac.update(
                clientFirstMessage,
                clientFirstMessageBareStart,
                clientFirstMessage.length - clientFirstMessageBareStart);
            mac.update((byte) ',');
            mac.update(serverFirstMessage);
            mac.update((byte) ',');
            mac.update(response, 0, proofOffset); // client-final-message-without-proof
            serverSignature = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Server signature: %s%n",
                  ByteIterator.ofBytes(serverSignature).hexEncode().drainToString());

            if (trace)
              log.tracef(
                  "[S] Client proof string: %s%n",
                  CodePointIterator.ofUtf8Bytes(recoveredClientProofEncoded).drainToString());
            b.setLength(0);
            byte[] recoveredClientProof =
                ByteIterator.ofBytes(recoveredClientProofEncoded).base64Decode().drain();
            if (trace)
              log.tracef(
                  "[S] Client proof: %s%n",
                  ByteIterator.ofBytes(recoveredClientProof).hexEncode().drainToString());

            // now check the proof: ClientSignature XOR ClientProof must yield the client key
            byte[] recoveredClientKey = clientSignature.clone();
            ScramUtil.xor(recoveredClientKey, recoveredClientProof);
            if (trace)
              log.tracef(
                  "[S] Recovered client key: %s%n",
                  ByteIterator.ofBytes(recoveredClientKey).hexEncode().drainToString());
            // Constant-time comparison so response timing does not leak how many bytes matched.
            if (!java.security.MessageDigest.isEqual(recoveredClientKey, clientKey)) {
              // bad auth, send error
              if (sendErrors) {
                b.setLength(0);
                b.append("e=invalid-proof");
                setNegotiationState(FAILED_STATE);
                return b.toArray();
              }
              throw log.saslAuthenticationRejectedInvalidProof(getMechanismName());
            }

            if (authorizationID == null) {
              authorizationID = userName;
            } else {
              ByteStringBuilder bsb = new ByteStringBuilder();
              StringPrep.encode(
                  authorizationID,
                  bsb,
                  StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
              authorizationID = new String(bsb.toArray(), StandardCharsets.UTF_8);
            }
            final AuthorizeCallback authorizeCallback =
                new AuthorizeCallback(userName, authorizationID);
            try {
              tryHandleCallbacks(authorizeCallback);
            } catch (UnsupportedCallbackException e) {
              throw log.saslAuthorizationUnsupported(getMechanismName(), e);
            }
            if (!authorizeCallback.isAuthorized()) {
              throw log.saslAuthorizationFailed(getMechanismName(), userName, authorizationID);
            }

            // == send response ==
            b.setLength(0);
            b.append('v').append('=');
            b.appendUtf8(ByteIterator.ofBytes(serverSignature).base64Encode());

            setNegotiationState(COMPLETE_STATE);
            ok = true;
            return b.toArray();
          }
        case COMPLETE_STATE:
          {
            if (response != null && response.length != 0) {
              throw log.saslClientSentExtraMessage(getMechanismName());
            }
            ok = true;
            return null;
          }
        case FAILED_STATE:
          {
            throw log.saslAuthenticationFailed(getMechanismName());
          }
      }
      throw Assert.impossibleSwitchCase(state);
    } catch (ArrayIndexOutOfBoundsException
        | java.util.NoSuchElementException
        | InvalidKeyException ignored) {
      // Malformed or truncated client input (e.g. an iterator exhausted mid-parse) is reported
      // as an invalid client message rather than escaping as an unchecked exception.
      throw log.saslInvalidClientMessage(getMechanismName());
    } finally {
      if (!ok) {
        setNegotiationState(FAILED_STATE);
      }
    }
  }