/**
   * Attempts to set the covariance parameter stored for the a<->b error edge. The value is stored
   * only when it lies strictly between the low and high ends of the permitted range for that edge;
   * otherwise the stored parameters are left untouched.
   *
   * <p>The permitted range is obtained from {@code getParameterRange(edge)} and kept in {@code
   * range} for as long as consecutive calls address the same edge (tracked by {@code
   * editingEdge}); it is recomputed only when a different edge is edited.
   *
   * @param a a <-> b
   * @param b a <-> b
   * @param covar The proposed covariance of a <-> b.
   * @return true if the covariance was set (i.e. was within range), false if not.
   * @throws IllegalArgumentException if no covariance parameter is stored for a <-> b.
   */
  public boolean setErrorCovariance(Node a, Node b, final double covar) {
    final Edge covarianceEdge =
        Edges.bidirectedEdge(semGraph.getExogenous(a), semGraph.getExogenous(b));

    if (edgeParameters.get(covarianceEdge) == null) {
      throw new IllegalArgumentException(
          "Not a covariance parameter in this model: " + covarianceEdge);
    }

    // Recompute the admissible range only when the edge being edited changes.
    final boolean sameEdgeAsLastCall = editingEdge != null && covarianceEdge.equals(editingEdge);

    if (!sameEdgeAsLastCall) {
      range = getParameterRange(covarianceEdge);
      editingEdge = covarianceEdge;
    }

    // Both bounds are exclusive.
    final boolean strictlyInside = covar > range.getLow() && covar < range.getHigh();

    if (strictlyInside) {
      edgeParameters.put(covarianceEdge, covar);
    }

    return strictlyInside;
  }
  /**
   * Reports whether a parameter is stored for the given edge. A bidirected edge is first
   * re-expressed as a bidirected edge between the nodes that {@code semGraph.getExogenous} returns
   * for its two endpoints; any other edge is looked up as given.
   *
   * @param edge the edge to look up.
   * @return true if {@code edgeParameters} has an entry keyed by the (possibly re-expressed) edge.
   */
  public boolean containsParameter(Edge edge) {
    final Edge lookupKey;

    if (Edges.isBidirectedEdge(edge)) {
      Node exogenous1 = semGraph.getExogenous(edge.getNode1());
      Node exogenous2 = semGraph.getExogenous(edge.getNode2());
      lookupKey = Edges.bidirectedEdge(exogenous1, exogenous2);
    } else {
      lookupKey = edge;
    }

    return edgeParameters.containsKey(lookupKey);
  }
  /**
   * Looks up the covariance parameter stored for the a<->b error edge.
   *
   * @param a a <-> b
   * @param b a <-> b
   * @return The covariance stored for a <-> b.
   * @throws IllegalArgumentException if no covariance parameter is stored for a <-> b.
   */
  public double getErrorCovariance(Node a, Node b) {
    Node exogenousA = semGraph.getExogenous(a);
    Node exogenousB = semGraph.getExogenous(b);
    Edge covarianceEdge = Edges.bidirectedEdge(exogenousA, exogenousB);

    Double stored = edgeParameters.get(covarianceEdge);

    if (stored != null) {
      return stored;
    }

    throw new IllegalArgumentException(
        "Not a covariance parameter in this model: " + covarianceEdge);
  }
  /**
   * Assembles the error covariance matrix of the model: entry [a][b] is the covariance of E_a and
   * E_b, with [a][a] being the variance of E_a. THESE ARE NOT PARAMETERS OF THE MODEL; THEY ARE
   * CALCULATED. Rows and columns follow the order of {@code getVariableNodes()}.
   *
   * <p>Diagonal entries come from {@code getErrorVariance}; an off-diagonal entry is filled from
   * {@code getErrorCovariance} only when the graph has a bidirected edge between the two error
   * terms, and is otherwise left at the matrix's initial value. The original documentation notes
   * that elements may be Double.NaN, indicating that they cannot be calculated.
   *
   * <p>NOTE(review): the {@code errorVariances} argument is used only to size the matrix; the
   * variances themselves are re-read through {@code getErrorVariance}. Confirm this is intended.
   *
   * @param errorVariances map from error term to variance; only its size is read here.
   * @return the error covariance matrix.
   */
  private TetradMatrix errCovar(Map<Node, Double> errorVariances) {
    List<Node> errorTerms = new ArrayList<Node>();

    for (Node variable : getVariableNodes()) {
      errorTerms.add(semGraph.getExogenous(variable));
    }

    final int dimension = errorVariances.size();
    TetradMatrix matrix = new TetradMatrix(dimension, dimension);
    final int errorCount = errorTerms.size();

    // Diagonal: the variance of each error term.
    for (int i = 0; i < errorCount; i++) {
      matrix.set(i, i, getErrorVariance(errorTerms.get(i)));
    }

    // Remaining entries: covariances for pairs of error terms joined by a bidirected edge.
    for (int row = 0; row < errorCount; row++) {
      Node rowError = errorTerms.get(row);

      for (int col = 0; col < errorCount; col++) {
        Node colError = errorTerms.get(col);
        Edge connecting = semGraph.getEdge(rowError, colError);

        if (connecting == null || !Edges.isBidirectedEdge(connecting)) {
          continue;
        }

        matrix.set(row, col, getErrorCovariance(rowError, colError));
      }
    }

    return matrix;
  }
  /**
   * @return a new list holding, for each node of {@code getVariableNodes()} in order, the node
   *     that {@code semGraph.getExogenous} returns for it.
   */
  public List<Node> getErrorNodes() {
    List<Node> variables = getVariableNodes();
    List<Node> errors = new ArrayList<Node>(variables.size());

    for (int i = 0; i < variables.size(); i++) {
      errors.add(semGraph.getExogenous(variables.get(i)));
    }

    return errors;
  }
  /**
   * Tests whether the model stays admissible when the parameter for {@code edge} takes the value
   * {@code newValue}.
   *
   * <p>Side effect: {@code newValue} is written into {@code edgeParameters} and is NOT restored by
   * this method; a caller that is only probing must put the previous value back itself.
   *
   * @param edge the edge whose parameter is being probed.
   * @param newValue the candidate value for that parameter.
   * @return false if any error variance calculated from the parameters is NaN, or if the resulting
   *     error covariance matrix does not pass {@code MatrixUtils.isPositiveDefinite}; true
   *     otherwise.
   */
  private boolean paramInBounds(Edge edge, double newValue) {
    edgeParameters.put(edge, newValue);

    Map<Node, Double> variancesByError = new HashMap<Node, Double>();

    for (Node variable : semPm.getVariableNodes()) {
      Node errorTerm = semGraph.getExogenous(variable);
      double variance = calculateErrorVarianceFromParams(errorTerm);

      // A NaN variance means it could not be calculated for this candidate value.
      if (Double.isNaN(variance)) {
        return false;
      }

      variancesByError.put(errorTerm, variance);
    }

    return MatrixUtils.isPositiveDefinite(errCovar(variancesByError));
  }
  /**
   * Exercises expression editing and reference tracking on the PM built by {@code
   * makeTypicalPm()}: which nodes reference parameter B1, which parameters and nodes X3
   * references, which nodes reference X3, replacing node and parameter expressions, and finally
   * building an IM and simulating a small data set from it.
   *
   * <p>A {@code ParseException} from any of the expression setters fails the test.
   */
  @Test
  public void test1() {
    GeneralizedSemPm pm = makeTypicalPm();

    print(pm);

    Node x1 = pm.getNode("X1");
    Node x2 = pm.getNode("X2");
    Node x3 = pm.getNode("X3");
    Node x4 = pm.getNode("X4");
    Node x5 = pm.getNode("X5");

    SemGraph graph = pm.getGraph();

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(B1) + E_X1");
      print(pm);

      String b1 = "B1";
      String b2 = "B2";
      String b3 = "B3";

      Set<Node> nodes = pm.getReferencingNodes(b1);

      assertTrue(nodes.contains(x1));
      // The original asserted this same condition twice in one expression; once is equivalent.
      assertTrue(!nodes.contains(x2));

      Set<String> referencedParameters = pm.getReferencedParameters(x3);

      print("Parameters referenced by X3 are: " + referencedParameters);

      assertTrue(referencedParameters.contains(b1) && referencedParameters.contains(b2));
      assertTrue(!(referencedParameters.contains(b1) && referencedParameters.contains(b3)));

      Node e_x3 = pm.getNode("E_X3");

      for (Node node : pm.getNodes()) {
        Set<Node> referencingNodes = pm.getReferencingNodes(node);
        print("Nodes referencing " + node + " are: " + referencingNodes);
      }

      for (Node node : pm.getVariableNodes()) {
        Set<Node> referencedNodes = pm.getReferencedNodes(node);
        print("Nodes referenced by " + node + " are: " + referencedNodes);
      }

      Set<Node> referencingX3 = pm.getReferencingNodes(x3);
      assertTrue(referencingX3.contains(x4));
      assertTrue(!referencingX3.contains(x5));

      Set<Node> referencedByX3 = pm.getReferencedNodes(x3);
      assertTrue(
          referencedByX3.contains(x1)
              && referencedByX3.contains(x2)
              && referencedByX3.contains(e_x3)
              && !referencedByX3.contains(x4));

      pm.setNodeExpression(x5, "a * E^X2 + X4 + E_X5");

      Node e_x5 = pm.getErrorNode(x5);

      graph.setShowErrorTerms(true);
      assertTrue(e_x5.equals(graph.getExogenous(x5)));

      pm.setNodeExpression(e_x5, "Beta(3, 5)");

      print(pm);

      assertEquals("Split(-1.5,-.5,.5,1.5)", pm.getParameterExpressionString(b1));
      pm.setParameterExpression(b1, "N(0, 2)");
      assertEquals("N(0, 2)", pm.getParameterExpressionString(b1));

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataAvoidInfinity(10, false);

      print(dataSet);

    } catch (ParseException e) {
      // Previously this only printed the stack trace, so a parse failure skipped every remaining
      // assertion and the test still passed. Fail loudly and keep the cause.
      throw new AssertionError("Unexpected ParseException: " + e.getMessage(), e);
    }
  }
  /**
   * Builds a five-variable graph (X1..X5, including the edge X5->X1) with error terms shown,
   * assigns an expression to every variable and a U(0, 1) expression to every error term,
   * simulates 1000 rows with {@code simulateDataNSteps}, and checks that the covariance of the
   * first two data columns is -0.002 to within 0.001. The random seed is fixed so the simulated
   * data is reproducible.
   *
   * <p>A {@code ParseException} from any of the expression setters fails the test.
   */
  @Test
  public void test3() {
    RandomUtil.getInstance().setSeed(49293843L);

    List<Node> variableNodes = new ArrayList<>();
    ContinuousVariable x1 = new ContinuousVariable("X1");
    ContinuousVariable x2 = new ContinuousVariable("X2");
    ContinuousVariable x3 = new ContinuousVariable("X3");
    ContinuousVariable x4 = new ContinuousVariable("X4");
    ContinuousVariable x5 = new ContinuousVariable("X5");

    variableNodes.add(x1);
    variableNodes.add(x2);
    variableNodes.add(x3);
    variableNodes.add(x4);
    variableNodes.add(x5);

    Graph _graph = new EdgeListGraph(variableNodes);
    SemGraph graph = new SemGraph(_graph);
    graph.setShowErrorTerms(true);

    Node e1 = graph.getExogenous(x1);
    Node e2 = graph.getExogenous(x2);
    Node e3 = graph.getExogenous(x3);
    Node e4 = graph.getExogenous(x4);
    Node e5 = graph.getExogenous(x5);

    graph.addDirectedEdge(x1, x3);
    graph.addDirectedEdge(x1, x2);
    graph.addDirectedEdge(x2, x3);
    graph.addDirectedEdge(x3, x4);
    graph.addDirectedEdge(x2, x4);
    graph.addDirectedEdge(x4, x5);
    graph.addDirectedEdge(x2, x5);
    graph.addDirectedEdge(x5, x1);

    GeneralizedSemPm pm = new GeneralizedSemPm(graph);

    List<Node> variablesNodes = pm.getVariableNodes();
    print(variablesNodes);

    List<Node> errorNodes = pm.getErrorNodes();
    print(errorNodes);

    try {
      pm.setNodeExpression(x1, "cos(b1) + a1 * X5 + E_X1");
      pm.setNodeExpression(x2, "a2 * X1 + E_X2");
      pm.setNodeExpression(x3, "tan(a3*X2 + a4*X1) + E_X3");
      pm.setNodeExpression(x4, "0.1 * E^X2 + X3 + E_X4");
      pm.setNodeExpression(x5, "0.1 * E^X4 + a6* X2 + E_X5");
      pm.setNodeExpression(e1, "U(0, 1)");
      pm.setNodeExpression(e2, "U(0, 1)");
      pm.setNodeExpression(e3, "U(0, 1)");
      pm.setNodeExpression(e4, "U(0, 1)");
      pm.setNodeExpression(e5, "U(0, 1)");

      GeneralizedSemIm im = new GeneralizedSemIm(pm);

      print(im);

      DataSet dataSet = im.simulateDataNSteps(1000, false);

      double[] d1 = dataSet.getDoubleData().getColumn(0).toArray();
      double[] d2 = dataSet.getDoubleData().getColumn(1).toArray();

      double cov = StatUtils.covariance(d1, d2);

      assertEquals(-0.002, cov, 0.001);
    } catch (ParseException e) {
      // Previously this only printed the stack trace, so a parse failure skipped the covariance
      // assertion and the test still passed. Fail loudly and keep the cause.
      throw new AssertionError("Unexpected ParseException: " + e.getMessage(), e);
    }
  }
  /**
   * Determines the interval of values the parameter on the given edge may take while {@code
   * paramInBounds} still accepts the model.
   *
   * <p>Starting from the currently stored value, the search steps outward in each direction,
   * doubling the distance each time, until a value is rejected; the boundary between accepted and
   * rejected values is then located by bisection to within 1e-10. If no rejected value is found
   * before the probe overflows, that side of the range is infinite. The stored parameter is put
   * back once the search finishes, including when a probe throws.
   *
   * <p>A bidirected edge is first re-expressed between the nodes that {@code
   * semGraph.getExogenous} returns for its endpoints.
   *
   * @param edge a->b or a<->b.
   * @return the range of the parameter for a->b or a<->b, together with the value the search
   *     started from.
   * @throws IllegalArgumentException if no parameter is stored for the edge.
   */
  public ParameterRange getParameterRange(Edge edge) {
    if (Edges.isBidirectedEdge(edge)) {
      edge =
          Edges.bidirectedEdge(
              semGraph.getExogenous(edge.getNode1()), semGraph.getExogenous(edge.getNode2()));
    }

    if (!edgeParameters.containsKey(edge)) {
      throw new IllegalArgumentException("Not an edge in this model: " + edge);
    }

    double initial = edgeParameters.get(edge);

    // Clamp infinite starting values to the nearest finite double. Double.MIN_VALUE is the
    // smallest POSITIVE double, so the most negative finite value is -Double.MAX_VALUE.
    if (initial == Double.NEGATIVE_INFINITY) {
      initial = -Double.MAX_VALUE;
    } else if (initial == Double.POSITIVE_INFINITY) {
      initial = Double.MAX_VALUE;
    }

    final double value = initial;
    double rangeHigh;
    double rangeLow;

    try {
      // look upward, then downward, for the boundary.
      rangeHigh = findRangeBoundary(edge, value, 1.0);
      rangeLow = findRangeBoundary(edge, value, -1.0);
    } finally {
      // paramInBounds overwrites the stored parameter with every probe; always put it back.
      edgeParameters.put(edge, value);
    }

    return new ParameterRange(edge, value, rangeLow, rangeHigh);
  }

  /**
   * Finds one end of the admissible range for an edge parameter: steps away from {@code value} in
   * the given direction until {@code paramInBounds} rejects a probe, then bisects between {@code
   * value} and that rejected probe.
   *
   * @param edge the edge whose parameter is probed (its stored value is overwritten by probing).
   * @param value the finite starting value.
   * @param direction 1.0 to search upward, -1.0 to search downward.
   * @return the boundary estimate, or the infinity of the search direction if no rejected value
   *     could be found.
   */
  private double findRangeBoundary(Edge edge, double value, double direction) {
    final double unbounded =
        direction > 0 ? Double.POSITIVE_INFINITY : Double.NEGATIVE_INFINITY;

    // look outward for a point that fails, doubling the distance from value each time.
    double outside = value + direction;

    while (paramInBounds(edge, outside)) {
      double next = value + 2 * (outside - value);

      // Stop when the probe overflows, or when it can no longer move because the step is below
      // the floating-point resolution at this magnitude (the loop would otherwise never end).
      if (Double.isInfinite(next) || next == outside) {
        return unbounded;
      }

      outside = next;
    }

    // find the boundary using binary search.
    double inside = value;

    while (Math.abs(outside - inside) > 1e-10) {
      double midpoint = (outside + inside) / 2.0;

      // No representable value strictly between the two ends: further halving cannot progress.
      boolean strictlyBetween =
          midpoint > Math.min(inside, outside) && midpoint < Math.max(inside, outside);
      if (!strictlyBetween) {
        break;
      }

      if (paramInBounds(edge, midpoint)) {
        inside = midpoint;
      } else {
        outside = midpoint;
      }
    }

    return (outside + inside) / 2.0;
  }
 /**
  * Convenience form of {@code getParameterRange} for the a<->b error edge.
  *
  * @param a a <-> b
  * @param b a <-> b
  * @return the range that {@code getParameterRange} reports for the bidirected edge between the
  *     nodes {@code semGraph.getExogenous} returns for a and b.
  */
 public ParameterRange getCovarianceRange(Node a, Node b) {
   Node exogenousA = semGraph.getExogenous(a);
   Node exogenousB = semGraph.getExogenous(b);
   Edge covarianceEdge = Edges.bidirectedEdge(exogenousA, exogenousB);

   return getParameterRange(covarianceEdge);
 }
  /**
   * Constructs a new standardized SEM IM from the freeParameters in the given SEM IM.
   *
   * <p>A copy of the given IM's SEM PM and of its graph (with error terms shown) is kept. Edge
   * parameters are then initialized in one of two ways:
   *
   * <ul>
   *   <li>{@code CALCULATE_FROM_SEM}: each coefficient a->b is rescaled by sd(a) / sd(b), using
   *       the variances on the diagonal of the given IM's implied covariance matrix.
   *   <li>otherwise: 1000 rows are simulated from the given IM, standardized, and a SEM is
   *       estimated from them; the estimated coefficients are used.
   * </ul>
   *
   * <p>In both cases each error covariance a<->b is stored as the given IM's error covariance
   * divided by the square root of the product of the two error variances.
   *
   * @param im The SEM IM to standardize.
   * @param initialization CALCULATE_FROM_SEM if the initial values will be calculated from the
   *     given SEM IM; INITIALIZE_FROM_DATA if data will be simulated from the given SEM,
   *     standardized, and estimated.
   * @throws IllegalArgumentException if the model's graph contains a directed cycle.
   */
  public StandardizedSemIm(SemIm im, Initialization initialization) {
    this.semPm = new SemPm(im.getSemPm());
    this.semGraph = new SemGraph(semPm.getGraph());
    semGraph.setShowErrorTerms(true);

    if (semGraph.existsDirectedCycle()) {
      throw new IllegalArgumentException("The cyclic case is not handled.");
    }

    if (initialization == Initialization.CALCULATE_FROM_SEM) {
      // This code calculates the new coefficients directly from the old ones.
      edgeParameters = new HashMap<Edge, Double>();

      List<Node> nodes = im.getVariableNodes();
      TetradMatrix impliedCovar = im.getImplCovar(true);

      for (Parameter parameter : im.getSemPm().getParameters()) {
        if (parameter.getType() == ParamType.COEF) {
          Node a = parameter.getNodeA();
          Node b = parameter.getNodeB();
          int aindex = nodes.indexOf(a);
          int bindex = nodes.indexOf(b);
          double vara = impliedCovar.get(aindex, aindex);
          double stda = Math.sqrt(vara);
          double varb = impliedCovar.get(bindex, bindex);
          double stdb = Math.sqrt(varb);
          double oldCoef = im.getEdgeCoef(a, b);
          // Standardized coefficient for a->b: rescale by sd(a) / sd(b).
          double newCoef = (stda / stdb) * oldCoef;
          edgeParameters.put(Edges.directedEdge(a, b), newCoef);
        } else if (parameter.getType() == ParamType.COVAR) {
          Node a = parameter.getNodeA();
          Node b = parameter.getNodeB();
          Node exoa = semGraph.getExogenous(a);
          Node exob = semGraph.getExogenous(b);
          double covar = im.getErrCovar(a, b) / Math.sqrt(im.getErrVar(a) * im.getErrVar(b));
          edgeParameters.put(Edges.bidirectedEdge(exoa, exob), covar);
        }
      }
    } else {

      // This code estimates the new coefficients from simulated data from the old model.
      DataSet dataSet = im.simulateData(1000, false);
      TetradMatrix _dataSet = dataSet.getDoubleData();
      _dataSet = DataUtils.standardizeData(_dataSet);
      DataSet dataSetStandardized = ColtDataSet.makeData(dataSet.getVariables(), _dataSet);

      SemEstimator estimator = new SemEstimator(dataSetStandardized, im.getSemPm());
      SemIm imStandardized = estimator.estimate();

      edgeParameters = new HashMap<Edge, Double>();

      for (Parameter parameter : imStandardized.getSemPm().getParameters()) {
        if (parameter.getType() == ParamType.COEF) {
          Node a = parameter.getNodeA();
          Node b = parameter.getNodeB();
          double coef = imStandardized.getEdgeCoef(a, b);
          edgeParameters.put(Edges.directedEdge(a, b), coef);
        } else if (parameter.getType() == ParamType.COVAR) {
          Node a = parameter.getNodeA();
          Node b = parameter.getNodeB();
          Node exoa = semGraph.getExogenous(a);
          Node exob = semGraph.getExogenous(b);
          // Same quantity as in the CALCULATE_FROM_SEM branch. It was previously negated here,
          // which made the sign of the stored covariance depend on the initialization mode.
          double covar = im.getErrCovar(a, b) / Math.sqrt(im.getErrVar(a) * im.getErrVar(b));
          edgeParameters.put(Edges.bidirectedEdge(exoa, exob), covar);
        }
      }
    }

    this.measuredNodes = Collections.unmodifiableList(semPm.getMeasuredNodes());
  }