예제 #1
0
 /**
  * Parses the authorization-id section of the GS2 header that the client repeats inside the
  * (base64-decoded) channel-binding data of its final message, and checks that it agrees with the
  * authorization id recorded from the client-first-message.
  *
  * <p>The iterator must be positioned right after the cbind flag (or after the binding type for
  * {@code p=}). It must then yield {@code ","} followed by either {@code ","} (no authzid, which is
  * only accepted when {@code authorizationID} is {@code null}) or {@code "a=" authzid ","} where
  * the authzid equals {@code authorizationID}.
  *
  * @param bindingIterator decoded channel-binding bytes, positioned after the cbind flag
  * @throws SaslException if the section is malformed or does not match the recorded authzid
  */
 private void parseAuthorizationId(ByteIterator bindingIterator) throws SaslException {
   boolean matches = false;
   if (bindingIterator.next() == ',') {
     final int marker = bindingIterator.next();
     if (marker == ',') {
       // no authzid now, so the first message must not have carried one either
       matches = authorizationID == null;
     } else if (marker == 'a') {
       // "a=" authzid "," -- each step is evaluated only if the previous one succeeded
       matches =
           bindingIterator.next() == '='
               && bindingIterator
                   .delimitedBy(',')
                   .asUtf8String()
                   .drainToString()
                   .equals(authorizationID)
               && bindingIterator.next() == ',';
     }
   }
   if (!matches) {
     throw log.saslInvalidClientMessage(getMechanismName());
   }
 }
