예제 #1
0
  /**
   * Builds a read-only table of prediction errors from the variables currently held by {@code
   * DBSeerGUI.runner}.
   *
   * <p>The model starts with one column of row labels ({@code ""}, {@code "MAE"}, {@code "MRE"})
   * followed by one column per error entry. An entry column holds the header text and the first
   * element of the matching {@code meanAbsError} and {@code meanRelError} arrays, each formatted
   * with three decimals. A {@code "Type N"} placeholder in a header is replaced by the N-th
   * transaction type name of the first training dataset.
   *
   * <p>Only index positions present in all of {@code meanAbsError}, {@code meanRelError} and {@code
   * errorHeader} produce a column, so arrays of different lengths no longer cause an {@link
   * ArrayIndexOutOfBoundsException}.
   *
   * @param center prediction center whose training configuration supplies the transaction type
   *     names
   * @return the populated table, or {@code null} when both error lists are empty
   */
  public static JTable createErrorTable(PredictionCenter center) {
    StatisticalPackageRunner runner = DBSeerGUI.runner;

    Object[] maeList = (Object[]) runner.getVariableCell("meanAbsError");
    Object[] mreList = (Object[]) runner.getVariableCell("meanRelError");
    Object[] headers = (Object[]) runner.getVariableCell("errorHeader");

    if (maeList.length == 0 && mreList.length == 0) {
      return null;
    }

    DefaultTableModel errorTableModel =
        new DefaultTableModel() {
          @Override
          public boolean isCellEditable(int row, int column) {
            return false;
          }
        };
    // Leading column carries the row labels; its first cell lines up with the entry headers.
    errorTableModel.addColumn(null, new String[] {"", "MAE", "MRE"});

    final java.util.List<String> transactionNames =
        center.getTrainConfig().getDataset(0).getTransactionTypeNames();

    // A complete column needs a header, an MAE and an MRE, so stop at the shortest array.
    int columnCount = Math.min(Math.min(maeList.length, mreList.length), headers.length);
    for (int i = 0; i < columnCount; ++i) {
      double[] mae = (double[]) maeList[i];
      double[] mre = (double[]) mreList[i];

      String header = (String) headers[i];
      // NOTE(review): "Type 1" is also a substring of "Type 10", "Type 11", ...; with ten or more
      // transaction types the first match below may substitute the wrong name.
      for (int j = 0; j < transactionNames.size(); ++j) {
        String placeholder = "Type " + (j + 1);
        if (header.contains(placeholder)) {
          header = header.replace(placeholder, transactionNames.get(j));
          break;
        }
      }

      errorTableModel.addColumn(
          null,
          new Object[] {header, String.format("%.3f", mae[0]), String.format("%.3f", mre[0])});
    }

    return new JTable(errorTableModel);
  }
