// Example #1
// 0
/**
 * Implementation of {@link Query}.
 *
 * <p>A query combines three sources of state: the criteria tree inherited from {@link
 * CriteriaContainerImpl}, an optional raw base query set through {@link
 * #setQueryObject(DBObject)}, and a lazily created {@link FindOptions} carrying cursor settings
 * such as limit, skip, sort, projection, modifiers and read preference.
 *
 * @param <T> The type we will be querying for, and returning.
 * @author Scott Hernandez
 */
@SuppressWarnings("deprecation")
public class QueryImpl<T> extends CriteriaContainerImpl implements Query<T> {
  private static final Logger LOG = MorphiaLoggerFactory.get(QueryImpl.class);
  private final org.mongodb.morphia.DatastoreImpl ds;
  private final DBCollection dbColl;
  private final Class<T> clazz;
  // Entity cache handed to the iterators this query creates; cloneQuery() gives clones a new one.
  private EntityCache cache;
  private boolean validateName = true;
  private boolean validateType = true;
  // null until the first projection is added; then true for an inclusion projection and false
  // for an exclusion projection (mixing the two is rejected).
  private Boolean includeFields;
  // Raw query document from setQueryObject(); merged in before the criteria by getQueryObject().
  private BasicDBObject baseQuery;
  // Created lazily by getOptions(). Whether it is null participates in equals()/hashCode().
  private FindOptions options;

  /**
   * Returns this query's {@link FindOptions}, creating them on first access.
   *
   * @return the options backing this query, never null
   */
  FindOptions getOptions() {
    if (options == null) {
      options = new FindOptions();
    }
    return options;
  }

  /**
   * Creates a Query for the given type and collection.
   *
   * <p>If the mapped class carries an {@link Entity} annotation, the read preference is set to
   * {@code secondaryPreferred} when its {@code queryNonPrimary} flag is set, and to null
   * otherwise.
   *
   * @param clazz the type to return
   * @param coll the collection to query
   * @param ds the Datastore to use; must be an {@code org.mongodb.morphia.DatastoreImpl}
   */
  public QueryImpl(final Class<T> clazz, final DBCollection coll, final Datastore ds) {
    super(CriteriaJoin.AND);

    setQuery(this);
    this.clazz = clazz;
    this.ds = ((org.mongodb.morphia.DatastoreImpl) ds);
    dbColl = coll;
    cache = this.ds.getMapper().createEntityCache();

    final MappedClass mc = this.ds.getMapper().getMappedClass(clazz);
    final Entity entAn = mc == null ? null : mc.getEntityAnnotation();
    if (entAn != null) {
      // Reuse the annotation already fetched above instead of looking the mapped class up again.
      getOptions()
          .readPreference(entAn.queryNonPrimary() ? ReadPreference.secondaryPreferred() : null);
    }
  }

  /**
   * Parses a comma-separated field list (e.g. {@code "name, -age"}) into a document that maps each
   * field to {@code 1}, or to {@code -1} when the name is prefixed with {@code -}.
   *
   * @param str the String to parse
   * @param clazz the class to use when validating
   * @param mapper the Mapper to use
   * @param validate true if each field name should be validated (and, where the validator
   *     translates it, replaced by its stored name)
   * @return the DBObject
   */
  public static BasicDBObject parseFieldsString(
      final String str, final Class clazz, final Mapper mapper, final boolean validate) {
    BasicDBObject ret = new BasicDBObject();
    final String[] parts = str.split(",");
    for (String s : parts) {
      s = s.trim();
      int dir = 1;

      if (s.startsWith("-")) {
        dir = -1;
        s = s.substring(1).trim();
      }

      if (validate) {
        final StringBuilder sb = new StringBuilder(s);
        validateQuery(clazz, mapper, sb, FilterOperator.IN, "", true, false);
        s = sb.toString();
      }
      ret.put(s, dir);
    }
    return ret;
  }

  @Override
  public List<Key<T>> asKeyList() {
    return asKeyList(getOptions());
  }

  /** Drains {@link #fetchKeys(FindOptions)} into a list, always closing the iterator. */
  @Override
  public List<Key<T>> asKeyList(final FindOptions options) {
    final List<Key<T>> results = new ArrayList<Key<T>>();
    final MorphiaKeyIterator<T> keys = fetchKeys(options);
    try {
      for (final Key<T> key : keys) {
        results.add(key);
      }
    } finally {
      keys.close();
    }
    return results;
  }

  @Override
  public List<T> asList() {
    return asList(getOptions());
  }

  /** Drains {@link #fetch(FindOptions)} into a list, always closing the iterator. */
  @Override
  public List<T> asList(final FindOptions options) {
    final List<T> results = new ArrayList<T>();
    final MorphiaIterator<T, T> iter = fetch(options);
    try {
      for (final T ent : iter) {
        results.add(ent);
      }
    } finally {
      iter.close();
    }

    if (LOG.isTraceEnabled()) {
      LOG.trace(
          format(
              "asList: %s \t %d entities, iterator time: driver %d ms, mapper %d ms %n\t cache: %s %n\t for %s",
              dbColl.getName(),
              results.size(),
              iter.getDriverTime(),
              iter.getMapperTime(),
              cache.stats(),
              getQueryObject()));
    }

    return results;
  }

  @Override
  @Deprecated
  public long countAll() {
    final DBObject query = getQueryObject();
    if (LOG.isTraceEnabled()) {
      LOG.trace("Executing count(" + dbColl.getName() + ") for query: " + query);
    }
    return dbColl.getCount(query);
  }

  @Override
  public long count() {
    return dbColl.getCount(getQueryObject());
  }

  @Override
  public long count(final CountOptions options) {
    return dbColl.getCount(getQueryObject(), options.getOptions());
  }

  @Override
  public MorphiaIterator<T, T> fetch() {
    return fetch(getOptions());
  }

  /** Opens a cursor for this query and wraps it in an iterator that maps documents to T. */
  @Override
  public MorphiaIterator<T, T> fetch(final FindOptions options) {
    final DBCursor cursor = prepareCursor(options);
    if (LOG.isTraceEnabled()) {
      LOG.trace("Getting cursor(" + dbColl.getName() + ")  for query:" + cursor.getQuery());
    }

    return new MorphiaIterator<T, T>(ds, cursor, ds.getMapper(), clazz, dbColl.getName(), cache);
  }

  @Override
  public MorphiaIterator<T, T> fetchEmptyEntities() {
    return fetchEmptyEntities(getOptions());
  }

  /**
   * Fetches entities with only their {@code _id} populated, running a clone of this query with an
   * {@code _id}-only inclusion projection. The given options are honoured; the clone's own
   * projection and sort still take precedence (see {@link #prepareCursor(FindOptions)}).
   */
  @Override
  public MorphiaIterator<T, T> fetchEmptyEntities(final FindOptions options) {
    QueryImpl<T> cloned = cloneQuery();
    cloned.getOptions().projection(new BasicDBObject(Mapper.ID_KEY, 1));
    cloned.includeFields = true;
    // Pass the caller's options through; previously they were silently ignored.
    return cloned.fetch(options);
  }

  @Override
  public MorphiaKeyIterator<T> fetchKeys() {
    return fetchKeys(getOptions());
  }

  /**
   * Fetches only keys, running a clone of this query with an {@code _id}-only inclusion
   * projection.
   */
  @Override
  public MorphiaKeyIterator<T> fetchKeys(final FindOptions options) {
    QueryImpl<T> cloned = cloneQuery();
    cloned.getOptions().projection(new BasicDBObject(Mapper.ID_KEY, 1));
    cloned.includeFields = true;

    return new MorphiaKeyIterator<T>(
        ds, cloned.prepareCursor(options), ds.getMapper(), clazz, dbColl.getName());
  }

  @Override
  public T get() {
    return get(getOptions());
  }

