private static CompoundInitialState createState(InferenceSession topLevelSession) {
    // Substitutor obtained from replaceVariables(...) for the top-level inference variables; it is
    // handed to every nested session and used below to find each variable's replacement.
    final PsiSubstitutor variablesReplacement =
        replaceVariables(topLevelSession.getInferenceVariables());
    // Populated after the container below is created. The container's override reads it at lookup
    // time, so entries added later are still visible to it.
    final Map<PsiElement, InitialInferenceState> savedStates =
        new LinkedHashMap<PsiElement, InitialInferenceState>();

    final InferenceSessionContainer restoredContainer =
        new InferenceSessionContainer() {
          @Override
          public PsiSubstitutor findNestedSubstitutor(
              PsiElement arg, @Nullable PsiSubstitutor defaultSession) {
            // Case foo(bar(a -> m())): the top-level inference never reaches lambda "a -> m()".
            // Case foo(a -> bar(b -> m())): the top-level inference descends into the nested
            // lambda "b -> m()", and its saved state is reachable through the enclosing call
            // "bar(b -> m())". Resuming from that saved point with additional constraints,
            // however, would create new expression constraints over different inference
            // variables (those recorded in myNestedSessions). The system could not resolve
            // them unless stored sessions are rejected in such cases, so sessions registered
            // in this container are consulted first.
            final PsiSubstitutor registered = super.findNestedSubstitutor(arg, null);
            if (registered != null) {
              return registered;
            }

            final InitialInferenceState enclosingCallState =
                savedStates.get(PsiTreeUtil.getParentOfType(arg, PsiCall.class));
            return enclosingCallState == null
                ? super.findNestedSubstitutor(arg, defaultSession)
                : enclosingCallState.getInferenceSubstitutor();
          }
        };

    // Snapshot every nested session of the top-level session, keyed by the same element.
    for (Map.Entry<PsiElement, InferenceSession> nested :
        topLevelSession.getInferenceSessionContainer().myNestedSessions.entrySet()) {
      final PsiElement key = nested.getKey();
      final InitialInferenceState snapshot =
          nested
              .getValue()
              .createInitialState(
                  restoredContainer,
                  topLevelSession.getInferenceVariables(),
                  variablesReplacement);
      savedStates.put(key, snapshot);
    }

    // For each instantiated top-level variable whose replacement resolves to an inference
    // variable, record "replacement -> instantiation".
    PsiSubstitutor instantiations = PsiSubstitutor.EMPTY;
    for (InferenceVariable variable : topLevelSession.getInferenceVariables()) {
      final PsiType instantiation = variable.getInstantiation();
      if (instantiation == PsiType.NULL) {
        continue;
      }
      final PsiClass replacement =
          PsiUtil.resolveClassInClassTypeOnly(variablesReplacement.substitute(variable));
      if (replacement instanceof InferenceVariable) {
        instantiations = instantiations.put((PsiTypeParameter) replacement, instantiation);
      }
    }

    return new CompoundInitialState(instantiations, savedStates);
  }
  /**
   * Infers a substitutor for {@code typeParameters}.
   *
   * <p>When {@code parent} is a call whose current candidate is not being checked for
   * applicability and whose argument list is not on the overload-resolution stack, this method
   * looks for the top-level call enclosing it. If one is found, it reuses the top-level inference
   * session through {@code inferNested}. That session is either started afresh or taken from the
   * cached value attached to the top-level call. In every other case, including {@code
   * inferNested} returning {@code null}, a standalone {@link InferenceSession} is run over
   * {@code parameters} and {@code arguments}.
   *
   * @param typeParameters type parameters whose substitution is inferred
   * @param parameters parameters matched against {@code arguments}
   * @param arguments argument expressions
   * @param partialSubstitutor substitution supplied up front to every session created here
   * @param parent element providing the inference context; only a {@link PsiCall} enables
   *     top-level reuse
   * @param policy policy forwarded to every session created here
   * @return the substitutor produced by whichever inference path was taken
   */
  static PsiSubstitutor infer(
      @NotNull PsiTypeParameter[] typeParameters,
      @NotNull PsiParameter[] parameters,
      @NotNull PsiExpression[] arguments,
      @NotNull PsiSubstitutor partialSubstitutor,
      @NotNull final PsiElement parent,
      @NotNull final ParameterTypeInferencePolicy policy) {
    if (parent instanceof PsiCall) {
      final PsiCall call = (PsiCall) parent;
      final PsiExpressionList argumentList = call.getArgumentList();
      final MethodCandidateInfo.CurrentCandidateProperties properties =
          MethodCandidateInfo.getCurrentMethod(argumentList);
      if (isTopLevelInferenceAllowed(properties, argumentList)) {
        final PsiCall topLevelCall = findTopLevelCallGuarded(parent);
        if (topLevelCall != null) {
          final InferenceSession topLevelSession =
              obtainTopLevelSession(topLevelCall, parent, typeParameters, policy);
          if (topLevelSession == null) {
            if (topLevelCall instanceof PsiMethodCallExpression) {
              return new InferenceSession(
                      typeParameters, partialSubstitutor, parent.getManager(), parent, policy)
                  .prepareSubstitution();
            }
          } else {
            final PsiSubstitutor nestedResult =
                inferNested(
                    typeParameters,
                    parameters,
                    arguments,
                    partialSubstitutor,
                    call,
                    policy,
                    properties,
                    topLevelSession);
            if (nestedResult != null) {
              return nestedResult;
            }
          }
        }
      }
    }

    final InferenceSession standalone =
        new InferenceSession(
            typeParameters, partialSubstitutor, parent.getManager(), parent, policy);
    standalone.initExpressionConstraints(parameters, arguments, parent);
    return standalone.infer(parameters, arguments, parent);
  }

  /**
   * Tells whether inference for the call owning {@code argumentList} may walk up to the
   * top-level call.
   *
   * <p>Overload resolution must not depend on the outer call, so an applicability check blocks
   * the walk. The overload-guard check stays here to avoid caching a candidate's errors on the
   * parent call. Overload resolution can still depend on the type of a lambda parameter. It
   * cannot depend on the lambda body, though, so the downward traversal stops at the lambda and
   * does not take the overloaded method into account.
   */
  private static boolean isTopLevelInferenceAllowed(
      @Nullable MethodCandidateInfo.CurrentCandidateProperties properties,
      PsiExpressionList argumentList) {
    if (properties == null || properties.isApplicabilityCheck()) {
      return false;
    }
    return !MethodCandidateInfo.ourOverloadGuard.currentStack().contains(argumentList);
  }

  /**
   * Under {@code PsiResolveHelper.ourGraphGuard}, returns {@code LambdaUtil.treeWalkUp(parent)},
   * or {@code null} when {@code parent} is an expression that is not a poly expression.
   */
  private static PsiCall findTopLevelCallGuarded(final PsiElement parent) {
    return PsiResolveHelper.ourGraphGuard.doPreventingRecursion(
        parent,
        false,
        new Computable<PsiCall>() {
          @Override
          public PsiCall compute() {
            final boolean standaloneExpression =
                parent instanceof PsiExpression
                    && !PsiPolyExpressionUtil.isPolyExpression((PsiExpression) parent);
            return standaloneExpression ? null : LambdaUtil.treeWalkUp(parent);
          }
        });
  }

  /**
   * Returns the inference session of {@code topLevelCall}.
   *
   * <p>The session is started afresh during overload checks, diamond inference, or lambda
   * parameter checks. Otherwise it is taken from the cached value stored on {@code topLevelCall}
   * (invalidated by {@code PsiModificationTracker.MODIFICATION_COUNT}). A cached session is
   * discarded and recomputed when the nested session recorded for {@code parent} lacks any of
   * {@code typeParameters}.
   */
  private static InferenceSession obtainTopLevelSession(
      final PsiCall topLevelCall,
      PsiElement parent,
      PsiTypeParameter[] typeParameters,
      final ParameterTypeInferencePolicy policy) {
    final boolean mustRecompute =
        MethodCandidateInfo.isOverloadCheck()
            || !PsiDiamondType.ourDiamondGuard.currentStack().isEmpty()
            || LambdaUtil.isLambdaParameterCheck();
    if (mustRecompute) {
      return startTopLevelInference(topLevelCall, policy);
    }

    final InferenceSession cached =
        CachedValuesManager.getCachedValue(
            topLevelCall,
            new CachedValueProvider<InferenceSession>() {
              @Nullable
              @Override
              public Result<InferenceSession> compute() {
                return new Result<InferenceSession>(
                    startTopLevelInference(topLevelCall, policy),
                    PsiModificationTracker.MODIFICATION_COUNT);
              }
            });
    if (cached == null) {
      return null;
    }

    // The cached session may have been built for a different candidate. Overload resolution
    // avoids this (see isOverloadCheck above), but client code that iterates over e.g.
    // PsiResolveHelper.getReferencedMethodCandidates cannot be detected. So the cache is rejected
    // when the nested session for this call does not know all of the candidate's type
    // parameters.
    final InferenceSession childSession =
        cached.getInferenceSessionContainer().myNestedSessions.get(parent);
    if (childSession != null) {
      for (PsiTypeParameter typeParameter : typeParameters) {
        if (!childSession
            .getInferenceSubstitution()
            .getSubstitutionMap()
            .containsKey(typeParameter)) {
          return startTopLevelInference(topLevelCall, policy);
        }
      }
    }
    return cached;
  }
 /**
  * Records {@code session} under its context element. Every nested session already recorded
  * in {@code session}'s own container is then copied in as well. The copied entries are applied
  * last, so they replace existing entries with the same key.
  */
 public void registerNestedSession(InferenceSession session) {
   final PsiElement context = session.getContext();
   myNestedSessions.put(context, session);
   final Map<PsiElement, InferenceSession> inherited =
       session.getInferenceSessionContainer().myNestedSessions;
   myNestedSessions.putAll(inherited);
 }