/**
 * Resolves {@code topLevelCall} and, when it resolves to a {@link MethodCandidateInfo}, builds and
 * runs an inference session for that call: expression constraints are collected from the call's
 * arguments and then inference is performed.
 *
 * <p>The session is computed through {@code PsiResolveHelper.ourGraphGuard}, keyed by the call
 * itself; presumably the guard yields {@code null} when the same call is already being computed
 * (NOTE(review): confirm against the guard's contract).
 *
 * @param topLevelCall the outermost call whose inference drives nested calls
 * @param policy inference policy handed to the new session
 * @return the session after inference, or {@code null} when the call does not resolve to a
 *     method candidate or the recursion guard produced no value
 */
@Nullable
private static InferenceSession startTopLevelInference(
    final PsiCall topLevelCall, final ParameterTypeInferencePolicy policy) {
  final JavaResolveResult resolved = topLevelCall.resolveMethodGenerics();
  if (!(resolved instanceof MethodCandidateInfo)) {
    return null;
  }

  final MethodCandidateInfo candidate = (MethodCandidateInfo) resolved;
  final PsiMethod method = candidate.getElement();
  final PsiParameter[] parameters = method.getParameterList().getParameters();
  final PsiExpressionList argumentList = topLevelCall.getArgumentList();
  LOG.assertTrue(argumentList != null, topLevelCall);
  final PsiExpression[] arguments = argumentList.getExpressions();

  final Computable<InferenceSession> runInference =
      new Computable<InferenceSession>() {
        @Override
        public InferenceSession compute() {
          final InferenceSession session =
              new InferenceSession(
                  method.getTypeParameters(),
                  candidate.getSiteSubstitutor(),
                  topLevelCall.getManager(),
                  topLevelCall,
                  policy);
          session.initExpressionConstraints(
              parameters, arguments, topLevelCall, method, candidate.isVarargs());
          session.infer(parameters, arguments, topLevelCall, candidate.createProperties());
          return session;
        }
      };
  return PsiResolveHelper.ourGraphGuard.doPreventingRecursion(topLevelCall, true, runInference);
}
private static PsiSubstitutor inferNested( final PsiTypeParameter[] typeParameters, @NotNull final PsiParameter[] parameters, @NotNull final PsiExpression[] arguments, final PsiSubstitutor partialSubstitutor, @NotNull final PsiCall parent, @NotNull final ParameterTypeInferencePolicy policy, final MethodCandidateInfo.CurrentCandidateProperties properties, final InferenceSession parentSession) { final CompoundInitialState compoundInitialState = createState(parentSession); InitialInferenceState initialInferenceState = compoundInitialState.getInitialState(parent); if (initialInferenceState != null) { final InferenceSession childSession = new InferenceSession(initialInferenceState); final List<String> errorMessages = parentSession.getIncompatibleErrorMessages(); if (errorMessages != null) { return childSession.prepareSubstitution(); } return childSession.collectAdditionalAndInfer( parameters, arguments, properties, compoundInitialState.getInitialSubstitutor()); } // we do not investigate lambda return expressions when lambda's return type is already inferred // (proper) // this way all calls from lambda's return expressions won't appear in nested sessions else { PsiElement gParent = PsiUtil.skipParenthesizedExprUp(parent.getParent()); // find the nearest parent which appears in the map and start inference with a provided target // type for a nested lambda while (true) { if (gParent instanceof PsiReturnStatement) { // process code block lambda final PsiElement returnContainer = gParent.getParent(); if (returnContainer instanceof PsiCodeBlock) { gParent = returnContainer.getParent(); } } if (gParent instanceof PsiLambdaExpression) { final PsiCall call = PsiTreeUtil.getParentOfType(gParent, PsiCall.class); if (call != null) { initialInferenceState = compoundInitialState.getInitialState(call); if (initialInferenceState != null) { final int idx = LambdaUtil.getLambdaIdx(call.getArgumentList(), gParent); final PsiMethod method = call.resolveMethod(); if (method != null && idx 
> -1) { final PsiType parameterType = PsiTypesUtil.getParameterType( method.getParameterList().getParameters(), idx, true); final PsiType parameterTypeInTermsOfSession = initialInferenceState.getInferenceSubstitutor().substitute(parameterType); final PsiType lambdaTargetType = compoundInitialState .getInitialSubstitutor() .substitute(parameterTypeInTermsOfSession); return LambdaUtil.performWithLambdaTargetType( (PsiLambdaExpression) gParent, lambdaTargetType, new Producer<PsiSubstitutor>() { @Nullable @Override public PsiSubstitutor produce() { if (call.equals(PsiTreeUtil.getParentOfType(parent, PsiCall.class, true))) { // parent was mentioned in the top inference session // just proceed with the target type final InferenceSession inferenceSession = new InferenceSession( typeParameters, partialSubstitutor, parent.getManager(), parent, policy); inferenceSession.initExpressionConstraints(parameters, arguments, parent); return inferenceSession.infer(parameters, arguments, parent); } // one of the grand parents were found in the top inference session // start from it as it is the top level call final InferenceSession sessionInsideLambda = startTopLevelInference(call, policy); return inferNested( typeParameters, parameters, arguments, partialSubstitutor, parent, policy, properties, sessionInsideLambda); } }); } } else { gParent = PsiUtil.skipParenthesizedExprUp(call.getParent()); continue; } } } break; } } return null; }
/**
 * Returns the type expected at the position of {@code expression}, as determined by its enclosing
 * context. The contexts handled are:
 * <ul>
 *   <li>array initializer element;
 *   <li>cast (for an intersection cast, the first conjunct that has a functional interface method);
 *   <li>variable initializer;
 *   <li>right-hand side of an assignment;
 *   <li>method/constructor argument;
 *   <li>{@code return} value;
 *   <li>expression-lambda body.
 * </ul>
 * Parentheses and the then/else branches of conditional expressions are looked through.
 *
 * @param expression the expression whose target type is requested (typically a lambda or method
 *     reference)
 * @param tryToSubstitute when {@code false}, an argument position does not resolve the enclosing
 *     call's generics. It yields either the type from the candidate currently being checked, or
 *     the normalized declared parameter type.
 * @param paramIdx index into the functional interface method's parameters whose type the caller
 *     is interested in, or {@code -1}. It only enables the cached-type shortcut for arguments.
 * @return the target type, or {@code null} if it cannot be determined from the context
 */
@Nullable
public static PsiType getFunctionalInterfaceType(
    PsiElement expression, final boolean tryToSubstitute, int paramIdx) {
  PsiElement parent = expression.getParent();
  PsiElement element = expression;
  // Look through parentheses and conditional branches (but not the condition itself).
  while (parent instanceof PsiParenthesizedExpression
      || parent instanceof PsiConditionalExpression) {
    if (parent instanceof PsiConditionalExpression
        && ((PsiConditionalExpression) parent).getThenExpression() != element
        && ((PsiConditionalExpression) parent).getElseExpression() != element) {
      break;
    }
    element = parent;
    parent = parent.getParent();
  }

  if (parent instanceof PsiArrayInitializerExpression) {
    final PsiType psiType = ((PsiArrayInitializerExpression) parent).getType();
    if (psiType instanceof PsiArrayType) {
      return ((PsiArrayType) psiType).getComponentType();
    }
  } else if (parent instanceof PsiTypeCastExpression) {
    final PsiType castType = ((PsiTypeCastExpression) parent).getType();
    if (castType instanceof PsiIntersectionType) {
      for (PsiType conjunctType : ((PsiIntersectionType) castType).getConjuncts()) {
        if (getFunctionalInterfaceMethod(conjunctType) != null) return conjunctType;
      }
    }
    return castType;
  } else if (parent instanceof PsiVariable) {
    return ((PsiVariable) parent).getType();
  } else if (parent instanceof PsiAssignmentExpression
      && expression instanceof PsiExpression
      && !PsiUtil.isOnAssignmentLeftHand((PsiExpression) expression)) {
    final PsiExpression lExpression = ((PsiAssignmentExpression) parent).getLExpression();
    return lExpression.getType();
  } else if (parent instanceof PsiExpressionList) {
    final PsiExpressionList expressionList = (PsiExpressionList) parent;
    final int lambdaIdx = getLambdaIdx(expressionList, expression);
    if (lambdaIdx > -1) {
      // Type from the candidate currently being checked for this argument list, if any.
      PsiType cachedType = null;
      final Pair<PsiMethod, PsiSubstitutor> method = MethodCandidateInfo.getCurrentMethod(parent);
      if (method != null) {
        final PsiParameter[] parameters = method.first.getParameterList().getParameters();
        cachedType =
            lambdaIdx < parameters.length
                ? method.second.substitute(
                    getNormalizedType(
                        parameters[adjustLambdaIdx(lambdaIdx, method.first, parameters)]))
                : null;
        if (!tryToSubstitute) return cachedType;
      }

      PsiElement gParent = expressionList.getParent();
      if (gParent instanceof PsiAnonymousClass) {
        gParent = gParent.getParent();
      }
      if (gParent instanceof PsiCall) {
        final PsiCall contextCall = (PsiCall) gParent;
        final JavaResolveResult resolveResult = contextCall.resolveMethodGenerics();
        final PsiElement resolve = resolveResult.getElement();
        if (resolve instanceof PsiMethod) {
          final PsiParameter[] parameters =
              ((PsiMethod) resolve).getParameterList().getParameters();
          final int finalLambdaIdx = adjustLambdaIdx(lambdaIdx, (PsiMethod) resolve, parameters);
          if (finalLambdaIdx < parameters.length) {
            if (!tryToSubstitute) return getNormalizedType(parameters[finalLambdaIdx]);
            if (cachedType != null && paramIdx > -1) {
              final PsiMethod interfaceMethod = getFunctionalInterfaceMethod(cachedType);
              if (interfaceMethod != null) {
                final PsiParameter[] interfaceMethodParameters =
                    interfaceMethod.getParameterList().getParameters();
                // paramIdx can exceed the functional method's arity in erroneous code
                // (e.g. a lambda declaring too many parameters); skip the shortcut then.
                if (paramIdx < interfaceMethodParameters.length) {
                  final PsiClassType.ClassResolveResult cachedResult =
                      PsiUtil.resolveGenericsClassInType(cachedType);
                  final PsiType interfaceMethodParameterType =
                      interfaceMethodParameters[paramIdx].getType();
                  if (!dependsOnTypeParams(
                      cachedResult.getSubstitutor().substitute(interfaceMethodParameterType),
                      cachedType,
                      expression)) {
                    return cachedType;
                  }
                }
              }
            }
            // Full computation from the enclosing call's resolved substitutor, guarded per
            // expression to avoid re-entering for the same element.
            return PsiResolveHelper.ourGuard.doPreventingRecursion(
                expression,
                true,
                new Computable<PsiType>() {
                  @Override
                  public PsiType compute() {
                    return resolveResult
                        .getSubstitutor()
                        .substitute(getNormalizedType(parameters[finalLambdaIdx]));
                  }
                });
          }
        }
        return null;
      }
    }
  } else if (parent instanceof PsiReturnStatement) {
    // Only a lambda *inside* the nearest enclosing method owns this return; a lambda further out
    // (e.g. around an anonymous class whose method contains the return) must not be picked.
    final PsiLambdaExpression gParent =
        PsiTreeUtil.getParentOfType(parent, PsiLambdaExpression.class, true, PsiMethod.class);
    if (gParent != null) {
      return getFunctionalInterfaceTypeByContainingLambda(gParent);
    }
    final PsiMethod method = PsiTreeUtil.getParentOfType(parent, PsiMethod.class);
    if (method != null) {
      return method.getReturnType();
    }
  } else if (parent instanceof PsiLambdaExpression) {
    // expression lambda body
    return getFunctionalInterfaceTypeByContainingLambda((PsiLambdaExpression) parent);
  }
  return null;
}