  /** Returns the first matching entity (using a copy of the options limited to 1), or null. */
  @Override
  public T get(final FindOptions options) {
    final MorphiaIterator<T, T> it = fetch(options.copy().limit(1));
    try {
      return (it.hasNext()) ? it.next() : null;
    } finally {
      it.close();
    }
  }

  @Override
  public Key<T> getKey() {
    return getKey(getOptions());
  }

  /** Returns the key of the first matching entity, or null; the iterator is always closed. */
  @Override
  public Key<T> getKey(final FindOptions options) {
    final MorphiaIterator<T, Key<T>> it = fetchKeys(options.copy().limit(1));
    try {
      return it.hasNext() ? it.next() : null;
    } finally {
      it.close();
    }
  }

  @Override
  @Deprecated
  public MorphiaIterator<T, T> tail() {
    return tail(true);
  }

  @Override
  @Deprecated
  public MorphiaIterator<T, T> tail(final boolean awaitData) {
    return fetch(getOptions().copy().cursorType(awaitData ? TailableAwait : Tailable));
  }

  @Override
  @Deprecated
  public Query<T> batchSize(final int value) {
    getOptions().batchSize(value);
    return this;
  }

  /**
   * Copies this query. Options, base query, projection mode and validation flags are copied; the
   * criteria list is a new list holding the same Criteria instances (a shallow copy), and the
   * clone gets a fresh entity cache.
   */
  @Override
  public QueryImpl<T> cloneQuery() {
    final QueryImpl<T> n = new QueryImpl<T>(clazz, dbColl, ds);
    n.cache = ds.getMapper().createEntityCache(); // fresh cache
    n.includeFields = includeFields;
    n.setQuery(n); // already done by the constructor; kept as an explicit re-assertion
    n.validateName = validateName;
    n.validateType = validateType;
    n.baseQuery = copy(baseQuery);
    n.options = options != null ? options.copy() : null;

    // fields from superclass
    n.setAttachedTo(getAttachedTo());
    n.setChildren(getChildren() == null ? null : new ArrayList<Criteria>(getChildren()));
    return n;
  }

  /**
   * Makes a shallow copy of a document as a {@link BasicDBObject}.
   *
   * @param dbObject the document to copy, may be null
   * @return the copy, or null if {@code dbObject} is null
   */
  protected BasicDBObject copy(final DBObject dbObject) {
    return dbObject == null ? null : new BasicDBObject(dbObject.toMap());
  }

  @Override
  @Deprecated
  public Query<T> comment(final String comment) {
    getOptions().modifier("$comment", comment);
    return this;
  }

  /** Adds a new AND-joined criteria container to this query and returns its field builder. */
  @Override
  public FieldEnd<? extends CriteriaContainerImpl> criteria(final String field) {
    final CriteriaContainerImpl container = new CriteriaContainerImpl(this, CriteriaJoin.AND);
    add(container);

    return new FieldEndImpl<CriteriaContainerImpl>(this, field, container);
  }

  @Override
  @Deprecated
  public Query<T> disableCursorTimeout() {
    getOptions().noCursorTimeout(true);
    return this;
  }

  @Override
  @Deprecated
  public Query<T> disableSnapshotMode() {
    getOptions().getModifiers().removeField("$snapshot");

    return this;
  }

  @Override
  public Query<T> disableValidation() {
    validateName = false;
    validateType = false;
    return this;
  }

  @Override
  @Deprecated
  public Query<T> enableCursorTimeout() {
    getOptions().noCursorTimeout(false);
    return this;
  }

  @Override
  @Deprecated
  public Query<T> enableSnapshotMode() {
    getOptions().modifier("$snapshot", true);
    return this;
  }

  @Override
  public Query<T> enableValidation() {
    validateName = true;
    validateType = true;
    return this;
  }

  @Override
  @SuppressWarnings("unchecked")
  public Map<String, Object> explain() {
    return explain(getOptions());
  }

  /** Runs explain for this query with the given options; the prepared cursor is closed. */
  @Override
  @SuppressWarnings("unchecked")
  public Map<String, Object> explain(final FindOptions options) {
    final DBCursor cursor = prepareCursor(options);
    try {
      return cursor.explain().toMap();
    } finally {
      cursor.close();
    }
  }

  @Override
  public FieldEnd<? extends Query<T>> field(final String name) {
    return new FieldEndImpl<QueryImpl<T>>(this, name, this);
  }

  /**
   * Adds a filter from a textual condition of the form {@code "field"} or {@code "field op"}
   * (e.g. {@code "age >="}). Tokens may be separated by any run of whitespace.
   *
   * <p>NOTE(review): conditions of three to six tokens are accepted and treated as EQUAL on the
   * first token; only the two-token form uses the operator.
   *
   * @throws IllegalArgumentException if the condition has more than six tokens
   */
  @Override
  public Query<T> filter(final String condition, final Object value) {
    // Split on whitespace runs so e.g. "age  >" (two spaces) is not mistaken for a 3-token
    // condition and silently turned into an equality filter.
    final String[] parts = condition.trim().split("\\s+");
    if (parts.length < 1 || parts.length > 6) {
      throw new IllegalArgumentException("'" + condition + "' is not a legal filter condition");
    }

    final String prop = parts[0].trim();
    final FilterOperator op = (parts.length == 2) ? translate(parts[1]) : FilterOperator.EQUAL;

    add(new FieldCriteria(this, prop, op, value));

    return this;
  }

  @Override
  @Deprecated
  public int getBatchSize() {
    return getOptions().getBatchSize();
  }

  @Override
  @Deprecated
  public DBCollection getCollection() {
    return dbColl;
  }

  @Override
  public Class<T> getEntityClass() {
    return clazz;
  }

  /**
   * Returns a copy of the projection, or null when there is none. For inclusion projections on an
   * entity that stores its class name, the class-name field is added so results can be mapped.
   */
  @Override
  @Deprecated
  public DBObject getFieldsObject() {
    DBObject projection = getOptions().getProjection();
    if (projection == null || projection.keySet().size() == 0) {
      return null;
    }

    final MappedClass mc = ds.getMapper().getMappedClass(clazz);

    final Entity entityAnnotation = mc == null ? null : mc.getEntityAnnotation();
    final BasicDBObject fieldsFilter = copy(projection);

    // includeFields can still be null if a projection was put on the options without going
    // through project(); avoid unboxing null.
    if (Boolean.TRUE.equals(includeFields)
        && entityAnnotation != null
        && !entityAnnotation.noClassnameStored()) {
      fieldsFilter.put(Mapper.CLASS_NAME_FIELDNAME, 1);
    }

    return fieldsFilter;
  }

  @Override
  @Deprecated
  public int getLimit() {
    return getOptions().getLimit();
  }

  @Override
  @Deprecated
  public int getOffset() {
    return getOptions().getSkip();
  }

  /** Builds the query document: the base query (if any) followed by all criteria. */
  @Override
  @Deprecated
  public DBObject getQueryObject() {
    final DBObject obj = new BasicDBObject();

    if (baseQuery != null) {
      obj.putAll((BSONObject) baseQuery);
    }

    addTo(obj);

    return obj;
  }

  /**
   * Sets query structure directly
   *
   * @param query the DBObject containing the query; it is copied
   */
  public void setQueryObject(final DBObject query) {
    baseQuery = new BasicDBObject(query.toMap());
  }

  @Override
  @Deprecated
  public DBObject getSortObject() {
    DBObject sort = getOptions().getSortDBObject();
    return (sort == null) ? null : new BasicDBObject(sort.toMap());
  }

  @Override
  @Deprecated
  public Query<T> hintIndex(final String idxName) {
    getOptions().modifier("$hint", idxName);
    return this;
  }

  @Override
  @Deprecated
  public Query<T> limit(final int value) {
    getOptions().limit(value);
    return this;
  }

  @Override
  @Deprecated
  @SuppressWarnings("unchecked")
  public Query<T> lowerIndexBound(final DBObject lowerBound) {
    if (lowerBound != null) {
      getOptions().modifier("$min", new Document(lowerBound.toMap()));
    }
    return this;
  }

