final class PositionsCheckNode extends Node { @Children private final PositionCheckNode[] positionsCheck; private final ElementAccessMode mode; private final BranchProfile errorBranch = BranchProfile.create(); private final VectorLengthProfile selectedPositionsCountProfile = VectorLengthProfile.create(); private final VectorLengthProfile maxOutOfBoundsProfile = VectorLengthProfile.create(); private final ConditionProfile containsNAProfile = ConditionProfile.createBinaryProfile(); private final BranchProfile unsupportedProfile = BranchProfile.create(); private final boolean replace; private final int positionsLength; PositionsCheckNode( ElementAccessMode mode, RType containerType, Object[] positions, boolean exact, boolean replace, boolean recursive) { this.mode = mode; this.replace = replace; this.positionsCheck = new PositionCheckNode[positions.length]; for (int i = 0; i < positions.length; i++) { positionsCheck[i] = PositionCheckNode.createNode( mode, containerType, positions[i], i, positions.length, exact, replace, recursive); } this.positionsLength = positions.length; } public PositionCheckNode getPositionCheckAt(int index) { return positionsCheck[index]; } @ExplodeLoop public boolean isSupported(Object[] positions) { if (positionsCheck.length != positions.length) { unsupportedProfile.enter(); return false; } for (int i = 0; i < positionsCheck.length; i++) { if (!positionsCheck[i].isSupported(positions[i])) { unsupportedProfile.enter(); return false; } } return true; } public int getDimensions() { return positionsCheck.length; } public boolean isSingleDimension() { return positionsCheck.length == 1; } public boolean isMultiDimension() { return positionsCheck.length > 1; } @ExplodeLoop public PositionProfile[] executeCheck( RAbstractContainer vector, int[] vectorDimensions, int vectorLength, Object[] positions) { assert isSupported(positions); verifyDimensions(vectorDimensions); PositionProfile[] statistics = new PositionProfile[positionsLength]; for (int i = 0; i < positionsLength; i++) { Object position = positions[i]; PositionProfile profile = new PositionProfile(); positions[i] = positionsCheck[i].execute(profile, vector, vectorDimensions, vectorLength, position); statistics[i] = profile; } return statistics; } @TruffleBoundary private void print() { System.out.println(positionsCheck.length); } private void verifyDimensions(int[] vectorDimensions) { if (vectorDimensions == null) { if (isMultiDimension()) { errorBranch.enter(); throw dimensionsError(); } } else { if (getDimensions() > vectorDimensions.length || getDimensions() < vectorDimensions.length) { errorBranch.enter(); throw dimensionsError(); } } } private RError dimensionsError() { if (replace) { if (mode.isSubset()) { if (getDimensions() == 2) { return RError.error(this, RError.Message.INCORRECT_SUBSCRIPTS_MATRIX); } else { return RError.error(this, RError.Message.INCORRECT_SUBSCRIPTS); } } else { return RError.error(this, RError.Message.IMPROPER_SUBSCRIPT); } } else { return RError.error(this, RError.Message.INCORRECT_DIMENSIONS); } } @ExplodeLoop public int getSelectedPositionsCount(PositionProfile[] profiles) { if (positionsCheck.length == 1) { return selectedPositionsCountProfile.profile(profiles[0].selectedPositionsCount); } else { int newSize = 1; for (int i = 0; i < positionsCheck.length; i++) { newSize *= profiles[i].selectedPositionsCount; } return selectedPositionsCountProfile.profile(newSize); } } public int getCachedSelectedPositionsCount() { return selectedPositionsCountProfile.getCachedLength(); } @ExplodeLoop public boolean getContainsNA(PositionProfile[] profiles) { if (positionsCheck.length == 1) { return containsNAProfile.profile(profiles[0].containsNA); } else { boolean containsNA = false; for (int i = 0; i < positionsCheck.length; i++) { containsNA |= profiles[i].containsNA; } return containsNAProfile.profile(containsNA); } } @ExplodeLoop public int getMaxOutOfBounds(PositionProfile[] replacementStatistics) { if (positionsCheck.length == 1) { return maxOutOfBoundsProfile.profile(replacementStatistics[0].maxOutOfBoundsIndex); } else { // impossible to be relevant as position check will throw an error in this case. return 0; } } public boolean isMissing() { return positionsCheck.length == 1 && // (positionsCheck[0].getPositionClass() == RMissing.class || positionsCheck[0].getPositionClass() == REmpty.class || // positionsCheck[0].getPositionClass() == RSymbol.class); } final class PositionProfile { int selectedPositionsCount; int maxOutOfBoundsIndex; boolean containsNA; } }
abstract class PositionCheckNode extends Node { protected final Class<?> positionClass; protected final int dimensionIndex; protected final int numDimensions; protected final VectorLengthProfile positionLengthProfile = VectorLengthProfile.create(); protected final BranchProfile error = BranchProfile.create(); protected final boolean replace; protected final RType containerType; @Child private PositionCastNode castNode; @Child private RLengthNode positionLengthNode = RLengthNode.create(); @Child private PositionCharacterLookupNode characterLookup; PositionCheckNode( ElementAccessMode mode, RType containerType, Object positionValue, int dimensionIndex, int numDimensions, boolean exact, boolean replace) { this.positionClass = positionValue.getClass(); this.dimensionIndex = dimensionIndex; this.numDimensions = numDimensions; this.replace = replace; this.containerType = containerType; this.castNode = PositionCastNode.create(mode, replace); if (positionValue instanceof String || positionValue instanceof RAbstractStringVector) { boolean useNAForNotFound = !replace && isListLike(containerType) && mode.isSubscript(); characterLookup = new PositionCharacterLookupNode( mode, numDimensions, dimensionIndex, useNAForNotFound, exact); } } protected static boolean isListLike(RType type) { switch (type) { case Language: case DataFrame: case Expression: case PairList: case List: return true; } return false; } public boolean isIgnoreDimension() { return positionClass == RMissing.class; } public Class<?> getPositionClass() { return positionClass; } public final boolean isSupported(Object object) { return object.getClass() == positionClass; } public static PositionCheckNode createNode( ElementAccessMode mode, RType containerType, Object position, int positionIndex, int numDimensions, boolean exact, boolean replace, boolean recursive) { if (mode.isSubset()) { return PositionCheckSubsetNodeGen.create( mode, containerType, position, positionIndex, numDimensions, exact, replace); } else { return PositionCheckSubscriptNodeGen.create( mode, containerType, position, positionIndex, numDimensions, exact, replace, recursive); } } protected boolean isMultiDimension() { return numDimensions > 1; } public final Object execute( PositionProfile profile, RAbstractContainer vector, int[] vectorDimensions, int vectorLength, Object position) { Object castPosition = castNode.execute(positionClass.cast(position)); int dimensionLength; if (numDimensions == 1) { dimensionLength = vectorLength; } else { assert vectorDimensions != null; assert vectorDimensions.length == numDimensions; dimensionLength = vectorDimensions[dimensionIndex]; } if (characterLookup != null) { castPosition = characterLookup.execute(vector, (RAbstractStringVector) castPosition, dimensionLength); } RTypedValue positionVector = (RTypedValue) profilePosition(castPosition); int positionLength; if (positionVector instanceof RMissing) { positionLength = -1; } else { positionLength = positionLengthProfile.profile(((RAbstractVector) positionVector).getLength()); } assert isValidCastedType(positionVector) : "result type of a position cast node must be integer or logical"; return execute(profile, dimensionLength, positionVector, positionLength); } private final ValueProfile castedValue = ValueProfile.createClassProfile(); Object profilePosition(Object positionVector) { return castedValue.profile(positionVector); } private static boolean isValidCastedType(RTypedValue positionVector) { RType type = positionVector.getRType(); return type == RType.Integer || type == RType.Logical || type == RType.Character || type == RType.Double || type == RType.Null; } public abstract Object execute( PositionProfile statistics, int dimensionLength, Object position, int positionLength); }