예제 #2
0
  /**
   * Builds a bar chart from the plot variables currently held by {@code DBSeerGUI.runner}.
   *
   * <p>Reads {@code title}, {@code legends}, {@code Ydata}, {@code Xlabel} and {@code Ylabel} from
   * the runner. Every column of every {@code Ydata} cell becomes one dataset row key (category).
   * Categories are named after the legends in order; once the legends run out, the last legend
   * suffixed with the 1-based data index is used, and when there are no legends at all the name is
   * {@code "Data N"}. {@code "Type N"} placeholders in legends and axis labels are replaced by the
   * N-th transaction type name of the first training dataset.
   *
   * <p>Negative, NaN and infinite values are charted as {@code 0.0}.
   *
   * @param center prediction center whose training configuration supplies the transaction type
   *     names
   * @return the bar chart built from the runner's current plot variables
   */
  public static JFreeChart createPredictionBarChart(PredictionCenter center) {
    StatisticalPackageRunner runner = DBSeerGUI.runner;

    String title = runner.getVariableString("title");
    Object[] legends = (Object[]) runner.getVariableCell("legends");
    Object[] yCellArray = (Object[]) runner.getVariableCell("Ydata");
    String xLabel = runner.getVariableString("Xlabel");
    String yLabel = runner.getVariableString("Ylabel");

    DefaultCategoryDataset dataset = new DefaultCategoryDataset();

    int numLegends = legends.length;
    int numYCellArray = yCellArray.length;
    int dataCount = 0;

    // NOTE(review): "Type 1" is also a substring of "Type 10", "Type 11", ...; with ten or more
    // transaction types the first-match substitutions below may pick the wrong name.
    final java.util.List<String> transactionNames =
        center.getTrainConfig().getDataset(0).getTransactionTypeNames();
    for (int i = 0; i < numLegends; ++i) {
      String legend = (String) legends[i];
      for (int j = 0; j < transactionNames.size(); ++j) {
        String placeholder = "Type " + (j + 1);
        if (legend.contains(placeholder)) {
          legends[i] = legend.replace(placeholder, transactionNames.get(j));
          break;
        }
      }
    }
    for (int j = 0; j < transactionNames.size(); ++j) {
      String placeholder = "Type " + (j + 1);
      if (xLabel.contains(placeholder)) {
        xLabel = xLabel.replace(placeholder, transactionNames.get(j));
        break;
      }
    }
    for (int j = 0; j < transactionNames.size(); ++j) {
      String placeholder = "Type " + (j + 1);
      if (yLabel.contains(placeholder)) {
        yLabel = yLabel.replace(placeholder, transactionNames.get(j));
        break;
      }
    }

    for (int i = 0; i < numYCellArray; ++i) {
      // The runner's cell indices are 1-based, hence (i + 1).
      runner.eval("yArraySize = size(Ydata{" + (i + 1) + "});");
      runner.eval("yArray = Ydata{" + (i + 1) + "};");
      double[] yArraySize = runner.getVariableDouble("yArraySize");
      double[] yArray = runner.getVariableDouble("yArray");

      int row = (int) yArraySize[0];
      int col = (int) yArraySize[1];

      for (int c = 0; c < col; ++c) {
        // Pick the legend only after confirming one exists; indexing first would read legends[-1]
        // when the legend list is empty.
        String category;
        if (numLegends == 0) {
          category = "Data " + (dataCount + 1);
        } else if (dataCount >= numLegends) {
          category = (String) legends[numLegends - 1] + (dataCount + 1);
        } else {
          category = (String) legends[dataCount];
        }

        for (int r = 0; r < row; ++r) {
          // yArray is flattened column by column.
          double yValue = yArray[r + c * row];
          // NaN never compares equal to anything (itself included), so it needs Double.isNaN.
          if (yValue < 0 || Double.isNaN(yValue) || Double.isInfinite(yValue)) {
            yValue = 0.0;
          }

          // NOTE(review): every row targets the same (category, "") cell, so later rows replace
          // earlier ones - confirm that only the last row is meant to be charted.
          dataset.addValue(yValue, category, "");
        }
        ++dataCount;
      }
    }

    return ChartFactory.createBarChart(title, xLabel, yLabel, dataset);
  }