  @Override
  @Deprecated
  public Query<T> maxScan(final int value) {
    if (value > 0) {
      getOptions().modifier("$maxScan", value);
    }
    return this;
  }

  @Override
  @Deprecated
  public Query<T> maxTime(final long value, final TimeUnit unit) {
    getOptions().modifier("$maxTimeMS", MILLISECONDS.convert(value, unit));
    return this;
  }

  /**
   * Reads the {@code $maxTimeMS} modifier.
   *
   * @param unit the unit to convert to
   * @return the max time in {@code unit}, or 0 when not set
   */
  long getMaxTime(final TimeUnit unit) {
    // Read through Number so a non-Long numeric modifier does not cause a ClassCastException.
    final Number maxTime = (Number) getOptions().getModifiers().get("$maxTimeMS");
    return unit.convert(maxTime != null ? maxTime.longValue() : 0L, MILLISECONDS);
  }

  @Override
  @Deprecated
  public Query<T> offset(final int value) {
    getOptions().skip(value);
    return this;
  }

  /** Sets the sort from a comma-separated field list; see {@link #parseFieldsString}. */
  @Override
  public Query<T> order(final String sort) {
    getOptions().sort(parseFieldsString(sort, clazz, ds.getMapper(), validateName));
    return this;
  }

  /** Sorts by a meta field. Name validation is deliberately disabled for meta fields. */
  @Override
  public Query<T> order(final Meta sort) {
    validateQuery(
        clazz,
        ds.getMapper(),
        new StringBuilder(sort.getField()),
        FilterOperator.IN,
        "",
        false,
        false);

    getOptions().sort(sort.toDatabase());

    return this;
  }

  /** Sets the sort from the given Sort entries, validating names when validation is enabled. */
  @Override
  public Query<T> order(final Sort... sorts) {
    BasicDBObject sortList = new BasicDBObject();
    for (Sort sort : sorts) {
      String s = sort.getField();
      if (validateName) {
        final StringBuilder sb = new StringBuilder(s);
        validateQuery(clazz, ds.getMapper(), sb, FilterOperator.IN, "", true, false);
        s = sb.toString();
      }
      sortList.put(s, sort.getOrder());
    }
    getOptions().sort(sortList);
    return this;
  }

  @Override
  @Deprecated
  public Query<T> queryNonPrimary() {
    getOptions().readPreference(ReadPreference.secondaryPreferred());
    return this;
  }

  @Override
  @Deprecated
  public Query<T> queryPrimaryOnly() {
    getOptions().readPreference(ReadPreference.primary());
    return this;
  }

  /** Projects exactly the persisted fields of the mapped class (by stored name). */
  @Override
  public Query<T> retrieveKnownFields() {
    final MappedClass mc = ds.getMapper().getMappedClass(clazz);
    final List<String> fields = new ArrayList<String>(mc.getPersistenceFields().size() + 1);
    for (final MappedField mf : mc.getPersistenceFields()) {
      fields.add(mf.getNameToStore());
    }
    retrievedFields(true, fields.toArray(new String[fields.size()]));
    return this;
  }

  @Override
  public Query<T> project(final String field, final boolean include) {
    final StringBuilder sb = new StringBuilder(field);
    validateQuery(clazz, ds.getMapper(), sb, FilterOperator.EQUAL, null, validateName, false);
    String fieldName = sb.toString();
    validateProjections(fieldName, include);
    project(fieldName, include ? 1 : 0);
    return this;
  }

  /** Puts one entry into the projection document, creating the document if needed. */
  private void project(final String fieldName, final Object value) {
    DBObject projection = getOptions().getProjection();
    if (projection == null) {
      projection = new BasicDBObject();
      getOptions().projection(projection);
    }
    projection.put(fieldName, value);
  }

  /** Merges all entries of {@code value} into the projection document, creating it if needed. */
  private void project(final DBObject value) {
    DBObject projection = getOptions().getProjection();
    if (projection == null) {
      projection = new BasicDBObject();
      getOptions().projection(projection);
    }
    projection.putAll(value);
  }

  @Override
  public Query<T> project(final String field, final ArraySlice slice) {
    final StringBuilder sb = new StringBuilder(field);
    validateQuery(clazz, ds.getMapper(), sb, FilterOperator.EQUAL, null, validateName, false);
    String fieldName = sb.toString();
    validateProjections(fieldName, true);
    project(fieldName, slice.toDatabase());
    return this;
  }

  @Override
  public Query<T> project(final Meta meta) {
    final StringBuilder sb = new StringBuilder(meta.getField());
    validateQuery(clazz, ds.getMapper(), sb, FilterOperator.EQUAL, null, false, false);
    String fieldName = sb.toString();
    validateProjections(fieldName, true);
    project(meta.toDatabase());
    return this;
  }

  /**
   * Records the projection mode on first use and rejects mixing inclusions and exclusions, except
   * excluding {@code _id} from an inclusion projection.
   *
   * @throws ValidationException if the projection modes are mixed
   */
  private void validateProjections(final String field, final boolean include) {
    if (includeFields != null && include != includeFields) {
      if (!includeFields || !"_id".equals(field)) {
        throw new ValidationException("You cannot mix included and excluded fields together");
      }
    }
    if (includeFields == null) {
      includeFields = include;
    }
  }

  @Override
  @Deprecated
  public Query<T> retrievedFields(final boolean include, final String... list) {
    if (includeFields != null && include != includeFields) {
      throw new IllegalStateException("You cannot mix included and excluded fields together");
    }
    for (String field : list) {
      project(field, include);
    }
    return this;
  }

  @Override
  @Deprecated
  public Query<T> returnKey() {
    getOptions().getModifiers().put("$returnKey", true);
    return this;
  }

  /** Adds a {@code $text} criterion with the given {@code $search} string. */
  @Override
  public Query<T> search(final String search) {

    final BasicDBObject op = new BasicDBObject("$search", search);

    this.criteria("$text").equal(op);

    return this;
  }

  /** Adds a {@code $text} criterion with the given {@code $search} string and language. */
  @Override
  public Query<T> search(final String search, final String language) {

    final BasicDBObject op = new BasicDBObject("$search", search).append("$language", language);

    this.criteria("$text").equal(op);

    return this;
  }

  @Override
  @Deprecated
  public Query<T> upperIndexBound(final DBObject upperBound) {
    if (upperBound != null) {
      getOptions().getModifiers().put("$max", new BasicDBObject(upperBound.toMap()));
    }

    return this;
  }

  @Override
  @Deprecated
  public Query<T> useReadPreference(final ReadPreference readPref) {
    getOptions().readPreference(readPref);
    return this;
  }

  @Override
  public Query<T> where(final String js) {
    add(new WhereCriteria(js));
    return this;
  }

  @Override
  public Query<T> where(final CodeWScope js) {
    add(new WhereCriteria(js));
    return this;
  }

  /** A query is not tied to a single field, so this always returns null. */
  @Override
  public String getFieldName() {
    return null;
  }

  /**
   * @return the Datastore
   * @deprecated this is an internal method that exposes an internal type and will likely go away
   *     soon
   */
  @Deprecated
  public org.mongodb.morphia.DatastoreImpl getDatastore() {
    return ds;
  }

  /** @return true if field names are being validated */
  public boolean isValidatingNames() {
    return validateName;
  }

  /** @return true if query parameter value types are being validated against the field types */
  public boolean isValidatingTypes() {
    return validateType;
  }

  @Override
  public MorphiaIterator<T, T> iterator() {
    return fetch();
  }

  /**
   * Prepares cursor for iteration
   *
   * @return the cursor
   * @deprecated this is an internal method. no replacement is planned.
   */
  @Deprecated
  public DBCursor prepareCursor() {
    return prepareCursor(getOptions());
  }

