예제 #1
0
  /**
   * Decodes a value from the given byte array according to the type flags sent with it.
   *
   * <p>If {@code SERIALIZED_FLAG} is set, the bytes are handed to the JDK marshaller. Otherwise the
   * high byte of {@code flags} ({@code flags & 0xff00}) selects a primitive encoding (boolean as
   * {@code '1'}/other, big/fixed-size numbers via {@code U.bytesToInt}/{@code U.bytesToLong},
   * raw byte array, ...). Any unrecognized type is decoded as a UTF-8 string. UTF-8 is used
   * explicitly so the result does not depend on the server JVM's default charset.
   *
   * @param flags Flags describing how {@code bytes} were encoded.
   * @param bytes Byte array to decode; must not be {@code null}.
   * @return Decoded value.
   * @throws GridException If deserialization failed.
   */
  private Object decodeObj(short flags, byte[] bytes) throws GridException {
    assert bytes != null;

    if ((flags & SERIALIZED_FLAG) != 0) {
      // NOTE(review): callers pass bytes parsed from client packets, so this is JDK
      // deserialization of network input; consider an ObjectInputFilter/whitelist if clients
      // are not fully trusted.
      return jdkMarshaller.unmarshal(new ByteArrayInputStream(bytes), null);
    }

    int masked = flags & 0xff00;

    switch (masked) {
      case BOOLEAN_FLAG:
        return bytes[0] == '1';
      case INT_FLAG:
        return U.bytesToInt(bytes, 0);
      case LONG_FLAG:
        return U.bytesToLong(bytes, 0);
      case DATE_FLAG:
        return new Date(U.bytesToLong(bytes, 0));
      case BYTE_FLAG:
        return bytes[0];
      case FLOAT_FLAG:
        return Float.intBitsToFloat(U.bytesToInt(bytes, 0));
      case DOUBLE_FLAG:
        return Double.longBitsToDouble(U.bytesToLong(bytes, 0));
      case BYTE_ARR_FLAG:
        return bytes;
      default:
        // Strings travel as UTF-8 on the wire; never rely on the platform default charset.
        return new String(bytes, java.nio.charset.StandardCharsets.UTF_8);
    }
  }
예제 #2
0
  /**
   * Encodes the given object, writes the encoded bytes to {@code out}, and returns flags that
   * describe how the object was encoded.
   *
   * <p>Strings are written as UTF-8 with no type flag. Boolean, int, long, {@link Date}, byte,
   * float, double and {@code byte[]} use compact encodings tagged with their type flag. Any
   * other object is JDK-serialized directly into {@code out} and tagged with
   * {@code SERIALIZED_FLAG}. UTF-8 is used explicitly so the wire format does not depend on the
   * server JVM's default charset.
   *
   * @param obj Object to serialize.
   * @param out Output stream to which the encoded object is written.
   * @return Serialization flags.
   * @throws GridException If JDK serialization failed.
   */
  private int encodeObj(Object obj, ByteArrayOutputStream out) throws GridException {
    int flags = 0;

    byte[] data = null;

    if (obj instanceof String) {
      // Strings travel as UTF-8 on the wire; never rely on the platform default charset.
      data = ((String) obj).getBytes(java.nio.charset.StandardCharsets.UTF_8);
    } else if (obj instanceof Boolean) {
      data = new byte[] {(byte) ((Boolean) obj ? '1' : '0')};

      flags |= BOOLEAN_FLAG;
    } else if (obj instanceof Integer) {
      data = U.intToBytes((Integer) obj);

      flags |= INT_FLAG;
    } else if (obj instanceof Long) {
      data = U.longToBytes((Long) obj);

      flags |= LONG_FLAG;
    } else if (obj instanceof Date) {
      data = U.longToBytes(((Date) obj).getTime());

      flags |= DATE_FLAG;
    } else if (obj instanceof Byte) {
      data = new byte[] {(Byte) obj};

      flags |= BYTE_FLAG;
    } else if (obj instanceof Float) {
      data = U.intToBytes(Float.floatToIntBits((Float) obj));

      flags |= FLOAT_FLAG;
    } else if (obj instanceof Double) {
      data = U.longToBytes(Double.doubleToLongBits((Double) obj));

      flags |= DOUBLE_FLAG;
    } else if (obj instanceof byte[]) {
      data = (byte[]) obj;

      flags |= BYTE_ARR_FLAG;
    } else {
      // The marshaller writes straight into the stream, so no intermediate array is produced.
      jdkMarshaller.marshal(obj, out);

      flags |= SERIALIZED_FLAG;
    }

    if (data != null) out.write(data, 0, data.length);

    return flags;
  }