예제 #2
0
  /**
   * Processes one step of the server side of a SCRAM exchange.
   *
   * <ul>
   *   <li>{@code S_NO_MESSAGE}: an empty response yields an empty initial challenge and moves to
   *       {@code S_FIRST_MESSAGE}; a non-empty response is handled as the client-first-message.
   *   <li>{@code S_FIRST_MESSAGE}: parses the client-first-message (GS2 cbind flag, optional
   *       authzid, user name, client nonce), obtains the salted password through the callback
   *       helpers and returns the server-first-message {@code r=<nonce>,s=<salt>,i=<iterations>}.
   *   <li>{@code S_FINAL_MESSAGE}: checks the echoed channel-binding data and the combined nonce,
   *       verifies the client proof, runs the authorization callback and returns
   *       {@code v=<server signature>} (or {@code e=invalid-proof} when {@code sendErrors} is set
   *       and the proof is wrong).
   *   <li>{@code COMPLETE_STATE}: accepts only an empty response and returns {@code null}.
   *   <li>{@code FAILED_STATE}: always fails.
   * </ul>
   *
   * <p>Unless a step succeeds, the negotiation state is forced to {@code FAILED_STATE} (this
   * includes the {@code e=invalid-proof} reply).
   *
   * @param state the current negotiation state
   * @param response the client's response for this step
   * @return the next challenge, or {@code null} once the exchange is complete
   * @throws SaslException if the client message is malformed, or authentication or authorization
   *     fails
   */
  protected byte[] evaluateMessage(final int state, final byte[] response) throws SaslException {
    // NOTE(review): trace output below includes the salted password and derived keys.
    boolean trace = log.isTraceEnabled();
    boolean ok = false;
    try {
      switch (state) {
        case S_NO_MESSAGE:
          {
            if (response == null || response.length == 0) {
              setNegotiationState(S_FIRST_MESSAGE);
              // initial challenge
              ok = true;
              return NO_BYTES;
            }
            // fall through: the initial response already carries the client-first-message
          }
        case S_FIRST_MESSAGE:
          {
            if (response == null || response.length == 0) {
              throw log.saslClientRefusesToInitiateAuthentication(getMechanismName());
            }
            if (trace)
              log.tracef(
                  "[S] Client first message: %s%n",
                  ByteIterator.ofBytes(response).hexEncode().drainToString());

            final ByteStringBuilder b = new ByteStringBuilder();
            int c;
            ByteIterator bi = ByteIterator.ofBytes(response);
            // views over bi that stop at the next ',' delimiter
            ByteIterator di = bi.delimitedBy(',');
            CodePointIterator cpi = di.asUtf8String();

            // == parse message ==

            // binding type
            cbindFlag = bi.next();
            if (cbindFlag == 'p' && plus) {
              assert bindingType != null; // because {@code plus} is true
              assert bindingData != null;
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              if (!bindingType.equals(cpi.drainToString())) {
                // nope, auth must fail because we cannot acquire the same binding
                throw log.saslChannelBindingTypeMismatch(getMechanismName());
              }
              bi.next(); // skip delimiter
            } else if ((cbindFlag == 'y' || cbindFlag == 'n') && !plus) {
              if (bi.next() != ',') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
            } else {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // authorization ID
            c = bi.next();
            if (c == 'a') {
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              authorizationID = cpi.drainToString();
              bi.next(); // skip delimiter
            } else if (c != ',') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // everything from here on is the client-first-message-bare used in the AuthMessage
            clientFirstMessageBareStart = bi.offset();

            // user name
            if (bi.next() == 'n') {
              if (bi.next() != '=') {
                throw log.saslInvalidClientMessage(getMechanismName());
              }
              ByteStringBuilder bsb = new ByteStringBuilder();
              StringPrep.encode(
                  cpi.drainToString(),
                  bsb,
                  StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
              userName = new String(bsb.toArray(), StandardCharsets.UTF_8);
              bi.next(); // skip delimiter
            } else {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // random nonce
            if (bi.next() != 'r' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            byte[] nonce = di.drain();
            if (trace)
              log.tracef(
                  "[S] Client nonce: %s%n",
                  ByteIterator.ofBytes(nonce).hexEncode().drainToString());

            if (bi.hasNext()) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            clientFirstMessage = response;

            // == send first challenge ==

            // get salted password, trying each credential source in turn
            final NameCallback nameCallback =
                new NameCallback("Remote authentication name", userName);
            saltedPassword = null;
            getPredigestedSaltedPassword(nameCallback);
            if (saltedPassword == null) {
              getSaltedPasswordFromTwoWay(nameCallback, b);
            }
            if (saltedPassword == null) {
              getSaltedPasswordFromPasswordCallback(nameCallback, b);
            }
            if (saltedPassword == null) {
              throw log.saslCallbackHandlerDoesNotSupportCredentialAcquisition(
                  getMechanismName(), null);
            }
            if (trace)
              log.tracef("[S] Salt: %s%n", ByteIterator.ofBytes(salt).hexEncode().drainToString());
            if (trace)
              log.tracef(
                  "[S] Salted password: %s%n",
                  ByteIterator.ofBytes(saltedPassword).hexEncode().drainToString());

            // nonce (client + server nonce)
            b.append('r').append('=');
            b.append(nonce);
            b.append(ScramUtil.generateNonce(28, getRandom()));
            b.append(',');

            // salt
            b.append('s').append('=');
            b.appendLatin1(ByteIterator.ofBytes(salt).base64Encode());
            b.append(',');
            b.append('i').append('=');
            b.append(Integer.toString(iterationCount));

            setNegotiationState(S_FINAL_MESSAGE);
            ok = true;
            return serverFirstMessage = b.toArray();
          }
        case S_FINAL_MESSAGE:
          {
            final ByteStringBuilder b = new ByteStringBuilder();

            ByteIterator bi = ByteIterator.ofBytes(response);
            ByteIterator di = bi.delimitedBy(',');

            // == parse message ==

            // first comes the channel binding
            if (bi.next() != 'c' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            final ByteIterator bindingIterator = di.base64Decode();

            // -- sub-parse of binding data --
            if (bindingIterator.next() != cbindFlag) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            switch (cbindFlag) {
              case 'n':
              case 'y':
                { // n,[a=authzid],
                  if (plus) throw log.saslChannelBindingNotProvided(getMechanismName());

                  parseAuthorizationId(bindingIterator);

                  if (bindingIterator.hasNext()) { // require end
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  break;
                }
              case 'p':
                { // p=bindingType,[a=authzid],bindingData
                  if (!plus) {
                    throw log.saslChannelBindingNotSupported(getMechanismName());
                  }
                  if (bindingIterator.next() != '=') {
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  if (!bindingType.equals(
                      bindingIterator.delimitedBy(',').asUtf8String().drainToString())) {
                    throw log.saslChannelBindingTypeMismatch(getMechanismName());
                  }
                  parseAuthorizationId(bindingIterator);

                  // following is the raw channel binding data
                  if (!bindingIterator.contentEquals(ByteIterator.ofBytes(bindingData))) {
                    throw log.saslChannelBindingTypeMismatch(getMechanismName());
                  }
                  if (bindingIterator.hasNext()) { // require end
                    throw log.saslInvalidClientMessage(getMechanismName());
                  }
                  break;
                }
            }
            bi.next(); // skip delimiter

            // nonce: the client must echo exactly the combined client+server nonce that we sent
            // as the leading "r=" attribute of the server-first-message (RFC 5802, section 5.1)
            if (bi.next() != 'r' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            final ByteIterator sentNonceIterator = ByteIterator.ofBytes(serverFirstMessage);
            sentNonceIterator.next(); // 'r'
            sentNonceIterator.next(); // '='
            final byte[] expectedNonce = sentNonceIterator.delimitedBy(',').drain();
            if (!Arrays.equals(di.drain(), expectedNonce)) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // proof; everything before this offset is client-final-message-without-proof
            final int proofOffset = bi.offset();
            bi.next(); // skip delimiter
            if (bi.next() != 'p' || bi.next() != '=') {
              throw log.saslInvalidClientMessage(getMechanismName());
            }
            byte[] recoveredClientProofEncoded = di.drain();
            if (bi.hasNext()) {
              throw log.saslInvalidClientMessage(getMechanismName());
            }

            // == verify proof ==

            // client key
            byte[] clientKey;
            mac.reset();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(Scram.CLIENT_KEY_BYTES);
            clientKey = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Client key: %s%n",
                  ByteIterator.ofBytes(clientKey).hexEncode().drainToString());

            // stored key
            byte[] storedKey;
            messageDigest.reset();
            messageDigest.update(clientKey);
            storedKey = messageDigest.digest();
            if (trace)
              log.tracef(
                  "[S] Stored key: %s%n",
                  ByteIterator.ofBytes(storedKey).hexEncode().drainToString());

            // client signature over AuthMessage =
            //   client-first-message-bare "," server-first-message "," client-final-without-proof
            mac.reset();
            mac.init(new SecretKeySpec(storedKey, mac.getAlgorithm()));
            mac.update(
                clientFirstMessage,
                clientFirstMessageBareStart,
                clientFirstMessage.length - clientFirstMessageBareStart);
            if (trace)
              log.tracef(
                  "[S] Using client first message: %s%n",
                  ByteIterator.ofBytes(
                          copyOfRange(
                              clientFirstMessage,
                              clientFirstMessageBareStart,
                              clientFirstMessage.length))
                      .hexEncode()
                      .drainToString());
            mac.update((byte) ',');
            mac.update(serverFirstMessage);
            if (trace)
              log.tracef(
                  "[S] Using server first message: %s%n",
                  ByteIterator.ofBytes(serverFirstMessage).hexEncode().drainToString());
            mac.update((byte) ',');
            mac.update(response, 0, proofOffset); // client-final-message-without-proof
            if (trace)
              log.tracef(
                  "[S] Using client final message without proof: %s%n",
                  ByteIterator.ofBytes(copyOfRange(response, 0, proofOffset))
                      .hexEncode()
                      .drainToString());
            byte[] clientSignature = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Client signature: %s%n",
                  ByteIterator.ofBytes(clientSignature).hexEncode().drainToString());

            // server key
            byte[] serverKey;
            mac.reset();
            mac.init(new SecretKeySpec(saltedPassword, mac.getAlgorithm()));
            mac.update(Scram.SERVER_KEY_BYTES);
            serverKey = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Server key: %s%n",
                  ByteIterator.ofBytes(serverKey).hexEncode().drainToString());

            // server signature over the same AuthMessage
            byte[] serverSignature;
            mac.reset();
            mac.init(new SecretKeySpec(serverKey, mac.getAlgorithm()));
            mac.update(
                clientFirstMessage,
                clientFirstMessageBareStart,
                clientFirstMessage.length - clientFirstMessageBareStart);
            mac.update((byte) ',');
            mac.update(serverFirstMessage);
            mac.update((byte) ',');
            mac.update(response, 0, proofOffset); // client-final-message-without-proof
            serverSignature = mac.doFinal();
            if (trace)
              log.tracef(
                  "[S] Server signature: %s%n",
                  ByteIterator.ofBytes(serverSignature).hexEncode().drainToString());

            if (trace)
              log.tracef(
                  "[S] Client proof string: %s%n",
                  CodePointIterator.ofUtf8Bytes(recoveredClientProofEncoded).drainToString());
            b.setLength(0);
            byte[] recoveredClientProof =
                ByteIterator.ofBytes(recoveredClientProofEncoded).base64Decode().drain();
            if (trace)
              log.tracef(
                  "[S] Client proof: %s%n",
                  ByteIterator.ofBytes(recoveredClientProof).hexEncode().drainToString());

            // now check the proof: ClientSignature XOR ClientProof must reproduce ClientKey
            byte[] recoveredClientKey = clientSignature.clone();
            ScramUtil.xor(recoveredClientKey, recoveredClientProof);
            if (trace)
              log.tracef(
                  "[S] Recovered client key: %s%n",
                  ByteIterator.ofBytes(recoveredClientKey).hexEncode().drainToString());
            // constant-time comparison of secret-derived key material
            if (!java.security.MessageDigest.isEqual(recoveredClientKey, clientKey)) {
              // bad auth, send error
              if (sendErrors) {
                b.setLength(0);
                b.append("e=invalid-proof");
                setNegotiationState(FAILED_STATE);
                return b.toArray();
              }
              throw log.saslAuthenticationRejectedInvalidProof(getMechanismName());
            }

            if (authorizationID == null) {
              authorizationID = userName;
            } else {
              ByteStringBuilder bsb = new ByteStringBuilder();
              StringPrep.encode(
                  authorizationID,
                  bsb,
                  StringPrep.PROFILE_SASL_QUERY | StringPrep.UNMAP_SCRAM_LOGIN_CHARS);
              authorizationID = new String(bsb.toArray(), StandardCharsets.UTF_8);
            }
            final AuthorizeCallback authorizeCallback =
                new AuthorizeCallback(userName, authorizationID);
            try {
              tryHandleCallbacks(authorizeCallback);
            } catch (UnsupportedCallbackException e) {
              throw log.saslAuthorizationUnsupported(getMechanismName(), e);
            }
            if (!authorizeCallback.isAuthorized()) {
              throw log.saslAuthorizationFailed(getMechanismName(), userName, authorizationID);
            }

            // == send response ==
            b.setLength(0);
            b.append('v').append('=');
            b.appendUtf8(ByteIterator.ofBytes(serverSignature).base64Encode());

            setNegotiationState(COMPLETE_STATE);
            ok = true;
            return b.toArray();
          }
        case COMPLETE_STATE:
          {
            if (response != null && response.length != 0) {
              throw log.saslClientSentExtraMessage(getMechanismName());
            }
            ok = true;
            return null;
          }
        case FAILED_STATE:
          {
            throw log.saslAuthenticationFailed(getMechanismName());
          }
      }
      throw Assert.impossibleSwitchCase(state);
    } catch (ArrayIndexOutOfBoundsException
        | java.util.NoSuchElementException
        | IllegalArgumentException
        | InvalidKeyException ignored) {
      // Malformed or truncated client input surfaces from the byte iterators, base64 decoding and
      // StringPrep as unchecked exceptions (presumably NoSuchElementException when an iterator is
      // exhausted, IllegalArgumentException for undecodable or prohibited data -- TODO confirm
      // against ByteIterator/StringPrep). Report them as an invalid client message instead of
      // letting an unchecked exception escape the SASL API.
      throw log.saslInvalidClientMessage(getMechanismName());
    } finally {
      if (!ok) {
        setNegotiationState(FAILED_STATE);
      }
    }
  }