  /**
   * Opens a driver cursor for this query. The sort and projection of {@code findOptions} are
   * replaced by this query's own sort and projection; everything else comes from a copy of
   * {@code findOptions}.
   */
  private DBCursor prepareCursor(final FindOptions findOptions) {
    final DBObject query = getQueryObject();

    if (LOG.isTraceEnabled()) {
      LOG.trace(
          String.format(
              "Running query(%s) : %s, options: %s,", dbColl.getName(), query, findOptions));
    }

    if (findOptions.isSnapshot()
        && (findOptions.getSortDBObject() != null || findOptions.hasHint())) {
      LOG.warning("Snapshotted query should not have hint/sort.");
    }

    if (findOptions.getCursorType() != NonTailable && (findOptions.getSortDBObject() != null)) {
      LOG.warning("Sorting on tail is not allowed.");
    }

    return dbColl
        .find(
            query,
            findOptions.getOptions().copy().sort(getSortObject()).projection(getFieldsObject()))
        .setDecoderFactory(ds.getDecoderFact());
  }

  /** Renders the query document and, if present, the projection, without creating options. */
  @Override
  public String toString() {
    // Read the field directly: getOptions() would lazily create options and thereby change the
    // result of equals()/hashCode() as a side effect of printing.
    final boolean hasProjection = options != null && options.getProjection() != null;
    return String.format(
        "{ query: %s %s }",
        getQueryObject(),
        hasProjection ? ", projection: " + getFieldsObject() : "");
  }

  /**
   * Converts the textual operator (">", "<=", etc) into a FilterOperator. Forgiving about the
   * syntax; != and <> are NOT_EQUAL, = and == are EQUAL.
   */
  protected FilterOperator translate(final String operator) {
    return FilterOperator.fromString(operator);
  }

  /**
   * Compares validation flags, collection, entity class, projection mode, base query and options.
   *
   * <p>NOTE(review): the criteria added via filter()/field()/criteria() are not compared, so
   * queries that differ only in their filters compare equal. hashCode() is consistent with this.
   */
  @Override
  public boolean equals(final Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof QueryImpl)) {
      return false;
    }

    final QueryImpl<?> query = (QueryImpl<?>) o;

    if (validateName != query.validateName) {
      return false;
    }
    if (validateType != query.validateType) {
      return false;
    }
    if (!dbColl.equals(query.dbColl)) {
      return false;
    }
    if (!clazz.equals(query.clazz)) {
      return false;
    }
    if (includeFields != null
        ? !includeFields.equals(query.includeFields)
        : query.includeFields != null) {
      return false;
    }
    if (baseQuery != null ? !baseQuery.equals(query.baseQuery) : query.baseQuery != null) {
      return false;
    }
    return compare(options, query.options);
  }

  /**
   * Field-by-field comparison of two FindOptions via their driver options. Two nulls are equal;
   * a null and a non-null are not.
   */
  private boolean compare(final FindOptions these, final FindOptions those) {
    if (these == null && those != null || these != null && those == null) {
      return false;
    }
    if (these == null) {
      return true;
    }

    DBCollectionFindOptions dbOptions = these.getOptions();
    DBCollectionFindOptions that = those.getOptions();

    if (dbOptions.getBatchSize() != that.getBatchSize()) {
      return false;
    }
    if (dbOptions.getLimit() != that.getLimit()) {
      return false;
    }
    if (dbOptions.getMaxTime(MILLISECONDS) != that.getMaxTime(MILLISECONDS)) {
      return false;
    }
    if (dbOptions.getMaxAwaitTime(MILLISECONDS) != that.getMaxAwaitTime(MILLISECONDS)) {
      return false;
    }
    if (dbOptions.getSkip() != that.getSkip()) {
      return false;
    }
    if (dbOptions.isNoCursorTimeout() != that.isNoCursorTimeout()) {
      return false;
    }
    if (dbOptions.isOplogReplay() != that.isOplogReplay()) {
      return false;
    }
    if (dbOptions.isPartial() != that.isPartial()) {
      return false;
    }
    if (dbOptions.getModifiers() != null
        ? !dbOptions.getModifiers().equals(that.getModifiers())
        : that.getModifiers() != null) {
      return false;
    }
    if (dbOptions.getProjection() != null
        ? !dbOptions.getProjection().equals(that.getProjection())
        : that.getProjection() != null) {
      return false;
    }
    if (dbOptions.getSort() != null
        ? !dbOptions.getSort().equals(that.getSort())
        : that.getSort() != null) {
      return false;
    }
    if (dbOptions.getCursorType() != that.getCursorType()) {
      return false;
    }
    if (dbOptions.getReadPreference() != null
        ? !dbOptions.getReadPreference().equals(that.getReadPreference())
        : that.getReadPreference() != null) {
      return false;
    }
    if (dbOptions.getReadConcern() != null
        ? !dbOptions.getReadConcern().equals(that.getReadConcern())
        : that.getReadConcern() != null) {
      return false;
    }
    return dbOptions.getCollation() != null
        ? dbOptions.getCollation().equals(that.getCollation())
        : that.getCollation() == null;
  }

  /** Hashes the same option fields that compare() looks at; 0 for null options. */
  private int hash(final FindOptions options) {
    if (options == null) {
      return 0;
    }
    int result = options.getBatchSize();
    // Hash the limit of the options passed in, not of this query's own options.
    result = 31 * result + options.getLimit();
    result = 31 * result + (options.getModifiers() != null ? options.getModifiers().hashCode() : 0);
    result =
        31 * result + (options.getProjection() != null ? options.getProjection().hashCode() : 0);
    result =
        31 * result
            + (int) (options.getMaxTime(MILLISECONDS) ^ options.getMaxTime(MILLISECONDS) >>> 32);
    result =
        31 * result
            + (int)
                (options.getMaxAwaitTime(MILLISECONDS)
                    ^ options.getMaxAwaitTime(MILLISECONDS) >>> 32);
    result = 31 * result + options.getSkip();
    result =
        31 * result
            + (options.getSortDBObject() != null ? options.getSortDBObject().hashCode() : 0);
    result =
        31 * result + (options.getCursorType() != null ? options.getCursorType().hashCode() : 0);
    result = 31 * result + (options.isNoCursorTimeout() ? 1 : 0);
    result = 31 * result + (options.isOplogReplay() ? 1 : 0);
    result = 31 * result + (options.isPartial() ? 1 : 0);
    result =
        31 * result
            + (options.getReadPreference() != null ? options.getReadPreference().hashCode() : 0);
    result =
        31 * result + (options.getReadConcern() != null ? options.getReadConcern().hashCode() : 0);
    result = 31 * result + (options.getCollation() != null ? options.getCollation().hashCode() : 0);
    return result;
  }

  @Override
  public int hashCode() {
    int result = dbColl.hashCode();
    result = 31 * result + clazz.hashCode();
    result = 31 * result + (validateName ? 1 : 0);
    result = 31 * result + (validateType ? 1 : 0);
    result = 31 * result + (includeFields != null ? includeFields.hashCode() : 0);
    result = 31 * result + (baseQuery != null ? baseQuery.hashCode() : 0);
    result = 31 * result + hash(options);
    return result;
  }
}
// Example #2
// 0
/**
 * Static helpers that validate query/update field paths against the mapped model and check that
 * the supplied values are compatible with the mapped field types.
 */
final class QueryValidator {
  private static final Logger LOG = MorphiaLoggerFactory.get(QueryValidator.class);

  private QueryValidator() {}