예제 #3
0
  /**
   * Determines the address the REST TCP server should bind to.
   *
   * <p>The explicitly configured REST TCP host wins. If it is not set, the grid's local host
   * setting is used instead. The chosen value is resolved with {@code U.resolveLocalHost}.
   *
   * @param cfg Grid configuration to read host settings from.
   * @return Resolved REST host address.
   * @throws IOException If the host could not be resolved.
   */
  private InetAddress resolveRestTcpHost(GridConfiguration cfg) throws IOException {
    String restHost = cfg.getRestTcpHost();

    return U.resolveLocalHost(restHost != null ? restHost : cfg.getLocalHost());
  }
예제 #4
0
  /**
   * Incrementally parses a length-prefixed custom (GRIDGAIN) packet whose payload is decoded by
   * {@code marshaller}.
   *
   * <p>Wire layout: a 4-byte length (decoded with {@code U.bytesToInt}) followed by that many
   * payload bytes. Partially received bytes are kept in {@code state.buffer()} between calls.
   * Once the length is known it is remembered via {@code state.index(len)}, so an index of zero
   * means the length prefix is still being read.
   *
   * @param ses Session the bytes arrived on (used in error messages).
   * @param buf Buffer containing not yet parsed bytes; consumed one byte at a time.
   * @param state Parser state carrying the partial buffer and expected payload length.
   * @return {@code PING_MESSAGE} for a zero-length packet, the unmarshalled message once the
   *     whole payload has arrived, or {@code null} if more bytes are needed.
   * @throws IOException If the length prefix is negative or deserialization failed.
   */
  @Nullable
  private GridClientMessage parseCustomPacket(GridNioSession ses, ByteBuffer buf, ParserState state)
      throws IOException {
    assert state.packetType() == PacketType.GRIDGAIN;
    assert state.packet() == null;

    ByteArrayOutputStream acc = state.buffer();

    // Zero means the 4-byte length prefix has not been completely read yet.
    int payloadLen = state.index();

    while (buf.hasRemaining()) {
      acc.write(buf.get());

      if (payloadLen != 0) {
        // Payload phase: unmarshal as soon as every byte has arrived.
        // NOTE(review): this deserializes bytes straight off the network.
        if (acc.size() == payloadLen)
          return marshaller.unmarshal(acc.toByteArray());

        continue;
      }

      // Length-prefix phase: wait until all 4 bytes are buffered.
      if (acc.size() != 4)
        continue;

      payloadLen = U.bytesToInt(acc.toByteArray(), 0);

      acc.reset();

      if (payloadLen == 0)
        return PING_MESSAGE;

      if (payloadLen < 0)
        throw new IOException(
            "Failed to parse incoming packet (invalid packet length) [ses="
                + ses
                + ", len="
                + payloadLen
                + ']');

      state.index(payloadLen);
    }

    return null;
  }
예제 #5
0
  /**
   * {@inheritDoc}
   *
   * <p>Memcache packets are delegated to {@code encodeMemcache} and the ping message becomes the
   * pre-built {@code PING_PACKET}. Every other message is marshalled and framed as: request
   * flag byte, payload length ({@code U.intToBytes}), payload.
   */
  @Override
  public ByteBuffer encode(GridNioSession ses, GridClientMessage msg)
      throws IOException, GridException {
    assert msg != null;

    if (msg instanceof GridTcpRestPacket)
      return encodeMemcache((GridTcpRestPacket) msg);

    if (msg == PING_MESSAGE)
      return ByteBuffer.wrap(PING_PACKET);

    byte[] payload = marshaller.marshal(msg);

    assert payload.length > 0;

    // 1 byte for the request flag plus 4 bytes for the length header.
    ByteBuffer frame = ByteBuffer.allocate(payload.length + 5);

    frame.put(GRIDGAIN_REQ_FLAG).put(U.intToBytes(payload.length)).put(payload);

    frame.flip();

    return frame;
  }