예제 #3
0
  /**
   * Builds an XY line chart from the plot variables currently held by {@code DBSeerGUI.runner}.
   *
   * <p>Reads {@code title}, {@code legends}, {@code Xdata}, {@code Ydata}, {@code Xlabel} and
   * {@code Ylabel} from the runner. Every column of every {@code Ydata} cell becomes one series,
   * plotted against the {@code Xdata} cell at the same index. Series are named after the legends in
   * order; once the legends run out, the last legend suffixed with the 1-based data index is used,
   * and when there are no legends at all the name is {@code "Data N"}. {@code "Type N"}
   * placeholders in legends and axis labels are replaced by the N-th transaction type name of the
   * first training dataset.
   *
   * <p>Negative, NaN and infinite Y values are plotted as {@code 0.0}. Series whose name contains
   * "predicted" or "Predicted" are drawn with the {@code STYLE_DOT} stroke, all others with the
   * {@code STYLE_LINE} stroke.
   *
   * @param center prediction center whose training configuration supplies the transaction type
   *     names
   * @return the line chart, or {@code null} (after showing an error dialog) when the numbers of X
   *     and Y datasets differ
   */
  public static JFreeChart createXYLinePredictionChart(PredictionCenter center) {
    StatisticalPackageRunner runner = DBSeerGUI.runner;

    String title = runner.getVariableString("title");
    Object[] legends = (Object[]) runner.getVariableCell("legends");
    Object[] xCellArray = (Object[]) runner.getVariableCell("Xdata");
    Object[] yCellArray = (Object[]) runner.getVariableCell("Ydata");
    String xLabel = runner.getVariableString("Xlabel");
    String yLabel = runner.getVariableString("Ylabel");

    XYSeriesCollection dataSet = new XYSeriesCollection();

    int numLegends = legends.length;
    int numXCellArray = xCellArray.length;
    int numYCellArray = yCellArray.length;
    int dataCount = 0;

    if (numXCellArray != numYCellArray) {
      JOptionPane.showMessageDialog(
          null,
          "The number of X dataset and Y dataset does not match.",
          "The number of X dataset and Y dataset does not match.",
          JOptionPane.ERROR_MESSAGE);
      System.out.println(numXCellArray + " : " + numYCellArray);
      return null;
    }

    // NOTE(review): "Type 1" is also a substring of "Type 10", "Type 11", ...; with ten or more
    // transaction types the first-match substitutions below may pick the wrong name.
    final java.util.List<String> transactionNames =
        center.getTrainConfig().getDataset(0).getTransactionTypeNames();
    for (int i = 0; i < numLegends; ++i) {
      String legend = (String) legends[i];
      for (int j = 0; j < transactionNames.size(); ++j) {
        String placeholder = "Type " + (j + 1);
        if (legend.contains(placeholder)) {
          legends[i] = legend.replace(placeholder, transactionNames.get(j));
          break;
        }
      }
    }
    for (int j = 0; j < transactionNames.size(); ++j) {
      String placeholder = "Type " + (j + 1);
      if (xLabel.contains(placeholder)) {
        xLabel = xLabel.replace(placeholder, transactionNames.get(j));
        break;
      }
    }
    for (int j = 0; j < transactionNames.size(); ++j) {
      String placeholder = "Type " + (j + 1);
      if (yLabel.contains(placeholder)) {
        yLabel = yLabel.replace(placeholder, transactionNames.get(j));
        break;
      }
    }

    for (int i = 0; i < numYCellArray; ++i) {
      double[] xArray = (double[]) xCellArray[i];
      // The runner's cell indices are 1-based, hence (i + 1).
      runner.eval("yArraySize = size(Ydata{" + (i + 1) + "});");
      runner.eval("yArray = Ydata{" + (i + 1) + "};");
      double[] yArraySize = runner.getVariableDouble("yArraySize");
      double[] yArray = runner.getVariableDouble("yArray");

      int xLength = xArray.length;
      int row = (int) yArraySize[0];
      int col = (int) yArraySize[1];

      for (int c = 0; c < col; ++c) {
        // Pick the legend only after confirming one exists; indexing first would read legends[-1]
        // when the legend list is empty.
        XYSeries series;
        if (numLegends == 0) {
          series = new XYSeries("Data " + (dataCount + 1));
        } else if (dataCount >= numLegends) {
          series = new XYSeries((String) legends[numLegends - 1] + (dataCount + 1));
        } else {
          series = new XYSeries((String) legends[dataCount]);
        }

        for (int r = 0; r < row; ++r) {
          // When Y has more rows than X, reuse the last X sample.
          // NOTE(review): an empty X array still yields index -1 here and throws.
          int xRow = (r >= xLength) ? xLength - 1 : r;
          // yArray is flattened column by column.
          double yValue = yArray[r + c * row];
          // NaN never compares equal to anything (itself included), so it needs Double.isNaN.
          if (yValue < 0 || Double.isNaN(yValue) || Double.isInfinite(yValue)) {
            yValue = 0.0;
          }
          series.add(xArray[xRow], yValue);
        }
        dataSet.addSeries(series);
        ++dataCount;
      }
    }

    JFreeChart chart = ChartFactory.createXYLineChart(title, xLabel, yLabel, dataSet);

    // Give 'predicted' series the STYLE_DOT stroke and every other series the STYLE_LINE stroke.
    BasicStroke dotStroke = toStroke(STYLE_DOT);
    BasicStroke lineStroke = toStroke(STYLE_LINE);
    XYItemRenderer renderer = chart.getXYPlot().getRenderer();
    for (int i = 0; i < dataSet.getSeriesCount(); ++i) {
      String legend = (String) dataSet.getSeriesKey(i);
      boolean predicted = legend.contains("predicted") || legend.contains("Predicted");
      renderer.setSeriesStroke(i, predicted ? dotStroke : lineStroke);
    }

    return chart;
  }