  /*package*/
  /**
   * Runs the chain of operator-specific and then type-specific validators for one value.
   *
   * <p>The chain is joined with {@code ||}, so evaluation stops at the first validator whose
   * {@code apply(...)} returns true; failures found are appended to {@code validationFailures}.
   *
   * @param mappedClass the mapped class owning {@code mappedField}
   * @param mappedField the field being queried
   * @param type the type to check {@code value} against (callers pass the field's type or its
   *     subclass type)
   * @param op the filter operator in use
   * @param value the value supplied for the field
   * @param validationFailures receives any validation failures
   * @return true if {@code value} or {@code type} is null; otherwise true only if some validator
   *     applied and {@code validationFailures} is empty afterwards
   */
  static boolean isCompatibleForOperator(
      final MappedClass mappedClass,
      final MappedField mappedField,
      final Class<?> type,
      final FilterOperator op,
      final Object value,
      final List<ValidationFailure> validationFailures) {
    // TODO: it's really OK to have null values?  I think this is to prevent null pointers further
    // down, but I want to move the null check into the operations that care whether they allow
    // nulls or not.
    if (value == null || type == null) {
      return true;
    }

    final boolean validationApplied =
        ExistsOperationValidator.getInstance().apply(mappedField, op, value, validationFailures)
            || SizeOperationValidator.getInstance()
                .apply(mappedField, op, value, validationFailures)
            || InOperationValidator.getInstance().apply(mappedField, op, value, validationFailures)
            || NotInOperationValidator.getInstance()
                .apply(mappedField, op, value, validationFailures)
            || ModOperationValidator.getInstance().apply(mappedField, op, value, validationFailures)
            || GeoWithinOperationValidator.getInstance()
                .apply(mappedField, op, value, validationFailures)
            || AllOperationValidator.getInstance().apply(mappedField, op, value, validationFailures)
            || KeyValueTypeValidator.getInstance().apply(type, value, validationFailures)
            || IntegerTypeValidator.getInstance().apply(type, value, validationFailures)
            || LongTypeValidator.getInstance().apply(type, value, validationFailures)
            || DoubleTypeValidator.getInstance().apply(type, value, validationFailures)
            || PatternValueValidator.getInstance().apply(type, value, validationFailures)
            || EntityAnnotatedValueValidator.getInstance().apply(type, value, validationFailures)
            || ListValueValidator.getInstance().apply(type, value, validationFailures)
            || EntityTypeAndIdValueValidator.getInstance()
                .apply(mappedClass, mappedField, value, validationFailures)
            || DefaultTypeValidator.getInstance().apply(type, value, validationFailures);

    return validationApplied && validationFailures.isEmpty();
  }

  /**
   * Validate the path, and value type, returning the mapped field for the field at the path.
   *
   * <p>When {@code validateNames} is set, each dot-separated segment is resolved against the
   * current mapped class:
   * <ul>
   *   <li>Segments given as Java field names are translated to their stored names. If any
   *       translation happens, {@code origProp} is rewritten in place with the stored path.</li>
   *   <li>A {@code $} segment is accepted without lookup.</li>
   *   <li>The segment following a map field is skipped (not validated).</li>
   * </ul>
   * Type checks only run when names are validated as well and the path resolved to a field.
   * Type mismatches are logged as warnings, not thrown.
   *
   * @param clazz the class the path starts from
   * @param mapper the mapper used to look up mapped classes
   * @param origProp the property path; rewritten in place if segments were translated
   * @param op the filter operator used with the value
   * @param val the value used with the property
   * @param validateNames whether to validate (and translate) the path
   * @param validateTypes whether to check the value type (only effective with validateNames)
   * @return the mapped field found for the last validated segment; null if names are not
   *     validated, {@code clazz} is null, or that segment was not resolved (e.g. {@code $})
   * @throws ValidationException if a segment cannot be found (including an empty path such as
   *     {@code "."}), or dot-notation continues past a reference or {@code @Serialized} field
   */
  static MappedField validateQuery(
      final Class clazz,
      final Mapper mapper,
      final StringBuilder origProp,
      final FilterOperator op,
      final Object val,
      final boolean validateNames,
      final boolean validateTypes) {
    // TODO: cache validations (in static?).

    MappedField mf = null;
    final String prop = origProp.toString();
    boolean hasTranslations = false;

    if (validateNames) {
      if (clazz == null) {
        return null;
      }
      final String[] parts = prop.split("\\.");
      if (parts.length == 0) {
        // String.split drops trailing empty strings, so paths like "." yield no segments at all;
        // report that as a validation error instead of failing with an index exception below.
        throw new ValidationException(
            format(
                "The field '%s' could not be found in '%s' while validating - %s; if "
                    + "you wish to continue please disable validation.",
                prop, clazz.getName(), prop));
      }

      MappedClass mc = mapper.getMappedClass(clazz);
      // CHECKSTYLE:OFF
      for (int i = 0; ; ) {
        // CHECKSTYLE:ON
        final String part = parts[i];
        final boolean fieldIsArrayOperator = part.equals("$");

        mf = mc.getMappedField(part);

        // translate from java field name to stored field name
        if (mf == null && !fieldIsArrayOperator) {
          mf = mc.getMappedFieldByJavaField(part);
          if (mf == null) {
            throw new ValidationException(
                format(
                    "The field '%s' could not be found in '%s' while validating - %s; if "
                        + "you wish to continue please disable validation.",
                    part, clazz.getName(), prop));
          }
          hasTranslations = true;
          parts[i] = mf.getNameToStore();
        }

        i++;
        if (mf != null && mf.isMap()) {
          // skip the map key validation, and move to the next part
          i++;
        }

        if (i >= parts.length) {
          break;
        }

        // After "$" the current mapped class stays the same for the next segment.
        if (!fieldIsArrayOperator) {
          // catch people trying to search/update into @Reference/@Serialized fields
          if (!canQueryPast(mf)) {
            throw new ValidationException(
                format(
                    "Can not use dot-notation past '%s' could not be found in '%s' while"
                        + " validating - %s",
                    part, clazz.getName(), prop));
          }

          // get the next MappedClass for the next field validation
          mc = mapper.getMappedClass((mf.isSingleValue()) ? mf.getType() : mf.getSubClass());
        }
      }

      // record new property string if there has been a translation to any part
      if (hasTranslations) {
        origProp.setLength(0); // clear existing content
        origProp.append(parts[0]);
        for (int i = 1; i < parts.length; i++) {
          origProp.append('.');
          origProp.append(parts[i]);
        }
      }

      if (validateTypes && mf != null) {
        final List<ValidationFailure> typeValidationFailures = new ArrayList<ValidationFailure>();
        final boolean compatibleForType =
            isCompatibleForOperator(mc, mf, mf.getType(), op, val, typeValidationFailures);
        final List<ValidationFailure> subclassValidationFailures =
            new ArrayList<ValidationFailure>();
        final boolean compatibleForSubclass =
            isCompatibleForOperator(mc, mf, mf.getSubClass(), op, val, subclassValidationFailures);

        // Single-valued fields must match their declared type; multi-valued fields may match
        // either the container type or the element (sub)class.
        if ((mf.isSingleValue() && !compatibleForType)
            || mf.isMultipleValues() && !(compatibleForSubclass || compatibleForType)) {

          if (LOG.isWarningEnabled()) {
            LOG.warning(
                format(
                    "The type(s) for the query/update may be inconsistent; using an instance of type '%s' "
                        + "for the field '%s.%s' which is declared as '%s'",
                    val.getClass().getName(),
                    mf.getDeclaringClass().getName(),
                    mf.getJavaFieldName(),
                    mf.getType().getName()));
            typeValidationFailures.addAll(subclassValidationFailures);
            LOG.warning("Validation warnings: \n" + typeValidationFailures);
          }
        }
      }
    }
    return mf;
  }

  /** Dot-notation may not continue past reference fields or {@code @Serialized} fields. */
  private static boolean canQueryPast(final MappedField mf) {
    return !(mf.isReference() || mf.hasAnnotation(Serialized.class));
  }
}
/**
 * Builds an aggregation pipeline as an ordered list of stage documents and runs it against the
 * collection mapped for the source entity type.
 *
 * @param <T> the source entity type whose collection is aggregated
 * @param <U> the type results are mapped to
 */
public class AggregationPipelineImpl<T, U> implements AggregationPipeline<T, U> {
  private static final Logger LOG = MorphiaLoggerFactory.get(AggregationPipelineImpl.class);