예제 #6
0
  /**
   * Validates an incoming memcache packet and decodes the fields it carries in raw form.
   *
   * <p>When the command has flags, key flags are read from extras offset 0 and value flags from
   * offset 2, and the key and value (if present) are decoded with {@code decodeObj}. Depending
   * on the command, the following are also read from extras:
   * <ul>
   *   <li>expiration, at offset 4, treated as unsigned 32-bit;</li>
   *   <li>initial value, at offset 8;</li>
   *   <li>delta, at offset 0.</li>
   * </ul>
   * Extras bytes beyond the fixed, command-specific part are decoded as a UTF-8 cache name.
   *
   * @param ses Session on which packet is being parsed (used in error messages).
   * @param req Raw packet; its fields are replaced in place.
   * @return The same packet with fields deserialized.
   * @throws IOException If extras are too short for a field the command requires.
   * @throws GridException If key or value deserialization failed.
   */
  private GridClientMessage assemble(GridNioSession ses, GridTcpRestPacket req)
      throws IOException, GridException {
    byte[] extras = req.extras();

    // First, decode key and value, if any.
    if (req.key() != null || req.value() != null) {
      short keyFlags = 0;
      short valFlags = 0;

      if (req.hasFlags()) {
        if (extras == null || extras.length < FLAGS_LENGTH)
          throw malformedPacketError(ses, req, "flags required for command");

        keyFlags = U.bytesToShort(extras, 0);
        valFlags = U.bytesToShort(extras, 2);
      }

      if (req.key() != null) {
        assert req.key() instanceof byte[];

        // The key is decoded with its own flags, exactly like the value below.
        req.key(decodeObj(keyFlags, (byte[]) req.key()));
      }

      if (req.value() != null) {
        assert req.value() instanceof byte[];

        req.value(decodeObj(valFlags, (byte[]) req.value()));
      }
    }

    if (req.hasExpiration()) {
      if (extras == null || extras.length < 8)
        throw malformedPacketError(ses, req, "expiration value required for command");

      // Expiration is an unsigned 32-bit value, hence the widening mask.
      req.expiration(U.bytesToInt(extras, 4) & 0xFFFFFFFFL);
    }

    if (req.hasInitial()) {
      if (extras == null || extras.length < 16)
        throw malformedPacketError(ses, req, "initial value required for command");

      req.initial(U.bytesToLong(extras, 8));
    }

    if (req.hasDelta()) {
      if (extras == null || extras.length < 8)
        throw malformedPacketError(ses, req, "delta value required for command");

      req.delta(U.bytesToLong(extras, 0));
    }

    if (extras != null) {
      // Clients that include cache name must always include flags, so the fixed part starts
      // at 4 bytes and grows with each command-specific field.
      int fixedLen = 4;

      if (req.hasExpiration()) fixedLen += 4;

      if (req.hasDelta()) fixedLen += 8;

      if (req.hasInitial()) fixedLen += 8;

      int nameLen = extras.length - fixedLen;

      // Decode explicitly as UTF-8 so the name does not depend on the platform charset.
      if (nameLen > 0)
        req.cacheName(
            new String(extras, fixedLen, nameLen, java.nio.charset.StandardCharsets.UTF_8));
    }

    return req;
  }

  /**
   * Builds the exception reported when a packet's extras lack a field its command requires.
   *
   * @param ses Session on which packet is being parsed.
   * @param req Packet being assembled.
   * @param reason Short description placed in parentheses in the message.
   * @return Exception describing the malformed packet.
   */
  private static IOException malformedPacketError(
      GridNioSession ses, GridTcpRestPacket req, String reason) {
    return new IOException(
        "Failed to parse incoming packet ("
            + reason
            + ") [ses="
            + ses
            + ", opCode="
            + Integer.toHexString(req.operationCode() & 0xFF)
            + ']');
  }
