/**
 * Estimates the cost of a single transaction using the cost model of the time interval that the
 * transaction falls into.
 *
 * <p>The interval index is obtained from {@code Workload#getTimeInterval} over
 * {@code cost_models.length} intervals. The estimate itself is delegated, with the same
 * arguments, to that interval's cost model.
 *
 * @param catalogContext catalog handle forwarded to the interval's cost model
 * @param workload the workload {@code xact} belongs to; expected to be non-null (checked only by
 *     an assertion)
 * @param filter workload filter forwarded to the interval's cost model
 * @param xact the transaction to estimate
 * @return the cost reported by the interval's cost model for {@code xact}
 * @throws Exception propagated from the interval lookup or the interval model's estimate
 */
 @Override
 public double estimateTransactionCost(
     CatalogContext catalogContext, Workload workload, Filter filter, TransactionTrace xact)
     throws Exception {
   assert (workload != null) : "The workload handle is null";
   // Pick the per-interval model responsible for this txn, then let it do the pricing
   final int slot = workload.getTimeInterval(xact, this.cost_models.length);
   return this.cost_models[slot].estimateTransactionCost(catalogContext, workload, filter, xact);
 }
  /**
   * Estimates the cost of the whole workload by splitting its transactions into time intervals,
   * evaluating each interval with its own cost model, and combining the per-interval execution
   * and skew costs into a single score.
   *
   * <p>A {@code Producer} tags every transaction with its time interval and hands it to an
   * {@code IntervalProcessor}. There is one processor per global worker thread, and each one owns
   * a contiguous run of intervals. After the pool finishes, this method computes:
   *
   * <ol>
   *   <li>per interval, an execution cost: partitions touched divided by the worst-case number of
   *       partitions the interval's txns could have touched, optionally scaled by the
   *       multi-partition penalty. The cost stays 0.0 for intervals without multi-partition txns;
   *   <li>per interval, a skew cost: the mean of the txn-access skew and the execution skew;
   *   <li>the final cost: {@code execution_weight * execution_cost + skew_weight * skew_cost},
   *       where each component is weighted by {@code total_interval_txns} across intervals and
   *       either term may be switched off via {@code use_execution} / {@code use_skew}.
   * </ol>
   *
   * <p>Side effects: resets and refills the per-interval counter/histogram fields, accumulates
   * into {@code query_ctr}, {@code txn_ctr} and the {@code histogram_*_partitions} fields, and
   * stores {@code last_execution_cost}, {@code last_skew_cost} and {@code last_final_cost}.
   *
   * @param catalogContext catalog handle passed to every {@code IntervalProcessor}
   * @param workload the workload whose transactions are estimated
   * @param filter optional filter applied when iterating the workload; may be {@code null}
   * @param upper_bound not read by this implementation
   * @return the final cost rounded to five decimal places
   * @throws Exception as declared; presumably propagated from the worker pool or the interval
   *     cost models
   */
  @Override
  protected double estimateWorkloadCostImpl(
      final CatalogContext catalogContext,
      final Workload workload,
      final Filter filter,
      final Double upper_bound)
      throws Exception {

    if (debug.val)
      LOG.debug(
          "Calculating workload execution cost across "
              + num_intervals
              + " intervals for "
              + num_partitions
              + " partitions");

    // (1) Reset the per-interval counters and histograms so nothing carries over from a
    //     previous invocation. total_txns counts every txn the producer dispatches.
    final AtomicLong total_txns = new AtomicLong(0);
    for (int i = 0; i < num_intervals; i++) {
      total_interval_txns[i] = 0;
      total_interval_queries[i] = 0;
      singlepartition_ctrs[i] = 0;
      singlepartition_with_partitions_ctrs[i] = 0;
      multipartition_ctrs[i] = 0;
      partitions_touched[i] = 0;
      incomplete_txn_ctrs[i] = 0;
      exec_mismatch_ctrs[i] = 0;
      incomplete_txn_histogram[i].clear();
      missing_txn_histogram[i].clear();
      exec_histogram[i].clear();
    } // FOR

    // (2) Now go through the workload and estimate the partitions that each txn
    //     will touch for the given catalog setups
    if (trace.val) {
      LOG.trace("Total # of Txns in Workload: " + workload.getTransactionCount());
      if (filter != null)
        LOG.trace(
            "Workload Filter Chain:       " + StringUtil.join("   ", "\n", filter.getFilters()));
    }

    // QUEUING THREAD
    // Tags each txn with its time interval and routes it to the consumer owning that interval.
    tmp_consumers.clear();
    Producer<TransactionTrace, Pair<TransactionTrace, Integer>> producer =
        new Producer<TransactionTrace, Pair<TransactionTrace, Integer>>(
            CollectionUtil.iterable(workload.iterator(filter))) {
          @Override
          public Pair<Consumer<Pair<TransactionTrace, Integer>>, Pair<TransactionTrace, Integer>>
              transform(TransactionTrace txn_trace) {
            int i = workload.getTimeInterval(txn_trace, num_intervals);
            assert (i >= 0)
                : "Invalid time interval '" + i + "'\n" + txn_trace.debug(catalogContext.database);
            assert (i < num_intervals)
                : "Invalid interval: " + i + "\n" + txn_trace.debug(catalogContext.database);
            total_txns.incrementAndGet();
            Pair<TransactionTrace, Integer> p = Pair.of(txn_trace, i);
            return (Pair.of(tmp_consumers.get(i), p));
          }
        };

    // PROCESSING THREADS
    // One IntervalProcessor per global worker thread, each owning a contiguous run of intervals.
    final int num_threads = ThreadUtil.getMaxGlobalThreads();
    final int intervals_per_thread = (int) Math.ceil(num_intervals / (double) num_threads);
    int interval_ctr = 0;
    for (int thread = 0; thread < num_threads; thread++) {
      IntervalProcessor ip = new IntervalProcessor(catalogContext, workload, filter);

      for (int i = 0; i < intervals_per_thread; i++) {
        // Valid intervals are [0, num_intervals); '>' here would also register a consumer
        // for the non-existent interval #num_intervals.
        if (interval_ctr >= num_intervals) break;
        tmp_consumers.put(interval_ctr++, ip);
        if (trace.val)
          LOG.trace(
              String.format("Interval #%02d => IntervalProcessor #%02d", interval_ctr - 1, thread));
      } // FOR

      producer.addConsumer(ip);
    } // FOR (threads)

    ThreadUtil.runGlobalPool(producer.getRunnablesList()); // BLOCKING
    if (debug.val) {
      int processed = 0;
      for (Consumer<?> c : producer.getConsumers()) {
        processed += c.getProcessedCounter();
      } // FOR
      assert (total_txns.get() == processed)
          : String.format("Expected[%d] != Processed[%d]", total_txns.get(), processed);
    }

    // We have to convert all of the costs into the range of [0.0, 1.0]
    // For each interval, divide the number of partitions touched by the total number
    // of partitions that the interval could have touched (worst case scenario)
    final double execution_costs[] = new double[num_intervals];
    StringBuilder sb = (this.isDebugEnabled() || debug.get() ? new StringBuilder() : null);
    Map<String, Object> debug_m = null;
    if (sb != null) {
      debug_m = new LinkedHashMap<String, Object>();
      // These debug accumulators are fields; without clearing them the report would also
      // include the values collected by every previous invocation.
      tmp_penalties.clear();
      tmp_total.clear();
      tmp_touched.clear();
      tmp_potential.clear();
    }

    if (debug.val)
      LOG.debug("Calculating execution cost for " + this.num_intervals + " intervals...");
    long total_multipartition_txns = 0;
    for (int i = 0; i < this.num_intervals; i++) {
      // NOTE: evaluates to NaN when no txns were dispatched at all (0 / 0.0)
      interval_weights[i] = total_interval_txns[i] / (double) total_txns.get();
      long total_txns_in_interval = (long) total_interval_txns[i];
      long total_queries_in_interval = (long) total_interval_queries[i];
      long num_txns = this.cost_models[i].txn_ctr.get();
      // Worst case: every txn in this interval touches every partition
      long potential_txn_touches = (total_txns_in_interval * num_partitions); // TXNS
      double penalty = 0.0d;
      total_multipartition_txns += multipartition_ctrs[i];

      // Intervals without multi-partition txns keep an execution cost of 0.0
      if (multipartition_ctrs[i] > 0) {
        assert (partitions_touched[i] > 0) : "No touched partitions for interval " + i;
        // Partitions actually touched / partitions that could have been touched
        double cost = (partitions_touched[i] / (double) potential_txn_touches);

        if (this.use_multitpartition_penalty) {
          penalty =
              this.multipartition_penalty
                  * (1.0d + (multipartition_ctrs[i] / (double) total_txns_in_interval));
          assert (penalty >= 1.0) : "The multipartition penalty is less than one: " + penalty;
          cost *= penalty;
        }
        // NOTE(review): this caps the cost at a partition *count*, not at 1.0 as the comment
        // above suggests, so a penalized cost can exceed 1.0. Kept unchanged because altering it
        // would shift every cost this model reports -- confirm the intended bound.
        execution_costs[i] = Math.min(cost, (double) potential_txn_touches);
      }

      // Txns that this interval's model never evaluated are recorded in the missing histogram
      // as touching all partitions
      if (num_txns < total_txns_in_interval) {
        if (trace.val)
          LOG.trace(
              "Adding "
                  + (total_txns_in_interval - num_txns)
                  + " entries to the incomplete histogram for interval #"
                  + i);
        for (long ii = num_txns; ii < total_txns_in_interval; ii++) {
          missing_txn_histogram[i].put(all_partitions);
        } // FOR
      }

      if (sb != null) {
        tmp_penalties.add(penalty);
        tmp_total.add(total_txns_in_interval);
        tmp_touched.add(partitions_touched[i]);
        tmp_potential.add(potential_txn_touches);

        Map<String, Object> inner = new LinkedHashMap<String, Object>();
        inner.put("Partitions Touched", partitions_touched[i]);
        inner.put("Potential Touched", potential_txn_touches);
        inner.put("Multi-Partition Txns", multipartition_ctrs[i]);
        inner.put("Total Txns", total_txns_in_interval);
        inner.put("Total Queries", total_queries_in_interval);
        inner.put("Missing Txns", (total_txns_in_interval - num_txns));
        inner.put("Cost", String.format("%.05f", execution_costs[i]));
        inner.put("Exec Txns", exec_histogram[i].getSampleCount());
        debug_m.put("Interval #" + i, inner);
      }
    } // FOR

    if (sb != null) {
      Map<String, Object> m0 = new LinkedHashMap<String, Object>();
      m0.put("SinglePartition Txns", (total_txns.get() - total_multipartition_txns));
      m0.put("MultiPartition Txns", total_multipartition_txns);
      m0.put(
          "Total Txns",
          String.format(
              "%d [%.06f]",
              total_txns.get(), (1.0d - (total_multipartition_txns / (double) total_txns.get()))));

      Map<String, Object> m1 = new LinkedHashMap<String, Object>();
      m1.put("Touched Partitions", tmp_touched);
      m1.put("Potential Partitions", tmp_potential);
      m1.put("Total Partitions", tmp_total);
      m1.put("Penalties", tmp_penalties);

      sb.append(StringUtil.formatMaps(debug_m, m0, m1));
      if (debug.val) LOG.debug("**** Execution Cost ****\n" + sb);
      this.appendDebugMessage(sb);
    }

    // (3) We then need to go through and grab the histograms of the partitions that were accessed
    if (sb != null) {
      if (debug.val)
        LOG.debug("Calculating skew factor for " + this.num_intervals + " intervals...");
      debug_histograms.clear();
      sb = new StringBuilder();
    }
    for (int i = 0; i < this.num_intervals; i++) {
      ObjectHistogram<Integer> histogram_txn = this.cost_models[i].getTxnPartitionAccessHistogram();
      ObjectHistogram<Integer> histogram_query =
          this.cost_models[i].getQueryPartitionAccessHistogram();
      this.histogram_query_partitions.put(histogram_query);
      long num_queries = this.cost_models[i].query_ctr.get();
      this.query_ctr.addAndGet(num_queries);

      // Sanity check: the partition touches counted for this interval must match the sample
      // count of the interval model's txn access histogram (plus execution mismatches)
      boolean is_valid =
          (partitions_touched[i] + singlepartition_with_partitions_ctrs[i])
              == (this.cost_models[i].getTxnPartitionAccessHistogram().getSampleCount()
                  + exec_mismatch_ctrs[i]);
      if (!is_valid) {
        this.logTxnPartitionMismatch(i);
      }
      assert (is_valid)
          : String.format(
              "Partitions Touched by Txns Mismatch in Interval #%d\n"
                  + "(partitions_touched[%d] + singlepartition_with_partitions_ctrs[%d]) != "
                  + "(histogram_txn[%d] + exec_mismatch_ctrs[%d])",
              i,
              partitions_touched[i],
              singlepartition_with_partitions_ctrs[i],
              this.cost_models[i].getTxnPartitionAccessHistogram().getSampleCount(),
              exec_mismatch_ctrs[i]);

      this.histogram_java_partitions.put(this.cost_models[i].getJavaExecutionHistogram());
      this.histogram_txn_partitions.put(histogram_txn);
      long num_txns = this.cost_models[i].txn_ctr.get();
      assert (num_txns >= 0) : "The transaction counter at interval #" + i + " is " + num_txns;
      this.txn_ctr.addAndGet(num_txns);

      // Calculate the skew factor at this time interval
      // XXX: Should the number of txns be the total number of unique txns
      //      that were executed or the total number of times a txn touched the partitions?
      // XXX: What do we do when the number of elements that we are examining is zero?
      //      I guess the cost just needs to be zero?
      // XXX: What histogram do we want to use?
      target_histogram.clear();
      target_histogram.put(histogram_txn);

      // Each txn without an estimate at this interval was recorded (in missing_txn_histogram,
      // above) as broadcast to all partitions, so the access histogram looks uniform. The
      // intent is that the skew cost never decreases as more information is added, but only
      // increases.
      long total_txns_in_interval = (long) total_interval_txns[i];
      if (sb != null) {
        debug_histograms.put("Incomplete Txns", incomplete_txn_histogram[i]);
        debug_histograms.put("Missing Txns", missing_txn_histogram[i]);
        debug_histograms.put(
            "Target Partitions (BEFORE)", new ObjectHistogram<Integer>(target_histogram));
        // Live reference on purpose: it reflects target_histogram after the merge below
        debug_histograms.put("Target Partitions (AFTER)", target_histogram);
      }

      // Merge the values from the incomplete and missing histograms into the target histogram
      target_histogram.put(incomplete_txn_histogram[i]);
      target_histogram.put(missing_txn_histogram[i]);
      exec_histogram[i].put(missing_txn_histogram[i]);

      long num_elements = target_histogram.getSampleCount();

      // The number of partition touches should never be greater than our potential touches
      assert (num_elements <= (total_txns_in_interval * num_partitions))
          : "New Partitions Touched Sample Count ["
              + num_elements
              + "] > "
              + "Maximum Potential Touched Count ["
              + (total_txns_in_interval * num_partitions)
              + "]";

      if (sb != null) {
        Map<String, Object> m = new LinkedHashMap<String, Object>();
        for (String key : debug_histograms.keySet()) {
          ObjectHistogram<?> h = debug_histograms.get(key);
          m.put(
              key,
              String.format("[Sample=%d, Value=%d]\n%s", h.getSampleCount(), h.getValueCount(), h));
        } // FOR
        sb.append(
            String.format(
                "INTERVAL #%d [total_txns_in_interval=%d, num_txns=%d, incomplete_txns=%d]\n%s",
                i,
                total_txns_in_interval,
                num_txns,
                incomplete_txn_ctrs[i],
                StringUtil.formatMaps(m)));
      }

      // Txn Skew
      if (num_elements == 0) {
        txn_skews[i] = 0.0d;
      } else {
        txn_skews[i] = SkewFactorUtil.calculateSkew(num_partitions, num_elements, target_histogram);
      }

      // Exec Skew
      if (exec_histogram[i].getSampleCount() == 0) {
        exec_skews[i] = 0.0d;
      } else {
        exec_skews[i] =
            SkewFactorUtil.calculateSkew(
                num_partitions, exec_histogram[i].getSampleCount(), exec_histogram[i]);
      }
      // Equal-weight blend of the two skew measures
      total_skews[i] = (0.5 * exec_skews[i]) + (0.5 * txn_skews[i]);

      if (sb != null) {
        sb.append("Txn Skew   = " + MathUtil.roundToDecimals(txn_skews[i], 6) + "\n");
        sb.append("Exec Skew  = " + MathUtil.roundToDecimals(exec_skews[i], 6) + "\n");
        sb.append("Total Skew = " + MathUtil.roundToDecimals(total_skews[i], 6) + "\n");
        sb.append(StringUtil.DOUBLE_LINE);
      }
    } // FOR
    if (sb != null && sb.length() > 0) {
      if (debug.val) LOG.debug("**** Skew Factor ****\n" + sb);
      this.appendDebugMessage(sb);
    }
    if (trace.val) {
      for (int i = 0; i < num_intervals; i++) {
        LOG.trace(
            "Time Interval #"
                + i
                + "\n"
                + "Total # of Txns: "
                + this.cost_models[i].txn_ctr.get()
                + "\n"
                + "Multi-Partition Txns: "
                + multipartition_ctrs[i]
                + "\n"
                + "Execution Cost: "
                + execution_costs[i]
                + "\n"
                + "ProcHistogram:\n"
                + this.cost_models[i].getProcedureHistogram().toString()
                + "\n"
                + StringUtil.SINGLE_LINE);
      }
    }

    // (4) The final execution cost is the mean of the per-interval execution costs, weighted by
    //     the number of txns in each interval (total_interval_txns)
    this.last_execution_cost = MathUtil.weightedMean(execution_costs, total_interval_txns);

    // The final skew cost is weighted the same way, so intervals with few txns have
    // proportionally less influence on the result
    this.last_skew_cost = MathUtil.weightedMean(total_skews, total_interval_txns);
    double new_final_cost =
        (this.use_execution ? (this.execution_weight * this.last_execution_cost) : 0)
            + (this.use_skew ? (this.skew_weight * this.last_skew_cost) : 0);

    if (sb != null) {
      Map<String, Object> m = new LinkedHashMap<String, Object>();
      m.put("Total Txns", total_txns.get());
      m.put("Interval Txns", Arrays.toString(total_interval_txns));
      m.put("Execution Costs", Arrays.toString(execution_costs));
      m.put("Skew Factors", Arrays.toString(total_skews));
      m.put("Txn Skew", Arrays.toString(txn_skews));
      m.put("Exec Skew", Arrays.toString(exec_skews));
      m.put("Interval Weights", Arrays.toString(interval_weights));
      m.put(
          "Final Cost",
          String.format(
              "%f = %f + %f", new_final_cost, this.last_execution_cost, this.last_skew_cost));
      if (debug.val) LOG.debug(StringUtil.formatMaps(m));
      this.appendDebugMessage(StringUtil.formatMaps(m));
    }

    this.last_final_cost = new_final_cost;
    return (MathUtil.roundToDecimals(this.last_final_cost, 5));
  }

  /**
   * Writes diagnostics to the error log when interval {@code i}'s partition-touch counters do not
   * match the sample count of that interval model's txn partition access histogram.
   *
   * <p>NOTE(review): assumes the interval cost models are {@code SingleSitedCostModel}s, exactly
   * as the former inline diagnostics did; the cast now happens only on this error path.
   *
   * @param i index of the time interval whose counters are inconsistent
   */
  private void logTxnPartitionMismatch(int i) {
    SingleSitedCostModel inner_costModel = (SingleSitedCostModel) this.cost_models[i];
    LOG.error("Transaction Entries: " + inner_costModel.getTransactionCacheEntries().size());

    // Recount the touched partitions directly from the cached txn entries
    ObjectHistogram<Integer> check = new ObjectHistogram<Integer>();
    for (TransactionCacheEntry tce : inner_costModel.getTransactionCacheEntries()) {
      check.put(tce.getTouchedPartitions());
    }
    LOG.error(
        "Check Touched Partitions: sample="
            + check.getSampleCount()
            + ", values="
            + check.getValueCount());
    LOG.error(
        "Cache Touched Partitions: sample="
            + this.cost_models[i].getTxnPartitionAccessHistogram().getSampleCount()
            + ", values="
            + this.cost_models[i].getTxnPartitionAccessHistogram().getValueCount());

    int qtotal = inner_costModel.getAllQueryCacheEntries().size();
    int ctr = 0;
    int multip = 0;
    for (QueryCacheEntry qce : inner_costModel.getAllQueryCacheEntries()) {
      ctr += (qce.getAllPartitions().isEmpty() ? 0 : 1);
      multip += (qce.getAllPartitions().size() > 1 ? 1 : 0);
    } // FOR
    LOG.error("# of QueryCacheEntries with Touched Partitions: " + ctr + " / " + qtotal);
    LOG.error("# of MultiP QueryCacheEntries: " + multip);
  }