  private final DBCollection collection;
  private final Class<T> source;
  private final List<DBObject> stages = new ArrayList<DBObject>();
  private final Mapper mapper;
  private final DatastoreImpl datastore;
  // True while the $project currently being built is the first stage of the pipeline; only then
  // are projected field names translated through the source class mapping.
  private boolean firstStage = false;

  /**
   * @param datastore the datastore providing the mapper and the source collection
   * @param source the entity type whose collection will be aggregated
   */
  public AggregationPipelineImpl(final DatastoreImpl datastore, final Class<T> source) {
    this.datastore = datastore;
    this.collection = datastore.getCollection(source);
    mapper = datastore.getMapper();
    this.source = source;
  }

  /**
   * Converts a projection (and, recursively, its sub-projections) into its document form.
   *
   * <p>For the first stage, the source field name is translated to the stored name of the mapped
   * field. If the name does not resolve to a mapped field (e.g. a computed or nested path), it is
   * used as given instead of failing with a NullPointerException.
   *
   * @param projection the projection to convert
   * @return a single-entry document keyed by the (possibly translated) source field name
   */
  @SuppressWarnings("unchecked")
  public DBObject toDBObject(final Projection projection) {
    String sourceFieldName = projection.getSourceField();
    if (firstStage) {
      MappedField field = mapper.getMappedClass(source).getMappedField(sourceFieldName);
      if (field != null) {
        sourceFieldName = field.getNameToStore();
      }
    }

    if (projection.getProjections() != null) {
      List<Projection> list = projection.getProjections();
      DBObject projections = new BasicDBObject();
      for (Projection subProjection : list) {
        projections.putAll(toDBObject(subProjection));
      }
      return new BasicDBObject(sourceFieldName, projections);
    } else if (projection.getProjectedField() != null) {
      return new BasicDBObject(sourceFieldName, projection.getProjectedField());
    } else {
      return new BasicDBObject(sourceFieldName, projection.isSuppressed() ? 0 : 1);
    }
  }

  /** Appends a {@code $project} stage built from the given projections. */
  public AggregationPipeline<T, U> project(final Projection... projections) {
    firstStage = stages.isEmpty();
    DBObject dbObject = new BasicDBObject();
    for (Projection projection : projections) {
      dbObject.putAll(toDBObject(projection));
    }
    stages.add(new BasicDBObject("$project", dbObject));
    return this;
  }

  /** Appends a {@code $group} stage grouping on the single field {@code id}. */
  public AggregationPipeline<T, U> group(final String id, final Group... groupings) {
    DBObject group = new BasicDBObject("_id", "$" + id);
    for (Group grouping : groupings) {
      Accumulator accumulator = grouping.getAccumulator();
      group.put(
          grouping.getName(),
          new BasicDBObject(accumulator.getOperation(), accumulator.getField()));
    }

    stages.add(new BasicDBObject("$group", group));
    return this;
  }

  /** Appends a {@code $group} stage whose {@code _id} is a compound document built from {@code id}. */
  public AggregationPipeline<T, U> group(final List<Group> id, final Group... groupings) {
    DBObject idGroup = new BasicDBObject();
    for (Group group : id) {
      idGroup.put(group.getName(), group.getSourceField());
    }
    DBObject group = new BasicDBObject("_id", idGroup);
    for (Group grouping : groupings) {
      Accumulator accumulator = grouping.getAccumulator();
      group.put(
          grouping.getName(),
          new BasicDBObject(accumulator.getOperation(), accumulator.getField()));
    }

    stages.add(new BasicDBObject("$group", group));
    return this;
  }

  /** Appends a {@code $match} stage using the query's query document. */
  public AggregationPipeline<T, U> match(final Query query) {
    stages.add(new BasicDBObject("$match", query.getQueryObject()));
    return this;
  }

  /** Appends a {@code $sort} stage; sort keys keep the order they are given in. */
  public AggregationPipeline<T, U> sort(final Sort... sorts) {
    DBObject sortList = new BasicDBObject();
    for (Sort sort : sorts) {
      sortList.put(sort.getField(), sort.getDirection());
    }

    stages.add(new BasicDBObject("$sort", sortList));
    return this;
  }

  /** Appends a {@code $limit} stage. */
  public AggregationPipeline<T, U> limit(final int count) {
    stages.add(new BasicDBObject("$limit", count));
    return this;
  }

  /** Appends a {@code $skip} stage. */
  public AggregationPipeline<T, U> skip(final int count) {
    stages.add(new BasicDBObject("$skip", count));
    return this;
  }

  /** Appends a {@code $unwind} stage on the given field. */
  public AggregationPipeline<T, U> unwind(final String field) {
    stages.add(new BasicDBObject("$unwind", "$" + field));
    return this;
  }

  /** Appends a {@code $geoNear} stage; options that are null are omitted from the stage document. */
  public AggregationPipeline<T, U> geoNear(final GeoNear geoNear) {
    DBObject geo = new BasicDBObject();
    putIfNotNull(geo, "near", geoNear.getNear());
    putIfNotNull(geo, "distanceField", geoNear.getDistanceField());
    putIfNotNull(geo, "limit", geoNear.getLimit());
    putIfNotNull(geo, "num", geoNear.getMaxDocuments());
    putIfNotNull(geo, "maxDistance", geoNear.getMaxDistance());
    if (geoNear.getQuery() != null) {
      geo.put("query", geoNear.getQuery().getQueryObject());
    }
    putIfNotNull(geo, "spherical", geoNear.getSpherical());
    putIfNotNull(geo, "distanceMultiplier", geoNear.getDistanceMultiplier());
    putIfNotNull(geo, "includeLocs", geoNear.getIncludeLocations());
    putIfNotNull(geo, "uniqueDocs", geoNear.getUniqueDocuments());
    stages.add(new BasicDBObject("$geoNear", geo));

    return this;
  }

  /** Puts {@code name -> value} into {@code dbObject} only when {@code value} is not null. */
  private void putIfNotNull(final DBObject dbObject, final String name, final Object value) {
    if (value != null) {
      dbObject.put(name, value);
    }
  }

  @Override
  public MorphiaIterator<U, U> out(final Class<U> target) {
    return out(datastore.getCollection(target).getName(), target);
  }

  @Override
  public MorphiaIterator<U, U> out(final Class<U> target, final AggregationOptions options) {
    return out(datastore.getCollection(target).getName(), target, options);
  }

  @Override
  public MorphiaIterator<U, U> out(final String collectionName, final Class<U> target) {
    return out(collectionName, target, AggregationOptions.builder().build());
  }

  /**
   * Appends a {@code $out} stage writing to {@code collectionName} and runs the pipeline.
   *
   * <p>Note: the {@code $out} stage stays in this pipeline's stage list after the call.
   */
  @Override
  public MorphiaIterator<U, U> out(
      final String collectionName, final Class<U> target, final AggregationOptions options) {
    stages.add(new BasicDBObject("$out", collectionName));
    // Hand the output collection's name to the iterator rather than target's default collection.
    return aggregate(collectionName, target, options, collection.getReadPreference());
  }

  @Override
  public MorphiaIterator<U, U> aggregate(final Class<U> target) {
    return aggregate(target, AggregationOptions.builder().build(), collection.getReadPreference());
  }

  @Override
  public MorphiaIterator<U, U> aggregate(final Class<U> target, final AggregationOptions options) {
    return aggregate(target, options, collection.getReadPreference());
  }

  @Override
  public MorphiaIterator<U, U> aggregate(
      final Class<U> target,
      final AggregationOptions options,
      final ReadPreference readPreference) {
    return aggregate(datastore.getCollection(target).getName(), target, options, readPreference);
  }