예제 #7
0
  /**
   * Incrementally parses a memcache binary-protocol message, one byte at a time.
   *
   * <p>{@code i} is the offset of the next byte within the current packet. It is restored from
   * {@code state.index()} on entry and saved back with {@code state.index(i)} when the buffer
   * runs out before the packet is complete. Multi-byte fields are accumulated in
   * {@code state.buffer()}, which is reset after each field is stored. Once the byte at offset
   * {@code HDR_LEN + totalLength - 1} is consumed, the packet is handed to {@code assemble}.
   *
   * @param ses Session.
   * @param buf Buffer containing not parsed bytes.
   * @param state Current parser state (packet being filled, byte offset, scratch buffer).
   * @return Parsed packet, or {@code null} if more bytes are needed.
   * @throws IOException If packet cannot be parsed.
   * @throws GridException If deserialization error occurred.
   */
  @Nullable
  private GridClientMessage parseMemcachePacket(
      GridNioSession ses, ByteBuffer buf, ParserState state) throws IOException, GridException {
    assert state.packetType() == PacketType.MEMCACHE;
    assert state.packet() != null;
    assert state.packet() instanceof GridTcpRestPacket;

    GridTcpRestPacket req = (GridTcpRestPacket) state.packet();
    ByteArrayOutputStream tmp = state.buffer();
    int i = state.index();

    // Header bytes handled below (offset -> field):
    //   0 request flag (presumably the memcache magic byte), 1 operation code,
    //   2-3 key length, 4 extras length, 8-11 total body length, 12-15 opaque.
    // Offsets 5-7 and any header bytes between 16 and HDR_LEN are consumed but not stored
    // (HDR_LEN is presumably 24 per the memcache binary protocol - TODO confirm).
    // The body follows the header: extras, then key, then value.
    // NOTE(review): keyLength comes from U.bytesToShort and extrasLength from a raw byte, so
    // large values may arrive negative if those are signed - confirm.
    while (buf.remaining() > 0) {
      byte b = buf.get();

      if (i == 0) req.requestFlag(b);
      else if (i == 1) req.operationCode(b);
      else if (i == 2 || i == 3) {
        tmp.write(b);

        if (i == 3) {
          req.keyLength(U.bytesToShort(tmp.toByteArray(), 0));

          tmp.reset();
        }
      } else if (i == 4) req.extrasLength(b);
      else if (i >= 8 && i <= 11) {
        tmp.write(b);

        if (i == 11) {
          req.totalLength(U.bytesToInt(tmp.toByteArray(), 0));

          tmp.reset();
        }
      } else if (i >= 12 && i <= 15) {
        tmp.write(b);

        if (i == 15) {
          req.opaque(tmp.toByteArray());

          tmp.reset();
        }
      } else if (i >= HDR_LEN && i < HDR_LEN + req.extrasLength()) {
        // Body part 1: extras.
        tmp.write(b);

        if (i == HDR_LEN + req.extrasLength() - 1) {
          req.extras(tmp.toByteArray());

          tmp.reset();
        }
      } else if (i >= HDR_LEN + req.extrasLength()
          && i < HDR_LEN + req.extrasLength() + req.keyLength()) {
        // Body part 2: key.
        tmp.write(b);

        if (i == HDR_LEN + req.extrasLength() + req.keyLength() - 1) {
          req.key(tmp.toByteArray());

          tmp.reset();
        }
      } else if (i >= HDR_LEN + req.extrasLength() + req.keyLength()
          && i < HDR_LEN + req.totalLength()) {
        // Body part 3: value (whatever remains of the total body length).
        tmp.write(b);

        if (i == HDR_LEN + req.totalLength() - 1) {
          req.value(tmp.toByteArray());

          tmp.reset();
        }
      }

      // NOTE(review): totalLength is not validated. A negative value means this never matches
      // and the packet never completes. A value smaller than extrasLength + keyLength completes
      // the packet early with partially-read fields. Before offset 11, totalLength still holds
      // the fresh packet's initial value (presumably 0) - confirm.
      if (i == HDR_LEN + req.totalLength() - 1)
        // Assembled the packet. state.index is not updated here; presumably the caller resets
        // the parser state after a message is returned.
        return assemble(ses, req);

      i++;
    }

    state.index(i);

    return null;
  }
예제 #8
0
  /**
   * {@inheritDoc}
   *
   * <p>Resolves the bind host and, if SSL is enabled, builds an SSL context from the configured
   * factory. It then tries each port from {@code restTcpPort} to
   * {@code restTcpPort + restPortRange - 1} until {@code startTcpServer} succeeds. The
   * {@code host} and {@code port} fields hold the values last attempted. If no server starts,
   * or on SSL or I/O errors, a warning is logged instead of an exception being thrown.
   */
  @SuppressWarnings("BusyWait")
  @Override
  public void start(final GridRestProtocolHandler hnd) throws GridException {
    assert hnd != null;

    GridConfiguration cfg = ctx.config();

    GridNioServerListener<GridClientMessage> lsnr = new GridTcpRestNioListener(log, hnd);

    GridNioParser parser = new GridTcpRestParser(log);

    try {
      host = resolveRestTcpHost(cfg);

      SSLContext sslCtx = null;

      if (cfg.isRestTcpSslEnabled()) {
        GridSslContextFactory factory = cfg.getRestTcpSslContextFactory();

        // An SSLException (rather than GridException) is thrown on purpose so that the
        // SSL-specific warning below is the one that gets logged.
        if (factory == null)
          throw new SSLException("SSL is enabled, but SSL context factory is not specified.");

        sslCtx = factory.createSslContext();
      }

      int firstPort = cfg.getRestTcpPort();
      int lastPort = firstPort + cfg.getRestPortRange() - 1;

      // Iterate with the 'port' field itself so the warnings below report the last port tried.
      port = firstPort;

      while (port <= lastPort) {
        if (startTcpServer(host, port, lsnr, parser, sslCtx, cfg)) {
          if (log.isInfoEnabled())
            log.info(startInfo());

          return;
        }

        port++;
      }

      U.warn(
          log,
          "Failed to start TCP binary REST server (possibly all ports in range are in use) "
              + "[firstPort="
              + firstPort
              + ", lastPort="
              + lastPort
              + ", host="
              + host
              + ']');
    } catch (SSLException e) {
      String prefix = "Failed to start " + name() + " protocol on port " + port;

      U.warn(
          log,
          prefix + ": " + e.getMessage(),
          prefix + ". Check if SSL context factory is properly configured.");
    } catch (IOException e) {
      String prefix = "Failed to start " + name() + " protocol on port " + port;

      U.warn(
          log,
          prefix + ": " + e.getMessage(),
          prefix + ". Check restTcpHost configuration property.");
    }
  }