  /**
   * Runs the accumulated stages against the source collection.
   *
   * @param collectionName the collection name handed to the returned iterator (previously
   *     ignored in favour of the source collection's name)
   * @param target the type each result document is mapped to
   * @param options the aggregation options
   * @param readPreference the read preference for the aggregate command
   * @return an iterator over the mapped results
   */
  @Override
  public MorphiaIterator<U, U> aggregate(
      final String collectionName,
      final Class<U> target,
      final AggregationOptions options,
      final ReadPreference readPreference) {
    if (LOG.isDebugEnabled()) {
      LOG.debug("stages = " + stages);
    }

    Cursor cursor = collection.aggregate(stages, options, readPreference);
    return new MorphiaIterator<U, U>(
        cursor, mapper, target, collectionName, mapper.createEntityCache());
  }
}
Example #4
0
/**
 * Translates index annotations on mapped classes ({@code @Indexes}/{@code @Index}, {@code
 * @Indexed}, {@code @Text}) into driver index keys and options, and creates those indexes.
 */
final class IndexHelper {
  private static final Logger LOG = MorphiaLoggerFactory.get(IndexHelper.class);
  private static final EncoderContext ENCODER_CONTEXT = EncoderContext.builder().build();

  private final Mapper mapper;
  private final MongoDatabase database;

  /**
   * @param mapper resolves mapped classes, mapped fields and subtypes
   * @param database supplies the codec registry used to encode index key values
   */
  IndexHelper(final Mapper mapper, final MongoDatabase database) {
    this.mapper = mapper;
    this.database = database;
  }

  /** Joins the path elements with {@code delimiter}, e.g. {@code [a, b]} and '.' give "a.b". */
  private static String join(final List<String> path, final char delimiter) {
    StringBuilder builder = new StringBuilder();
    for (String element : path) {
      if (builder.length() != 0) {
        builder.append(delimiter);
      }
      builder.append(element);
    }
    return builder.toString();
  }

  /**
   * Copies every explicitly set field weight (anything other than -1) into the driver options.
   *
   * @throws MappingException if a weight is set on a field whose index type is not TEXT
   */
  private void calculateWeights(
      final Index index, final com.mongodb.client.model.IndexOptions indexOptions) {
    Document weights = new Document();
    for (Field field : index.fields()) {
      if (field.weight() != -1) {
        if (field.type() != IndexType.TEXT) {
          throw new MappingException(
              "Weight values only apply to text indexes: " + Arrays.toString(index.fields()));
        }
        weights.put(field.value(), field.weight());
      }
    }
    if (!weights.isEmpty()) {
      indexOptions.weights(weights);
    }
  }

  /** Converts a field-level {@code @Text} annotation into a single-field TEXT index. */
  Index convert(final Text text, final String nameToStore) {
    return new IndexBuilder()
        .options(text.options())
        .fields(
            Collections.<Field>singletonList(
                new FieldBuilder().value(nameToStore).type(IndexType.TEXT).weight(text.value())));
  }

  /**
   * Converts a field-level {@code @Indexed} annotation into a single-field index.
   *
   * <p>If no {@code @IndexOptions} values are set, the deprecated {@code @Indexed} attributes are
   * migrated into the options instead.
   *
   * @throws MappingException if deprecated {@code @Indexed} values and {@code @IndexOptions}
   *     values are both set
   */
  @SuppressWarnings("deprecation")
  Index convert(final Indexed indexed, final String nameToStore) {
    if (indexed.dropDups() || indexed.options().dropDups()) {
      LOG.warning(
          "dropDups value is no longer supported by the server.  Please set this value to false and "
              + "validate your system behaves as expected.");
    }
    final Map<String, Object> newOptions = extractOptions(indexed.options());
    if (!extractOptions(indexed).isEmpty() && !newOptions.isEmpty()) {
      throw new MappingException(
          "Mixed usage of deprecated @Indexed values with the new @IndexOption values is not "
              + "allowed.  Please migrate all settings to @IndexOptions");
    }

    List<Field> fields =
        Collections.<Field>singletonList(
            new FieldBuilder().value(nameToStore).type(fromValue(indexed.value().toIndexValue())));
    return newOptions.isEmpty()
        ? new IndexBuilder().options(new IndexOptionsBuilder().migrate(indexed)).fields(fields)
        : new IndexBuilder().options(indexed.options()).fields(fields);
  }

  /**
   * Collects indexes declared on individual persistent fields via {@code @Indexed} or {@code
   * @Text}. {@code @Indexed} takes precedence when both are present.
   */
  @SuppressWarnings("deprecation")
  private List<Index> collectFieldIndexes(final MappedClass mc) {
    List<Index> list = new ArrayList<Index>();
    for (final MappedField mf : mc.getPersistenceFields()) {
      if (mf.hasAnnotation(Indexed.class)) {
        final Indexed indexed = mf.getAnnotation(Indexed.class);
        list.add(convert(indexed, mf.getNameToStore()));
      } else if (mf.hasAnnotation(Text.class)) {
        final Text text = mf.getAnnotation(Text.class);
        list.add(convert(text, mf.getNameToStore()));
      }
    }
    return list;
  }

  /**
   * Collects class-level, field-level and nested indexes for {@code mc}.
   *
   * @param parentMCs the classes already on the current nesting path. Returns nothing if {@code
   *     mc} is already on that path (to stop cycles), or if {@code mc} is {@code @Embedded} and
   *     is being processed at the top level.
   */
  private List<Index> collectIndexes(final MappedClass mc, final List<MappedClass> parentMCs) {
    if (parentMCs.contains(mc) || (mc.getEmbeddedAnnotation() != null && parentMCs.isEmpty())) {
      return emptyList();
    }

    List<Index> indexes = collectTopLevelIndexes(mc);
    indexes.addAll(collectFieldIndexes(mc));
    indexes.addAll(collectNestedIndexes(mc, parentMCs));

    return indexes;
  }

  /**
   * Collects indexes declared on the types of non-Mongo-compatible fields, and on their mapped
   * subtypes. Each field path is prefixed with the owning field's stored name; wildcard ("$**")
   * fields are kept unprefixed.
   */
  private List<Index> collectNestedIndexes(
      final MappedClass mc, final List<MappedClass> parentMCs) {
    List<Index> list = new ArrayList<Index>();
    for (final MappedField mf : mc.getPersistenceFields()) {
      if (!mf.isTypeMongoCompatible()
          && !mf.hasAnnotation(Reference.class)
          && !mf.hasAnnotation(Serialized.class)
          && !mf.hasAnnotation(NotSaved.class)
          && !mf.isTransient()) {

        final List<MappedClass> parents = new ArrayList<MappedClass>(parentMCs);
        parents.add(mc);

        List<MappedClass> classes = new ArrayList<MappedClass>();
        MappedClass mappedClass =
            mapper.getMappedClass(mf.isSingleValue() ? mf.getType() : mf.getSubClass());
        classes.add(mappedClass);
        classes.addAll(mapper.getSubTypes(mappedClass));
        for (MappedClass aClass : classes) {
          for (Index index : collectIndexes(aClass, parents)) {
            List<Field> fields = new ArrayList<Field>();
            for (Field field : index.fields()) {
              fields.add(
                  new FieldBuilder()
                      .value(
                          field.value().equals("$**")
                              ? field.value()
                              : mf.getNameToStore() + "." + field.value())
                      .type(field.type())
                      .weight(field.weight()));
            }
            list.add(new IndexBuilder(index).fields(fields));
          }
        }
      }
    }

    return list;
  }

  /**
   * Collects indexes declared through class-level {@code @Indexes}. An {@code @Index} without
   * {@code fields} uses the deprecated configuration; it is migrated and a warning is logged.
   * Field paths are resolved to stored names via {@link #findField}.
   */
  private List<Index> collectTopLevelIndexes(final MappedClass mc) {
    List<Index> list = new ArrayList<Index>();
    final List<Indexes> annotations = mc.getAnnotations(Indexes.class);
    if (annotations != null) {
      for (final Indexes indexes : annotations) {
        for (final Index index : indexes.value()) {
          Index updated = index;
          if (index.fields().length == 0) {
            LOG.warning(
                format(
                    "This index on '%s' is using deprecated configuration options.  Please update to use the "
                        + "fields value on @Index: %s",
                    mc.getClazz().getName(), index.toString()));
            updated = new IndexBuilder().migrate(index);
          }
          // Validate paths against the options of the index actually being built. For a migrated
          // index these may carry deprecated settings (e.g. disableValidation) that
          // index.options() does not.
          final IndexOptions options = updated.options();
          List<Field> fields = new ArrayList<Field>();
          for (Field field : updated.fields()) {
            fields.add(
                new FieldBuilder()
                    .value(findField(mc, options, asList(field.value().split("\\."))))
                    .type(field.type())
                    .weight(field.weight()));
          }

          list.add(replaceFields(updated, fields));
        }
      }
    }
    return list;
  }

  /** Returns the non-default values of an {@code @IndexOptions} annotation as a map. */
  private Map<String, Object> extractOptions(final IndexOptions options) {
    return toMap(options);
  }

  /**
   * Returns the deprecated option values set on an {@code @Indexed} annotation. The "value" entry
   * is always dropped; the "options" entry is dropped when its collation locale is empty.
   */
  private Map<String, Object> extractOptions(final Indexed indexed) {
    Map<String, Object> map = toMap(indexed);
    if (indexed.options().collation().locale().equals("")) {
      map.remove("options");
    }
    map.remove("value");
    return map;
  }

  /** Builds the exception reported when a dotted path cannot be resolved against {@code mc}. */
  private MappingException pathFail(final MappedClass mc, final List<String> path) {
    return new MappingException(
        format(
            "Could not resolve path '%s' against '%s'.", join(path, '.'), mc.getClazz().getName()));
  }

  /** Copies {@code original} with its fields replaced by {@code list}. */
  private Index replaceFields(final Index original, final List<Field> list) {
    return new IndexBuilder(original).fields(list);
  }

  /**
   * Encodes {@code {key: value}} into a BsonDocument, using the database's codec for the
   * value's runtime class.
   */
  @SuppressWarnings("unchecked")
  private BsonDocument toBsonDocument(final String key, final Object value) {
    BsonDocumentWriter writer = new BsonDocumentWriter(new BsonDocument());
    writer.writeStartDocument();
    writer.writeName(key);
    ((Encoder) database.getCodecRegistry().get(value.getClass()))
        .encode(writer, value, ENCODER_CONTEXT);
    writer.writeEndDocument();
    return writer.getDocument();
  }

  /**
   * Builds the index key document, resolving each field path to its stored name.
   *
   * <p>If a path cannot be resolved, the raw value is used and a warning is logged. This only
   * happens when validation is disabled on the index options.
   *
   * @throws MappingException if a path cannot be resolved and validation is enabled
   */
  BsonDocument calculateKeys(final MappedClass mc, final Index index) {
    BsonDocument keys = new BsonDocument();
    for (Field field : index.fields()) {
      String path;
      try {
        path =
            findField(
                mc, index.options(), new ArrayList<String>(asList(field.value().split("\\."))));
      } catch (Exception e) {
        path = field.value();
        String message =
            format(
                "The path '%s' can not be validated against '%s' and may represent an invalid index",
                path, mc.getClazz().getName());
        if (!index.options().disableValidation()) {
          throw new MappingException(message);
        }
        LOG.warning(message);
      }
      keys.putAll(toBsonDocument(path, field.type().toIndexValue()));
    }
    return keys;
  }

  /**
   * Converts annotation index options into driver {@code IndexOptions}. Empty strings and
   * {@code expireAfterSeconds == -1} are treated as unset.
   *
   * @param background forces background index creation when true, even if the annotation does
   *     not request it
   */
  @SuppressWarnings("deprecation")
  com.mongodb.client.model.IndexOptions convert(
      final IndexOptions options, final boolean background) {
    if (options.dropDups()) {
      LOG.warning(
          "dropDups value is no longer supported by the server.  Please set this value to false and "
              + "validate your system behaves as expected.");
    }
    com.mongodb.client.model.IndexOptions indexOptions =
        new com.mongodb.client.model.IndexOptions()
            .background(options.background() || background)
            .sparse(options.sparse())
            .unique(options.unique());

    if (!options.language().equals("")) {
      indexOptions.defaultLanguage(options.language());
    }
    if (!options.languageOverride().equals("")) {
      indexOptions.languageOverride(options.languageOverride());
    }
    if (!options.name().equals("")) {
      indexOptions.name(options.name());
    }
    if (options.expireAfterSeconds() != -1) {
      indexOptions.expireAfter((long) options.expireAfterSeconds(), TimeUnit.SECONDS);
    }
    if (!options.partialFilter().equals("")) {
      indexOptions.partialFilterExpression(Document.parse(options.partialFilter()));
    }
    if (!options.collation().locale().equals("")) {
      indexOptions.collation(convert(options.collation()));
    }

    return indexOptions;
  }

  /** Converts a {@code @Collation} annotation into the driver's Collation model. */
  com.mongodb.client.model.Collation convert(final Collation collation) {
    return com.mongodb.client.model.Collation.builder()
        .locale(collation.locale())
        .backwards(collation.backwards())
        .caseLevel(collation.caseLevel())
        .collationAlternate(collation.alternate())
        .collationCaseFirst(collation.caseFirst())
        .collationMaxVariable(collation.maxVariable())
        .collationStrength(collation.strength())
        .normalization(collation.normalization())
        .numericOrdering(collation.numericOrdering())
        .build();
  }

  /**
   * Resolves a dotted path (split into segments) to the stored field names.
   *
   * <p>Each segment is matched by stored name, then by Java field name. For interfaces, mapped
   * subtypes are tried in turn. The wildcard "$**" is returned as-is. When resolution fails and
   * validation is disabled, the path is returned joined but untranslated.
   *
   * @throws MappingException if the path cannot be resolved and validation is enabled
   */
  String findField(final MappedClass mc, final IndexOptions options, final List<String> path) {
    String segment = path.get(0);
    if (segment.equals("$**")) {
      return segment;
    }

    MappedField mf = mc.getMappedField(segment);
    if (mf == null) {
      mf = mc.getMappedFieldByJavaField(segment);
    }
    if (mf == null && mc.isInterface()) {
      for (final MappedClass mappedClass : mapper.getSubTypes(mc)) {
        try {
          return findField(mappedClass, options, new ArrayList<String>(path));
        } catch (MappingException e) {
          // try the next one
        }
      }
    }
    String namePath;
    if (mf != null) {
      namePath = mf.getNameToStore();
    } else {
      if (!options.disableValidation()) {
        throw pathFail(mc, path);
      } else {
        return join(path, '.');
      }
    }
    if (path.size() > 1) {
      try {
        Class concreteType = !mf.isSingleValue() ? mf.getSubClass() : mf.getConcreteType();
        namePath +=
            "."
                + findField(
                    mapper.getMappedClass(concreteType), options, path.subList(1, path.size()));
      } catch (MappingException e) {
        if (!options.disableValidation()) {
          throw pathFail(mc, path);
        } else {
          return join(path, '.');
        }
      }
    }
    return namePath;
  }

  /** Creates every index collected for {@code mc}, including nested ones, on {@code collection}. */
  void createIndex(
      final MongoCollection collection, final MappedClass mc, final boolean background) {
    for (Index index : collectIndexes(mc, Collections.<MappedClass>emptyList())) {
      createIndex(collection, mc, index, background);
    }
  }

  /**
   * Normalizes {@code index}, computes its keys and driver options (including text weights), and
   * creates it on {@code collection}.
   */
  void createIndex(
      final MongoCollection collection,
      final MappedClass mc,
      final Index index,
      final boolean background) {
    Index normalized = IndexBuilder.normalize(index);

    BsonDocument keys = calculateKeys(mc, normalized);
    com.mongodb.client.model.IndexOptions indexOptions = convert(normalized.options(), background);
    calculateWeights(normalized, indexOptions);

    collection.createIndex(keys, indexOptions);
  }
}