Example #1
0
/**
 * Helpers for running {@link ResumableTask}s on an {@link Executor}, re-submitting the task each
 * time it yields until it reports that it is finished.
 */
public final class ResumableTasks {
  private static final Logger log = Logger.get(ResumableTasks.class);

  private ResumableTasks() {}

  /**
   * Schedules {@code task} on {@code executor}. After each invocation, if the returned status is
   * not finished, the same runnable is handed back to {@code executor} once the status'
   * continuation future completes.
   *
   * <p>NOTE(review): {@code thenRun} only fires on normal completion; if the continuation future
   * completes exceptionally the task is never resumed. Confirm this is intended.
   */
  public static void submit(Executor executor, ResumableTask task) {
    // The runnable must be able to resubmit itself, so it looks itself up through this holder.
    AtomicReference<Runnable> selfHolder = new AtomicReference<>();
    Runnable step =
        () -> {
          ResumableTask.TaskStatus taskStatus = safeProcessTask(task);
          if (taskStatus.isFinished()) {
            return;
          }
          taskStatus
              .getContinuationFuture()
              .thenRun(() -> executor.execute(selfHolder.get()));
        };
    selfHolder.set(step);
    executor.execute(step);
  }

  /**
   * Invokes {@code task.process()}, turning any {@link Throwable} into a logged warning and a
   * finished status so that a failing task is not rescheduled.
   */
  private static ResumableTask.TaskStatus safeProcessTask(ResumableTask task) {
    ResumableTask.TaskStatus outcome;
    try {
      outcome = task.process();
    } catch (Throwable failure) {
      log.warn(failure, "ResumableTask completed exceptionally");
      outcome = ResumableTask.TaskStatus.finished();
    }
    return outcome;
  }
}
 /**
  * Best-effort drop of {@code table} through the metastore client. Any {@link RuntimeException}
  * is logged as a warning instead of being propagated.
  */
 private void dropTable(SchemaTableName table) {
   try {
     String schemaName = table.getSchemaName();
     String tableName = table.getTableName();
     metastoreClient.dropTable(schemaName, tableName);
   } catch (RuntimeException dropFailure) {
     Logger.get(getClass()).warn(dropFailure, "Failed to drop table: %s", table);
   }
 }
Example #3
0
/**
 * Starts the JDK's built-in JMX remote agent via system properties, on a configured or
 * automatically chosen RMI registry port and RMI server port, with authentication and SSL turned
 * off.
 */
public class JmxAgent {
  private static final Logger log = Logger.get(JmxAgent.class);

  private final int registryPort;
  private final int serverPort;
  private final JMXServiceURL url;

  /**
   * Resolves both RMI ports from {@code config}; a port left unset (null) is replaced by the
   * result of {@code NetUtils.findUnusedPort()}.
   *
   * @throws IOException as declared; presumably raised by the unused-port lookup (confirm)
   */
  @Inject
  public JmxAgent(JmxConfig config) throws IOException {
    registryPort =
        config.getRmiRegistryPort() != null
            ? config.getRmiRegistryPort()
            : NetUtils.findUnusedPort();
    serverPort =
        config.getRmiServerPort() != null
            ? config.getRmiServerPort()
            : NetUtils.findUnusedPort();

    JMXServiceURL serviceUrl;
    try {
      // Mirrors the way the JDK JMX agent builds its own service URL.
      serviceUrl = new JMXServiceURL("rmi", null, registryPort);
    } catch (MalformedURLException impossible) {
      // A fixed "rmi" protocol with a numeric port is always well formed.
      throw new AssertionError(impossible);
    }
    url = serviceUrl;
  }

  /** Returns the service URL computed from the RMI registry port. */
  public JMXServiceURL getURL() {
    return url;
  }

  /**
   * Configures the JDK JMX agent through {@code com.sun.management.jmxremote*} system properties
   * and starts it. The Oracle/OpenJDK agent offers no programmatic configuration API, which is why
   * global system properties are used.
   */
  @PostConstruct
  public void start() throws IOException {
    String registryPortText = String.valueOf(registryPort);
    String serverPortText = String.valueOf(serverPort);

    System.setProperty("com.sun.management.jmxremote", "true");
    System.setProperty("com.sun.management.jmxremote.port", registryPortText);
    System.setProperty("com.sun.management.jmxremote.rmi.port", serverPortText);
    System.setProperty("com.sun.management.jmxremote.authenticate", "false");
    System.setProperty("com.sun.management.jmxremote.ssl", "false");

    try {
      Agent.startAgent();
    } catch (Exception startFailure) {
      throw Throwables.propagate(startFailure);
    }

    log.info("JMX Agent listening on %s:%s", url.getHost(), url.getPort());
  }
}
 /**
  * Verifies view creation, then drops the temporary view. A failure while dropping is only
  * logged, so it cannot mask the result of the verification itself.
  */
 @Test
 public void testViewCreation() {
   try {
     verifyViewCreation();
   } finally {
     try {
       metadata.dropView(SESSION, temporaryCreateView);
     } catch (RuntimeException cleanupFailure) {
       Logger cleanupLog = Logger.get(getClass());
       cleanupLog.warn(cleanupFailure, "Failed to drop view: %s", temporaryCreateView);
     }
   }
 }
Example #5
0
  /** Schema-setup helpers for {@link AliasDao}. */
  public static final class Utils {
    public static final Logger log = Logger.get(AliasDao.class);

    /** Creates the alias table through {@code dao}. */
    public static void createTables(AliasDao dao) {
      dao.createAliasTable();
    }

    /**
     * Calls {@link #createTables} until it succeeds. Whenever a database connection cannot be
     * obtained, logs a warning and sleeps 10 seconds before the next attempt; any other exception
     * propagates immediately.
     *
     * @throws InterruptedException if interrupted while sleeping between attempts
     */
    public static void createTablesWithRetry(AliasDao dao) throws InterruptedException {
      Duration retryDelay = new Duration(10, TimeUnit.SECONDS);
      for (;;) {
        try {
          createTables(dao);
          break;
        } catch (UnableToObtainConnectionException connectFailure) {
          log.warn(
              "Failed to connect to database. Will retry again in %s. Exception: %s",
              retryDelay, connectFailure.getMessage());
          Thread.sleep(retryDelay.toMillis());
        }
      }
    }
  }
Example #6
0
/**
 * Runs splits from many tasks on a fixed number of runner threads.
 *
 * <p>Runnable splits wait in a priority queue. A split's priority level comes from the total thread
 * time its task has consumed (see {@code calculatePriorityLevel}), and lower levels are taken
 * first, so tasks that have used less CPU are preferred. Each task also gets up to {@code
 * GUARANTEED_SPLITS_PER_TASK} splits scheduled immediately when they are enqueued or when one of
 * its splits finishes.
 */
@ThreadSafe
public class TaskExecutor {
  private static final Logger log = Logger.get(TaskExecutor.class);

  // each task is guaranteed a minimum number of scheduled splits
  private static final int GUARANTEED_SPLITS_PER_TASK = 3;

  // each time we run a split, run it for this length before returning to the pool
  private static final Duration SPLIT_RUN_QUANTA = new Duration(1, TimeUnit.SECONDS);

  private static final AtomicLong NEXT_RUNNER_ID = new AtomicLong();
  private static final AtomicLong NEXT_WORKER_ID = new AtomicLong();

  private final ExecutorService executor;
  private final ThreadPoolExecutorMBean executorMBean;

  private final int runnerThreads;
  private final int minimumNumberOfTasks;

  private final Ticker ticker;

  @GuardedBy("this")
  private final List<TaskHandle> tasks;

  // only mutated from synchronized methods
  private final Set<PrioritizedSplitRunner> allSplits = new HashSet<>();
  private final PriorityBlockingQueue<PrioritizedSplitRunner> pendingSplits;
  private final Set<PrioritizedSplitRunner> runningSplits =
      Sets.newSetFromMap(new ConcurrentHashMap<PrioritizedSplitRunner, Boolean>());
  private final Set<PrioritizedSplitRunner> blockedSplits =
      Sets.newSetFromMap(new ConcurrentHashMap<PrioritizedSplitRunner, Boolean>());

  // one counter per priority level (0..4) produced by calculatePriorityLevel
  private final AtomicLongArray completedTasksPerLevel = new AtomicLongArray(5);

  private final DistributionStat queuedTime = new DistributionStat();
  private final DistributionStat wallTime = new DistributionStat();

  private volatile boolean closed;

  @Inject
  public TaskExecutor(TaskManagerConfig config) {
    this(checkNotNull(config, "config is null").getMaxShardProcessorThreads());
  }

  public TaskExecutor(int runnerThreads) {
    this(runnerThreads, Ticker.systemTicker());
  }

  /**
   * @param runnerThreads number of threads that pull splits from the pending queue; must be >= 1
   * @param ticker time source used to measure split run time
   */
  @VisibleForTesting
  public TaskExecutor(int runnerThreads, Ticker ticker) {
    checkArgument(runnerThreads > 0, "runnerThreads must be at least 1");

    // we manage the thread pool size directly, so create an unlimited pool
    this.executor = Executors.newCachedThreadPool(threadsNamed("task-processor-%d"));
    this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor);
    this.runnerThreads = runnerThreads;

    this.ticker = checkNotNull(ticker, "ticker is null");

    // we assume we need at least two tasks per runner thread to keep the system busy
    this.minimumNumberOfTasks = 2 * this.runnerThreads;
    this.pendingSplits =
        new PriorityBlockingQueue<>(Runtime.getRuntime().availableProcessors() * 10);
    this.tasks = new LinkedList<>();
  }

  /** Starts {@code runnerThreads} runner threads. */
  @PostConstruct
  public synchronized void start() {
    checkState(!closed, "TaskExecutor is closed");
    for (int i = 0; i < runnerThreads; i++) {
      addRunnerThread();
    }
  }

  /** Marks the executor closed and interrupts all runner threads. */
  @PreDestroy
  public synchronized void stop() {
    closed = true;
    executor.shutdownNow();
  }

  @Override
  public synchronized String toString() {
    return Objects.toStringHelper(this)
        .add("runnerThreads", runnerThreads)
        .add("allSplits", allSplits.size())
        .add("pendingSplits", pendingSplits.size())
        .add("runningSplits", runningSplits.size())
        .add("blockedSplits", blockedSplits.size())
        .toString();
  }

  private synchronized void addRunnerThread() {
    try {
      executor.execute(new Runner());
    } catch (RejectedExecutionException ignored) {
      // the pool has been shut down (stop() was called); no replacement runner is needed
    }
  }

  /** Registers a new task and returns the handle used to enqueue its splits. */
  public synchronized TaskHandle addTask(TaskId taskId) {
    TaskHandle taskHandle = new TaskHandle(checkNotNull(taskId, "taskId is null"));
    tasks.add(taskHandle);
    return taskHandle;
  }

  /**
   * Destroys all of the task's queued and running splits, unregisters it, and records the priority
   * level its total thread usage reached.
   */
  public synchronized void removeTask(TaskHandle taskHandle) {
    taskHandle.destroy();
    tasks.remove(taskHandle);

    // record completed stats
    long threadUsageNanos = taskHandle.getThreadUsageNanos();
    int priorityLevel = calculatePriorityLevel(threadUsageNanos);
    completedTasksPerLevel.incrementAndGet(priorityLevel);
  }

  /**
   * Queues {@code taskSplit} under {@code taskHandle} and schedules splits if capacity allows.
   *
   * @return a future that completes when the split finishes, or fails if processing throws
   */
  public synchronized ListenableFuture<?> enqueueSplit(
      TaskHandle taskHandle, SplitRunner taskSplit) {
    PrioritizedSplitRunner prioritizedSplitRunner =
        new PrioritizedSplitRunner(taskHandle, taskSplit, ticker);
    taskHandle.addSplit(prioritizedSplitRunner);

    scheduleTaskIfNecessary(taskHandle);

    addNewEntrants();

    return prioritizedSplitRunner.getFinishedFuture();
  }

  /**
   * Puts {@code taskSplit} straight into the pending queue, bypassing the per-task queue and the
   * scheduling limits.
   *
   * @return a future that completes when the split finishes, or fails if processing throws
   */
  public synchronized ListenableFuture<?> forceRunSplit(
      TaskHandle taskHandle, SplitRunner taskSplit) {
    PrioritizedSplitRunner prioritizedSplitRunner =
        new PrioritizedSplitRunner(taskHandle, taskSplit, ticker);

    // Note: we do not record queued time for forced splits

    startSplit(prioritizedSplitRunner);

    return prioritizedSplitRunner.getFinishedFuture();
  }

  private synchronized void splitFinished(PrioritizedSplitRunner split) {
    allSplits.remove(split);
    pendingSplits.remove(split);

    TaskHandle taskHandle = split.getTaskHandle();
    taskHandle.splitComplete(split);

    wallTime.add(System.nanoTime() - split.createdNanos);

    scheduleTaskIfNecessary(taskHandle);

    addNewEntrants();
  }

  private synchronized void scheduleTaskIfNecessary(TaskHandle taskHandle) {
    // if task has less than the minimum guaranteed splits running,
    // immediately schedule a new split for this task.  This assures
    // that a task gets its fair amount of consideration (you have to
    // have splits to be considered for running on a thread).
    if (taskHandle.getRunningSplits() < GUARANTEED_SPLITS_PER_TASK) {
      PrioritizedSplitRunner split = taskHandle.pollNextSplit();
      if (split != null) {
        startSplit(split);
        queuedTime.add(System.nanoTime() - split.createdNanos);
      }
    }
  }

  // tops up the scheduled split count to minimumNumberOfTasks, taking splits round-robin by task
  private synchronized void addNewEntrants() {
    int running = allSplits.size();
    for (int i = 0; i < minimumNumberOfTasks - running; i++) {
      PrioritizedSplitRunner split = pollNextSplitWorker();
      if (split == null) {
        break;
      }

      queuedTime.add(System.nanoTime() - split.createdNanos);
      startSplit(split);
    }
  }

  private synchronized void startSplit(PrioritizedSplitRunner split) {
    allSplits.add(split);
    pendingSplits.put(split);
  }

  private synchronized PrioritizedSplitRunner pollNextSplitWorker() {
    // todo find a better algorithm for this
    // find the first task that produces a split, then move that task to the
    // end of the task list, so we get round robin
    for (Iterator<TaskHandle> iterator = tasks.iterator(); iterator.hasNext(); ) {
      TaskHandle task = iterator.next();
      PrioritizedSplitRunner split = task.pollNextSplit();
      if (split != null) {
        // move task to end of list
        iterator.remove();

        // CAUTION: we are modifying the list in the loop which would normally
        // cause a ConcurrentModificationException but we exit immediately
        tasks.add(task);
        return split;
      }
    }
    return null;
  }

  /**
   * Per-task bookkeeping: splits waiting to be scheduled, splits already scheduled, and the task's
   * accumulated thread time. Not thread safe; mutated only while holding the TaskExecutor lock.
   */
  @NotThreadSafe
  public static class TaskHandle {
    private final TaskId taskId;
    private final Queue<PrioritizedSplitRunner> queuedSplits = new ArrayDeque<>(10);
    private final List<PrioritizedSplitRunner> runningSplits = new ArrayList<>(10);
    private final AtomicLong taskThreadUsageNanos = new AtomicLong();

    private TaskHandle(TaskId taskId) {
      this.taskId = taskId;
    }

    private long addThreadUsageNanos(long durationNanos) {
      return taskThreadUsageNanos.addAndGet(durationNanos);
    }

    private TaskId getTaskId() {
      return taskId;
    }

    private void destroy() {
      // NOTE(review): destroyed splits never complete their finished future; callers waiting on
      // enqueueSplit()'s future for a removed task will not be notified.
      for (PrioritizedSplitRunner runningSplit : runningSplits) {
        runningSplit.destroy();
      }
      runningSplits.clear();

      for (PrioritizedSplitRunner queuedSplit : queuedSplits) {
        queuedSplit.destroy();
      }
      queuedSplits.clear();
    }

    private void addSplit(PrioritizedSplitRunner split) {
      queuedSplits.add(split);
    }

    private int getRunningSplits() {
      return runningSplits.size();
    }

    private long getThreadUsageNanos() {
      return taskThreadUsageNanos.get();
    }

    // moves the next queued split (if any) into the running list and returns it
    private PrioritizedSplitRunner pollNextSplit() {
      PrioritizedSplitRunner split = queuedSplits.poll();
      if (split != null) {
        runningSplits.add(split);
      }
      return split;
    }

    private void splitComplete(PrioritizedSplitRunner split) {
      runningSplits.remove(split);
      split.destroy();
    }

    @Override
    public String toString() {
      return Objects.toStringHelper(this).add("taskId", taskId).toString();
    }
  }

  /** Wraps a {@link SplitRunner} with the state used to order it in the pending queue. */
  private static class PrioritizedSplitRunner implements Comparable<PrioritizedSplitRunner> {
    private final long createdNanos = System.nanoTime();

    private final TaskHandle taskHandle;
    private final long workerId;
    private final SplitRunner split;

    private final Ticker ticker;

    private final SettableFuture<?> finishedFuture = SettableFuture.create();

    private final AtomicBoolean initialized = new AtomicBoolean();
    private final AtomicBoolean destroyed = new AtomicBoolean();

    private final AtomicInteger priorityLevel = new AtomicInteger();
    private final AtomicLong threadUsageNanos = new AtomicLong();
    private final AtomicLong lastRun = new AtomicLong();

    private PrioritizedSplitRunner(TaskHandle taskHandle, SplitRunner split, Ticker ticker) {
      this.taskHandle = taskHandle;
      this.split = split;
      this.ticker = ticker;
      this.workerId = NEXT_WORKER_ID.getAndIncrement();
    }

    private TaskHandle getTaskHandle() {
      return taskHandle;
    }

    private SettableFuture<?> getFinishedFuture() {
      return finishedFuture;
    }

    /** Calls {@code split.initialize()} exactly once, on the first invocation. */
    public void initializeIfNecessary() {
      if (initialized.compareAndSet(false, true)) {
        split.initialize();
      }
    }

    /** Closes the underlying split (logging any failure) and marks this runner destroyed. */
    public void destroy() {
      try {
        split.close();
      } catch (RuntimeException e) {
        log.error(e, "Error closing split for task %s", taskHandle.getTaskId());
      }
      destroyed.set(true);
    }

    /**
     * Returns true if the split has finished or this runner was destroyed. Completes the finished
     * future when the split reports it is finished.
     */
    public boolean isFinished() {
      boolean finished = split.isFinished();
      if (finished) {
        finishedFuture.set(null);
      }
      return finished || destroyed.get();
    }

    /**
     * Runs the split for one quantum, charges the elapsed time to the task, and updates this
     * runner's priority level and last-run timestamp. Any failure is also recorded on the finished
     * future before being rethrown.
     *
     * @return the future the split reports it is blocked on
     */
    public ListenableFuture<?> process() throws Exception {
      try {
        long start = ticker.read();
        ListenableFuture<?> blocked = split.processFor(SPLIT_RUN_QUANTA);
        long endTime = ticker.read();

        // update priority level based on total thread usage of task
        long durationNanos = endTime - start;
        long threadUsageNanos = taskHandle.addThreadUsageNanos(durationNanos);
        this.threadUsageNanos.set(threadUsageNanos);
        priorityLevel.set(calculatePriorityLevel(threadUsageNanos));

        // record last run for prioritization within a level
        lastRun.set(endTime);

        return blocked;
      } catch (Throwable e) {
        finishedFuture.setException(e);
        throw e;
      }
    }

    /**
     * Recomputes the priority level from the task's current thread usage.
     *
     * @return true if the level changed (the split must then be re-inserted into the queue)
     */
    public boolean updatePriorityLevel() {
      int newPriority = calculatePriorityLevel(taskHandle.getThreadUsageNanos());
      if (newPriority == priorityLevel.getAndSet(newPriority)) {
        return false;
      }

      // update thread usage only if the level changed
      threadUsageNanos.set(taskHandle.getThreadUsageNanos());
      return true;
    }

    /**
     * Orders by priority level, then by recorded thread usage (levels 0-3) or last run time (level
     * 4), and finally by worker id so distinct runners never compare equal.
     */
    @Override
    public int compareTo(PrioritizedSplitRunner o) {
      int level = priorityLevel.get();

      int result = Ints.compare(level, o.priorityLevel.get());
      if (result != 0) {
        return result;
      }

      if (level < 4) {
        // previously compared threadUsageNanos against itself, which always yielded 0
        result = Long.compare(threadUsageNanos.get(), o.threadUsageNanos.get());
      } else {
        result = Long.compare(lastRun.get(), o.lastRun.get());
      }
      if (result != 0) {
        return result;
      }

      return Longs.compare(workerId, o.workerId);
    }

    @Override
    public String toString() {
      return String.format(
          "Split %-15s %s %s",
          taskHandle.getTaskId(),
          priorityLevel,
          new Duration(threadUsageNanos.get(), TimeUnit.NANOSECONDS)
              .convertToMostSuccinctTimeUnit());
    }
  }

  /**
   * Maps accumulated thread time to a level: 0 (&lt; 1s), 1 (&lt; 10s), 2 (&lt; 60s), 3 (&lt;
   * 5min), 4 otherwise.
   */
  private static int calculatePriorityLevel(long threadUsageNanos) {
    long millis = TimeUnit.NANOSECONDS.toMillis(threadUsageNanos);

    int priorityLevel;
    if (millis < 1000) {
      priorityLevel = 0;
    } else if (millis < 10_000) {
      priorityLevel = 1;
    } else if (millis < 60_000) {
      priorityLevel = 2;
    } else if (millis < 300_000) {
      priorityLevel = 3;
    } else {
      priorityLevel = 4;
    }
    return priorityLevel;
  }

  /**
   * Worker loop: takes the highest-priority pending split, runs it for one quantum, then finishes
   * it, requeues it, or parks it until its blocked future completes. It replaces itself on exit
   * unless the executor is closed.
   */
  private class Runner implements Runnable {
    private final long runnerId = NEXT_RUNNER_ID.getAndIncrement();

    @Override
    public void run() {
      try (SetThreadName runnerName = new SetThreadName("SplitRunner-%s", runnerId)) {
        while (!closed && !Thread.currentThread().isInterrupted()) {
          // select next worker
          final PrioritizedSplitRunner split;
          try {
            split = pendingSplits.take();
            if (split.updatePriorityLevel()) {
              // priority level changed, return split to queue for re-prioritization
              pendingSplits.put(split);
              continue;
            }
          } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return;
          }

          try (SetThreadName splitName = new SetThreadName(split.toString())) {
            runningSplits.add(split);

            boolean finished;
            ListenableFuture<?> blocked;
            try {
              split.initializeIfNecessary();
              blocked = split.process();
              finished = split.isFinished();
            } finally {
              runningSplits.remove(split);
            }

            if (finished) {
              log.debug("%s is finished", split);
              splitFinished(split);
            } else {
              if (blocked.isDone()) {
                pendingSplits.put(split);
              } else {
                blockedSplits.add(split);
                blocked.addListener(
                    new Runnable() {
                      @Override
                      public void run() {
                        blockedSplits.remove(split);
                        split.updatePriorityLevel();
                        pendingSplits.put(split);
                      }
                    },
                    executor);
              }
            }
          } catch (Throwable t) {
            log.error(t, "Error processing %s", split);
            splitFinished(split);
          }
        }
      } finally {
        // unless we have been closed, we need to replace this thread
        if (!closed) {
          addRunnerThread();
        }
      }
    }
  }

  //
  // STATS
  //

  // synchronized: tasks is @GuardedBy("this") and is a non-thread-safe LinkedList
  @Managed
  public synchronized int getTasks() {
    return tasks.size();
  }

  @Managed
  public int getRunnerThreads() {
    return runnerThreads;
  }

  @Managed
  public int getMinimumNumberOfTasks() {
    return minimumNumberOfTasks;
  }

  // synchronized: allSplits is a plain HashSet mutated only under this lock
  @Managed
  public synchronized int getTotalSplits() {
    return allSplits.size();
  }

  @Managed
  public int getPendingSplits() {
    return pendingSplits.size();
  }

  @Managed
  public int getRunningSplits() {
    return runningSplits.size();
  }

  @Managed
  public int getBlockedSplits() {
    return blockedSplits.size();
  }

  @Managed
  public long getCompletedTasksLevel0() {
    return completedTasksPerLevel.get(0);
  }

  @Managed
  public long getCompletedTasksLevel1() {
    return completedTasksPerLevel.get(1);
  }

  @Managed
  public long getCompletedTasksLevel2() {
    return completedTasksPerLevel.get(2);
  }

  @Managed
  public long getCompletedTasksLevel3() {
    return completedTasksPerLevel.get(3);
  }

  @Managed
  public long getCompletedTasksLevel4() {
    return completedTasksPerLevel.get(4);
  }

  @Managed
  public long getRunningTasksLevel0() {
    return calculateRunningTasksForLevel(0);
  }

  @Managed
  public long getRunningTasksLevel1() {
    return calculateRunningTasksForLevel(1);
  }

  @Managed
  public long getRunningTasksLevel2() {
    return calculateRunningTasksForLevel(2);
  }

  @Managed
  public long getRunningTasksLevel3() {
    return calculateRunningTasksForLevel(3);
  }

  @Managed
  public long getRunningTasksLevel4() {
    return calculateRunningTasksForLevel(4);
  }

  @Managed
  @Nested
  public DistributionStat getQueuedTime() {
    return queuedTime;
  }

  @Managed
  @Nested
  public DistributionStat getWallTime() {
    return wallTime;
  }

  private synchronized int calculateRunningTasksForLevel(int level) {
    int count = 0;
    for (TaskHandle task : tasks) {
      if (calculatePriorityLevel(task.getThreadUsageNanos()) == level) {
        count++;
      }
    }
    return count;
  }

  @Managed(description = "Task processor executor")
  @Nested
  public ThreadPoolExecutorMBean getProcessorExecutor() {
    return executorMBean;
  }
}
/** Creates jobs that compact the shards of an {@link OrganizationSet} into new shards. */
public class OrganizationJobFactory implements JobFactory {
  private static final Logger log = Logger.get(OrganizationJobFactory.class);

  private final MetadataDao metadataDao;
  private final ShardManager shardManager;
  private final ShardCompactor compactor;

  @Inject
  public OrganizationJobFactory(
      @ForMetadata IDBI dbi, ShardManager shardManager, ShardCompactor compactor) {
    this.metadataDao = onDemandDao(dbi, MetadataDao.class);
    this.shardManager = requireNonNull(shardManager, "shardManager is null");
    this.compactor = requireNonNull(compactor, "compactor is null");
  }

  @Override
  public Runnable create(OrganizationSet organizationSet) {
    return new OrganizationJob(organizationSet);
  }

  /** Compacts one organization set's shards inside a shard-manager transaction. */
  private class OrganizationJob implements Runnable {
    private final OrganizationSet organizationSet;

    public OrganizationJob(OrganizationSet organizationSet) {
      this.organizationSet = requireNonNull(organizationSet, "organizationSet is null");
    }

    /**
     * Runs the compaction.
     *
     * @throws java.io.UncheckedIOException wrapping any {@link IOException} from the compaction
     */
    @Override
    public void run() {
      try {
        runJob(
            organizationSet.getTableId(),
            organizationSet.getBucketNumber(),
            organizationSet.getShards());
      } catch (IOException e) {
        // UncheckedIOException replaces the deprecated Throwables.propagate; still a
        // RuntimeException whose cause is the original IOException
        throw new java.io.UncheckedIOException(e);
      }
    }

    /**
     * Begins a transaction and runs the compaction in it. The transaction is rolled back if the
     * compaction fails or turns out to have nothing to do.
     */
    private void runJob(long tableId, OptionalInt bucketNumber, Set<UUID> shardUuids)
        throws IOException {
      long transactionId = shardManager.beginTransaction();
      boolean replaced;
      try {
        replaced = runJob(transactionId, bucketNumber, tableId, shardUuids);
      } catch (Throwable e) {
        shardManager.rollbackTransaction(transactionId);
        throw e;
      }
      if (!replaced) {
        // previously a no-op job left the transaction open; release it explicitly
        shardManager.rollbackTransaction(transactionId);
      }
    }

    /**
     * Compacts the still-existing shards in {@code shardUuids} and replaces them within {@code
     * transactionId}.
     *
     * @return true if shards were replaced; false if fewer than two shards still exist, in which
     *     case nothing was written
     */
    private boolean runJob(
        long transactionId, OptionalInt bucketNumber, long tableId, Set<UUID> shardUuids)
        throws IOException {
      // This job could be in the queue for quite some time, so before doing any expensive
      // operations, filter out shards that no longer exist, reducing the possibility of failure
      shardUuids = shardManager.getExistingShardUuids(tableId, shardUuids);
      if (shardUuids.size() <= 1) {
        return false;
      }

      TableMetadata metadata = getTableMetadata(tableId);

      List<ShardInfo> newShards =
          performCompaction(transactionId, bucketNumber, shardUuids, metadata);
      log.info(
          "Compacted shards %s into %s",
          shardUuids, newShards.stream().map(ShardInfo::getShardUuid).collect(toList()));
      shardManager.replaceShardUuids(
          transactionId,
          tableId,
          metadata.getColumns(),
          shardUuids,
          newShards,
          OptionalLong.empty());
      return true;
    }

    /** Loads the table's columns and the ids of its sort columns. */
    private TableMetadata getTableMetadata(long tableId) {
      List<TableColumn> sortColumns = metadataDao.listSortColumns(tableId);

      List<Long> sortColumnIds =
          sortColumns.stream().map(TableColumn::getColumnId).collect(toList());

      List<ColumnInfo> columns =
          metadataDao
              .listTableColumns(tableId)
              .stream()
              .map(TableColumn::toColumnInfo)
              .collect(toList());
      return new TableMetadata(tableId, columns, sortColumnIds);
    }

    /** Uses sorted compaction (ASC_NULLS_FIRST on every sort column) if the table is sorted. */
    private List<ShardInfo> performCompaction(
        long transactionId,
        OptionalInt bucketNumber,
        Set<UUID> shardUuids,
        TableMetadata tableMetadata)
        throws IOException {
      if (tableMetadata.getSortColumnIds().isEmpty()) {
        return compactor.compact(
            transactionId, bucketNumber, shardUuids, tableMetadata.getColumns());
      }
      return compactor.compactSorted(
          transactionId,
          bucketNumber,
          shardUuids,
          tableMetadata.getColumns(),
          tableMetadata.getSortColumnIds(),
          nCopies(tableMetadata.getSortColumnIds().size(), ASC_NULLS_FIRST));
    }
  }
}
Example #8
0
/**
 * Buffers entries in a bounded queue and passes them to a {@link BatchHandler} in batches of at
 * most {@code maxBatchSize}, on a single background thread. When the queue is full, {@link #put}
 * discards the oldest queued entries to make room.
 */
public class BatchProcessor<T> {
  private static final Logger log = Logger.get(BatchProcessor.class);

  private final BatchHandler<T> handler;
  private final int maxBatchSize;
  private final BlockingQueue<T> queue;
  private final String name;

  // guarded by this (only touched from synchronized start/stop)
  private ExecutorService executor;
  // null while not started or after stop()
  private volatile Future<?> future;

  private final AtomicLong processedEntries = new AtomicLong();
  private final AtomicLong droppedEntries = new AtomicLong();
  private final AtomicLong errors = new AtomicLong();

  public BatchProcessor(String name, BatchHandler<T> handler, int maxBatchSize, int queueSize) {
    Preconditions.checkNotNull(name, "name is null");
    Preconditions.checkNotNull(handler, "handler is null");
    Preconditions.checkArgument(queueSize > 0, "queue size needs to be a positive integer");
    Preconditions.checkArgument(maxBatchSize > 0, "max batch size needs to be a positive integer");

    this.name = name;
    this.handler = handler;
    this.maxBatchSize = maxBatchSize;
    this.queue = new ArrayBlockingQueue<T>(queueSize);
  }

  /**
   * Starts the background thread, which blocks for one entry, drains up to {@code maxBatchSize -
   * 1} more, and hands them to the handler. Does nothing if already started.
   */
  @PostConstruct
  public synchronized void start() {
    if (future == null) {
      executor = newSingleThreadExecutor(threadsNamed("batch-processor-" + name + "-%d"));

      future =
          executor.submit(
              new Runnable() {
                @Override
                public void run() {
                  while (!Thread.interrupted()) {
                    final List<T> entries = new ArrayList<T>(maxBatchSize);

                    try {
                      T first = queue.take();
                      entries.add(first);
                      queue.drainTo(entries, maxBatchSize - 1);

                      handler.processBatch(Collections.unmodifiableList(entries));

                      processedEntries.addAndGet(entries.size());
                    } catch (InterruptedException e) {
                      // restore the flag so the loop condition observes it and exits
                      Thread.currentThread().interrupt();
                    } catch (Throwable t) {
                      errors.incrementAndGet();
                      log.warn(t, "Error handling batch");
                    }

                    // TODO: expose timestamp of last execution via jmx
                  }
                }
              });
    }
  }

  @Managed
  public long getProcessedEntries() {
    return processedEntries.get();
  }

  @Managed
  public long getDroppedEntries() {
    return droppedEntries.get();
  }

  @Managed
  public long getErrors() {
    return errors.get();
  }

  @Managed
  public long getQueueSize() {
    return queue.size();
  }

  /** Cancels the background thread and shuts down its executor; a no-op if not running. */
  @PreDestroy
  public synchronized void stop() {
    if (future != null) {
      future.cancel(true);
      executor.shutdownNow();

      future = null;
    }
  }

  /**
   * Enqueues {@code entry}, evicting the oldest queued entries (counted as dropped) while the queue
   * is full.
   *
   * @throws IllegalStateException if the processor is not started, has been stopped, or its worker
   *     was cancelled
   * @throws NullPointerException if {@code entry} is null
   */
  public void put(T entry) {
    // Read the volatile field once: it is null before start() and after stop(). Dereferencing it
    // directly used to fail with a NullPointerException instead of the intended state error.
    Future<?> current = future;
    Preconditions.checkState(
        current != null && !current.isCancelled(), "Processor is not running");
    Preconditions.checkNotNull(entry, "entry is null");

    while (!queue.offer(entry)) {
      // throw away oldest and try again
      if (queue.poll() != null) {
        droppedEntries.incrementAndGet();
      }
    }
  }

  /** Receives each batch drained from the queue. */
  public static interface BatchHandler<T> {
    void processBatch(Collection<T> entries);
  }
}
/**
 * Split manager for the Riak connector. Every table is exposed as a single partition, and splits
 * are derived from a Riak coverage plan obtained through the local node's direct connection.
 */
public class RiakSplitManager implements ConnectorSplitManager {
  private static final Logger log = Logger.get(RiakSplitManager.class);
  private final String connectorId;
  private final RiakClient riakClient;
  private final RiakConfig riakConfig;
  private final DirectConnection directConnection;

  @Inject
  public RiakSplitManager(
      RiakConnectorId connectorId,
      RiakClient riakClient,
      RiakConfig config,
      DirectConnection directConnection) {
    this.connectorId = checkNotNull(connectorId, "connectorId is null").toString();
    this.riakClient = checkNotNull(riakClient, "client is null");
    this.riakConfig = checkNotNull(config, "config is null");
    this.directConnection = checkNotNull(directConnection, "directConnection is null");
  }

  /**
   * Returns a single {@link RiakPartition} for the table, carrying the full {@code tupleDomain}
   * and the names of the parent table's indexed columns.
   *
   * @throws IllegalArgumentException if {@code tableHandle} is not a {@code RiakTableHandle}
   * @throws TableNotFoundException if the parent table cannot be loaded; the original failure is
   *     attached as the cause
   */
  // TODO: compute real partitions here
  @Override
  public ConnectorPartitionResult getPartitions(
      ConnectorTableHandle tableHandle, TupleDomain<ColumnHandle> tupleDomain) {
    checkArgument(
        tableHandle instanceof RiakTableHandle,
        "tableHandle is not an instance of RiakTableHandle");
    RiakTableHandle riakTableHandle = (RiakTableHandle) tableHandle;

    log.debug("tupleDomain for %s: %s", riakTableHandle.toSchemaTableName(), tupleDomain);

    try {
      PRTable table = getParentTable(riakTableHandle);
      List<String> indexedColumns = Lists.newArrayList();
      for (RiakColumn riakColumn : table.getColumns()) {
        if (riakColumn.getIndex()) {
          indexedColumns.add(riakColumn.getName());
        }
      }

      // Riak connector has only one partition
      List<ConnectorPartition> partitions =
          ImmutableList.<ConnectorPartition>of(
              new RiakPartition(
                  riakTableHandle.getSchemaName(),
                  riakTableHandle.getTableName(),
                  tupleDomain,
                  indexedColumns));

      // Riak connector does not do any additional processing/filtering with the TupleDomain, so
      // just return the whole TupleDomain
      return new ConnectorPartitionResult(partitions, tupleDomain);
    } catch (Exception e) {
      log.error(e, "Failed to load partitions for %s", riakTableHandle.toSchemaTableName());
      throw new TableNotFoundException(riakTableHandle.toSchemaTableName(), e);
    }
  }

  /**
   * Builds one {@code CoverageSplit} per task of a Riak coverage plan, in shuffled order. If no
   * local node is configured, an error is logged and an empty split source is returned.
   *
   * @throws IllegalArgumentException if there is not exactly one {@code RiakPartition} or if
   *     {@code tableHandle} is not a {@code RiakTableHandle}
   * @throws TableNotFoundException if loading the table or planning coverage fails; the original
   *     failure is attached as the cause
   */
  // TODO: return correct splits from partitions
  @Override
  public ConnectorSplitSource getPartitionSplits(
      ConnectorTableHandle tableHandle, List<ConnectorPartition> partitions) {
    checkNotNull(partitions, "partitions is null");
    checkArgument(partitions.size() == 1, "Expected one partition but got %s", partitions.size());
    ConnectorPartition partition = partitions.get(0);

    checkArgument(
        partition instanceof RiakPartition, "partition is not an instance of RiakPartition");
    checkArgument(
        tableHandle instanceof RiakTableHandle,
        "tableHandle is not an instance of RiakTableHandle");
    RiakTableHandle riakTableHandle = (RiakTableHandle) tableHandle;

    try {
      PRTable table = getParentTable(riakTableHandle);
      log.debug("columns of %s: %s", riakTableHandle.toSchemaTableName(), table.getColumns());

      String hosts = riakClient.getHosts();
      log.debug("Riak hosts: %s", hosts);

      List<ConnectorSplit> splits;
      if (riakConfig.getLocalNode() != null) {
        splits = createCoverageSplits(riakTableHandle, table, partition);
      } else {
        // TODO: in Riak connector, you only need single access point for each presto worker???
        log.error("localNode must be set and working");
        splits = Lists.newArrayList();
      }
      log.debug(
          "table %s.%s has %d splits.",
          riakTableHandle.getSchemaName(), riakTableHandle.getTableName(), splits.size());

      Collections.shuffle(splits);
      return new FixedSplitSource(connectorId, splits);
    } catch (Exception e) {
      // this can happen if table is removed during a query
      throw new TableNotFoundException(riakTableHandle.toSchemaTableName(), e);
    }
  }

  /**
   * Loads the top-level table backing {@code handle}; the handle itself may name a top-level table
   * or a sub-table.
   */
  private PRTable getParentTable(RiakTableHandle handle) throws Exception {
    String parentTable = PRSubTable.parentTableName(handle.getTableName());
    return riakClient.getTable(new SchemaTableName(handle.getSchemaName(), parentTable));
  }

  /** Plans coverage over the direct connection and wraps each planned task as a split. */
  private List<ConnectorSplit> createCoverageSplits(
      RiakTableHandle riakTableHandle, PRTable table, ConnectorPartition partition)
      throws Exception {
    Coverage coverage = new Coverage(directConnection);
    coverage.plan();
    List<SplitTask> splitTasks = coverage.getSplits();
    log.debug("coverage plan: %s", coverage);

    List<ConnectorSplit> splits = Lists.newArrayList();
    for (SplitTask split : splitTasks) {
      log.debug("split at %s: %s", split.getHost(), split);
      splits.add(
          new CoverageSplit(
              riakTableHandle, // maybe toplevel or subtable
              table, // toplevel PRTable
              split.getHost(),
              split.toString(),
              partition.getTupleDomain()));
    }
    return splits;
  }
}
Example #10
0
/**
 * Iterates over the shards of a table as read from its shard index, producing {@link BucketShards}.
 *
 * <p>When {@code merged} is true, each element groups every shard of one bucket; otherwise each
 * element holds exactly one shard. The iterator owns a JDBC connection, statement and result set,
 * all released by {@link #close()}.
 */
final class ShardIterator extends AbstractIterator<BucketShards>
    implements ResultIterator<BucketShards> {
  private static final Logger log = Logger.get(ShardIterator.class);

  // node id -> node identifier: pre-loaded from the DAO, then filled lazily on cache misses
  private final Map<Integer, String> nodeMap = new HashMap<>();

  private final boolean merged;
  // bucket number -> node identifier, or null when no bucket mapping was supplied
  private final Map<Integer, String> bucketToNode;
  private final ShardDao dao;
  private final Connection connection;
  private final PreparedStatement statement;
  private final ResultSet resultSet;
  // true until computeMerged() has positioned the cursor on the first row
  private boolean first = true;

  /**
   * Opens a streaming query over the shard index of {@code tableId}.
   *
   * @param tableId table whose shards are listed
   * @param merged whether to group all shards of a bucket into a single element
   * @param bucketToNode bucket-to-node mapping; when present the query reads bucket numbers,
   *     ordered, instead of per-shard node ids
   * @param effectivePredicate predicate used to filter rows of the shard index
   * @param dbi source of the metadata database connection
   * @throws RuntimeException the unchecked exception built by {@code metadataError} when opening
   *     the query fails with an {@link SQLException}
   */
  public ShardIterator(
      long tableId,
      boolean merged,
      Optional<Map<Integer, String>> bucketToNode,
      TupleDomain<RaptorColumnHandle> effectivePredicate,
      IDBI dbi) {
    this.merged = merged;
    this.bucketToNode = bucketToNode.orElse(null);
    boolean bucketed = bucketToNode.isPresent();
    ShardPredicate predicate = ShardPredicate.create(effectivePredicate, bucketed);

    // bucketed rows come back ordered by bucket so computeMerged() can group adjacent rows
    String template =
        bucketed
            ? "SELECT shard_uuid, bucket_number FROM %s WHERE %s ORDER BY bucket_number"
            : "SELECT shard_uuid, node_ids FROM %s WHERE %s";
    String sql = format(template, shardIndexTable(tableId), predicate.getPredicate());

    dao = onDemandDao(dbi, ShardDao.class);
    fetchNodes();

    try {
      connection = dbi.open().getConnection();
      statement = connection.prepareStatement(sql);
      enableStreamingResults(statement);
      predicate.bind(statement);
      log.debug("Running query: %s", statement);
      resultSet = statement.executeQuery();
    } catch (SQLException e) {
      // release whatever was opened before the failure
      close();
      throw metadataError(e);
    }
  }

  @Override
  protected BucketShards computeNext() {
    try {
      if (merged) {
        return computeMerged();
      }
      return compute();
    } catch (SQLException e) {
      throw metadataError(e);
    }
  }

  /** Closes the result set, statement and connection, ignoring any {@link SQLException}. */
  @SuppressWarnings({"UnusedDeclaration", "EmptyTryBlock"})
  @Override
  public void close() {
    // the resources are closed in reverse declaration order; null ones are skipped
    try (Connection connection = this.connection;
        Statement statement = this.statement;
        ResultSet resultSet = this.resultSet) {
      // nothing to do: closing is the whole point
    } catch (SQLException ignored) {
      // best effort
    }
  }

  /** Split-per-shard: produces one element, holding a single shard, per result row. */
  private BucketShards compute() throws SQLException {
    if (!resultSet.next()) {
      return endOfData();
    }

    UUID shardUuid = uuidFromBytes(resultSet.getBytes("shard_uuid"));
    if (bucketToNode == null) {
      List<Integer> nodeIds = intArrayFromBytes(resultSet.getBytes("node_ids"));
      ShardNodes shard = new ShardNodes(shardUuid, getNodeIdentifiers(nodeIds, shardUuid));
      return new BucketShards(OptionalInt.empty(), ImmutableSet.of(shard));
    }

    int bucket = resultSet.getInt("bucket_number");
    ShardNodes shard = new ShardNodes(shardUuid, ImmutableSet.of(getBucketNode(bucket)));
    return new BucketShards(OptionalInt.of(bucket), ImmutableSet.of(shard));
  }

  /** Split-per-bucket: collects consecutive rows sharing a bucket number into one element. */
  private BucketShards computeMerged() throws SQLException {
    if (resultSet.isAfterLast()) {
      return endOfData();
    }
    if (first) {
      first = false;
      if (!resultSet.next()) {
        return endOfData();
      }
    }

    // the cursor now sits on the first row of the bucket being collected
    int bucketNumber = resultSet.getInt("bucket_number");
    ImmutableSet.Builder<ShardNodes> bucketShards = ImmutableSet.builder();
    boolean sameBucket;
    do {
      UUID shardUuid = uuidFromBytes(resultSet.getBytes("shard_uuid"));
      // every row in this loop has bucket_number == bucketNumber (checked before looping)
      bucketShards.add(new ShardNodes(shardUuid, ImmutableSet.of(getBucketNode(bucketNumber))));
      sameBucket = resultSet.next() && resultSet.getInt("bucket_number") == bucketNumber;
    } while (sameBucket);

    return new BucketShards(OptionalInt.of(bucketNumber), bucketShards.build());
  }

  /** Looks up the node assigned to {@code bucket}, failing if the mapping has no entry. */
  private String getBucketNode(int bucket) {
    String assigned = bucketToNode.get(bucket);
    if (assigned != null) {
      return assigned;
    }
    throw new PrestoException(RAPTOR_ERROR, "No node mapping for bucket: " + bucket);
  }

  /** Resolves node ids to identifiers, consulting the DAO only for ids not yet cached. */
  private Set<String> getNodeIdentifiers(List<Integer> nodeIds, UUID shardUuid) {
    return nodeIds.stream()
        .map(nodeId -> nodeMap.computeIfAbsent(nodeId, missing -> fetchNode(missing, shardUuid)))
        .collect(toSet());
  }

  /** Fetches a single node identifier from the DAO, failing if the id is unknown. */
  private String fetchNode(int id, UUID shardUuid) {
    String identifier = dao.getNodeIdentifier(id);
    if (identifier != null) {
      return identifier;
    }
    throw new PrestoException(
        RAPTOR_ERROR, format("Missing node ID [%s] for shard: %s", id, shardUuid));
  }

  /** Pre-populates the node cache with every node the DAO knows about. */
  private void fetchNodes() {
    for (RaptorNode raptorNode : dao.getNodes()) {
      nodeMap.put(raptorNode.getNodeId(), raptorNode.getNodeIdentifier());
    }
  }

  /** Switches MySQL statements to row-by-row streaming; other drivers are left untouched. */
  private static void enableStreamingResults(Statement statement) throws SQLException {
    if (!statement.isWrapperFor(com.mysql.jdbc.Statement.class)) {
      return;
    }
    statement.unwrap(com.mysql.jdbc.Statement.class).enableStreamingResults();
  }
}
/**
 * {@link ServiceInventory} that reads each slot's {@code airship-service-inventory.json} from the
 * slot's assigned config in the {@link Repository}, caching the raw JSON under {@code cacheDir}.
 */
public class HttpServiceInventory implements ServiceInventory {
  private static final Logger log = Logger.get(HttpServiceInventory.class);
  // placeholder inside service property values that is replaced with the slot's host
  private static final String HOST_PLACEHOLDER = "${airship.host}";

  private final Repository repository;
  private final JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec;
  // configs whose inventory failed to load; a failure is logged only when first added here, and
  // the entry is removed again after a successful load
  private final Set<String> invalidServiceInventory =
      Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>());
  private final File cacheDir;

  @Inject
  public HttpServiceInventory(
      Repository repository,
      JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec,
      CoordinatorConfig config) {
    this(repository, descriptorsJsonCodec, new File(config.getServiceInventoryCacheDir()));
  }

  public HttpServiceInventory(
      Repository repository,
      JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec,
      File cacheDir) {
    Preconditions.checkNotNull(repository, "repository is null");
    Preconditions.checkNotNull(descriptorsJsonCodec, "descriptorsJsonCodec is null");

    this.repository = repository;
    this.descriptorsJsonCodec = descriptorsJsonCodec;
    this.cacheDir = cacheDir;
  }

  /**
   * Returns one descriptor per service declared by each slot's inventory, stamped with the slot's
   * id and location and marked RUNNING only when the slot is RUNNING. Slots without a self URI,
   * without an assignment, or whose inventory cannot be loaded contribute nothing.
   */
  @Override
  public ImmutableList<ServiceDescriptor> getServiceInventory(Iterable<SlotStatus> allSlotStatus) {
    ImmutableList.Builder<ServiceDescriptor> newDescriptors = ImmutableList.builder();
    for (SlotStatus slotStatus : allSlotStatus) {
      // if the self reference is null, the slot is totally offline so skip for now
      if (slotStatus.getSelf() == null) {
        continue;
      }

      List<ServiceDescriptor> serviceDescriptors = getServiceInventory(slotStatus);
      if (serviceDescriptors == null) {
        continue;
      }
      for (ServiceDescriptor serviceDescriptor : serviceDescriptors) {
        newDescriptors.add(
            new ServiceDescriptor(
                null,
                slotStatus.getId().toString(),
                serviceDescriptor.getType(),
                serviceDescriptor.getPool(),
                slotStatus.getLocation(),
                slotStatus.getState() == SlotLifecycleState.RUNNING
                    ? ServiceState.RUNNING
                    : ServiceState.STOPPED,
                interpolateProperties(serviceDescriptor.getProperties(), slotStatus)));
      }
    }
    return newDescriptors.build();
  }

  /** Copies {@code properties}, replacing every {@code ${airship.host}} with the slot's host. */
  private Map<String, String> interpolateProperties(
      Map<String, String> properties, SlotStatus slotStatus) {
    ImmutableMap.Builder<String, String> builder = ImmutableMap.builder();
    for (Entry<String, String> entry : properties.entrySet()) {
      String value = entry.getValue();
      if (value.contains(HOST_PLACEHOLDER)) {
        // String.replace is literal on both sides; replaceAll would treat '$' and '\' in the
        // host as replacement-group syntax
        value = value.replace(HOST_PLACEHOLDER, slotStatus.getSelf().getHost());
      }
      builder.put(entry.getKey(), value);
    }
    return builder.build();
  }

  /**
   * Loads the service descriptors for the slot's assigned config: first from the on-disk cache,
   * otherwise from the repository (then caching the JSON).
   *
   * @return the descriptors, or {@code null} if the slot has no assignment, the repository returns
   *     no supplier for the config, the config has no inventory, or loading failed
   */
  private List<ServiceDescriptor> getServiceInventory(SlotStatus slotStatus) {
    Assignment assignment = slotStatus.getAssignment();
    if (assignment == null) {
      return null;
    }

    String config = assignment.getConfig();

    File cacheFile = getCacheFile(config);
    if (cacheFile.canRead()) {
      try {
        String json = CharStreams.toString(Files.newReaderSupplier(cacheFile, Charsets.UTF_8));
        List<ServiceDescriptor> descriptors = descriptorsJsonCodec.fromJson(json);
        invalidServiceInventory.remove(config);
        return descriptors;
      } catch (Exception ignored) {
        // unreadable or corrupt cache: delete it and fall through to the repository
        cacheFile.delete();
      }
    }

    InputSupplier<? extends InputStream> configFile =
        ConfigUtils.newConfigEntrySupplier(repository, config, "airship-service-inventory.json");
    if (configFile == null) {
      return null;
    }

    try {
      String json;
      try {
        json = CharStreams.toString(CharStreams.newReaderSupplier(configFile, Charsets.UTF_8));
      } catch (FileNotFoundException e) {
        // no service inventory in the config, so replace with json null so caching works
        json = "null";
      }
      invalidServiceInventory.remove(config);

      // cache json
      cacheFile.getParentFile().mkdirs();
      Files.write(json, cacheFile, Charsets.UTF_8);

      List<ServiceDescriptor> descriptors = descriptorsJsonCodec.fromJson(json);
      return descriptors;
    } catch (Exception e) {
      // log each failing config once until it loads successfully again
      if (invalidServiceInventory.add(config)) {
        log.error(e, "Unable to read service inventory for %s", config);
      }
    }
    return null;
  }

  /** Maps a config name to a file-system-safe cache file under {@code cacheDir}. */
  private File getCacheFile(String config) {
    String cacheName = config;
    if (cacheName.startsWith("@")) {
      cacheName = cacheName.substring(1);
    }
    cacheName = cacheName.replaceAll("[^a-zA-Z0-9_.-]", "_");

    // the md5 suffix is taken over the sanitized name, as before
    cacheName = cacheName + "_" + DigestUtils.md5Hex(cacheName);
    return new File(cacheDir, cacheName).getAbsoluteFile();
  }
}
Example #12
0
public class PredicatePushDown extends PlanOptimizer {
  private static final Logger log = Logger.get(PredicatePushDown.class);

  private final Metadata metadata;
  private final SplitManager splitManager;
  private final boolean experimentalSyntaxEnabled;

  /**
   * @param metadata metadata handed to the rewriter; must not be null
   * @param splitManager split manager handed to the rewriter; must not be null
   * @param experimentalSyntaxEnabled flag forwarded unchanged to the rewriter
   */
  public PredicatePushDown(
      Metadata metadata, SplitManager splitManager, boolean experimentalSyntaxEnabled) {
    checkNotNull(metadata, "metadata is null");
    checkNotNull(splitManager, "splitManager is null");

    this.metadata = metadata;
    this.splitManager = splitManager;
    this.experimentalSyntaxEnabled = experimentalSyntaxEnabled;
  }

  /**
   * Rewrites {@code plan}, pushing predicates from filter nodes down toward the plan's sources.
   * The {@code types} map is only null-checked here.
   */
  @Override
  public PlanNode optimize(
      PlanNode plan,
      Session session,
      Map<Symbol, Type> types,
      SymbolAllocator symbolAllocator,
      PlanNodeIdAllocator idAllocator) {
    checkNotNull(plan, "plan is null");
    checkNotNull(session, "session is null");
    checkNotNull(types, "types is null");
    checkNotNull(idAllocator, "idAllocator is null");

    Rewriter rewriter =
        new Rewriter(
            symbolAllocator,
            idAllocator,
            metadata,
            splitManager,
            session,
            experimentalSyntaxEnabled);
    // nothing is inherited at the root, so the starting predicate is TRUE
    return PlanRewriter.rewriteWith(rewriter, plan, BooleanLiteral.TRUE_LITERAL);
  }

  private static class Rewriter extends PlanNodeRewriter<Expression> {
    private final SymbolAllocator symbolAllocator;
    private final PlanNodeIdAllocator idAllocator;
    private final Metadata metadata;
    private final SplitManager splitManager;
    private final Session session;
    private final boolean experimentalSyntaxEnabled;

    /** Captures the collaborators used while rewriting; all but the flag must be non-null. */
    private Rewriter(
        SymbolAllocator symbolAllocator,
        PlanNodeIdAllocator idAllocator,
        Metadata metadata,
        SplitManager splitManager,
        Session session,
        boolean experimentalSyntaxEnabled) {
      checkNotNull(symbolAllocator, "symbolAllocator is null");
      checkNotNull(idAllocator, "idAllocator is null");
      checkNotNull(metadata, "metadata is null");
      checkNotNull(splitManager, "splitManager is null");
      checkNotNull(session, "session is null");

      this.symbolAllocator = symbolAllocator;
      this.idAllocator = idAllocator;
      this.metadata = metadata;
      this.splitManager = splitManager;
      this.session = session;
      this.experimentalSyntaxEnabled = experimentalSyntaxEnabled;
    }

    /**
     * Fallback for node types with no specific rule: the sources are rewritten with a TRUE
     * predicate, and any non-trivial inherited predicate is kept above this node as a filter.
     */
    @Override
    public PlanNode rewriteNode(
        PlanNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      PlanNode rewritten = planRewriter.defaultRewrite(node, BooleanLiteral.TRUE_LITERAL);
      if (inheritedPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        return rewritten;
      }
      // the predicate cannot be pushed any further down, so filter right here
      return new FilterNode(idAllocator.getNextId(), rewritten, inheritedPredicate);
    }

    /**
     * Translates the inherited predicate through the projection's output map and pushes the
     * result into the projection's source.
     */
    @Override
    public PlanNode rewriteProject(
        ProjectNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      ExpressionSymbolInliner inliner = new ExpressionSymbolInliner(node.getOutputMap());
      Expression sourcePredicate = ExpressionTreeRewriter.rewriteWith(inliner, inheritedPredicate);
      return planRewriter.defaultRewrite(node, sourcePredicate);
    }

    /**
     * Pushes the inherited predicate through unchanged, after asserting that it does not reference
     * the marker symbol this node produces.
     */
    @Override
    public PlanNode rewriteMarkDistinct(
        MarkDistinctNode node,
        Expression inheritedPredicate,
        PlanRewriter<Expression> planRewriter) {
      boolean usesMarker =
          DependencyExtractor.extractUnique(inheritedPredicate).contains(node.getMarkerSymbol());
      checkState(!usesMarker, "predicate depends on marker symbol");
      return planRewriter.defaultRewrite(node, inheritedPredicate);
    }

    /** Passes the inherited predicate straight through the sort to its source. */
    @Override
    public PlanNode rewriteSort(
        SortNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      PlanNode rewritten = planRewriter.defaultRewrite(node, inheritedPredicate);
      return rewritten;
    }

    /**
     * Pushes the inherited predicate into every branch of the union, translated through that
     * branch's symbol mapping. The original node is returned if no branch changed.
     */
    @Override
    public PlanNode rewriteUnion(
        UnionNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      List<PlanNode> sources = node.getSources();
      ImmutableList.Builder<PlanNode> rewrittenSources = ImmutableList.builder();
      boolean changed = false;

      for (int index = 0; index < sources.size(); index++) {
        ExpressionSymbolInliner inliner = new ExpressionSymbolInliner(node.sourceSymbolMap(index));
        Expression branchPredicate = ExpressionTreeRewriter.rewriteWith(inliner, inheritedPredicate);
        PlanNode original = sources.get(index);
        PlanNode rewritten = planRewriter.rewrite(original, branchPredicate);
        changed |= rewritten != original;
        rewrittenSources.add(rewritten);
      }

      return changed
          ? new UnionNode(node.getId(), rewrittenSources.build(), node.getSymbolMapping())
          : node;
    }

    /**
     * Removes the filter node, merging its predicate with the inherited one and pushing the
     * combination into the filter's source.
     */
    @Override
    public PlanNode rewriteFilter(
        FilterNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      Expression combined = combineConjuncts(node.getPredicate(), inheritedPredicate);
      return planRewriter.rewrite(node.getSource(), combined);
    }

    /**
     * Pushes the inherited predicate, and predicates derivable from the join, into the join's
     * left and right sources.
     *
     * <p>Whatever cannot be pushed into either side ends up in a {@link FilterNode} above the
     * join. When the join predicate changes (or the join started out as a CROSS join), both sides
     * get a {@link ProjectNode} that keeps their existing symbols and adds one new symbol per
     * equi-join expression, and the join criteria are rebuilt from those symbols.
     */
    @Override
    public PlanNode rewriteJoin(
        JoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      // Captured before normalization. NOTE(review): the switch below has no CROSS case, so
      // tryNormalizeToInnerJoin presumably rewrites CROSS joins as INNER - confirm.
      boolean isCrossJoin = (node.getType() == JoinNode.Type.CROSS);

      // See if we can rewrite outer joins in terms of a plain inner join
      node = tryNormalizeToInnerJoin(node, inheritedPredicate);

      // Predicates each source already guarantees, and the join's own criteria as an expression
      // (presumably the conjunction of its equi-join clauses - confirm in extractJoinPredicate)
      Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft());
      Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight());
      Expression joinPredicate = extractJoinPredicate(node);

      Expression leftPredicate;
      Expression rightPredicate;
      Expression postJoinPredicate;
      Expression newJoinPredicate;

      switch (node.getType()) {
        case INNER:
          InnerJoinPushDownResult innerJoinPushDownResult =
              processInnerJoin(
                  inheritedPredicate,
                  leftEffectivePredicate,
                  rightEffectivePredicate,
                  joinPredicate,
                  node.getLeft().getOutputSymbols());
          leftPredicate = innerJoinPushDownResult.getLeftPredicate();
          rightPredicate = innerJoinPushDownResult.getRightPredicate();
          postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = innerJoinPushDownResult.getJoinPredicate();
          break;
        case LEFT:
          // the left source is the outer side
          OuterJoinPushDownResult leftOuterJoinPushDownResult =
              processOuterJoin(
                  inheritedPredicate,
                  leftEffectivePredicate,
                  rightEffectivePredicate,
                  joinPredicate,
                  node.getLeft().getOutputSymbols());
          leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate();
          rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate();
          postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = joinPredicate; // Use the same as the original
          break;
        case RIGHT:
          // mirror image of LEFT: the right source is the outer side, so the arguments and the
          // results are swapped
          OuterJoinPushDownResult rightOuterJoinPushDownResult =
              processOuterJoin(
                  inheritedPredicate,
                  rightEffectivePredicate,
                  leftEffectivePredicate,
                  joinPredicate,
                  node.getRight().getOutputSymbols());
          leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate();
          rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate();
          postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate();
          newJoinPredicate = joinPredicate; // Use the same as the original
          break;
        default:
          throw new UnsupportedOperationException("Unsupported join type: " + node.getType());
      }

      PlanNode leftSource = planRewriter.rewrite(node.getLeft(), leftPredicate);
      PlanNode rightSource = planRewriter.rewrite(node.getRight(), rightPredicate);

      // Only rebuild the join when a source or the join predicate actually changed
      PlanNode output = node;
      if (leftSource != node.getLeft()
          || rightSource != node.getRight()
          || !newJoinPredicate.equals(joinPredicate)) {
        List<JoinNode.EquiJoinClause> criteria = node.getCriteria();

        // Rewrite criteria and add projections if there is a new join predicate

        if (!newJoinPredicate.equals(joinPredicate) || isCrossJoin) {
          // Create identity projections for all existing symbols
          ImmutableMap.Builder<Symbol, Expression> leftProjections = ImmutableMap.builder();
          leftProjections.putAll(
              IterableTransformer.<Symbol>on(node.getLeft().getOutputSymbols())
                  .toMap(symbolToQualifiedNameReference())
                  .map());
          ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder();
          rightProjections.putAll(
              IterableTransformer.<Symbol>on(node.getRight().getOutputSymbols())
                  .toMap(symbolToQualifiedNameReference())
                  .map());

          // HACK! we don't support cross joins right now, so put in a simple fake join predicate
          // instead if all of the join clauses got simplified out
          // TODO: remove this code when cross join support is added
          Iterable<Expression> simplifiedJoinConjuncts =
              transform(extractConjuncts(newJoinPredicate), simplifyExpressions());
          simplifiedJoinConjuncts =
              filter(
                  simplifiedJoinConjuncts,
                  not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL)));
          if (Iterables.isEmpty(simplifiedJoinConjuncts)) {
            simplifiedJoinConjuncts =
                ImmutableList.<Expression>of(
                    new ComparisonExpression(
                        ComparisonExpression.Type.EQUAL,
                        new LongLiteral("0"),
                        new LongLiteral("0")));
          }

          // Create new projections for the new join clauses
          ImmutableList.Builder<JoinNode.EquiJoinClause> builder = ImmutableList.builder();
          for (Expression conjunct : simplifiedJoinConjuncts) {
            checkState(
                joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct),
                "Expected join predicate to be a valid join equality");

            ComparisonExpression equality = (ComparisonExpression) conjunct;

            // The equality may be written either way round; orient it so that leftExpression
            // only references left-side symbols
            boolean alignedComparison =
                Iterables.all(
                    DependencyExtractor.extractUnique(equality.getLeft()),
                    in(node.getLeft().getOutputSymbols()));
            Expression leftExpression =
                (alignedComparison) ? equality.getLeft() : equality.getRight();
            Expression rightExpression =
                (alignedComparison) ? equality.getRight() : equality.getLeft();

            // Each side computes its join key in a projection under a freshly allocated symbol
            Symbol leftSymbol =
                symbolAllocator.newSymbol(leftExpression, extractType(leftExpression));
            leftProjections.put(leftSymbol, leftExpression);
            Symbol rightSymbol =
                symbolAllocator.newSymbol(rightExpression, extractType(rightExpression));
            rightProjections.put(rightSymbol, rightExpression);

            builder.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol));
          }

          leftSource =
              new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build());
          rightSource =
              new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build());
          criteria = builder.build();
        }
        output = new JoinNode(node.getId(), node.getType(), leftSource, rightSource, criteria);
      }
      // Anything that could not be pushed into either side is evaluated above the join
      if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate);
      }
      return output;
    }

    /**
     * Splits the predicates around an outer join into three parts:
     * <ul>
     *   <li>conjuncts that can be pushed into the outer side;</li>
     *   <li>conjuncts that can be pushed into the inner side;</li>
     *   <li>conjuncts that must be evaluated after the join.</li>
     * </ul>
     *
     * <p>Non-deterministic conjuncts of {@code inheritedPredicate} always stay after the join.
     * Non-deterministic conjuncts of the effective and join predicates are simply ignored when
     * deriving pushdowns.
     *
     * @param inheritedPredicate predicate inherited from above the join
     * @param outerEffectivePredicate predicate already guaranteed by the outer side; may only
     *     reference {@code outerSymbols}
     * @param innerEffectivePredicate predicate already guaranteed by the inner side; must not
     *     reference {@code outerSymbols}
     * @param joinPredicate the join's own predicate
     * @param outerSymbols output symbols of the outer side
     * @return the outer-side, inner-side and post-join predicates (each a conjunction)
     * @throws IllegalArgumentException if either effective predicate references the wrong side
     */
    private OuterJoinPushDownResult processOuterJoin(
        Expression inheritedPredicate,
        Expression outerEffectivePredicate,
        Expression innerEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> outerSymbols) {
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)),
          "outerEffectivePredicate must only contain symbols from outerSymbols");
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))),
          "innerEffectivePredicate must not contain symbols from outerSymbols");

      ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder();

      // Strip out non-deterministic conjuncts
      postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
      inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      outerEffectivePredicate = stripNonDeterministicConjuncts(outerEffectivePredicate);
      innerEffectivePredicate = stripNonDeterministicConjuncts(innerEffectivePredicate);
      joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

      // Generate equality inferences
      // - inheritedInference: equalities stated by the inherited predicate alone
      // - outerInference: inherited equalities plus those the outer side guarantees
      EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
      EqualityInference outerInference =
          createEqualityInference(inheritedPredicate, outerEffectivePredicate);

      EqualityInference.EqualityPartition equalityPartition =
          inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols));
      Expression outerOnlyInheritedEqualities =
          combineConjuncts(equalityPartition.getScopeEqualities());
      // Used to translate outer-side expressions onto inner-side symbols. Only inherited
      // equalities that are purely on the outer side are fed in here.
      EqualityInference potentialNullSymbolInference =
          createEqualityInference(
              outerOnlyInheritedEqualities,
              outerEffectivePredicate,
              innerEffectivePredicate,
              joinPredicate);
      // Same as above but without the inner effective predicate; used below to generate
      // equalities for the inner side
      EqualityInference potentialNullSymbolInferenceWithoutInnerInferred =
          createEqualityInference(
              outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate);

      // Sort through conjuncts in inheritedPredicate that were not used for inference
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols));
        if (outerRewritten != null) {
          outerPushdownConjuncts.add(outerRewritten);

          // A conjunct can only be pushed down into an inner side if it can be rewritten in terms
          // of the outer side
          Expression innerRewritten =
              potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols)));
          if (innerRewritten != null) {
            innerPushdownConjuncts.add(innerRewritten);
          }
        } else {
          postJoinConjuncts.add(conjunct);
        }
      }

      // See if we can push down any outer or join predicates to the inner side
      for (Expression conjunct :
          EqualityInference.nonInferrableConjuncts(and(outerEffectivePredicate, joinPredicate))) {
        Expression rewritten =
            potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols)));
        if (rewritten != null) {
          innerPushdownConjuncts.add(rewritten);
        }
      }

      // TODO: consider adding join predicate optimizations to outer joins

      // Add the equalities from the inferences back in
      // (outer-only equalities go down the outer side; inner-only and cross-side ones stay above)
      outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());
      innerPushdownConjuncts.addAll(
          potentialNullSymbolInferenceWithoutInnerInferred
              .generateEqualitiesPartitionedBy(not(in(outerSymbols)))
              .getScopeEqualities());

      return new OuterJoinPushDownResult(
          combineConjuncts(outerPushdownConjuncts.build()),
          combineConjuncts(innerPushdownConjuncts.build()),
          combineConjuncts(postJoinConjuncts.build()));
    }

    /** Immutable holder for the three predicates produced by {@code processOuterJoin}. */
    private static class OuterJoinPushDownResult {
      // conjunction to push into the outer side of the join
      private final Expression outerJoinPredicate;
      // conjunction to push into the inner side of the join
      private final Expression innerJoinPredicate;
      // conjunction that must be evaluated above the join
      private final Expression postJoinPredicate;

      private OuterJoinPushDownResult(Expression outer, Expression inner, Expression postJoin) {
        outerJoinPredicate = outer;
        innerJoinPredicate = inner;
        postJoinPredicate = postJoin;
      }

      private Expression getOuterJoinPredicate() {
        return this.outerJoinPredicate;
      }

      private Expression getInnerJoinPredicate() {
        return this.innerJoinPredicate;
      }

      private Expression getPostJoinPredicate() {
        return this.postJoinPredicate;
      }
    }

    /**
     * Distributes the predicates surrounding an inner join. Each conjunct goes to the left input,
     * the right input, the join criteria, or a filter evaluated after the join.
     *
     * <p>A conjunct is pushed to the left (right) input when the equality inference can rewrite it
     * purely in terms of {@code leftSymbols} (purely in terms of symbols not in {@code
     * leftSymbols}). Non-deterministic conjuncts of {@code inheritedPredicate} and {@code
     * joinPredicate} are never pushed down. Non-deterministic conjuncts of the two effective
     * predicates are simply dropped. Of the conjuncts left at the join, only those accepted by
     * {@code joinEqualityExpression} stay in the join predicate: deterministic {@code =}
     * comparisons whose operands reference opposite inputs. All others become the post-join
     * predicate.
     *
     * @param inheritedPredicate predicate pushed down from above the join
     * @param leftEffectivePredicate effective predicate of the left input; may only reference
     *     {@code leftSymbols}
     * @param rightEffectivePredicate effective predicate of the right input; must not reference
     *     {@code leftSymbols}
     * @param joinPredicate the predicate of the join itself
     * @param leftSymbols the symbols belonging to the left input
     * @return the left, right, join and post-join predicates
     * @throws IllegalArgumentException (via {@code checkArgument}) if an effective predicate
     *     references symbols from the wrong side
     */
    private InnerJoinPushDownResult processInnerJoin(
        Expression inheritedPredicate,
        Expression leftEffectivePredicate,
        Expression rightEffectivePredicate,
        Expression joinPredicate,
        Collection<Symbol> leftSymbols) {
      checkArgument(
          Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)),
          "leftEffectivePredicate must only contain symbols from leftSymbols");
      checkArgument(
          Iterables.all(
              DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))),
          "rightEffectivePredicate must not contain symbols from leftSymbols");

      ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder();
      ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder();

      // Keep non-deterministic conjuncts at the join (they never pass joinEqualityExpression, so
      // they end up in the post-join filter) and strip them from the predicates used below
      joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic())));
      inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(deterministic())));
      joinPredicate = stripNonDeterministicConjuncts(joinPredicate);

      // Non-deterministic parts of the effective predicates are discarded entirely
      leftEffectivePredicate = stripNonDeterministicConjuncts(leftEffectivePredicate);
      rightEffectivePredicate = stripNonDeterministicConjuncts(rightEffectivePredicate);

      // Generate equality inferences. The "without" variants omit one side's effective predicate,
      // presumably so equalities that side already guarantees are not pushed back into it
      EqualityInference allInference =
          createEqualityInference(
              inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate);
      EqualityInference allInferenceWithoutLeftInferred =
          createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate);
      EqualityInference allInferenceWithoutRightInferred =
          createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate);

      // Sort through conjuncts in inheritedPredicate that were not used for inference
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression leftRewrittenConjunct =
            allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewrittenConjunct != null) {
          leftPushDownConjuncts.add(leftRewrittenConjunct);
        }

        Expression rightRewrittenConjunct =
            allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewrittenConjunct != null) {
          rightPushDownConjuncts.add(rightRewrittenConjunct);
        }

        // Keep the conjunct at the join only if it could not be pushed to either side
        if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) {
          joinConjuncts.add(conjunct);
        }
      }

      // See if we can push the right effective predicate to the left side
      for (Expression conjunct :
          EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (rewritten != null) {
          leftPushDownConjuncts.add(rewritten);
        }
      }

      // See if we can push the left effective predicate to the right side
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) {
        Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rewritten != null) {
          rightPushDownConjuncts.add(rewritten);
        }
      }

      // See if we can push any parts of the join predicates to either side
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) {
        Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols));
        if (leftRewritten != null) {
          leftPushDownConjuncts.add(leftRewritten);
        }

        Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols)));
        if (rightRewritten != null) {
          rightPushDownConjuncts.add(rightRewritten);
        }

        if (leftRewritten == null && rightRewritten == null) {
          joinConjuncts.add(conjunct);
        }
      }

      // Add equalities from the inference back in
      leftPushDownConjuncts.addAll(
          allInferenceWithoutLeftInferred
              .generateEqualitiesPartitionedBy(in(leftSymbols))
              .getScopeEqualities());
      rightPushDownConjuncts.addAll(
          allInferenceWithoutRightInferred
              .generateEqualitiesPartitionedBy(not(in(leftSymbols)))
              .getScopeEqualities());
      // Scope-straddling equalities (relating a left symbol to a right symbol) become part of the
      // join predicate
      joinConjuncts.addAll(
          allInference
              .generateEqualitiesPartitionedBy(in(leftSymbols))
              .getScopeStraddlingEqualities());

      // Only cross-side equalities are supported as join criteria, so factor every other join
      // conjunct out into a post-join filter
      List<Expression> joinConjunctsList = joinConjuncts.build();
      List<Expression> postJoinConjuncts =
          ImmutableList.copyOf(filter(joinConjunctsList, not(joinEqualityExpression(leftSymbols))));
      joinConjunctsList =
          ImmutableList.copyOf(filter(joinConjunctsList, joinEqualityExpression(leftSymbols)));

      return new InnerJoinPushDownResult(
          combineConjuncts(leftPushDownConjuncts.build()),
          combineConjuncts(rightPushDownConjuncts.build()),
          combineConjuncts(joinConjunctsList),
          combineConjuncts(postJoinConjuncts));
    }

    /**
     * Immutable holder for the four predicates produced when pushing a predicate through an inner
     * join: one for each input, the equi-join predicate, and a filter applied after the join.
     */
    private static class InnerJoinPushDownResult {
      private final Expression postJoinPredicate;
      private final Expression joinPredicate;
      private final Expression rightPredicate;
      private final Expression leftPredicate;

      private InnerJoinPushDownResult(
          Expression left, Expression right, Expression join, Expression postJoin) {
        leftPredicate = left;
        rightPredicate = right;
        joinPredicate = join;
        postJoinPredicate = postJoin;
      }

      /** Predicate to push into the left input. */
      private Expression getLeftPredicate() {
        return this.leftPredicate;
      }

      /** Predicate to push into the right input. */
      private Expression getRightPredicate() {
        return this.rightPredicate;
      }

      /** Predicate that stays as the join condition. */
      private Expression getJoinPredicate() {
        return this.joinPredicate;
      }

      /** Predicate evaluated in a filter above the join. */
      private Expression getPostJoinPredicate() {
        return this.postJoinPredicate;
      }
    }

    /**
     * Builds the conjunction of {@code left = right} comparisons, one per equi-join clause of the
     * given join, in clause order.
     */
    private static Expression extractJoinPredicate(JoinNode joinNode) {
      List<Expression> clauseEqualities = new ArrayList<>();
      for (JoinNode.EquiJoinClause clause : joinNode.getCriteria()) {
        Expression equality = equalsExpression(clause.getLeft(), clause.getRight());
        clauseEqualities.add(equality);
      }
      return combineConjuncts(clauseEqualities);
    }

    /** Returns the expression {@code symbol1 = symbol2} over qualified-name references. */
    private static Expression equalsExpression(Symbol symbol1, Symbol symbol2) {
      Expression lhs = new QualifiedNameReference(symbol1.toQualifiedName());
      Expression rhs = new QualifiedNameReference(symbol2.toQualifiedName());
      return new ComparisonExpression(ComparisonExpression.Type.EQUAL, lhs, rhs);
    }

    // TODO: temporary addition to infer result type from expression. fix this with the new planner
    // refactoring (martint)
    /**
     * Infers the result type of {@code expression}. Each symbol the expression references is
     * described as an unqualified field typed from the symbol allocator, and the expression is
     * analyzed against that tuple.
     */
    private Type extractType(Expression expression) {
      ExpressionAnalyzer analyzer =
          new ExpressionAnalyzer(new Analysis(), session, metadata, experimentalSyntaxEnabled);

      ImmutableList.Builder<Field> fieldList = ImmutableList.builder();
      for (Symbol referenced : DependencyExtractor.extractUnique(expression)) {
        Type symbolType = symbolAllocator.getTypes().get(referenced);
        fieldList.add(Field.newUnqualified(referenced.getName(), symbolType));
      }

      TupleDescriptor tuple = new TupleDescriptor(fieldList.build());
      return analyzer.analyze(expression, tuple, new AnalysisContext());
    }

    /**
     * Rewrites {@code node} as an INNER join when that is semantically safe.
     *
     * <p>CROSS joins always become INNER joins with the same criteria. A LEFT (RIGHT) join becomes
     * INNER when {@code canConvertOuterToInner} reports that the inherited predicate rejects the
     * all-NULL rows produced for the right (left) side. Otherwise the node is returned unchanged.
     *
     * @throws IllegalArgumentException if the join type is not INNER, LEFT, RIGHT or CROSS
     */
    private JoinNode tryNormalizeToInnerJoin(JoinNode node, Expression inheritedPredicate) {
      Preconditions.checkArgument(
          EnumSet.of(INNER, RIGHT, LEFT, CROSS).contains(node.getType()),
          "Unsupported join type: %s",
          node.getType());

      boolean keepAsIs;
      switch (node.getType()) {
        case INNER:
          keepAsIs = true;
          break;
        case LEFT:
          keepAsIs = !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate);
          break;
        case RIGHT:
          keepAsIs = !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate);
          break;
        default:
          // CROSS is the only remaining type permitted by the precondition above
          keepAsIs = false;
          break;
      }

      if (keepAsIs) {
        return node;
      }
      return new JoinNode(
          node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria());
    }

    /**
     * Returns true if some deterministic conjunct of {@code inheritedPredicate} evaluates to FALSE
     * or NULL when every inner-side symbol is bound to NULL. Such a conjunct filters out every row
     * the outer join would pad with NULLs, so the join behaves like an inner join.
     *
     * @param innerSymbolsForOuterJoin output symbols of the inner (NULL-padded) side
     * @param inheritedPredicate predicate applied above the join
     */
    private boolean canConvertOuterToInner(
        List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate) {
      Set<Symbol> nulledSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin);
      for (Expression conjunct : extractConjuncts(inheritedPredicate)) {
        // Non-deterministic conjuncts cannot be evaluated reliably, so they are ignored
        if (!DeterminismEvaluator.isDeterministic(conjunct)) {
          continue;
        }
        Object outcome = nullInputEvaluator(nulledSymbols, conjunct);
        boolean rejectsNullRows =
            outcome == null || outcome instanceof NullLiteral || Boolean.FALSE.equals(outcome);
        if (rejectsNullRows) {
          return true;
        }
      }
      return false;
    }

    // Temporary implementation for joins: the SimplifyExpressions optimizer cannot run properly on
    // join clauses.
    /**
     * Returns a function that constant-folds an expression with the expression optimizer, resolving
     * no symbols, and converts the result back into an expression.
     */
    private Function<Expression, Expression> simplifyExpressions() {
      return new Function<Expression, Expression>() {
        @Override
        public Expression apply(Expression expression) {
          Object folded =
              ExpressionInterpreter.expressionOptimizer(expression, metadata, session)
                  .optimize(NoOpSymbolResolver.INSTANCE);
          return LiteralInterpreter.toExpression(folded);
        }
      };
    }

    /**
     * Optimizes {@code expression} with every symbol in {@code nullSymbols} bound to NULL. All
     * other symbols are left as unresolved references.
     *
     * @return the optimizer's result; may be {@code null}, a {@code NullLiteral}, a constant, or a
     *     residual expression
     */
    private Object nullInputEvaluator(final Collection<Symbol> nullSymbols, Expression expression) {
      ExpressionInterpreter optimizer =
          ExpressionInterpreter.expressionOptimizer(expression, metadata, session);
      SymbolResolver nullingResolver =
          new SymbolResolver() {
            @Override
            public Object getValue(Symbol symbol) {
              if (nullSymbols.contains(symbol)) {
                return null;
              }
              return new QualifiedNameReference(symbol.toQualifiedName());
            }
          };
      return optimizer.optimize(nullingResolver);
    }

    /**
     * Returns a predicate that accepts deterministic {@code =} comparisons whose operands reference
     * opposite sides of the join. One operand must use only {@code leftSymbols} and the other only
     * symbols outside {@code leftSymbols}, in either order.
     */
    private static Predicate<Expression> joinEqualityExpression(
        final Collection<Symbol> leftSymbols) {
      return new Predicate<Expression>() {
        @Override
        public boolean apply(Expression expression) {
          // At this point in time, our join predicates need to be deterministic
          if (!isDeterministic(expression) || !(expression instanceof ComparisonExpression)) {
            return false;
          }
          ComparisonExpression comparison = (ComparisonExpression) expression;
          if (comparison.getType() != ComparisonExpression.Type.EQUAL) {
            return false;
          }

          Set<Symbol> lhsSymbols = DependencyExtractor.extractUnique(comparison.getLeft());
          Set<Symbol> rhsSymbols = DependencyExtractor.extractUnique(comparison.getRight());
          Predicate<Symbol> onLeft = in(leftSymbols);
          Predicate<Symbol> onRight = not(onLeft);

          boolean leftEqualsRight =
              Iterables.all(lhsSymbols, onLeft) && Iterables.all(rhsSymbols, onRight);
          return leftEqualsRight
              || (Iterables.all(rhsSymbols, onLeft) && Iterables.all(lhsSymbols, onRight));
        }
      };
    }

    /**
     * Pushes predicates through a semi join.
     *
     * <p>Conjuncts of the inherited predicate and of the source's effective predicate are
     * considered for the filtering source. They are used when they can be rewritten in terms of the
     * filtering-source join symbol, using the contrived equality between the two join symbols. Each
     * pushed conjunct is wrapped by {@code expressionOrNullSymbols} so NULL join values in the
     * filtering source are not filtered out. Inherited conjuncts that can be expressed over the
     * source's output symbols are pushed into the source. Everything else is applied in a filter
     * above the semi join.
     */
    @Override
    public PlanNode rewriteSemiJoin(
        SemiJoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(node.getSource());

      List<Expression> sourceConjuncts = new ArrayList<>();
      List<Expression> filteringSourceConjuncts = new ArrayList<>();
      List<Expression> postJoinConjuncts = new ArrayList<>();

      // TODO: see if there are predicates that can be inferred from the semi join output

      // Push inherited and source predicates to filtering source via a contrived join predicate
      // (but needs to avoid touching NULL values in the filtering source)
      Expression joinPredicate =
          equalsExpression(node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol());
      EqualityInference joinInference =
          createEqualityInference(inheritedPredicate, sourceEffectivePredicate, joinPredicate);
      for (Expression conjunct :
          Iterables.concat(
              EqualityInference.nonInferrableConjuncts(inheritedPredicate),
              EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) {
        Expression rewrittenConjunct =
            joinInference.rewriteExpression(conjunct, equalTo(node.getFilteringSourceJoinSymbol()));
        // Only deterministic rewrites are pushed to the filtering source; others are dropped here
        // (inherited ones are still handled by the source/post-join pass below)
        if (rewrittenConjunct != null && DeterminismEvaluator.isDeterministic(rewrittenConjunct)) {
          // Alter conjunct to include an OR filteringSourceJoinSymbol IS NULL disjunct
          Expression rewrittenConjunctOrNull =
              expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))
                  .apply(rewrittenConjunct);
          filteringSourceConjuncts.add(rewrittenConjunctOrNull);
        }
      }
      // Inferred equalities scoped to the filtering-source join symbol get the same NULL-preserving
      // treatment
      EqualityInference.EqualityPartition joinInferenceEqualityPartition =
          joinInference.generateEqualitiesPartitionedBy(
              equalTo(node.getFilteringSourceJoinSymbol()));
      filteringSourceConjuncts.addAll(
          ImmutableList.copyOf(
              transform(
                  joinInferenceEqualityPartition.getScopeEqualities(),
                  expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol())))));

      // Push inheritedPredicates down to the source if they don't involve the semi join output
      EqualityInference inheritedInference = createEqualityInference(inheritedPredicate);
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) {
        Expression rewrittenConjunct =
            inheritedInference.rewriteExpression(conjunct, in(node.getSource().getOutputSymbols()));
        // Since each source row is reflected exactly once in the output, ok to push
        // non-deterministic predicates down
        if (rewrittenConjunct != null) {
          sourceConjuncts.add(rewrittenConjunct);
        } else {
          postJoinConjuncts.add(conjunct);
        }
      }

      // Add the inherited equality predicates back in
      EqualityInference.EqualityPartition equalityPartition =
          inheritedInference.generateEqualitiesPartitionedBy(
              in(node.getSource().getOutputSymbols()));
      sourceConjuncts.addAll(equalityPartition.getScopeEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities());
      postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities());

      PlanNode rewrittenSource =
          planRewriter.rewrite(node.getSource(), combineConjuncts(sourceConjuncts));
      PlanNode rewrittenFilteringSource =
          planRewriter.rewrite(
              node.getFilteringSource(), combineConjuncts(filteringSourceConjuncts));

      // Rebuild the semi join only if either input actually changed (identity comparison)
      PlanNode output = node;
      if (rewrittenSource != node.getSource()
          || rewrittenFilteringSource != node.getFilteringSource()) {
        output =
            new SemiJoinNode(
                node.getId(),
                rewrittenSource,
                rewrittenFilteringSource,
                node.getSourceJoinSymbol(),
                node.getFilteringSourceJoinSymbol(),
                node.getSemiJoinOutput());
      }
      if (!postJoinConjuncts.isEmpty()) {
        output =
            new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postJoinConjuncts));
      }
      return output;
    }

    /**
     * Pushes the deterministic parts of the inherited predicate that depend only on grouping keys
     * below the aggregation. Everything else, including all non-deterministic conjuncts, is applied
     * in a filter above the aggregation.
     */
    @Override
    public PlanNode rewriteAggregation(
        AggregationNode node,
        Expression inheritedPredicate,
        PlanRewriter<Expression> planRewriter) {
      // Note: the inference is built from the full inherited predicate, before stripping
      EqualityInference inference = createEqualityInference(inheritedPredicate);

      List<Expression> belowAggregation = new ArrayList<>();
      List<Expression> aboveAggregation = new ArrayList<>();

      // Non-deterministic conjuncts always stay above the aggregation
      aboveAggregation.addAll(
          ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(deterministic()))));
      Expression deterministicPredicate = stripNonDeterministicConjuncts(inheritedPredicate);

      // A non-equality conjunct moves below only if it can be expressed over the grouping keys
      for (Expression conjunct : EqualityInference.nonInferrableConjuncts(deterministicPredicate)) {
        Expression overGroupingKeys = inference.rewriteExpression(conjunct, in(node.getGroupBy()));
        if (overGroupingKeys == null) {
          aboveAggregation.add(conjunct);
        } else {
          belowAggregation.add(overGroupingKeys);
        }
      }

      // Re-add inferred equalities on whichever side they belong to
      EqualityInference.EqualityPartition partition =
          inference.generateEqualitiesPartitionedBy(in(node.getGroupBy()));
      belowAggregation.addAll(partition.getScopeEqualities());
      aboveAggregation.addAll(partition.getScopeComplementEqualities());
      aboveAggregation.addAll(partition.getScopeStraddlingEqualities());

      PlanNode rewrittenSource =
          planRewriter.rewrite(node.getSource(), combineConjuncts(belowAggregation));

      PlanNode result =
          rewrittenSource == node.getSource()
              ? node
              : new AggregationNode(
                  node.getId(),
                  rewrittenSource,
                  node.getGroupBy(),
                  node.getAggregations(),
                  node.getFunctions(),
                  node.getMasks(),
                  node.getStep(),
                  node.getSampleWeight(),
                  node.getConfidence());

      if (aboveAggregation.isEmpty()) {
        return result;
      }
      return new FilterNode(idAllocator.getNextId(), result, combineConjuncts(aboveAggregation));
    }

    /** Delegates to the rewriter's default handling, passing the inherited predicate unchanged. */
    @Override
    public PlanNode rewriteSample(
        SampleNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      PlanNode rewritten = planRewriter.defaultRewrite(node, inheritedPredicate);
      return rewritten;
    }

    /**
     * Pushes the inherited predicate into a table scan.
     *
     * <p>{@code DomainTranslator} splits the predicate into a {@code TupleDomain} and a remaining
     * expression. The tuple domain is used to ask the split manager for matching partitions. The
     * domain the split manager could not decide, plus the remaining expression, become a filter
     * above the scan. Partitions whose fixed column values make any conjunct of that filter FALSE
     * or NULL are pruned. The scan node is replaced only when the generated partitions change.
     */
    @Override
    public PlanNode rewriteTableScan(
        TableScanNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) {
      DomainTranslator.ExtractionResult extractionResult =
          DomainTranslator.fromPredicate(
              inheritedPredicate, symbolAllocator.getTypes(), node.getAssignments());
      Expression extractionRemainingExpression = extractionResult.getRemainingExpression();
      TupleDomain tupleDomain = extractionResult.getTupleDomain();

      if (node.getGeneratedPartitions().isPresent()) {
        // Intersect with the TupleDomain that generated the previous set of partitions, and also
        // with the partitions' domain summary, since that can only narrow the ranges further.
        // The domains should never widen between passes.
        tupleDomain =
            tupleDomain
                .intersect(node.getGeneratedPartitions().get().getTupleDomainInput())
                .intersect(node.getPartitionsDomainSummary());
      }

      PartitionResult matchingPartitions =
          splitManager.getPartitions(node.getTable(), Optional.of(tupleDomain));
      List<Partition> partitions = matchingPartitions.getPartitions();
      TupleDomain undeterminedTupleDomain = matchingPartitions.getUndeterminedTupleDomain();

      Expression unevaluatedDomainPredicate =
          DomainTranslator.toPredicate(
              undeterminedTupleDomain, ImmutableBiMap.copyOf(node.getAssignments()).inverse());

      // Construct the post scan predicate. Add the unevaluated TupleDomain back first since those
      // are generally cheaper to evaluate than anything we can't extract
      Expression postScanPredicate =
          combineConjuncts(unevaluatedDomainPredicate, extractionRemainingExpression);

      // Do some early partition pruning
      partitions =
          ImmutableList.copyOf(
              filter(
                  partitions, not(shouldPrunePartition(postScanPredicate, node.getAssignments()))));
      GeneratedPartitions generatedPartitions = new GeneratedPartitions(tupleDomain, partitions);

      PlanNode output = node;
      if (!node.getGeneratedPartitions().equals(Optional.of(generatedPartitions))) {
        // Only overwrite the originalConstraint if it was previously null
        Expression originalConstraint =
            node.getOriginalConstraint() == null
                ? inheritedPredicate
                : node.getOriginalConstraint();
        output =
            new TableScanNode(
                node.getId(),
                node.getTable(),
                node.getOutputSymbols(),
                node.getAssignments(),
                originalConstraint,
                Optional.of(generatedPartitions));
      }
      // A TRUE post-scan predicate needs no filter
      if (!postScanPredicate.equals(BooleanLiteral.TRUE_LITERAL)) {
        output = new FilterNode(idAllocator.getNextId(), output, postScanPredicate);
      }
      return output;
    }

    /**
     * Returns a predicate that is true for partitions that can never satisfy {@code predicate}.
     * Partition columns with a fixed value that map to a symbol are substituted into each conjunct.
     * If any conjunct then optimizes to FALSE or NULL, the whole predicate can never be true and
     * the partition is pruned.
     */
    private Predicate<Partition> shouldPrunePartition(
        final Expression predicate, final Map<Symbol, ColumnHandle> symbolToColumn) {
      return new Predicate<Partition>() {
        @Override
        public boolean apply(Partition partition) {
          // Only fixed-value columns that some symbol is assigned to can be bound
          Map<ColumnHandle, Comparable<?>> boundColumns =
              Maps.filterKeys(
                  partition.getTupleDomain().extractFixedValues(), in(symbolToColumn.values()));
          Map<Symbol, Comparable<?>> boundSymbols =
              DomainUtils.columnHandleToSymbol(boundColumns, symbolToColumn);
          LookupSymbolResolver resolver =
              new LookupSymbolResolver(ImmutableMap.<Symbol, Object>copyOf(boundSymbols));

          for (Expression conjunct : extractConjuncts(predicate)) {
            Object outcome =
                ExpressionInterpreter.expressionOptimizer(conjunct, metadata, session)
                    .optimize(resolver);
            boolean neverTrue =
                outcome == null || outcome instanceof NullLiteral || Boolean.FALSE.equals(outcome);
            if (neverTrue) {
              return true;
            }
          }
          return false;
        }
      };
    }
  }
}
Example #13
0
//
// NOTE:  As a general strategy the methods should "stage" a change and only
// process the actual change before lock release (DriverLockResult.close()).
// This ensures that only one thread works with the operators at a time and
// that threads requesting state changes are not blocked.
//
public class Driver {
  private static final Logger log = Logger.get(Driver.class);

  private final DriverContext driverContext;
  // Operator pipeline in order; completion is judged by the last operator's isFinished()
  private final List<Operator> operators;
  // Operators that implement SourceOperator, keyed by their source plan node id
  private final Map<PlanNodeId, SourceOperator> sourceOperators;
  // Source updates staged by updateSource() and drained by processNewSources() under the lock
  private final ConcurrentMap<PlanNodeId, TaskSource> newSources = new ConcurrentHashMap<>();

  // Lifecycle state; starts ALIVE and is moved forward with compareAndSet
  private final AtomicReference<State> state = new AtomicReference<>(State.ALIVE);

  // NOTE(review): presumably the "driver lock" acquired by tryLockAndProcessPendingStateChanges
  // (not visible here) — confirm
  private final ReentrantLock exclusiveLock = new ReentrantLock();

  // Thread to interrupt from close() when the lock could not be acquired; presumably the current
  // lock holder (it is assigned outside this view — confirm)
  @GuardedBy("this")
  private Thread lockHolder;

  // The TaskSource most recently applied to each source operator by processNewSource()
  @GuardedBy("exclusiveLock")
  private Map<PlanNodeId, TaskSource> currentSources = new ConcurrentHashMap<>();

  /** Driver lifecycle states. */
  private enum State {
    // Normal operation; new sources are processed only in this state
    ALIVE,
    // close() or isFinished() has requested shutdown
    NEED_DESTRUCTION,
    // NOTE(review): presumably set once teardown completes (not visible here) — confirm
    DESTROYED
  }

  /**
   * Creates a driver over a pipeline of at least one operator.
   *
   * @throws NullPointerException if any argument is null
   */
  public Driver(DriverContext driverContext, Operator firstOperator, Operator... otherOperators) {
    this(
        checkNotNull(driverContext, "driverContext is null"),
        ImmutableList.<Operator>builder()
            .add(checkNotNull(firstOperator, "firstOperator is null"))
            .add(checkNotNull(otherOperators, "otherOperators is null"))
            .build());
  }

  /**
   * Creates a driver over the given operator pipeline and indexes every {@code SourceOperator} in
   * it by source id.
   *
   * @throws NullPointerException if an argument is null
   * @throws IllegalArgumentException if {@code operators} is empty or two source operators share
   *     a source id
   */
  public Driver(DriverContext driverContext, List<Operator> operators) {
    this.driverContext = checkNotNull(driverContext, "driverContext is null");
    this.operators = ImmutableList.copyOf(checkNotNull(operators, "operators is null"));
    checkArgument(!operators.isEmpty(), "There must be at least one operator");

    ImmutableMap.Builder<PlanNodeId, SourceOperator> sourcesById = ImmutableMap.builder();
    for (Operator candidate : this.operators) {
      if (!(candidate instanceof SourceOperator)) {
        continue;
      }
      SourceOperator asSource = (SourceOperator) candidate;
      sourcesById.put(asSource.getSourceId(), asSource);
    }
    this.sourceOperators = sourcesById.build();
  }

  /** Returns the context this driver was created with. */
  public DriverContext getDriverContext() {
    return this.driverContext;
  }

  /** Returns the plan node ids of the source operators in this driver's pipeline. */
  public Set<PlanNodeId> getSourceIds() {
    return this.sourceOperators.keySet();
  }

  /**
   * Requests shutdown of this driver. If the driver is no longer ALIVE this is a no-op.
   * Otherwise the driver is marked for destruction. If the lock is free, the clean shutdown
   * happens when the lock is released. If another thread holds the lock, that thread is
   * interrupted.
   */
  public void close() {
    if (!state.compareAndSet(State.ALIVE, State.NEED_DESTRUCTION)) {
      return;
    }

    try (DriverLockResult lockResult =
        tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS)) {
      if (lockResult.wasAcquired()) {
        // Clean shutdown is triggered when lockResult is closed at the end of this block
        return;
      }
      // There is a benign race here: the lock holder can change between the failed attempt to
      // take the lock and entering the synchronized block. Either way, the current holder is the
      // thread that should be interrupted.
      synchronized (this) {
        Thread holder = lockHolder;
        if (holder != null) {
          holder.interrupt();
        }
      }
    }
  }

  /**
   * Returns whether this driver is done. The check is based on: the driver is no longer ALIVE,
   * the driver context reports done, or (only when the lock can be taken without waiting) the
   * last operator is finished. A finished driver is marked for destruction.
   *
   * @throws IllegalStateException presumably, via {@code checkLockNotHeld}, if the caller already
   *     holds the driver lock
   */
  public boolean isFinished() {
    checkLockNotHeld("Can not check finished status while holding the driver lock");

    // Try the lock without waiting; any pending clean shutdown runs when lockResult is closed
    try (DriverLockResult lockResult =
        tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS)) {
      if (!lockResult.wasAcquired()) {
        // Without the lock the operators cannot be inspected, nor can the driver be destroyed
        return state.get() != State.ALIVE || driverContext.isDone();
      }

      boolean done = state.get() != State.ALIVE || driverContext.isDone();
      if (!done) {
        done = operators.get(operators.size() - 1).isFinished();
      }
      if (done) {
        state.compareAndSet(State.ALIVE, State.NEED_DESTRUCTION);
      }
      return done;
    }
  }

  /**
   * Stages a source update for this driver and then tries to apply it. Updates for plan nodes
   * this driver has no source operator for are ignored. The update is merged into any
   * already-staged update for the same node. Staged updates are applied only if the driver lock
   * can be acquired immediately; otherwise they stay staged.
   */
  public void updateSource(TaskSource source) {
    checkLockNotHeld("Can not update sources while holding the driver lock");

    PlanNodeId planNodeId = source.getPlanNodeId();
    if (!sourceOperators.containsKey(planNodeId)) {
      return;
    }

    // Lock-free staging: retry until our update is stored or is already covered
    boolean staged = false;
    while (!staged) {
      TaskSource pending = newSources.putIfAbsent(planNodeId, source);
      if (pending == null) {
        staged = true;
      } else {
        TaskSource merged = pending.update(source);
        // Same instance means nothing new; otherwise only succeed if no one raced us
        staged = merged == pending || newSources.replace(planNodeId, pending, merged);
      }
    }

    // Staged updates are processed when the lock result is closed, if the lock was acquired
    tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS).close();
  }

  /**
   * Applies every staged source update while the driver is ALIVE. Must be called with the driver
   * lock held.
   */
  private void processNewSources() {
    checkLockHeld("Lock must be held to call processNewSources");

    if (state.get() != State.ALIVE) {
      return;
    }

    // Work from a snapshot. Updates staged concurrently are picked up on the next call, and the
    // two-argument remove leaves an entry in place if someone replaced it meanwhile.
    Map<PlanNodeId, TaskSource> snapshot = new HashMap<>(newSources);
    for (Entry<PlanNodeId, TaskSource> pending : snapshot.entrySet()) {
      PlanNodeId planNodeId = pending.getKey();
      TaskSource taskSource = pending.getValue();
      newSources.remove(planNodeId, taskSource);
      processNewSource(taskSource);
    }
  }

  /**
   * Merges {@code source} into the current source for its plan node and applies the difference
   * to the matching source operator. Splits not seen before are added, and no-more-splits is
   * signalled if {@code source} says so. Must be called with the driver lock held.
   */
  private void processNewSource(TaskSource source) {
    checkLockHeld("Lock must be held to call processNewSource");

    PlanNodeId planNodeId = source.getPlanNodeId();

    // Work out which splits have not yet been delivered, and record the merged source
    Set<ScheduledSplit> newSplits;
    TaskSource currentSource = currentSources.get(planNodeId);
    if (currentSource == null) {
      newSplits = source.getSplits();
      currentSources.put(planNodeId, source);
    } else {
      TaskSource newSource = currentSource.update(source);

      // Same instance means the update carried nothing new
      if (newSource == currentSource) {
        return;
      }

      newSplits = Sets.difference(newSource.getSplits(), currentSource.getSplits());
      currentSources.put(planNodeId, newSource);
    }

    // The operator lookup does not depend on the split, so do it once. updateSource() only
    // stages sources that have an operator, but guard both uses consistently anyway.
    SourceOperator sourceOperator = sourceOperators.get(planNodeId);
    if (sourceOperator == null) {
      return;
    }

    for (ScheduledSplit newSplit : newSplits) {
      sourceOperator.addSplit(newSplit.getSplit());
    }

    if (source.isNoMoreSplits()) {
      sourceOperator.noMoreSplits();
    }
  }

  /**
   * Repeatedly calls {@link #process()} until the driver blocks, finishes, or the time budget is
   * spent.
   *
   * @param duration the time budget; must not be null
   * @return the future of the blocking operator if processing blocked, otherwise {@code NOT_BLOCKED}
   */
  public ListenableFuture<?> processFor(Duration duration) {
    checkLockNotHeld("Can not process for a duration while holding the driver lock");

    checkNotNull(duration, "duration is null");

    final long budgetNanos = duration.roundTo(TimeUnit.NANOSECONDS);
    final long startNanos = System.nanoTime();
    while (true) {
      ListenableFuture<?> blockedFuture = process();
      if (!blockedFuture.isDone()) {
        return blockedFuture;
      }
      // Elapsed time is computed as a difference so nanoTime wrap-around is harmless.
      boolean outOfTime = System.nanoTime() - startNanos >= budgetNanos;
      if (outOfTime || isFinished()) {
        return NOT_BLOCKED;
      }
    }
  }

  /**
   * Runs one pass over the operator pipeline, moving data between each adjacent pair of operators.
   *
   * <p>The driver lock is tried for up to 100 ms. If it cannot be acquired, {@code NOT_BLOCKED} is
   * returned without doing any work. While the lock is held, pending sources are applied first.
   * Then, for each adjacent pair (current, next):
   *
   * <ul>
   *   <li>if either operator is blocked, its future is returned;
   *   <li>if current is finished, next is finished;
   *   <li>otherwise, when next needs input, at most one page is moved from current to next.
   * </ul>
   *
   * <p>Any Throwable raised inside the pass is reported to {@code driverContext.failed} and
   * rethrown. NOTE(review): exceptions thrown while closing {@code lockResult} (which processes
   * sources and may destroy the driver) occur outside the inner try and are not reported that way.
   *
   * @return a future of the first blocked operator encountered, or {@code NOT_BLOCKED}
   */
  public ListenableFuture<?> process() {
    checkLockNotHeld("Can not process while holding the driver lock");

    try (DriverLockResult lockResult =
        tryLockAndProcessPendingStateChanges(100, TimeUnit.MILLISECONDS)) {
      try {
        if (!lockResult.wasAcquired()) {
          // this is unlikely to happen unless the driver is being
          // destroyed and in that case the caller should notice
          // this state change by calling isFinished
          return NOT_BLOCKED;
        }

        driverContext.start();

        if (!newSources.isEmpty()) {
          processNewSources();
        }

        // iterate over adjacent operator pairs (hence size() - 1); stop early once the context is done
        for (int i = 0; i < operators.size() - 1 && !driverContext.isDone(); i++) {
          // check if current operator is blocked
          Operator current = operators.get(i);
          ListenableFuture<?> blocked = current.isBlocked();
          if (!blocked.isDone()) {
            current.getOperatorContext().recordBlocked(blocked);
            return blocked;
          }

          // check if next operator is blocked
          Operator next = operators.get(i + 1);
          blocked = next.isBlocked();
          if (!blocked.isDone()) {
            next.getOperatorContext().recordBlocked(blocked);
            return blocked;
          }

          // if current operator is finished...
          if (current.isFinished()) {
            // let next operator know there will be no more data
            next.getOperatorContext().startIntervalTimer();
            next.finish();
            next.getOperatorContext().recordFinish();
          } else {
            // if next operator needs input...
            if (next.needsInput()) {
              // get an output page from current operator
              current.getOperatorContext().startIntervalTimer();
              Page page = current.getOutput();
              current.getOperatorContext().recordGetOutput(page);

              // if we got an output page, add it to the next operator
              if (page != null) {
                next.getOperatorContext().startIntervalTimer();
                next.addInput(page);
                next.getOperatorContext().recordAddInput(page);
              }
            }
          }
        }
        return NOT_BLOCKED;
      } catch (Throwable t) {
        // record the failure on the driver context, then rethrow unchanged (precise rethrow)
        driverContext.failed(t);
        throw t;
      }
    }
  }

  /**
   * Tears the driver down if a destruction has been requested.
   *
   * <p>Only the caller that moves {@code state} from {@code NEED_DESTRUCTION} to {@code DESTROYED}
   * does any work, so teardown runs at most once. The steps are:
   *
   * <ul>
   *   <li>{@code finish()} is called on every operator;
   *   <li>every {@link AutoCloseable} operator is closed, even if finishing failed;
   *   <li>{@code driverContext.finished()} is called.
   * </ul>
   *
   * <p>An exception from the finish phase propagates. Throwables from the close phase are passed to
   * {@code addSuppressedException}, which keeps {@link Error}s (as the propagating exception or as
   * suppressed exceptions) and only logs everything else. The thread's interrupt status is cleared
   * while closing and restored afterwards.
   *
   * <p>Must be called while holding the driver lock.
   */
  private void destroyIfNecessary() {
    checkLockHeld("Lock must be held to call destroyIfNecessary");

    if (!state.compareAndSet(State.NEED_DESTRUCTION, State.DESTROYED)) {
      return;
    }

    Throwable inFlightException = null;
    try {
      // call finish on every operator; if error occurs, just bail out; we will still call close
      for (Operator operator : operators) {
        operator.finish();
      }
    } catch (Throwable t) {
      // record in flight exception so we can add suppressed exceptions below
      inFlightException = t;
      throw t;
    } finally {
      // record the current interrupted status (and clear the flag); we'll reset it later
      boolean wasInterrupted = Thread.interrupted();

      // if we get an error while closing a driver, record it and we will throw it at the end
      try {
        for (Operator operator : operators) {
          if (operator instanceof AutoCloseable) {
            try {
              ((AutoCloseable) operator).close();
            } catch (InterruptedException t) {
              // don't record the stack
              wasInterrupted = true;
            } catch (Throwable t) {
              inFlightException =
                  addSuppressedException(
                      inFlightException,
                      t,
                      "Error closing operator %s for task %s",
                      operator.getOperatorContext().getOperatorId(),
                      driverContext.getTaskId());
            }
          }
        }
        driverContext.finished();
      } catch (Throwable t) {
        // this shouldn't happen but be safe
        inFlightException =
            addSuppressedException(
                inFlightException,
                t,
                "Error destroying driver for task %s",
                driverContext.getTaskId());
      } finally {
        // reset the interrupted flag
        if (wasInterrupted) {
          Thread.currentThread().interrupt();
        }
      }

      if (inFlightException != null) {
        // Expected to be an Error or a RuntimeException. Throwables.propagate rethrows those as-is
        // (carrying any suppressed exceptions added above) and wraps anything else in a
        // RuntimeException. Throwing here replaces the rethrow from the catch block above.
        throw Throwables.propagate(inFlightException);
      }
    }
  }

  /**
   * Folds {@code newException} into the exception that is already propagating.
   *
   * <p>{@link Error}s are kept. The first one becomes the propagating exception, and later ones are
   * attached to it via {@link Throwable#addSuppressed}. Any other throwable is logged with
   * {@code message}/{@code args} and then dropped.
   *
   * @param inFlightException the exception currently propagating, or {@code null} if none
   * @param newException the throwable that was just caught
   * @param message log format string used when {@code newException} is not an {@link Error}
   * @param args arguments for {@code message}
   * @return the exception that should continue to propagate, possibly {@code null}
   */
  private Throwable addSuppressedException(
      Throwable inFlightException, Throwable newException, String message, Object... args) {
    boolean isError = newException instanceof Error;
    if (!isError) {
      // log normal exceptions instead of rethrowing them
      log.error(newException, message, args);
      return inFlightException;
    }
    if (inFlightException == null) {
      return newException;
    }
    inFlightException.addSuppressed(newException);
    return inFlightException;
  }

  /**
   * Tries to acquire the driver lock, waiting at most {@code timeout} units of {@code unit}.
   *
   * <p>The returned handle must always be closed. If the lock was acquired, closing the handle
   * applies pending sources, destroys the driver if needed, and releases the lock.
   */
  private DriverLockResult tryLockAndProcessPendingStateChanges(int timeout, TimeUnit unit) {
    checkLockNotHeld("Can not acquire the driver lock while already holding the driver lock");
    DriverLockResult lockResult = new DriverLockResult(timeout, unit);
    return lockResult;
  }

  /** Throws IllegalStateException with {@code message} if the calling thread holds the driver lock. */
  private void checkLockNotHeld(String message) {
    Thread caller = Thread.currentThread();
    // lockHolder is read and written under the Driver monitor
    synchronized (this) {
      checkState(caller != lockHolder, message);
    }
  }

  /** Throws IllegalStateException with {@code message} unless the calling thread holds the driver lock. */
  private void checkLockHeld(String message) {
    Thread caller = Thread.currentThread();
    // lockHolder is read and written under the Driver monitor
    synchronized (this) {
      checkState(caller == lockHolder, message);
    }
  }

  /**
   * Handle for one attempt to take {@code exclusiveLock}.
   *
   * <p>Construction tries the lock with a timeout. On success, the current thread is recorded as
   * {@code lockHolder}, which is what {@code checkLockHeld}/{@code checkLockNotHeld} inspect.
   *
   * <p>{@link #close()} is a no-op when the lock was not acquired. Otherwise it does the following,
   * in order:
   *
   * <ul>
   *   <li>applies pending sources;
   *   <li>destroys the driver if requested, even if source processing threw;
   *   <li>clears {@code lockHolder} and releases the lock, even if either step above threw.
   * </ul>
   *
   * <p>NOTE(review): assumes {@code exclusiveLock} is a {@link java.util.concurrent.locks.Lock}
   * (its {@code tryLock(long, TimeUnit)}/{@code unlock()} usage matches) — confirm at the field.
   */
  private class DriverLockResult implements AutoCloseable {
    private final boolean acquired;

    private DriverLockResult(int timeout, TimeUnit unit) {
      boolean acquired = false;
      try {
        acquired = exclusiveLock.tryLock(timeout, unit);
      } catch (InterruptedException e) {
        // treat interruption as "not acquired" but preserve the interrupt status for the caller
        Thread.currentThread().interrupt();
      }
      this.acquired = acquired;

      if (acquired) {
        // lockHolder is guarded by the Driver monitor
        synchronized (Driver.this) {
          lockHolder = Thread.currentThread();
        }
      }
    }

    /** Returns whether the constructor obtained the driver lock. */
    public boolean wasAcquired() {
      return acquired;
    }

    @Override
    public void close() {
      if (!acquired) {
        return;
      }

      // before releasing the lock, process any new sources and/or destroy the driver
      try {
        try {
          processNewSources();
        } finally {
          destroyIfNecessary();
        }
      } finally {
        // clear the holder before unlocking so another thread never sees a stale holder
        synchronized (Driver.this) {
          lockHolder = null;
        }
        exclusiveLock.unlock();
      }
    }
  }
}
Example #14
0
/**
 * Coordinator-side manager of cluster-wide memory.
 *
 * <p>Each call to {@link #process(Iterable)} does the following:
 *
 * <ul>
 *   <li>enforces per-query memory and CPU limits;
 *   <li>optionally kills the largest general-pool query once the cluster has been out of memory
 *       for longer than the configured delay;
 *   <li>refreshes the aggregated memory pools, exporting them over JMX and notifying registered
 *       listeners;
 *   <li>promotes the biggest eligible query to the reserved pool when the general pool is blocked;
 *   <li>pushes the resulting pool assignments to every active or shutting-down node.
 * </ul>
 *
 * <p>Processing is a no-op unless this server is a coordinator.
 */
public class ClusterMemoryManager implements ClusterMemoryPoolManager {
  private static final Logger log = Logger.get(ClusterMemoryManager.class);
  // Pool-change notifications are handed to this executor, so listener code never runs while this
  // object's monitor is held.
  private final ExecutorService listenerExecutor = Executors.newSingleThreadExecutor();
  private final NodeManager nodeManager;
  private final LocationFactory locationFactory;
  private final HttpClient httpClient;
  private final MBeanExporter exporter;
  private final JsonCodec<MemoryInfo> memoryInfoCodec;
  private final JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec;
  private final DataSize maxQueryMemory;
  private final Duration maxQueryCpuTime;
  private final boolean enabled;
  private final boolean killOnOutOfMemory;
  private final Duration killOnOutOfMemoryDelay;
  private final String coordinatorId;
  private final AtomicLong memoryPoolAssignmentsVersion = new AtomicLong();
  private final AtomicLong clusterMemoryUsageBytes = new AtomicLong();
  private final AtomicLong clusterMemoryBytes = new AtomicLong();
  private final AtomicLong queriesKilledDueToOutOfMemory = new AtomicLong();
  // Node id -> remote memory view. Within this class it is only touched from the synchronized
  // process() method and the helpers it calls.
  private final Map<String, RemoteNodeMemory> nodes = new HashMap<>();

  @GuardedBy("this")
  private final Map<MemoryPoolId, List<Consumer<MemoryPoolInfo>>> changeListeners = new HashMap<>();

  @GuardedBy("this")
  private final Map<MemoryPoolId, ClusterMemoryPool> pools = new HashMap<>();

  @GuardedBy("this")
  private long lastTimeNotOutOfMemory = System.nanoTime();

  // Query most recently killed because the cluster ran out of memory; null if none yet.
  @GuardedBy("this")
  private QueryId lastKilledQuery;

  @Inject
  public ClusterMemoryManager(
      @ForMemoryManager HttpClient httpClient,
      NodeManager nodeManager,
      LocationFactory locationFactory,
      MBeanExporter exporter,
      JsonCodec<MemoryInfo> memoryInfoCodec,
      JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec,
      QueryIdGenerator queryIdGenerator,
      ServerConfig serverConfig,
      MemoryManagerConfig config,
      QueryManagerConfig queryManagerConfig) {
    requireNonNull(config, "config is null");
    this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
    this.locationFactory = requireNonNull(locationFactory, "locationFactory is null");
    this.httpClient = requireNonNull(httpClient, "httpClient is null");
    this.exporter = requireNonNull(exporter, "exporter is null");
    this.memoryInfoCodec = requireNonNull(memoryInfoCodec, "memoryInfoCodec is null");
    this.assignmentsRequestJsonCodec =
        requireNonNull(assignmentsRequestJsonCodec, "assignmentsRequestJsonCodec is null");
    this.maxQueryMemory = config.getMaxQueryMemory();
    this.maxQueryCpuTime = queryManagerConfig.getQueryMaxCpuTime();
    this.coordinatorId = queryIdGenerator.getCoordinatorId();
    this.enabled = serverConfig.isCoordinator();
    this.killOnOutOfMemoryDelay = config.getKillOnOutOfMemoryDelay();
    this.killOnOutOfMemory = config.isKillOnOutOfMemory();
  }

  /** Registers {@code listener} to be notified (asynchronously) whenever {@code poolId} is refreshed. */
  @Override
  public synchronized void addChangeListener(
      MemoryPoolId poolId, Consumer<MemoryPoolInfo> listener) {
    changeListeners.computeIfAbsent(poolId, id -> new ArrayList<>()).add(listener);
  }

  /**
   * Runs one round of cluster memory management.
   *
   * @param queries the queries to evaluate; iterated several times, so it must be re-iterable
   */
  public synchronized void process(Iterable<QueryExecution> queries) {
    if (!enabled) {
      return;
    }

    boolean outOfMemory = isClusterOutOfMemory();
    if (!outOfMemory) {
      lastTimeNotOutOfMemory = System.nanoTime();
    }

    boolean queryKilled = false;
    long totalBytes = 0;
    for (QueryExecution query : queries) {
      long bytes = query.getTotalMemoryReservation();
      DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession());
      long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes());
      totalBytes += bytes;
      if (resourceOvercommit(query.getSession()) && outOfMemory) {
        // If a query has requested resource overcommit, only kill it if the cluster has run out of
        // memory
        DataSize memory = succinctBytes(bytes);
        query.fail(
            new PrestoException(
                CLUSTER_OUT_OF_MEMORY,
                format(
                    "The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory",
                    RESOURCE_OVERCOMMIT, memory)));
        queryKilled = true;
      }
      if (!resourceOvercommit(query.getSession()) && bytes > queryMemoryLimit) {
        DataSize maxMemory = succinctBytes(queryMemoryLimit);
        query.fail(exceededGlobalLimit(maxMemory));
        queryKilled = true;
      }
    }
    clusterMemoryUsageBytes.set(totalBytes);

    if (killOnOutOfMemory) {
      boolean shouldKillQuery =
          nanosSince(lastTimeNotOutOfMemory).compareTo(killOnOutOfMemoryDelay) > 0 && outOfMemory;
      boolean lastKilledQueryIsGone = isLastKilledQueryGone();

      // Kill at most one query per round. Only do so once the previous victim has released its
      // general-pool memory and no query was already killed by the checks above.
      if (shouldKillQuery && lastKilledQueryIsGone && !queryKilled) {
        // Kill the biggest query in the general pool
        QueryExecution biggestQuery = null;
        long maxMemory = -1;
        for (QueryExecution query : queries) {
          long bytesUsed = query.getTotalMemoryReservation();
          if (bytesUsed > maxMemory && query.getMemoryPool().getId().equals(GENERAL_POOL)) {
            biggestQuery = query;
            maxMemory = bytesUsed;
          }
        }
        if (biggestQuery != null) {
          biggestQuery.fail(
              new PrestoException(
                  CLUSTER_OUT_OF_MEMORY,
                  "The cluster is out of memory, and your query was killed. Please try again in a few minutes."));
          queriesKilledDueToOutOfMemory.incrementAndGet();
          lastKilledQuery = biggestQuery.getQueryId();
        }
      }
    }

    Map<MemoryPoolId, Integer> countByPool = new HashMap<>();
    for (QueryExecution query : queries) {
      MemoryPoolId id = query.getMemoryPool().getId();
      countByPool.put(id, countByPool.getOrDefault(id, 0) + 1);
    }

    updatePools(countByPool);

    updateNodes(updateAssignments(queries));

    // check if CPU usage is over limit
    for (QueryExecution query : queries) {
      Duration cpuTime = query.getTotalCpuTime();
      Duration sessionLimit = getQueryMaxCpuTime(query.getSession());
      Duration limit = maxQueryCpuTime.compareTo(sessionLimit) < 0 ? maxQueryCpuTime : sessionLimit;
      if (cpuTime.compareTo(limit) > 0) {
        query.fail(new ExceededCpuLimitException(limit));
      }
    }
  }

  /**
   * Returns whether the query last killed for running the cluster out of memory no longer holds a
   * reservation in the general pool. Returns true when no query has been killed yet.
   *
   * <p>While the general pool is unknown, the previous kill is conservatively treated as still in
   * progress. The reservation check is negated: a query that still has a reservation is NOT gone.
   * The previous inline version had this inverted.
   */
  private synchronized boolean isLastKilledQueryGone() {
    if (lastKilledQuery == null) {
      return true;
    }
    ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
    if (generalPool == null) {
      return false;
    }
    return !generalPool.getQueryMemoryReservations().containsKey(lastKilledQuery);
  }

  @VisibleForTesting
  synchronized Map<MemoryPoolId, ClusterMemoryPool> getPools() {
    return ImmutableMap.copyOf(pools);
  }

  /**
   * The cluster is out of memory when the reserved pool is already in use and the general pool
   * still has blocked nodes.
   */
  private synchronized boolean isClusterOutOfMemory() {
    ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL);
    ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
    return reservedPool != null
        && generalPool != null
        && reservedPool.getAssignedQueries() > 0
        && generalPool.getBlockedNodes() > 0;
  }

  /**
   * Bumps the assignment version and, if the reserved pool is empty while the general pool is
   * blocked, promotes the biggest query that did not request resource overcommit to the reserved
   * pool. Returns the full set of query-to-pool assignments tagged with the new version.
   */
  private synchronized MemoryPoolAssignmentsRequest updateAssignments(
      Iterable<QueryExecution> queries) {
    ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL);
    ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
    long version = memoryPoolAssignmentsVersion.incrementAndGet();
    // Check that all previous assignments have propagated to the visible nodes. This doesn't
    // account for temporary network issues,
    // and is more of a safety check than a guarantee
    if (reservedPool != null && generalPool != null && allAssignmentsHavePropagated(queries)) {
      if (reservedPool.getAssignedQueries() == 0 && generalPool.getBlockedNodes() > 0) {
        QueryExecution biggestQuery = null;
        long maxMemory = -1;
        for (QueryExecution queryExecution : queries) {
          if (resourceOvercommit(queryExecution.getSession())) {
            // Don't promote queries that requested resource overcommit to the reserved pool,
            // since their memory usage is unbounded.
            continue;
          }
          long bytesUsed = queryExecution.getTotalMemoryReservation();
          if (bytesUsed > maxMemory) {
            biggestQuery = queryExecution;
            maxMemory = bytesUsed;
          }
        }
        if (biggestQuery != null) {
          biggestQuery.setMemoryPool(new VersionedMemoryPoolId(RESERVED_POOL, version));
        }
      }
    }

    ImmutableList.Builder<MemoryPoolAssignment> assignments = ImmutableList.builder();
    for (QueryExecution queryExecution : queries) {
      assignments.add(
          new MemoryPoolAssignment(
              queryExecution.getQueryId(), queryExecution.getMemoryPool().getId()));
    }
    return new MemoryPoolAssignmentsRequest(coordinatorId, version, assignments.build());
  }

  /**
   * Returns true when every visible node reports an assignment version at least as new as the
   * oldest version held by any query. Always false when no nodes are visible.
   */
  private boolean allAssignmentsHavePropagated(Iterable<QueryExecution> queries) {
    if (nodes.isEmpty()) {
      // Assignments can't have propagated, if there are no visible nodes.
      return false;
    }
    // Despite its name, this is the minimum (oldest) assignment version across the queries.
    long newestAssignment =
        ImmutableList.copyOf(queries)
            .stream()
            .map(QueryExecution::getMemoryPool)
            .mapToLong(VersionedMemoryPoolId::getVersion)
            .min()
            .orElse(-1);

    long mostOutOfDateNode =
        nodes
            .values()
            .stream()
            .mapToLong(RemoteNodeMemory::getCurrentAssignmentVersion)
            .min()
            .orElse(Long.MAX_VALUE);

    return newestAssignment <= mostOutOfDateNode;
  }

  /**
   * Syncs {@code nodes} with the ACTIVE and SHUTTING_DOWN nodes and schedules an asynchronous
   * refresh of each one with {@code assignments}.
   */
  private void updateNodes(MemoryPoolAssignmentsRequest assignments) {
    ImmutableSet.Builder<Node> builder = ImmutableSet.builder();
    Set<Node> aliveNodes =
        builder
            .addAll(nodeManager.getNodes(ACTIVE))
            .addAll(nodeManager.getNodes(SHUTTING_DOWN))
            .build();

    ImmutableSet<String> aliveNodeIds =
        aliveNodes.stream().map(Node::getNodeIdentifier).collect(toImmutableSet());

    // Remove nodes that don't exist anymore
    // Make a copy to materialize the set difference
    Set<String> deadNodes = ImmutableSet.copyOf(difference(nodes.keySet(), aliveNodeIds));
    nodes.keySet().removeAll(deadNodes);

    // Add new nodes
    for (Node node : aliveNodes) {
      if (!nodes.containsKey(node.getNodeIdentifier())) {
        nodes.put(
            node.getNodeIdentifier(),
            new RemoteNodeMemory(
                httpClient,
                memoryInfoCodec,
                assignmentsRequestJsonCodec,
                locationFactory.createMemoryInfoLocation(node)));
      }
    }

    // Schedule refresh
    for (RemoteNodeMemory node : nodes.values()) {
      node.asyncRefresh(assignments);
    }
  }

  /**
   * Rebuilds the cluster-wide pool view from the latest node memory reports.
   *
   * <p>Pools no longer reported are unexported, and their listeners receive an empty
   * {@code MemoryPoolInfo}. Newly reported pools are exported. Every active pool is refreshed and
   * its listeners are notified.
   */
  private synchronized void updatePools(Map<MemoryPoolId, Integer> queryCounts) {
    // Update view of cluster memory and pools
    List<MemoryInfo> nodeMemoryInfos =
        nodes
            .values()
            .stream()
            .map(RemoteNodeMemory::getInfo)
            .filter(Optional::isPresent)
            .map(Optional::get)
            .collect(toImmutableList());

    long totalClusterMemory =
        nodeMemoryInfos
            .stream()
            .map(MemoryInfo::getTotalNodeMemory)
            .mapToLong(DataSize::toBytes)
            .sum();
    clusterMemoryBytes.set(totalClusterMemory);

    Set<MemoryPoolId> activePoolIds =
        nodeMemoryInfos
            .stream()
            .flatMap(info -> info.getPools().keySet().stream())
            .collect(toImmutableSet());

    // Make a copy to materialize the set difference
    Set<MemoryPoolId> removedPools = ImmutableSet.copyOf(difference(pools.keySet(), activePoolIds));
    for (MemoryPoolId removed : removedPools) {
      unexport(pools.get(removed));
      pools.remove(removed);
      if (changeListeners.containsKey(removed)) {
        for (Consumer<MemoryPoolInfo> listener : changeListeners.get(removed)) {
          listenerExecutor.execute(
              () -> listener.accept(new MemoryPoolInfo(0, 0, ImmutableMap.of())));
        }
      }
    }
    for (MemoryPoolId id : activePoolIds) {
      ClusterMemoryPool pool =
          pools.computeIfAbsent(
              id,
              poolId -> {
                ClusterMemoryPool newPool = new ClusterMemoryPool(poolId);
                String objectName =
                    ObjectNames.builder(ClusterMemoryPool.class, newPool.getId().toString())
                        .build();
                try {
                  exporter.export(objectName, newPool);
                } catch (JmxException e) {
                  log.error(e, "Error exporting memory pool %s", poolId);
                }
                return newPool;
              });
      pool.update(nodeMemoryInfos, queryCounts.getOrDefault(pool.getId(), 0));
      if (changeListeners.containsKey(id)) {
        MemoryPoolInfo info = pool.getInfo();
        for (Consumer<MemoryPoolInfo> listener : changeListeners.get(id)) {
          listenerExecutor.execute(() -> listener.accept(info));
        }
      }
    }
  }

  /** Unexports all pools and stops the listener executor, even if unexporting fails. */
  @PreDestroy
  public synchronized void destroy() {
    try {
      for (ClusterMemoryPool pool : pools.values()) {
        unexport(pool);
      }
      pools.clear();
    } finally {
      listenerExecutor.shutdownNow();
    }
  }

  /** Removes {@code pool}'s JMX registration, logging (not throwing) on failure. */
  private void unexport(ClusterMemoryPool pool) {
    try {
      String objectName =
          ObjectNames.builder(ClusterMemoryPool.class, pool.getId().toString()).build();
      exporter.unexport(objectName);
    } catch (JmxException e) {
      log.error(e, "Failed to unexport pool %s", pool.getId());
    }
  }

  @Managed
  public long getClusterMemoryUsageBytes() {
    return clusterMemoryUsageBytes.get();
  }

  @Managed
  public long getClusterMemoryBytes() {
    return clusterMemoryBytes.get();
  }

  @Managed
  public long getQueriesKilledDueToOutOfMemory() {
    return queriesKilledDueToOutOfMemory.get();
  }
}
/**
 * Coordinator-side manager of cluster-wide memory.
 *
 * <p>Each call to {@link #process(Iterable)} does the following:
 *
 * <ul>
 *   <li>fails queries that exceed their memory limit;
 *   <li>refreshes the aggregated memory pools and their JMX exports;
 *   <li>promotes the biggest query to the reserved pool when the general pool is blocked;
 *   <li>pushes the resulting pool assignments to every active node.
 * </ul>
 *
 * <p>Processing is a no-op unless the cluster memory manager is enabled and this server is a
 * coordinator.
 */
public class ClusterMemoryManager {
  private static final Logger log = Logger.get(ClusterMemoryManager.class);
  private final NodeManager nodeManager;
  private final LocationFactory locationFactory;
  private final HttpClient httpClient;
  private final MBeanExporter exporter;
  private final JsonCodec<MemoryInfo> memoryInfoCodec;
  private final JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec;
  private final DataSize maxQueryMemory;
  private final boolean enabled;
  private final String coordinatorId;
  private final AtomicLong memoryPoolAssignmentsVersion = new AtomicLong();
  private final AtomicLong clusterMemoryUsageBytes = new AtomicLong();
  private final AtomicLong clusterMemoryBytes = new AtomicLong();
  // Node id -> remote memory view.
  private final Map<String, RemoteNodeMemory> nodes = new HashMap<>();

  @GuardedBy("this")
  private final Map<MemoryPoolId, ClusterMemoryPool> pools = new HashMap<>();

  @Inject
  public ClusterMemoryManager(
      @ForMemoryManager HttpClient httpClient,
      NodeManager nodeManager,
      LocationFactory locationFactory,
      MBeanExporter exporter,
      JsonCodec<MemoryInfo> memoryInfoCodec,
      JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec,
      QueryIdGenerator queryIdGenerator,
      ServerConfig serverConfig,
      MemoryManagerConfig config) {
    requireNonNull(config, "config is null");
    this.nodeManager = requireNonNull(nodeManager, "nodeManager is null");
    this.locationFactory = requireNonNull(locationFactory, "locationFactory is null");
    this.httpClient = requireNonNull(httpClient, "httpClient is null");
    this.exporter = requireNonNull(exporter, "exporter is null");
    this.memoryInfoCodec = requireNonNull(memoryInfoCodec, "memoryInfoCodec is null");
    this.assignmentsRequestJsonCodec =
        requireNonNull(assignmentsRequestJsonCodec, "assignmentsRequestJsonCodec is null");
    this.maxQueryMemory = config.getMaxQueryMemory();
    this.coordinatorId = queryIdGenerator.getCoordinatorId();
    this.enabled = config.isClusterMemoryManagerEnabled() && serverConfig.isCoordinator();
  }

  /**
   * Runs one round of cluster memory management.
   *
   * @param queries the queries to evaluate; iterated several times, so it must be re-iterable
   */
  public void process(Iterable<QueryExecution> queries) {
    if (!enabled) {
      return;
    }
    long totalBytes = 0;
    for (QueryExecution query : queries) {
      long bytes = query.getTotalMemoryReservation();
      DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession());
      long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes());
      totalBytes += bytes;
      if (bytes > queryMemoryLimit) {
        query.fail(
            new ExceededMemoryLimitException(
                "Query", DataSize.succinctDataSize(queryMemoryLimit, Unit.BYTE)));
      }
    }
    clusterMemoryUsageBytes.set(totalBytes);

    Map<MemoryPoolId, Integer> countByPool = new HashMap<>();
    for (QueryExecution query : queries) {
      MemoryPoolId id = query.getMemoryPool().getId();
      countByPool.put(id, countByPool.getOrDefault(id, 0) + 1);
    }

    updatePools(countByPool);

    updateNodes(updateAssignments(queries));
  }

  @VisibleForTesting
  synchronized Map<MemoryPoolId, ClusterMemoryPool> getPools() {
    return ImmutableMap.copyOf(pools);
  }

  /**
   * Bumps the assignment version and, if the reserved pool is empty while the general pool is
   * blocked, promotes the biggest query to the reserved pool. Returns the full set of
   * query-to-pool assignments tagged with the new version.
   *
   * <p>The method is synchronized because it reads {@code pools}, which is {@code
   * @GuardedBy("this")}. When {@code queries} is empty there is nothing to promote; previously that
   * case dereferenced a null {@code biggestQuery}.
   */
  private synchronized MemoryPoolAssignmentsRequest updateAssignments(
      Iterable<QueryExecution> queries) {
    ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL);
    ClusterMemoryPool generalPool = pools.get(GENERAL_POOL);
    long version = memoryPoolAssignmentsVersion.incrementAndGet();
    // Check that all previous assignments have propagated to the visible nodes. This doesn't
    // account for temporary network issues,
    // and is more of a safety check than a guarantee
    if (reservedPool != null && generalPool != null && allAssignmentsHavePropagated(queries)) {
      if (reservedPool.getQueries() == 0 && generalPool.getBlockedNodes() > 0) {
        QueryExecution biggestQuery = null;
        long maxMemory = -1;
        for (QueryExecution queryExecution : queries) {
          long bytesUsed = queryExecution.getTotalMemoryReservation();
          if (bytesUsed > maxMemory) {
            biggestQuery = queryExecution;
            maxMemory = bytesUsed;
          }
        }
        if (biggestQuery != null) {
          for (QueryExecution queryExecution : queries) {
            if (queryExecution.getQueryId().equals(biggestQuery.getQueryId())) {
              queryExecution.setMemoryPool(new VersionedMemoryPoolId(RESERVED_POOL, version));
            }
          }
        }
      }
    }

    ImmutableList.Builder<MemoryPoolAssignment> assignments = ImmutableList.builder();
    for (QueryExecution queryExecution : queries) {
      assignments.add(
          new MemoryPoolAssignment(
              queryExecution.getQueryId(), queryExecution.getMemoryPool().getId()));
    }
    return new MemoryPoolAssignmentsRequest(coordinatorId, version, assignments.build());
  }

  /**
   * Returns true when every visible node reports an assignment version at least as new as the
   * oldest version held by any query. Always false when no nodes are visible.
   */
  private boolean allAssignmentsHavePropagated(Iterable<QueryExecution> queries) {
    if (nodes.isEmpty()) {
      // Assignments can't have propagated, if there are no visible nodes.
      return false;
    }
    // Despite its name, this is the minimum (oldest) assignment version across the queries.
    long newestAssignment =
        ImmutableList.copyOf(queries)
            .stream()
            .map(QueryExecution::getMemoryPool)
            .mapToLong(VersionedMemoryPoolId::getVersion)
            .min()
            .orElse(-1);

    long mostOutOfDateNode =
        nodes
            .values()
            .stream()
            .mapToLong(RemoteNodeMemory::getCurrentAssignmentVersion)
            .min()
            .orElse(Long.MAX_VALUE);

    return newestAssignment <= mostOutOfDateNode;
  }

  /**
   * Syncs {@code nodes} with the currently active nodes and schedules an asynchronous refresh of
   * each one with {@code assignments}.
   */
  private void updateNodes(MemoryPoolAssignmentsRequest assignments) {
    Set<Node> activeNodes = nodeManager.getActiveNodes();
    ImmutableSet<String> activeNodeIds =
        activeNodes.stream().map(Node::getNodeIdentifier).collect(toImmutableSet());

    // Remove nodes that don't exist anymore
    // Make a copy to materialize the set difference
    Set<String> deadNodes = ImmutableSet.copyOf(difference(nodes.keySet(), activeNodeIds));
    nodes.keySet().removeAll(deadNodes);

    // Add new nodes
    for (Node node : activeNodes) {
      if (!nodes.containsKey(node.getNodeIdentifier())) {
        nodes.put(
            node.getNodeIdentifier(),
            new RemoteNodeMemory(
                httpClient,
                memoryInfoCodec,
                assignmentsRequestJsonCodec,
                locationFactory.createMemoryInfoLocation(node)));
      }
    }

    // Schedule refresh
    for (RemoteNodeMemory node : nodes.values()) {
      node.asyncRefresh(assignments);
    }
  }

  /**
   * Rebuilds the cluster-wide pool view from the latest node memory reports.
   *
   * <p>Pools no longer reported are unexported, newly reported pools are exported, and every active
   * pool is refreshed with the given per-pool query counts.
   */
  private synchronized void updatePools(Map<MemoryPoolId, Integer> queryCounts) {
    // Update view of cluster memory and pools
    List<MemoryInfo> nodeMemoryInfos =
        nodes
            .values()
            .stream()
            .map(RemoteNodeMemory::getInfo)
            .filter(Optional::isPresent)
            .map(Optional::get)
            .collect(toImmutableList());

    long totalClusterMemory =
        nodeMemoryInfos
            .stream()
            .map(MemoryInfo::getTotalNodeMemory)
            .mapToLong(DataSize::toBytes)
            .sum();
    clusterMemoryBytes.set(totalClusterMemory);

    Set<MemoryPoolId> activePoolIds =
        nodeMemoryInfos
            .stream()
            .flatMap(info -> info.getPools().keySet().stream())
            .collect(toImmutableSet());

    // Make a copy to materialize the set difference
    Set<MemoryPoolId> removedPools = ImmutableSet.copyOf(difference(pools.keySet(), activePoolIds));
    for (MemoryPoolId removed : removedPools) {
      unexport(pools.get(removed));
      pools.remove(removed);
    }
    for (MemoryPoolId id : activePoolIds) {
      ClusterMemoryPool pool =
          pools.computeIfAbsent(
              id,
              poolId -> {
                ClusterMemoryPool newPool = new ClusterMemoryPool(poolId);
                String objectName =
                    ObjectNames.builder(ClusterMemoryPool.class, newPool.getId().toString())
                        .build();
                try {
                  exporter.export(objectName, newPool);
                } catch (JmxException e) {
                  log.error(e, "Error exporting memory pool %s", poolId);
                }
                return newPool;
              });
      pool.update(nodeMemoryInfos, queryCounts.getOrDefault(pool.getId(), 0));
    }
  }

  /** Unexports every pool and forgets them. */
  @PreDestroy
  public synchronized void destroy() {
    for (ClusterMemoryPool pool : pools.values()) {
      unexport(pool);
    }
    pools.clear();
  }

  /** Removes {@code pool}'s JMX registration, logging (not throwing) on failure. */
  private void unexport(ClusterMemoryPool pool) {
    try {
      String objectName =
          ObjectNames.builder(ClusterMemoryPool.class, pool.getId().toString()).build();
      exporter.unexport(objectName);
    } catch (JmxException e) {
      log.error(e, "Failed to unexport pool %s", pool.getId());
    }
  }

  @Managed
  public long getClusterMemoryUsageBytes() {
    return clusterMemoryUsageBytes.get();
  }

  @Managed
  public long getClusterMemoryBytes() {
    return clusterMemoryBytes.get();
  }
}
Example #16
0
/**
 * Static helpers for shutting down Netty channel infrastructure and executors.
 *
 * <p>When a component does not stop within the allotted time, these helpers log a warning rather
 * than throwing.
 */
public class ShutdownUtil {
  private static final Logger log = Logger.get(ShutdownUtil.class);

  // Timeout used by the overloads that do not take an explicit one.
  private static final long DEFAULT_TIMEOUT_SECONDS = 5;

  /**
   * Shuts down a Netty channel factory and the resources attached to it.
   *
   * <p>The steps run in this order:
   *
   * <ol>
   *   <li>close all open channels;
   *   <li>shut down the channel factory;
   *   <li>stop the boss threads;
   *   <li>stop the I/O worker threads;
   *   <li>release the factory's external resources.
   * </ol>
   *
   * <p>Any argument may be {@code null}, in which case the corresponding step is skipped.
   */
  public static void shutdownChannelFactory(
      ChannelFactory channelFactory,
      ExecutorService bossExecutor,
      ExecutorService workerExecutor,
      ChannelGroup allChannels) {
    // Close all channels
    if (allChannels != null) {
      closeChannels(allChannels);
    }

    // Shutdown the channel factory
    if (channelFactory != null) {
      channelFactory.shutdown();
    }

    // Stop boss threads
    if (bossExecutor != null) {
      shutdownExecutor(bossExecutor, "bossExecutor");
    }

    // Finally stop I/O workers
    if (workerExecutor != null) {
      shutdownExecutor(workerExecutor, "workerExecutor");
    }

    // Release any other resources netty might be holding onto via this channelFactory
    if (channelFactory != null) {
      channelFactory.releaseExternalResources();
    }
  }

  /** Closes every channel in {@code allChannels}, waiting up to 5 seconds for completion. */
  public static void closeChannels(ChannelGroup allChannels) {
    closeChannels(allChannels, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
  }

  /**
   * Closes every channel in {@code allChannels}, waiting up to {@code timeout} for completion.
   *
   * <p>Logs a warning if the close does not finish in time. If the wait is interrupted, it logs a
   * warning and restores the thread's interrupt flag.
   *
   * @param allChannels the channels to close; an empty group is a no-op
   * @param timeout maximum time to wait for the close to complete
   * @param unit unit of {@code timeout}
   */
  public static void closeChannels(ChannelGroup allChannels, long timeout, TimeUnit unit) {
    if (allChannels.size() > 0) {
      // TODO : allow an option here to control if we need to drain connections and wait instead of
      // killing them all
      try {
        log.info("Closing %s open client connections", allChannels.size());
        if (!allChannels.close().await(timeout, unit)) {
          log.warn("Failed to close all open client connections");
        }
      } catch (InterruptedException e) {
        log.warn("Interrupted while closing client connections");
        Thread.currentThread().interrupt();
      }
    }
  }

  /**
   * Initiates an orderly shutdown of {@code executor} and waits up to 5 seconds for it to finish.
   */
  public static void shutdownExecutor(ExecutorService executor, final String name) {
    shutdownExecutor(executor, name, DEFAULT_TIMEOUT_SECONDS, TimeUnit.SECONDS);
  }

  /**
   * Initiates an orderly shutdown of {@code executor} and waits up to {@code timeout} for it to
   * terminate.
   *
   * <p>Tasks still running after the timeout are left running; a warning is logged. If the wait is
   * interrupted, it logs a warning and restores the thread's interrupt flag.
   *
   * @param executor the executor to shut down
   * @param name label used in log messages
   * @param timeout maximum time to wait for termination
   * @param unit unit of {@code timeout}
   */
  public static void shutdownExecutor(
      ExecutorService executor, final String name, long timeout, TimeUnit unit) {
    executor.shutdown();
    try {
      log.info("Waiting for %s to shutdown", name);
      if (!executor.awaitTermination(timeout, unit)) {
        log.warn("%s did not shutdown properly", name);
      }
    } catch (InterruptedException e) {
      log.warn("Interrupted while waiting for %s to shutdown", name);
      Thread.currentThread().interrupt();
    }
  }
}
/**
 * Periodically forces a GC when code cache usage crosses a configured threshold.
 *
 * <p>This works around a Java 8 HotSpot problem where the JIT stops compiling once the code cache
 * fills up.
 */
final class CodeCacheGcTrigger {
  private static final Logger log = Logger.get(CodeCacheGcTrigger.class);
  // Ensures at most one monitoring thread is started per JVM, regardless of instance count.
  private static final AtomicBoolean installed = new AtomicBoolean();

  // Pause between two consecutive code cache checks.
  private final Duration interval;
  // Percentage of the code cache's maximum size above which a GC is forced.
  private final int collectionThreshold;

  @Inject
  public CodeCacheGcTrigger(CodeCacheGcConfig config) {
    this.interval = config.getCodeCacheCheckInterval();
    this.collectionThreshold = config.getCodeCacheCollectionThreshold();
  }

  /** Lifecycle hook: installs the trigger when the container constructs this object. */
  @PostConstruct
  public void start() {
    installCodeCacheGcTrigger();
  }

  /**
   * Starts the daemon thread "Code-Cache-GC-Trigger", at most once per JVM.
   *
   * <p>The thread loops until interrupted, checking the code cache and then sleeping for
   * {@code interval}.
   *
   * @throws RuntimeException if no memory pool named "Code Cache" exists
   */
  public void installCodeCacheGcTrigger() {
    if (installed.getAndSet(true)) {
      return;
    }

    // Hack to work around bugs in java 8 (8u45+) related to code cache management.
    // See
    // http://openjdk.5641.n7.nabble.com/JIT-stops-compiling-after-a-while-java-8u45-td259603.html
    // for more info.
    MemoryPoolMXBean codeCacheMbean = findCodeCacheMBean();

    Thread gcThread =
        new Thread(
            () -> {
              while (!Thread.currentThread().isInterrupted()) {
                checkCodeCache(codeCacheMbean);

                try {
                  TimeUnit.MILLISECONDS.sleep(interval.toMillis());
                } catch (InterruptedException e) {
                  // Restore the flag so the loop condition observes it and the thread exits.
                  Thread.currentThread().interrupt();
                }
              }
            });
    gcThread.setDaemon(true);
    gcThread.setName("Code-Cache-GC-Trigger");
    gcThread.start();
  }

  /**
   * Inspects the code cache once.
   *
   * <p>Logs an error if the cache is more than 95% full. Forces a GC once usage exceeds
   * {@code collectionThreshold} percent of the maximum. Does nothing if the pool reports no usage
   * or no defined maximum.
   */
  private void checkCodeCache(MemoryPoolMXBean codeCacheMbean) {
    // One snapshot, so that used and max describe the same instant.
    java.lang.management.MemoryUsage usage = codeCacheMbean.getUsage();
    if (usage == null) {
      // getUsage() returns null when the pool is no longer valid.
      return;
    }
    long used = usage.getUsed();
    long max = usage.getMax();

    // getMax() returns -1 when the maximum is undefined. Without this guard, every comparison
    // below would be true and each tick would log an error and force a GC.
    if (max < 0) {
      return;
    }

    if (used > 0.95 * max) {
      log.error("Code Cache is more than 95% full. JIT may stop working.");
    }
    if (used > (max * collectionThreshold) / 100) {
      // Due to some obscure bug in hotspot (java 8), once the code cache fills up the JIT stops
      // compiling. By forcing a GC, we let the code cache evictor make room before the cache
      // fills up.
      log.info("Triggering GC to avoid Code Cache eviction bugs");
      System.gc();
    }
  }

  /**
   * Returns the memory pool MXBean named "Code Cache".
   *
   * @throws RuntimeException if no such pool is registered
   */
  private static MemoryPoolMXBean findCodeCacheMBean() {
    for (MemoryPoolMXBean bean : ManagementFactory.getMemoryPoolMXBeans()) {
      if (bean.getName().equals("Code Cache")) {
        return bean;
      }
    }
    throw new RuntimeException("Could not obtain a reference to the 'Code Cache' MemoryPoolMXBean");
  }
}
Example #18
0
/**
 * Renders query progress and a final summary for a {@link StatementClient} on a console.
 *
 * <p>While the query runs, {@link #printInitialStatusUpdates()} repeatedly redraws a status
 * screen. When the query has finished, {@link #printFinalInfo()} prints a one-off summary to the
 * output stream.
 */
public class StatusPrinter {
  private static final Logger log = Logger.get(StatusPrinter.class);

  private static final int CTRL_C = 3;
  private static final int CTRL_P = 16;

  private final long start = System.nanoTime();
  private final StatementClient client;
  private final PrintStream out;
  private final ConsolePrinter console;

  // Toggled at runtime with the 'D' key; initialised from the client.
  private boolean debug;

  public StatusPrinter(StatementClient client, PrintStream out) {
    this.client = client;
    this.out = out;
    this.console = new ConsolePrinter(out);
    this.debug = client.isDebug();
  }

  /*

  Query 16, RUNNING, 1 node, 855 splits
  http://my.server:8080/v1/query/16?pretty
  Splits:   646 queued, 34 running, 175 done
  CPU Time: 33.7s total,  191K rows/s, 16.6MB/s, 22% active
  Per Node: 2.5 parallelism,  473K rows/s, 41.1MB/s
  Parallelism: 2.5
  0:13 [6.45M rows,  560MB] [ 473K rows/s, 41.1MB/s] [=========>>           ] 20%

       STAGES   ROWS  ROWS/s  BYTES  BYTES/s   PEND    RUN   DONE
  0.........R  13.8M    336K  1.99G    49.5M      0      1    706
    1.......R   666K   41.5K  82.1M    5.12M    563     65     79
      2.....R  4.58M    234K   620M    31.6M    406     65    236

   */

  /**
   * Redraws the status screen while the query runs.
   *
   * <p>The loop ends when the client becomes invalid or the current result carries data. It
   * redraws roughly every half second and handles these keys:
   *
   * <ul>
   *   <li>Ctrl-P cancels the leaf stage;
   *   <li>Ctrl-C redraws once and closes the client;
   *   <li>'d'/'D' toggles debug output.
   * </ul>
   *
   * <p>Runtime exceptions inside an iteration are logged at debug level (and printed in debug
   * mode) without stopping the loop. The screen is always reset on exit.
   */
  public void printInitialStatusUpdates() {
    long lastPrint = System.nanoTime();
    try {
      while (client.isValid()) {
        try {
          // exit status loop if there is pending output
          if (client.current().getData() != null) {
            return;
          }

          // check if time to update screen
          boolean update = nanosSince(lastPrint).getValue(SECONDS) >= 0.5;

          // check for keyboard input
          int key = readKey();
          if (key == CTRL_P) {
            partialCancel();
          } else if (key == CTRL_C) {
            updateScreen();
            update = false;
            client.close();
          } else if (toUpperCase(key) == 'D') {
            debug = !debug;
            console.resetScreen();
            update = true;
          }

          // update screen
          if (update) {
            updateScreen();
            lastPrint = System.nanoTime();
          }

          // fetch next results (server will wait for a while if no data)
          client.advance();
        } catch (RuntimeException e) {
          log.debug(e, "error printing status");
          if (debug) {
            e.printStackTrace(out);
          }
        }
      }
    } finally {
      console.resetScreen();
    }
  }

  /** Moves the cursor back and redraws the status screen for the client's current results. */
  private void updateScreen() {
    console.repositionCursor();
    printQueryInfo(client.current());
  }

  /**
   * Prints a summary of the finished query to {@code out}.
   *
   * <p>Nothing is printed if the query used no nodes or had no splits. In debug mode the summary
   * also includes the info URI and the CPU-time, per-node and parallelism lines.
   */
  public void printFinalInfo() {
    Duration wallTime = nanosSince(start);

    QueryResults results = client.finalResults();
    StatementStats stats = results.getStats();

    int nodes = stats.getNodes();
    if ((nodes == 0) || (stats.getTotalSplits() == 0)) {
      return;
    }

    // blank line
    out.println();

    // Query 12, FINISHED, 1 node
    String querySummary =
        String.format(
            "Query %s, %s, %,d %s",
            results.getId(), stats.getState(), nodes, pluralize("node", nodes));
    out.println(querySummary);

    if (debug) {
      out.println(results.getInfoUri().toString());
    }

    // Splits: 1000 total, 842 done (84.20%)
    String splitsSummary =
        String.format(
            "Splits: %,d total, %,d done (%.2f%%)",
            stats.getTotalSplits(),
            stats.getCompletedSplits(),
            percentage(stats.getCompletedSplits(), stats.getTotalSplits()));
    out.println(splitsSummary);

    if (debug) {
      // CPU Time: 565.2s total,   26K rows/s, 3.85MB/s, 60% active
      Duration cpuTime = millis(stats.getCpuTimeMillis());
      out.println(formatCpuTimeSummary(stats, cpuTime));

      double parallelism = cpuTime.getValue(MILLISECONDS) / wallTime.getValue(MILLISECONDS);

      // Per Node: 3.5 parallelism, 83.3K rows/s, 0.7 MB/s
      // Written with out.println like every other line of the final summary.
      out.println(formatPerNodeSummary(stats, parallelism, nodes, wallTime));

      out.println(String.format("Parallelism: %.1f", parallelism));
    }

    // 0:32 [15M rows, 2.12GB] [463K rows/s, 67MB/s]
    String statsLine =
        String.format(
            "%s [%s rows, %s] [%s rows/s, %s]",
            formatTime(wallTime),
            formatCount(stats.getProcessedRows()),
            formatDataSize(bytes(stats.getProcessedBytes()), true),
            formatCountRate(stats.getProcessedRows(), wallTime, false),
            formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true));

    out.println(statsLine);

    // blank line
    out.println();
  }

  /**
   * Draws one frame of the live status screen for {@code results}.
   *
   * <p>On a real terminal this draws a multi-line view: query summary, optional debug lines, a
   * progress bar and the stage tree. Terminals narrower than 75 columns get a short warning view
   * instead. Any other output gets a single compact summary line.
   */
  private void printQueryInfo(QueryResults results) {
    StatementStats stats = results.getStats();
    Duration wallTime = nanosSince(start);

    // cap progress at 99%, otherwise it looks weird when the query is still running and it says
    // 100%
    int progressPercentage =
        (int) min(99, percentage(stats.getCompletedSplits(), stats.getTotalSplits()));

    if (console.isRealTerminal()) {
      // blank line
      reprintLine("");

      int terminalWidth = console.getWidth();

      // NOTE(review): the cutoff is 75 columns while the message asks for 80; confirm which is
      // intended.
      if (terminalWidth < 75) {
        reprintLine("WARNING: Terminal");
        reprintLine("must be at least");
        reprintLine("80 characters wide");
        reprintLine("");
        reprintLine(stats.getState());
        reprintLine(String.format("%s %d%%", formatTime(wallTime), progressPercentage));
        return;
      }

      int nodes = stats.getNodes();

      // Query 10, RUNNING, 1 node, 778 splits
      String querySummary =
          String.format(
              "Query %s, %s, %,d %s, %,d splits",
              results.getId(),
              stats.getState(),
              nodes,
              pluralize("node", nodes),
              stats.getTotalSplits());
      reprintLine(querySummary);

      String url = results.getInfoUri().toString();
      if (debug && (url.length() < terminalWidth)) {
        reprintLine(url);
      }

      if ((nodes == 0) || (stats.getTotalSplits() == 0)) {
        return;
      }

      if (debug) {
        // Splits:   620 queued, 34 running, 124 done
        String splitsSummary =
            String.format(
                "Splits:   %,d queued, %,d running, %,d done",
                stats.getQueuedSplits(), stats.getRunningSplits(), stats.getCompletedSplits());
        reprintLine(splitsSummary);

        // CPU Time: 56.5s total, 36.4K rows/s, 4.44MB/s, 60% active
        Duration cpuTime = millis(stats.getCpuTimeMillis());
        reprintLine(formatCpuTimeSummary(stats, cpuTime));

        double parallelism = cpuTime.getValue(MILLISECONDS) / wallTime.getValue(MILLISECONDS);

        // Per Node: 3.5 parallelism, 83.3K rows/s, 0.7 MB/s
        reprintLine(formatPerNodeSummary(stats, parallelism, nodes, wallTime));

        reprintLine(String.format("Parallelism: %.1f", parallelism));
      }

      assert terminalWidth >= 75;
      int progressWidth =
          (min(terminalWidth, 100) - 75) + 17; // progress bar is 17-42 characters wide

      if (stats.isScheduled()) {
        String progressBar =
            formatProgressBar(
                progressWidth,
                stats.getCompletedSplits(),
                max(0, stats.getRunningSplits()),
                stats.getTotalSplits());

        // 0:17 [ 802K rows,  103MB] [44.9K rows/s, 5.74MB/s] [=====>>                ] 10%
        String progressLine =
            String.format(
                "%s [%5s rows, %6s] [%5s rows/s, %8s] [%s] %d%%",
                formatTime(wallTime),
                formatCount(stats.getProcessedRows()),
                formatDataSize(bytes(stats.getProcessedBytes()), true),
                formatCountRate(stats.getProcessedRows(), wallTime, false),
                formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true),
                progressBar,
                progressPercentage);

        reprintLine(progressLine);
      } else {
        String progressBar =
            formatProgressBar(
                progressWidth, Ints.saturatedCast(nanosSince(start).roundTo(SECONDS)));

        // 0:17 [ 802K rows,  103MB] [44.9K rows/s, 5.74MB/s] [    <=>                ]
        String progressLine =
            String.format(
                "%s [%5s rows, %6s] [%5s rows/s, %8s] [%s]",
                formatTime(wallTime),
                formatCount(stats.getProcessedRows()),
                formatDataSize(bytes(stats.getProcessedBytes()), true),
                formatCountRate(stats.getProcessedRows(), wallTime, false),
                formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true),
                progressBar);

        reprintLine(progressLine);
      }

      // todo Mem: 1949M shared, 7594M private

      // blank line
      reprintLine("");

      // STAGE  S    ROWS  ROWS/s  BYTES  BYTES/s  QUEUED    RUN   DONE
      String stagesHeader =
          String.format(
              "%10s%1s  %5s  %6s  %5s  %7s  %6s  %5s  %5s",
              "STAGE", "S", "ROWS", "ROWS/s", "BYTES", "BYTES/s", "QUEUED", "RUN", "DONE");
      reprintLine(stagesHeader);

      printStageTree(stats.getRootStage(), "", new AtomicInteger());
    } else {
      // Query 31 [RUNNING] i[2.7M 67.3MB 62.7MBps] o[2.7M 67.3MB 62.7MBps] splits[252/16/380]
      // splits[...] is queued/running/completed.
      // NOTE(review): o[...] repeats the processed row/byte counters shown in i[...]; confirm
      // whether separate output counters were intended.
      String querySummary =
          String.format(
              "Query %s [%s] i[%s %s %s] o[%s %s %s] splits[%,d/%,d/%,d]",
              results.getId(),
              stats.getState(),
              formatCount(stats.getProcessedRows()),
              formatDataSize(bytes(stats.getProcessedBytes()), false),
              formatDataRate(bytes(stats.getProcessedBytes()), wallTime, false),
              formatCount(stats.getProcessedRows()),
              formatDataSize(bytes(stats.getProcessedBytes()), false),
              formatDataRate(bytes(stats.getProcessedBytes()), wallTime, false),
              stats.getQueuedSplits(),
              stats.getRunningSplits(),
              stats.getCompletedSplits());
      reprintLine(querySummary);
    }
  }

  /**
   * Prints one row for {@code stage}, then recurses into its sub-stages.
   *
   * <p>Each sub-stage is indented two more spaces than its parent. Stages are numbered in
   * depth-first order using {@code stageNumberCounter}. Finished stages show rates computed from
   * zero bytes and rows over a zero duration.
   */
  private void printStageTree(StageStats stage, String indent, AtomicInteger stageNumberCounter) {
    Duration elapsedTime = nanosSince(start);

    // STAGE  S    ROWS  ROWS/s  BYTES  BYTES/s  QUEUED    RUN   DONE
    // 0......Q     26M   9077M  9993G    9077M   9077M  9077M  9077M
    //   2....R     17K    627M   673M     627M    627M   627M   627M
    //     3..C     999    627M   673M     627M    627M   627M   627M
    //   4....R     26M    627M   673T     627M    627M   627M   627M
    //     5..F     29T    627M   673M     627M    627M   627M   627M

    String id = String.valueOf(stageNumberCounter.getAndIncrement());
    String name = indent + id;
    name += Strings.repeat(".", max(0, 10 - name.length()));

    String bytesPerSecond;
    String rowsPerSecond;
    if (stage.isDone()) {
      bytesPerSecond = formatDataRate(new DataSize(0, BYTE), new Duration(0, SECONDS), false);
      rowsPerSecond = formatCountRate(0, new Duration(0, SECONDS), false);
    } else {
      bytesPerSecond = formatDataRate(bytes(stage.getProcessedBytes()), elapsedTime, false);
      rowsPerSecond = formatCountRate(stage.getProcessedRows(), elapsedTime, false);
    }

    String stageSummary =
        String.format(
            "%10s%1s  %5s  %6s  %5s  %7s  %6s  %5s  %5s",
            name,
            stageStateCharacter(stage.getState()),
            formatCount(stage.getProcessedRows()),
            rowsPerSecond,
            formatDataSize(bytes(stage.getProcessedBytes()), false),
            bytesPerSecond,
            stage.getQueuedSplits(),
            stage.getRunningSplits(),
            stage.getCompletedSplits());
    reprintLine(stageSummary);

    for (StageStats subStage : stage.getSubStages()) {
      printStageTree(subStage, indent + "  ", stageNumberCounter);
    }
  }

  /** Asks the client to cancel the leaf stage (1s argument); failures are logged at debug. */
  private void partialCancel() {
    try {
      client.cancelLeafStage(new Duration(1, SECONDS));
    } catch (RuntimeException e) {
      log.debug(e, "error canceling leaf stage");
    }
  }

  private void reprintLine(String line) {
    console.reprintLine(line);
  }

  /** Formats the "CPU Time: ..." debug line, shared by the live screen and the final summary. */
  private static String formatCpuTimeSummary(StatementStats stats, Duration cpuTime) {
    return String.format(
        "CPU Time: %.1fs total, %5s rows/s, %8s, %d%% active",
        cpuTime.getValue(SECONDS),
        formatCountRate(stats.getProcessedRows(), cpuTime, false),
        formatDataRate(bytes(stats.getProcessedBytes()), cpuTime, true),
        (int) percentage(stats.getCpuTimeMillis(), stats.getWallTimeMillis()));
  }

  /**
   * Formats the "Per Node: ..." debug line, shared by the live screen and the final summary.
   *
   * <p>{@code nodes} must be non-zero; both callers return early when it is zero.
   */
  private static String formatPerNodeSummary(
      StatementStats stats, double parallelism, int nodes, Duration wallTime) {
    return String.format(
        "Per Node: %.1f parallelism, %5s rows/s, %8s",
        parallelism / nodes,
        formatCountRate((double) stats.getProcessedRows() / nodes, wallTime, false),
        formatDataRate(bytes(stats.getProcessedBytes() / nodes), wallTime, true));
  }

  /** Maps a stage state to its one-letter column value; "FAILED" becomes 'X'. */
  private static char stageStateCharacter(String state) {
    return "FAILED".equals(state) ? 'X' : state.charAt(0);
  }

  private static Duration millis(long millis) {
    return new Duration(millis, MILLISECONDS);
  }

  private static DataSize bytes(long bytes) {
    return new DataSize(bytes, BYTE);
  }

  /** Returns {@code count / total} as a percentage capped at 100, or 0 when {@code total} is 0. */
  private static double percentage(double count, double total) {
    if (total == 0) {
      return 0;
    }
    return min(100, (count * 100.0) / total);
  }
}
Example #19
0
/**
 * Publishes query and split lifecycle events to an {@link EventClient}.
 *
 * <p>The events are created, completed, split completed and split failed. When a query
 * completes, its timeline is also logged.
 */
public class QueryMonitor {
  private static final Logger log = Logger.get(QueryMonitor.class);

  private final ObjectMapper objectMapper;
  private final EventClient eventClient;
  private final String environment;

  @Inject
  public QueryMonitor(ObjectMapper objectMapper, EventClient eventClient, NodeInfo nodeInfo) {
    this.objectMapper = checkNotNull(objectMapper, "objectMapper is null");
    this.eventClient = checkNotNull(eventClient, "eventClient is null");
    this.environment = checkNotNull(nodeInfo, "nodeInfo is null").getEnvironment();
  }

  /** Posts a {@code QueryCreatedEvent} built from {@code queryInfo} and its session. */
  public void createdEvent(QueryInfo queryInfo) {
    eventClient.post(
        new QueryCreatedEvent(
            queryInfo.getQueryId(),
            queryInfo.getSession().getUser(),
            queryInfo.getSession().getSource(),
            environment,
            queryInfo.getSession().getCatalog(),
            queryInfo.getSession().getSchema(),
            queryInfo.getSession().getRemoteUserAddress(),
            queryInfo.getSession().getUserAgent(),
            queryInfo.getSelf(),
            queryInfo.getQuery(),
            queryInfo.getQueryStats().getCreateTime()));
  }

  /**
   * Posts a {@code QueryCompletionEvent} for {@code queryInfo}, then logs the query's timeline.
   *
   * @throws RuntimeException wrapping a {@link JsonProcessingException} if the output stage,
   *     failure info or inputs cannot be serialized to JSON
   */
  public void completionEvent(QueryInfo queryInfo) {
    try {
      QueryStats queryStats = queryInfo.getQueryStats();
      FailureInfo failureInfo = queryInfo.getFailureInfo();

      String failureType = failureInfo == null ? null : failureInfo.getType();
      String failureMessage = failureInfo == null ? null : failureInfo.getMessage();

      eventClient.post(
          new QueryCompletionEvent(
              queryInfo.getQueryId(),
              queryInfo.getSession().getUser(),
              queryInfo.getSession().getSource(),
              environment,
              queryInfo.getSession().getCatalog(),
              queryInfo.getSession().getSchema(),
              queryInfo.getSession().getRemoteUserAddress(),
              queryInfo.getSession().getUserAgent(),
              queryInfo.getState(),
              queryInfo.getSelf(),
              queryInfo.getFieldNames(),
              queryInfo.getQuery(),
              queryStats.getCreateTime(),
              queryStats.getExecutionStartTime(),
              queryStats.getEndTime(),
              queryStats.getQueuedTime(),
              queryStats.getAnalysisTime(),
              queryStats.getDistributedPlanningTime(),
              queryStats.getTotalScheduledTime(),
              queryStats.getTotalCpuTime(),
              queryStats.getRawInputDataSize(),
              queryStats.getRawInputPositions(),
              queryStats.getTotalDrivers(),
              queryInfo.getErrorCode(),
              failureType,
              failureMessage,
              objectMapper.writeValueAsString(queryInfo.getOutputStage()),
              objectMapper.writeValueAsString(queryInfo.getFailureInfo()),
              objectMapper.writeValueAsString(queryInfo.getInputs())));

      logQueryTimeline(queryInfo);
    } catch (JsonProcessingException e) {
      // The exception is checked; wrap it so callers see an unchecked exception.
      throw new RuntimeException(e);
    }
  }

  /**
   * Logs an INFO "TIMELINE" line that splits the query's elapsed time into phases.
   *
   * <p>The phases are:
   *
   * <ul>
   *   <li>planning;
   *   <li>scheduling (end of planning to the first leaf-task start);
   *   <li>running (first leaf-task start to the last leaf-task end);
   *   <li>finishing (last leaf-task end to query end).
   * </ul>
   *
   * <p>Only leaf stages (stages without sub-stages) contribute task times. Nothing is logged if
   * the query has no create or end time. Negative phases are clamped to zero. Any exception is
   * logged and swallowed: this logging is best-effort.
   */
  private void logQueryTimeline(QueryInfo queryInfo) {
    try {
      QueryStats queryStats = queryInfo.getQueryStats();
      DateTime queryStartTime = queryStats.getCreateTime();
      DateTime queryEndTime = queryStats.getEndTime();

      // query didn't finish cleanly
      if (queryStartTime == null || queryEndTime == null) {
        return;
      }

      // planning duration -- start to end of planning
      Duration planning = queryStats.getTotalPlanningTime();
      if (planning == null) {
        planning = new Duration(0, MILLISECONDS);
      }

      List<StageInfo> stages = StageInfo.getAllStages(queryInfo.getOutputStage());
      long firstTaskStartTime = queryEndTime.getMillis();
      long lastTaskEndTime = queryStartTime.getMillis() + planning.toMillis();
      for (StageInfo stage : stages) {
        // only consider leaf stages
        if (!stage.getSubStages().isEmpty()) {
          continue;
        }

        for (TaskInfo taskInfo : stage.getTasks()) {
          TaskStats taskStats = taskInfo.getStats();

          DateTime firstStartTime = taskStats.getFirstStartTime();
          if (firstStartTime != null) {
            firstTaskStartTime = Math.min(firstStartTime.getMillis(), firstTaskStartTime);
          }

          DateTime endTime = taskStats.getEndTime();
          if (endTime != null) {
            lastTaskEndTime = Math.max(endTime.getMillis(), lastTaskEndTime);
          }
        }
      }

      Duration elapsed = millis(queryEndTime.getMillis() - queryStartTime.getMillis());

      Duration scheduling =
          millis(firstTaskStartTime - queryStartTime.getMillis() - planning.toMillis());

      Duration running = millis(lastTaskEndTime - firstTaskStartTime);

      Duration finishing = millis(queryEndTime.getMillis() - lastTaskEndTime);

      log.info(
          "TIMELINE: Query %s :: elapsed %s :: planning %s :: scheduling %s :: running %s :: finishing %s :: begin %s :: end %s",
          queryInfo.getQueryId(),
          elapsed,
          planning,
          scheduling,
          running,
          finishing,
          queryStartTime,
          queryEndTime);
    } catch (Exception e) {
      log.error(e, "Error logging query timeline");
    }
  }

  /** Posts a successful {@code SplitCompletionEvent} for the driver of {@code taskId}. */
  public void splitCompletionEvent(TaskId taskId, DriverStats driverStats) {
    splitCompletionEvent(taskId, driverStats, null, null);
  }

  /**
   * Posts a failed {@code SplitCompletionEvent}.
   *
   * <p>The event records {@code cause}'s class name and message.
   */
  public void splitFailedEvent(TaskId taskId, DriverStats driverStats, Throwable cause) {
    splitCompletionEvent(taskId, driverStats, cause.getClass().getName(), cause.getMessage());
  }

  /**
   * Builds and posts a {@code SplitCompletionEvent}.
   *
   * <p>Time-to-start and time-to-end are measured from the driver's create time. Each is left
   * {@code null} when the corresponding timestamp is missing. A JSON serialization failure is
   * logged, not thrown.
   */
  private void splitCompletionEvent(
      TaskId taskId,
      DriverStats driverStats,
      @Nullable String failureType,
      @Nullable String failureMessage) {
    Duration timeToStart = null;
    if (driverStats.getStartTime() != null) {
      timeToStart =
          millis(driverStats.getStartTime().getMillis() - driverStats.getCreateTime().getMillis());
    }
    Duration timeToEnd = null;
    if (driverStats.getEndTime() != null) {
      timeToEnd =
          millis(driverStats.getEndTime().getMillis() - driverStats.getCreateTime().getMillis());
    }

    try {
      eventClient.post(
          new SplitCompletionEvent(
              taskId.getQueryId(),
              taskId.getStageId(),
              taskId,
              environment,
              driverStats.getQueuedTime(),
              driverStats.getStartTime(),
              timeToStart,
              timeToEnd,
              driverStats.getRawInputDataSize(),
              driverStats.getRawInputPositions(),
              driverStats.getRawInputReadTime(),
              driverStats.getElapsedTime(),
              driverStats.getTotalCpuTime(),
              driverStats.getTotalUserTime(),
              failureType,
              failureMessage,
              objectMapper.writeValueAsString(driverStats)));
    } catch (JsonProcessingException e) {
      log.error(e, "Error posting split completion event for task %s", taskId);
    }
  }

  /** Returns a millisecond {@link Duration}, clamping negative values to zero. */
  private static Duration millis(long millis) {
    if (millis < 0) {
      millis = 0;
    }
    return new Duration(millis, MILLISECONDS);
  }
}
/**
 * Tracks consecutive request failures against one remote task.
 *
 * <p>A {@link Backoff} decides both when the next request may be sent and when the failures
 * become fatal.
 */
@ThreadSafe
class RequestErrorTracker {
  private static final Logger log = Logger.get(RequestErrorTracker.class);

  private final TaskId taskId;
  private final URI taskUri;
  private final ScheduledExecutorService scheduledExecutor;
  // Human-readable description of the request kind, included in log and error messages.
  private final String jobDescription;
  private final Backoff backoff;

  // Errors seen since the last success; attached as suppressed exceptions when giving up.
  private final Queue<Throwable> errorsSinceLastSuccess = new ConcurrentLinkedQueue<>();

  public RequestErrorTracker(
      TaskId taskId,
      URI taskUri,
      Duration minErrorDuration,
      ScheduledExecutorService scheduledExecutor,
      String jobDescription) {
    this.taskId = taskId;
    this.taskUri = taskUri;
    this.scheduledExecutor = scheduledExecutor;
    this.backoff = new Backoff(minErrorDuration);
    this.jobDescription = jobDescription;
  }

  /**
   * Returns a future that completes once the next request may be sent.
   *
   * <p>The future is already complete when the backoff reports no delay. Otherwise it completes
   * after the delay, via {@code scheduledExecutor}.
   */
  public ListenableFuture<?> acquireRequestPermit() {
    long delayNanos = backoff.getBackoffDelayNanos();

    if (delayNanos == 0) {
      return Futures.immediateFuture(null);
    }

    ListenableFutureTask<Object> futureTask = ListenableFutureTask.create(() -> null);
    scheduledExecutor.schedule(futureTask, delayNanos, NANOSECONDS);
    return futureTask;
  }

  /**
   * Marks the start of a request.
   *
   * <p>When the backoff reports zero failures, this resets the backoff timer and clears the
   * remembered errors.
   */
  public void startRequest() {
    // before scheduling a new request clear the error timer
    // we consider a request to be "new" if there are no current failures
    if (backoff.getFailureCount() == 0) {
      requestSucceeded();
    }
  }

  /** Records a successful request: resets the backoff and forgets remembered errors. */
  public void requestSucceeded() {
    backoff.success();
    errorsSinceLastSuccess.clear();
  }

  /**
   * Records a failed request.
   *
   * <p>A {@link CancellationException} is ignored. Any other failure is logged and remembered
   * (up to about 10 errors).
   *
   * @throws PrestoException with {@code REMOTE_TASK_ERROR} if {@code reason} is a
   *     {@link RejectedExecutionException}
   * @throws PrestoException with {@code TOO_MANY_REQUESTS_FAILED} once the backoff gives up; the
   *     remembered errors are attached as suppressed exceptions
   */
  public void requestFailed(Throwable reason) throws PrestoException {
    // cancellation is not a failure
    if (reason instanceof CancellationException) {
      return;
    }

    if (reason instanceof RejectedExecutionException) {
      throw new PrestoException(REMOTE_TASK_ERROR, reason);
    }

    // log failure message. jobDescription is passed as an argument, not concatenated into the
    // format string, so a '%' in it cannot break formatting.
    if (isExpectedError(reason)) {
      // don't print a stack for a known errors
      log.warn("Error %s %s: %s: %s", jobDescription, taskId, reason.getMessage(), taskUri);
    } else {
      log.warn(reason, "Error %s %s: %s", jobDescription, taskId, taskUri);
    }

    // remember the first 10 errors. The size check and the add are not atomic, so concurrent
    // failures may push the queue slightly past 10.
    if (errorsSinceLastSuccess.size() < 10) {
      errorsSinceLastSuccess.add(reason);
    }

    // fail the task, if we have more than X failures in a row and more than Y seconds have passed
    // since the last request
    if (backoff.failure()) {
      // it is weird to mark the task failed locally and then cancel the remote task, but there is
      // no way to tell a remote task that it is failed
      PrestoException exception =
          new PrestoException(
              TOO_MANY_REQUESTS_FAILED,
              format(
                  "%s (%s %s - %s failures, time since last success %s)",
                  WORKER_NODE_ERROR,
                  jobDescription,
                  taskUri,
                  backoff.getFailureCount(),
                  backoff.getTimeSinceLastSuccess().convertTo(SECONDS)));
      errorsSinceLastSuccess.forEach(exception::addSuppressed);
      throw exception;
    }
  }

  /**
   * Logs {@code t} at error level.
   *
   * <p>Expected (network-type) errors are appended to the message without a stack trace. All
   * other errors are logged with their stack trace.
   */
  static void logError(Throwable t, String format, Object... args) {
    if (isExpectedError(t)) {
      log.error(format + ": %s", ObjectArrays.concat(args, t));
    } else {
      log.error(t, format, args);
    }
  }

  /**
   * Returns true if {@code t} or any of its causes is a socket, EOF, timeout or
   * service-unavailable exception.
   */
  private static boolean isExpectedError(Throwable t) {
    while (t != null) {
      if ((t instanceof SocketException)
          || (t instanceof SocketTimeoutException)
          || (t instanceof EOFException)
          || (t instanceof TimeoutException)
          || (t instanceof ServiceUnavailableException)) {
        return true;
      }
      t = t.getCause();
    }
    return false;
  }
}
/**
 * Hive implementation of {@link ConnectorMetadata}: table, view and schema metadata come from a
 * {@link HiveMetastore}, and table directories are created through {@link HdfsEnvironment}.
 */
public class HiveMetadata implements ConnectorMetadata {
  private static final Logger log = Logger.get(HiveMetadata.class);

  private final String connectorId;
  private final boolean allowDropTable;
  private final boolean allowRenameTable;
  private final boolean allowCorruptWritesForTesting;
  private final HiveMetastore metastore;
  private final HdfsEnvironment hdfsEnvironment;
  private final DateTimeZone timeZone;
  private final HiveStorageFormat hiveStorageFormat;
  private final TypeManager typeManager;

  /**
   * Injection constructor. Reads the storage time zone, DROP/RENAME permissions, the
   * corrupt-writes testing flag and the default storage format from {@code hiveClientConfig}.
   * {@code executorService} is accepted but not used by this class.
   */
  @Inject
  @SuppressWarnings("deprecation")
  public HiveMetadata(
      HiveConnectorId connectorId,
      HiveClientConfig hiveClientConfig,
      HiveMetastore metastore,
      HdfsEnvironment hdfsEnvironment,
      @ForHiveClient ExecutorService executorService,
      TypeManager typeManager) {
    this(
        connectorId,
        metastore,
        hdfsEnvironment,
        DateTimeZone.forTimeZone(hiveClientConfig.getTimeZone()),
        hiveClientConfig.getAllowDropTable(),
        hiveClientConfig.getAllowRenameTable(),
        hiveClientConfig.getAllowCorruptWritesForTesting(),
        hiveClientConfig.getHiveStorageFormat(),
        typeManager);
  }

  /**
   * Creates the metadata handler.
   *
   * <p>Logs a warning if {@code timeZone} differs from the JVM default time zone and corrupt
   * writes are not allowed. That is the same condition under which {@code verifyJvmTimeZone()}
   * rejects writes.
   *
   * @throws NullPointerException if {@code connectorId}, {@code metastore}, {@code
   *     hdfsEnvironment}, {@code timeZone} or {@code typeManager} is null
   */
  public HiveMetadata(
      HiveConnectorId connectorId,
      HiveMetastore metastore,
      HdfsEnvironment hdfsEnvironment,
      DateTimeZone timeZone,
      boolean allowDropTable,
      boolean allowRenameTable,
      boolean allowCorruptWritesForTesting,
      HiveStorageFormat hiveStorageFormat,
      TypeManager typeManager) {
    this.connectorId = checkNotNull(connectorId, "connectorId is null").toString();

    this.allowDropTable = allowDropTable;
    this.allowRenameTable = allowRenameTable;
    this.allowCorruptWritesForTesting = allowCorruptWritesForTesting;

    this.metastore = checkNotNull(metastore, "metastore is null");
    this.hdfsEnvironment = checkNotNull(hdfsEnvironment, "hdfsEnvironment is null");
    this.timeZone = checkNotNull(timeZone, "timeZone is null");
    this.hiveStorageFormat = hiveStorageFormat;
    this.typeManager = checkNotNull(typeManager, "typeManager is null");

    // surface the misconfiguration at startup rather than only on the first write attempt
    if (!allowCorruptWritesForTesting && !timeZone.equals(DateTimeZone.getDefault())) {
      log.warn(
          "Hive writes are disabled. "
              + "To write data to Hive, your JVM timezone must match the Hive storage timezone. "
              + "Add -Duser.timezone=%s to your JVM arguments",
          timeZone.getID());
    }
  }

  /** Returns the metastore this instance reads and writes through. */
  public HiveMetastore getMetastore() {
    return metastore;
  }

  @Override
  public List<String> listSchemaNames(ConnectorSession session) {
    return metastore.getAllDatabases();
  }

  /** Returns a handle for {@code tableName}, or {@code null} if the metastore does not know it. */
  @Override
  public HiveTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) {
    checkNotNull(tableName, "tableName is null");
    try {
      metastore.getTable(tableName.getSchemaName(), tableName.getTableName());
      return new HiveTableHandle(
          connectorId, tableName.getSchemaName(), tableName.getTableName(), session);
    } catch (NoSuchObjectException e) {
      // table was not found
      return null;
    }
  }

  @Override
  public ConnectorTableMetadata getTableMetadata(ConnectorTableHandle tableHandle) {
    checkNotNull(tableHandle, "tableHandle is null");
    SchemaTableName tableName = schemaTableName(tableHandle);
    return getTableMetadata(tableName);
  }

  /**
   * Loads column metadata and owner for {@code tableName}.
   *
   * @throws TableNotFoundException if the table does not exist or is a view
   */
  private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) {
    try {
      Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName());
      if (table.getTableType().equals(TableType.VIRTUAL_VIEW.name())) {
        throw new TableNotFoundException(tableName);
      }
      List<HiveColumnHandle> handles = hiveColumnHandles(typeManager, connectorId, table, false);
      List<ColumnMetadata> columns =
          ImmutableList.copyOf(transform(handles, columnMetadataGetter(table, typeManager)));
      return new ConnectorTableMetadata(tableName, columns, table.getOwner());
    } catch (NoSuchObjectException e) {
      throw new TableNotFoundException(tableName);
    }
  }

  /**
   * Lists tables in one schema, or in every schema when {@code schemaNameOrNull} is null. Schemas
   * that disappear while listing are skipped.
   */
  @Override
  public List<SchemaTableName> listTables(ConnectorSession session, String schemaNameOrNull) {
    ImmutableList.Builder<SchemaTableName> tableNames = ImmutableList.builder();
    for (String schemaName : listSchemas(session, schemaNameOrNull)) {
      try {
        for (String tableName : metastore.getAllTables(schemaName)) {
          tableNames.add(new SchemaTableName(schemaName, tableName));
        }
      } catch (NoSuchObjectException e) {
        // schema disappeared during listing operation
      }
    }
    return tableNames.build();
  }

  /** Returns all schema names when {@code schemaNameOrNull} is null, otherwise just that one. */
  private List<String> listSchemas(ConnectorSession session, String schemaNameOrNull) {
    if (schemaNameOrNull == null) {
      return listSchemaNames(session);
    }
    return ImmutableList.of(schemaNameOrNull);
  }

  /** Returns the sample weight column handle of the table, or {@code null} if it has none. */
  @Override
  public ColumnHandle getSampleWeightColumnHandle(ConnectorTableHandle tableHandle) {
    SchemaTableName tableName = schemaTableName(tableHandle);
    try {
      Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName());
      for (HiveColumnHandle columnHandle :
          hiveColumnHandles(typeManager, connectorId, table, true)) {
        if (columnHandle.getName().equals(SAMPLE_WEIGHT_COLUMN_NAME)) {
          return columnHandle;
        }
      }
      return null;
    } catch (NoSuchObjectException e) {
      throw new TableNotFoundException(tableName);
    }
  }

  @Override
  public boolean canCreateSampledTables(ConnectorSession session) {
    return true;
  }

  @Override
  public Map<String, ColumnHandle> getColumnHandles(ConnectorTableHandle tableHandle) {
    SchemaTableName tableName = schemaTableName(tableHandle);
    try {
      Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName());
      ImmutableMap.Builder<String, ColumnHandle> columnHandles = ImmutableMap.builder();
      for (HiveColumnHandle columnHandle :
          hiveColumnHandles(typeManager, connectorId, table, false)) {
        columnHandles.put(columnHandle.getName(), columnHandle);
      }
      return columnHandles.build();
    } catch (NoSuchObjectException e) {
      throw new TableNotFoundException(tableName);
    }
  }

  /**
   * Returns the columns of every table matching {@code prefix}. Views and tables that vanish
   * during listing are omitted.
   */
  @Override
  public Map<SchemaTableName, List<ColumnMetadata>> listTableColumns(
      ConnectorSession session, SchemaTablePrefix prefix) {
    checkNotNull(prefix, "prefix is null");
    ImmutableMap.Builder<SchemaTableName, List<ColumnMetadata>> columns = ImmutableMap.builder();
    for (SchemaTableName tableName : listTables(session, prefix)) {
      try {
        columns.put(tableName, getTableMetadata(tableName).getColumns());
      } catch (HiveViewNotSupportedException e) {
        // view is not supported
      } catch (TableNotFoundException e) {
        // table disappeared during listing operation
      }
    }
    return columns.build();
  }

  /**
   * Resolves {@code prefix} to concrete table names. A fully qualified prefix is returned as-is
   * without consulting the metastore.
   */
  private List<SchemaTableName> listTables(ConnectorSession session, SchemaTablePrefix prefix) {
    if (prefix.getSchemaName() == null || prefix.getTableName() == null) {
      return listTables(session, prefix.getSchemaName());
    }
    return ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName()));
  }

  /** NOTE: This method does not return column comment */
  @Override
  public ColumnMetadata getColumnMetadata(
      ConnectorTableHandle tableHandle, ColumnHandle columnHandle) {
    checkType(tableHandle, HiveTableHandle.class, "tableHandle");
    return checkType(columnHandle, HiveColumnHandle.class, "columnHandle")
        .getColumnMetadata(typeManager);
  }

  /**
   * Registers an empty managed table in the metastore. Columns flagged as partition keys become
   * Hive partition keys. The table location is a new directory under the database location.
   *
   * @throws IllegalArgumentException if the table owner is null or empty
   */
  @Override
  public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) {
    checkArgument(!isNullOrEmpty(tableMetadata.getOwner()), "Table owner is null or empty");

    SchemaTableName schemaTableName = tableMetadata.getTable();
    String schemaName = schemaTableName.getSchemaName();
    String tableName = schemaTableName.getTableName();

    ImmutableList.Builder<String> columnNames = ImmutableList.builder();
    ImmutableList.Builder<Type> columnTypes = ImmutableList.builder();

    buildColumnInfo(tableMetadata, columnNames, columnTypes);

    ImmutableList.Builder<FieldSchema> partitionKeys = ImmutableList.builder();
    ImmutableList.Builder<FieldSchema> columns = ImmutableList.builder();

    List<String> names = columnNames.build();
    List<String> typeNames =
        columnTypes
            .build()
            .stream()
            .map(HiveType::toHiveType)
            .map(HiveType::getHiveTypeName)
            .collect(toList());

    // buildColumnInfo appends the synthetic sample weight column for sampled tables, so names can
    // be longer than the declared columns; that trailing column is never a partition key
    List<ColumnMetadata> declaredColumns = tableMetadata.getColumns();
    for (int i = 0; i < names.size(); i++) {
      FieldSchema field = new FieldSchema(names.get(i), typeNames.get(i), null);
      if (i < declaredColumns.size() && declaredColumns.get(i).isPartitionKey()) {
        partitionKeys.add(field);
      } else {
        columns.add(field);
      }
    }

    Path targetPath = getTargetPath(schemaName, tableName, schemaTableName);

    HiveStorageFormat hiveStorageFormat = getHiveStorageFormat(session, this.hiveStorageFormat);
    SerDeInfo serdeInfo = new SerDeInfo();
    serdeInfo.setName(tableName);
    serdeInfo.setSerializationLib(hiveStorageFormat.getSerDe());

    StorageDescriptor sd = new StorageDescriptor();
    sd.setLocation(targetPath.toString());

    sd.setCols(columns.build());
    sd.setSerdeInfo(serdeInfo);
    sd.setInputFormat(hiveStorageFormat.getInputFormat());
    sd.setOutputFormat(hiveStorageFormat.getOutputFormat());

    Table table = new Table();
    table.setDbName(schemaName);
    table.setTableName(tableName);
    table.setOwner(tableMetadata.getOwner());
    table.setTableType(TableType.MANAGED_TABLE.toString());
    String tableComment = "Created by Presto";
    table.setParameters(ImmutableMap.of("comment", tableComment));
    table.setPartitionKeys(partitionKeys.build());
    table.setSd(sd);

    metastore.createTable(table);
  }

  /**
   * Renames a table in the metastore.
   *
   * @throws PrestoException with {@code PERMISSION_DENIED} if renaming is disabled
   */
  @Override
  public void renameTable(ConnectorTableHandle tableHandle, SchemaTableName newTableName) {
    if (!allowRenameTable) {
      throw new PrestoException(
          PERMISSION_DENIED, "Renaming tables is disabled in this Hive catalog");
    }

    HiveTableHandle handle = checkType(tableHandle, HiveTableHandle.class, "tableHandle");
    metastore.renameTable(
        handle.getSchemaName(),
        handle.getTableName(),
        newTableName.getSchemaName(),
        newTableName.getTableName());
  }

  /**
   * Drops a table. Only the table's owner, as recorded in the metastore, may drop it.
   *
   * @throws PrestoException with {@code PERMISSION_DENIED} if dropping is disabled or the session
   *     user is not the table owner
   * @throws TableNotFoundException if the table does not exist
   */
  @Override
  public void dropTable(ConnectorTableHandle tableHandle) {
    HiveTableHandle handle = checkType(tableHandle, HiveTableHandle.class, "tableHandle");
    SchemaTableName tableName = schemaTableName(tableHandle);

    if (!allowDropTable) {
      throw new PrestoException(PERMISSION_DENIED, "DROP TABLE is disabled in this Hive catalog");
    }

    try {
      Table table = metastore.getTable(handle.getSchemaName(), handle.getTableName());
      if (!handle.getSession().getUser().equals(table.getOwner())) {
        // report the table by name, not by dumping the whole metastore Table object
        throw new PrestoException(
            PERMISSION_DENIED,
            format(
                "Unable to drop table '%s': owner of the table is different from session user",
                tableName));
      }
      metastore.dropTable(handle.getSchemaName(), handle.getTableName());
    } catch (NoSuchObjectException e) {
      throw new TableNotFoundException(tableName);
    }
  }

  /**
   * Starts a CREATE TABLE AS. Data is written to a per-user temporary directory that is renamed
   * into place on commit, except on S3 where data is written directly to the target directory.
   *
   * @throws PrestoException with {@code HIVE_TIMEZONE_MISMATCH} if writes are disabled because of
   *     a time zone mismatch
   */
  @Override
  public HiveOutputTableHandle beginCreateTable(
      ConnectorSession session, ConnectorTableMetadata tableMetadata) {
    verifyJvmTimeZone();

    checkArgument(!isNullOrEmpty(tableMetadata.getOwner()), "Table owner is null or empty");

    HiveStorageFormat hiveStorageFormat = getHiveStorageFormat(session, this.hiveStorageFormat);

    ImmutableList.Builder<String> columnNames = ImmutableList.builder();
    ImmutableList.Builder<Type> columnTypes = ImmutableList.builder();

    // get the root directory for the database
    SchemaTableName schemaTableName = tableMetadata.getTable();
    String schemaName = schemaTableName.getSchemaName();
    String tableName = schemaTableName.getTableName();

    buildColumnInfo(tableMetadata, columnNames, columnTypes);

    Path targetPath = getTargetPath(schemaName, tableName, schemaTableName);

    if (!useTemporaryDirectory(targetPath)) {
      return new HiveOutputTableHandle(
          connectorId,
          schemaName,
          tableName,
          columnNames.build(),
          columnTypes.build(),
          tableMetadata.getOwner(),
          targetPath.toString(),
          targetPath.toString(),
          session,
          hiveStorageFormat);
    }

    // use a per-user temporary directory to avoid permission problems
    // TODO: this should use Hadoop UserGroupInformation
    String temporaryPrefix = "/tmp/presto-" + StandardSystemProperty.USER_NAME.value();

    // create a temporary directory on the same filesystem
    Path temporaryRoot = new Path(targetPath, temporaryPrefix);
    Path temporaryPath = new Path(temporaryRoot, randomUUID().toString());
    createDirectories(temporaryPath);

    return new HiveOutputTableHandle(
        connectorId,
        schemaName,
        tableName,
        columnNames.build(),
        columnTypes.build(),
        tableMetadata.getOwner(),
        targetPath.toString(),
        temporaryPath.toString(),
        session,
        hiveStorageFormat);
  }

  /**
   * Finishes a CREATE TABLE AS. If a temporary directory was used, it is moved to the target
   * location. The table is then registered in the metastore as an unpartitioned managed table.
   *
   * @throws PrestoException with {@code HIVE_PATH_ALREADY_EXISTS} if the target directory was
   *     created concurrently
   */
  @Override
  public void commitCreateTable(
      ConnectorOutputTableHandle tableHandle, Collection<Slice> fragments) {
    HiveOutputTableHandle handle =
        checkType(tableHandle, HiveOutputTableHandle.class, "tableHandle");

    Path targetPath = new Path(handle.getTargetPath());

    // rename if using a temporary directory
    if (handle.hasTemporaryPath()) {
      // verify no one raced us to create the target directory
      if (pathExists(targetPath)) {
        SchemaTableName table = new SchemaTableName(handle.getSchemaName(), handle.getTableName());
        throw new PrestoException(
            HIVE_PATH_ALREADY_EXISTS,
            format(
                "Unable to commit creation of table '%s': target directory already exists: %s",
                table, targetPath));
      }
      // rename the temporary directory to the target
      rename(new Path(handle.getTemporaryPath()), targetPath);
    }

    // create the table in the metastore
    List<String> types =
        handle
            .getColumnTypes()
            .stream()
            .map(HiveType::toHiveType)
            .map(HiveType::getHiveTypeName)
            .collect(toList());

    boolean sampled = false;
    ImmutableList.Builder<FieldSchema> columns = ImmutableList.builder();
    for (int i = 0; i < handle.getColumnNames().size(); i++) {
      String name = handle.getColumnNames().get(i);
      String type = types.get(i);
      if (name.equals(SAMPLE_WEIGHT_COLUMN_NAME)) {
        columns.add(new FieldSchema(name, type, "Presto sample weight column"));
        sampled = true;
      } else {
        columns.add(new FieldSchema(name, type, null));
      }
    }

    HiveStorageFormat hiveStorageFormat = handle.getHiveStorageFormat();

    SerDeInfo serdeInfo = new SerDeInfo();
    serdeInfo.setName(handle.getTableName());
    serdeInfo.setSerializationLib(hiveStorageFormat.getSerDe());
    serdeInfo.setParameters(ImmutableMap.<String, String>of());

    StorageDescriptor sd = new StorageDescriptor();
    sd.setLocation(targetPath.toString());
    sd.setCols(columns.build());
    sd.setSerdeInfo(serdeInfo);
    sd.setInputFormat(hiveStorageFormat.getInputFormat());
    sd.setOutputFormat(hiveStorageFormat.getOutputFormat());
    sd.setParameters(ImmutableMap.<String, String>of());

    Table table = new Table();
    table.setDbName(handle.getSchemaName());
    table.setTableName(handle.getTableName());
    table.setOwner(handle.getTableOwner());
    table.setTableType(TableType.MANAGED_TABLE.toString());
    String tableComment = "Created by Presto";
    if (sampled) {
      tableComment =
          "Sampled table created by Presto. Only query this table from Hive if you understand how Presto implements sampling.";
    }
    table.setParameters(ImmutableMap.of("comment", tableComment));
    table.setPartitionKeys(ImmutableList.<FieldSchema>of());
    table.setSd(sd);

    metastore.createTable(table);
  }

  /**
   * Computes {@code <database location>/<tableName>}. It validates that the database location is
   * set and is an existing directory, and that the table directory does not exist yet.
   */
  private Path getTargetPath(String schemaName, String tableName, SchemaTableName schemaTableName) {
    String location = getDatabase(schemaName).getLocationUri();
    if (isNullOrEmpty(location)) {
      throw new PrestoException(
          HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location is not set", schemaName));
    }

    Path databasePath = new Path(location);
    if (!pathExists(databasePath)) {
      throw new PrestoException(
          HIVE_DATABASE_LOCATION_ERROR,
          format("Database '%s' location does not exist: %s", schemaName, databasePath));
    }
    if (!isDirectory(databasePath)) {
      throw new PrestoException(
          HIVE_DATABASE_LOCATION_ERROR,
          format("Database '%s' location is not a directory: %s", schemaName, databasePath));
    }

    // verify the target directory for the table
    Path targetPath = new Path(databasePath, tableName);
    if (pathExists(targetPath)) {
      throw new PrestoException(
          HIVE_PATH_ALREADY_EXISTS,
          format(
              "Target directory for table '%s' already exists: %s", schemaTableName, targetPath));
    }
    return targetPath;
  }

  /** Fetches a database from the metastore, translating absence into SchemaNotFoundException. */
  private Database getDatabase(String database) {
    try {
      return metastore.getDatabase(database);
    } catch (NoSuchObjectException e) {
      throw new SchemaNotFoundException(database);
    }
  }

  /** Returns false for paths served by PrestoS3FileSystem, true for every other filesystem. */
  private boolean useTemporaryDirectory(Path path) {
    try {
      // skip using temporary directory for S3
      return !(hdfsEnvironment.getFileSystem(path) instanceof PrestoS3FileSystem);
    } catch (IOException e) {
      throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e);
    }
  }

  private boolean pathExists(Path path) {
    try {
      return hdfsEnvironment.getFileSystem(path).exists(path);
    } catch (IOException e) {
      throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e);
    }
  }

  private boolean isDirectory(Path path) {
    try {
      return hdfsEnvironment.getFileSystem(path).isDirectory(path);
    } catch (IOException e) {
      throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e);
    }
  }

  private void createDirectories(Path path) {
    try {
      if (!hdfsEnvironment.getFileSystem(path).mkdirs(path)) {
        throw new IOException("mkdirs returned false");
      }
    } catch (IOException e) {
      throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed to create directory: " + path, e);
    }
  }

  private void rename(Path source, Path target) {
    try {
      if (!hdfsEnvironment.getFileSystem(source).rename(source, target)) {
        throw new IOException("rename returned false");
      }
    } catch (IOException e) {
      throw new PrestoException(
          HIVE_FILESYSTEM_ERROR, format("Failed to rename %s to %s", source, target), e);
    }
  }

  /**
   * Stores a Presto view as a Hive VIRTUAL_VIEW table. The encoded view definition is kept as the
   * original view text, and a single dummy column satisfies the metastore's storage descriptor.
   *
   * @throws ViewAlreadyExistsException if a table or view with that name already exists
   */
  @Override
  public void createView(
      ConnectorSession session, SchemaTableName viewName, String viewData, boolean replace) {
    if (replace) {
      try {
        dropView(session, viewName);
      } catch (ViewNotFoundException ignored) {
        // nothing to replace
      }
    }

    Map<String, String> properties =
        ImmutableMap.<String, String>builder()
            .put("comment", "Presto View")
            .put(PRESTO_VIEW_FLAG, "true")
            .build();

    FieldSchema dummyColumn = new FieldSchema("dummy", STRING_TYPE_NAME, null);

    StorageDescriptor sd = new StorageDescriptor();
    sd.setCols(ImmutableList.of(dummyColumn));
    sd.setSerdeInfo(new SerDeInfo());

    Table table = new Table();
    table.setDbName(viewName.getSchemaName());
    table.setTableName(viewName.getTableName());
    table.setOwner(session.getUser());
    table.setTableType(TableType.VIRTUAL_VIEW.name());
    table.setParameters(properties);
    table.setViewOriginalText(encodeViewData(viewData));
    table.setViewExpandedText("/* Presto View */");
    table.setSd(sd);

    try {
      metastore.createTable(table);
    } catch (TableAlreadyExistsException e) {
      throw new ViewAlreadyExistsException(e.getTableName());
    }
  }

  /**
   * Drops a Presto view.
   *
   * @throws ViewNotFoundException if no Presto view with that name exists
   */
  @Override
  public void dropView(ConnectorSession session, SchemaTableName viewName) {
    String view = getViews(session, viewName.toSchemaTablePrefix()).get(viewName);
    if (view == null) {
      throw new ViewNotFoundException(viewName);
    }

    try {
      metastore.dropTable(viewName.getSchemaName(), viewName.getTableName());
    } catch (TableNotFoundException e) {
      throw new ViewNotFoundException(e.getTableName());
    }
  }

  /** Lists view names in one schema, or in all schemas; vanished schemas are skipped. */
  @Override
  public List<SchemaTableName> listViews(ConnectorSession session, String schemaNameOrNull) {
    ImmutableList.Builder<SchemaTableName> tableNames = ImmutableList.builder();
    for (String schemaName : listSchemas(session, schemaNameOrNull)) {
      try {
        for (String tableName : metastore.getAllViews(schemaName)) {
          tableNames.add(new SchemaTableName(schemaName, tableName));
        }
      } catch (NoSuchObjectException e) {
        // schema disappeared during listing operation
      }
    }
    return tableNames.build();
  }

  /**
   * Returns decoded definitions of the Presto views matching {@code prefix}. Hive views not
   * created by Presto, and views that vanish during lookup, are left out.
   */
  @Override
  public Map<SchemaTableName, String> getViews(ConnectorSession session, SchemaTablePrefix prefix) {
    ImmutableMap.Builder<SchemaTableName, String> views = ImmutableMap.builder();
    List<SchemaTableName> tableNames;
    if (prefix.getTableName() != null) {
      tableNames =
          ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName()));
    } else {
      tableNames = listViews(session, prefix.getSchemaName());
    }

    for (SchemaTableName schemaTableName : tableNames) {
      try {
        Table table =
            metastore.getTable(schemaTableName.getSchemaName(), schemaTableName.getTableName());
        if (HiveUtil.isPrestoView(table)) {
          views.put(schemaTableName, decodeViewData(table.getViewOriginalText()));
        }
      } catch (NoSuchObjectException ignored) {
        // view disappeared during listing operation
      }
    }

    return views.build();
  }

  @Override
  public ConnectorInsertTableHandle beginInsert(
      ConnectorSession session, ConnectorTableHandle tableHandle) {
    verifyJvmTimeZone();

    throw new PrestoException(NOT_SUPPORTED, "INSERT not yet supported for Hive");
  }

  @Override
  public void commitInsert(ConnectorInsertTableHandle insertHandle, Collection<Slice> fragments) {
    throw new PrestoException(NOT_SUPPORTED, "INSERT not yet supported for Hive");
  }

  @Override
  public String toString() {
    return toStringHelper(this).add("clientId", connectorId).toString();
  }

  /**
   * Rejects writes when the JVM default time zone differs from the Hive storage time zone, unless
   * corrupt writes are explicitly allowed for testing.
   */
  private void verifyJvmTimeZone() {
    if (!allowCorruptWritesForTesting && !timeZone.equals(DateTimeZone.getDefault())) {
      throw new PrestoException(
          HIVE_TIMEZONE_MISMATCH,
          format(
              "To write Hive data, your JVM timezone must match the Hive storage timezone. Add -Duser.timezone=%s to your JVM arguments.",
              timeZone.getID()));
    }
  }

  /**
   * Appends the declared column names and types of {@code tableMetadata} to the builders. For
   * sampled tables, the BIGINT sample weight column is appended after them.
   *
   * @throws PrestoException with {@code NOT_SUPPORTED} if a column type cannot be written
   */
  private static void buildColumnInfo(
      ConnectorTableMetadata tableMetadata,
      ImmutableList.Builder<String> names,
      ImmutableList.Builder<Type> types) {
    for (ColumnMetadata column : tableMetadata.getColumns()) {
      // TODO: also verify that the OutputFormat supports the type
      if (!HiveRecordSink.isTypeSupported(column.getType())) {
        throw new PrestoException(
            NOT_SUPPORTED,
            format(
                "Cannot create table with unsupported type: %s",
                column.getType().getDisplayName()));
      }
      names.add(column.getName());
      types.add(column.getType());
    }

    if (tableMetadata.isSampled()) {
      names.add(SAMPLE_WEIGHT_COLUMN_NAME);
      types.add(BIGINT);
    }
  }

  /**
   * Returns a converter from column handles to column metadata. The converter attaches any
   * comment the metastore holds for that column name, whether data or partition column.
   */
  private static Function<HiveColumnHandle, ColumnMetadata> columnMetadataGetter(
      Table table, final TypeManager typeManager) {
    ImmutableMap.Builder<String, String> builder = ImmutableMap.builder();
    for (FieldSchema field : concat(table.getSd().getCols(), table.getPartitionKeys())) {
      if (field.getComment() != null) {
        builder.put(field.getName(), field.getComment());
      }
    }
    final Map<String, String> columnComment = builder.build();

    return input ->
        new ColumnMetadata(
            input.getName(),
            typeManager.getType(input.getTypeSignature()),
            input.isPartitionKey(),
            columnComment.get(input.getName()),
            false);
  }
}
Example #22
0
/**
 * Renders the results, update status or failure of one statement executed through a {@link
 * StatementClient}. Closing the query closes the underlying client.
 */
public class Query implements Closeable {
  private static final Logger log = Logger.get(Query.class);

  private static final Signal SIGINT = new Signal("INT");

  private final AtomicBoolean ignoreUserInterrupt = new AtomicBoolean();
  private final AtomicBoolean userAbortedQuery = new AtomicBoolean();
  private final StatementClient client;

  public Query(StatementClient client) {
    this.client = requireNonNull(client, "client is null");
  }

  public Map<String, String> getSetSessionProperties() {
    return client.getSetSessionProperties();
  }

  public Set<String> getResetSessionProperties() {
    return client.getResetSessionProperties();
  }

  public String getStartedTransactionId() {
    return client.getStartedtransactionId();
  }

  public boolean isClearTransactionId() {
    return client.isClearTransactionId();
  }

  /**
   * Renders the query output to {@code out}.
   *
   * <p>While rendering, SIGINT closes the client and interrupts the calling thread. SIGINT is
   * ignored while the pager is showing or once the client is closed. Afterwards, the previous
   * SIGINT handler is restored and the thread's interrupt flag is cleared.
   */
  public void renderOutput(PrintStream out, OutputFormat outputFormat, boolean interactive) {
    Thread clientThread = Thread.currentThread();
    SignalHandler oldHandler =
        Signal.handle(
            SIGINT,
            signal -> {
              if (ignoreUserInterrupt.get() || client.isClosed()) {
                return;
              }
              userAbortedQuery.set(true);
              client.close();
              clientThread.interrupt();
            });
    try {
      renderQueryOutput(out, outputFormat, interactive);
    } finally {
      Signal.handle(SIGINT, oldHandler);
      Thread.interrupted(); // clear interrupt status
    }
  }

  /**
   * Prints results or the update status, then the final state of the query. In interactive mode,
   * errors go to {@code out} and status updates are printed; otherwise errors go to stderr.
   */
  private void renderQueryOutput(PrintStream out, OutputFormat outputFormat, boolean interactive) {
    StatusPrinter statusPrinter = null;
    @SuppressWarnings("resource")
    PrintStream errorChannel = interactive ? out : System.err;

    if (interactive) {
      statusPrinter = new StatusPrinter(client, out);
      statusPrinter.printInitialStatusUpdates();
    } else {
      waitForData();
    }

    if ((!client.isFailed()) && (!client.isGone()) && (!client.isClosed())) {
      QueryResults results = client.isValid() ? client.current() : client.finalResults();
      if (results.getUpdateType() != null) {
        renderUpdate(out, results);
      } else if (results.getColumns() == null) {
        errorChannel.printf("Query %s has no columns\n", results.getId());
        return;
      } else {
        renderResults(out, outputFormat, interactive, results.getColumns());
      }
    }

    if (statusPrinter != null) {
      statusPrinter.printFinalInfo();
    }

    if (client.isClosed()) {
      errorChannel.println("Query aborted by user");
    } else if (client.isGone()) {
      errorChannel.println("Query is gone (server restarted?)");
    } else if (client.isFailed()) {
      renderFailure(errorChannel);
    }
  }

  /** Advances the client until a page carrying data arrives or the client becomes invalid. */
  private void waitForData() {
    while (client.isValid() && (client.current().getData() == null)) {
      client.advance();
    }
  }

  /** Prints the update type (plus row count when present) and drains any remaining rows. */
  private void renderUpdate(PrintStream out, QueryResults results) {
    String status = results.getUpdateType();
    if (results.getUpdateCount() != null) {
      long count = results.getUpdateCount();
      status += format(": %s row%s", count, (count != 1) ? "s" : "");
    }
    out.println(status);
    discardResults();
  }

  private void discardResults() {
    try (OutputHandler handler = new OutputHandler(new NullPrinter())) {
      handler.processRows(client);
    } catch (IOException e) {
      throw Throwables.propagate(e);
    }
  }

  /**
   * Renders result rows. If the user aborts, the notice goes to the same {@code out} stream as
   * the results, and the client is closed.
   */
  private void renderResults(
      PrintStream out, OutputFormat outputFormat, boolean interactive, List<Column> columns) {
    try {
      doRenderResults(out, outputFormat, interactive, columns);
    } catch (QueryAbortedException e) {
      out.println("(query aborted by user)");
      client.close();
    } catch (IOException e) {
      throw Throwables.propagate(e);
    }
  }

  private void doRenderResults(
      PrintStream out, OutputFormat format, boolean interactive, List<Column> columns)
      throws IOException {
    List<String> fieldNames = Lists.transform(columns, Column::getName);
    if (interactive) {
      pageOutput(format, fieldNames);
    } else {
      sendOutput(out, format, fieldNames);
    }
  }

  /**
   * Streams rows through a pager. When the pager finishes (for example, the user quits it), the
   * query is marked as aborted and the client thread is interrupted. A failure raised after such
   * an abort is rethrown as {@link QueryAbortedException}.
   */
  private void pageOutput(OutputFormat format, List<String> fieldNames) throws IOException {
    try (Pager pager = Pager.create();
        Writer writer = createWriter(pager);
        OutputHandler handler = createOutputHandler(format, writer, fieldNames)) {
      if (!pager.isNullPager()) {
        // ignore the user pressing ctrl-C while in the pager
        ignoreUserInterrupt.set(true);
        Thread clientThread = Thread.currentThread();
        pager
            .getFinishFuture()
            .thenRun(
                () -> {
                  userAbortedQuery.set(true);
                  ignoreUserInterrupt.set(false);
                  clientThread.interrupt();
                });
      }
      handler.processRows(client);
    } catch (RuntimeException | IOException e) {
      if (userAbortedQuery.get() && !(e instanceof QueryAbortedException)) {
        throw new QueryAbortedException(e);
      }
      throw e;
    }
  }

  private void sendOutput(PrintStream out, OutputFormat format, List<String> fieldNames)
      throws IOException {
    try (OutputHandler handler = createOutputHandler(format, createWriter(out), fieldNames)) {
      handler.processRows(client);
    }
  }

  private static OutputHandler createOutputHandler(
      OutputFormat format, Writer writer, List<String> fieldNames) {
    return new OutputHandler(createOutputPrinter(format, writer, fieldNames));
  }

  private static OutputPrinter createOutputPrinter(
      OutputFormat format, Writer writer, List<String> fieldNames) {
    switch (format) {
      case ALIGNED:
        return new AlignedTablePrinter(fieldNames, writer);
      case VERTICAL:
        return new VerticalRecordPrinter(fieldNames, writer);
      case CSV:
        return new CsvPrinter(fieldNames, writer, false);
      case CSV_HEADER:
        return new CsvPrinter(fieldNames, writer, true);
      case TSV:
        return new TsvPrinter(fieldNames, writer, false);
      case TSV_HEADER:
        return new TsvPrinter(fieldNames, writer, true);
      case NULL:
        return new NullPrinter();
    }
    throw new RuntimeException(format + " not supported");
  }

  private static Writer createWriter(OutputStream out) {
    return new OutputStreamWriter(out, UTF_8);
  }

  @Override
  public void close() {
    client.close();
  }

  /**
   * Prints the failure message of the query. The stack trace is added when the client is in debug
   * mode, and the error location is rendered when the server reports one.
   */
  public void renderFailure(PrintStream out) {
    QueryResults results = client.finalResults();
    QueryError error = results.getError();
    checkState(error != null);

    out.printf("Query %s failed: %s%n", results.getId(), error.getMessage());
    if (client.isDebug() && (error.getFailureInfo() != null)) {
      error.getFailureInfo().toException().printStackTrace(out);
    }
    if (error.getErrorLocation() != null) {
      renderErrorLocation(client.getQuery(), error.getErrorLocation(), out);
    }
    out.println();
  }

  /**
   * Prints {@code query} with the reported error position highlighted: colored on a real
   * terminal, or as a caret under the offending line otherwise. The reported (1-based) line and
   * column are clamped into the query text, so a bad location cannot throw here and hide the
   * original failure.
   */
  private static void renderErrorLocation(String query, ErrorLocation location, PrintStream out) {
    List<String> lines = ImmutableList.copyOf(Splitter.on('\n').split(query).iterator());

    int lineNumber = Math.max(1, Math.min(location.getLineNumber(), lines.size()));
    String errorLine = lines.get(lineNumber - 1);
    int columnNumber = Math.max(1, Math.min(location.getColumnNumber(), errorLine.length() + 1));
    String good = errorLine.substring(0, columnNumber - 1);
    String bad = errorLine.substring(columnNumber - 1);

    if ((lineNumber == lines.size()) && bad.trim().isEmpty()) {
      bad = " <EOF>";
    }

    if (REAL_TERMINAL) {
      Ansi ansi = Ansi.ansi();

      ansi.fg(Ansi.Color.CYAN);
      for (int i = 1; i < lineNumber; i++) {
        ansi.a(lines.get(i - 1)).newline();
      }
      ansi.a(good);

      ansi.fg(Ansi.Color.RED);
      ansi.a(bad).newline();
      for (int i = lineNumber; i < lines.size(); i++) {
        ansi.a(lines.get(i)).newline();
      }

      ansi.reset();
      out.print(ansi);
    } else {
      String prefix = format("LINE %s: ", lineNumber);
      String padding = Strings.repeat(" ", prefix.length() + (columnNumber - 1));
      out.println(prefix + errorLine);
      out.println(padding + "^");
    }
  }
}
Example #23
0
public class ExpressionCompiler {
  private static final Logger log = Logger.get(ExpressionCompiler.class);

  // Monotonic counter appended to generated class names (e.g. "FilterAndProjectOperator_<n>")
  // so that every generated class gets a distinct name.
  private static final AtomicLong CLASS_ID = new AtomicLong();

  // Compile-time debug switches consulted in defineClasses():
  // dump the bytecode tree to stdout, trace raw bytecode to stderr, run ASM's verifier.
  private static final boolean DUMP_BYTE_CODE_TREE = false;
  private static final boolean DUMP_BYTE_CODE_RAW = false;
  private static final boolean RUN_ASM_VERIFIER = false; // verifier doesn't work right now
  // When it holds a non-null directory path, defineClasses() writes each generated .class file
  // there. Nothing in this section sets it; presumably it is set from elsewhere (tests/debugging)
  // — TODO confirm.
  private static final AtomicReference<String> DUMP_CLASS_FILES_TO = new AtomicReference<>();

  // Supplies the function registry used when compiling expressions (see compileExpression).
  private final Metadata metadata;

  // Compiled filter-and-project factories keyed by (filter, projections); the key's sourceId is
  // null for these entries. Bounded to 1000 entries. Compilation runs lazily on first lookup, so
  // referencing instance methods from this field initializer is safe.
  private final LoadingCache<OperatorCacheKey, FilterAndProjectOperatorFactoryFactory>
      operatorFactories =
          CacheBuilder.newBuilder()
              .maximumSize(1000)
              .build(
                  new CacheLoader<OperatorCacheKey, FilterAndProjectOperatorFactoryFactory>() {
                    @Override
                    public FilterAndProjectOperatorFactoryFactory load(OperatorCacheKey key)
                        throws Exception {
                      return internalCompileFilterAndProjectOperator(
                          key.getFilter(), key.getProjections());
                    }
                  });

  // Compiled scan-filter-and-project factories keyed by (filter, projections, sourceId).
  // Bounded to 1000 entries.
  private final LoadingCache<OperatorCacheKey, ScanFilterAndProjectOperatorFactoryFactory>
      sourceOperatorFactories =
          CacheBuilder.newBuilder()
              .maximumSize(1000)
              .build(
                  new CacheLoader<OperatorCacheKey, ScanFilterAndProjectOperatorFactoryFactory>() {
                    @Override
                    public ScanFilterAndProjectOperatorFactoryFactory load(OperatorCacheKey key)
                        throws Exception {
                      return internalCompileScanFilterAndProjectOperator(
                          key.getSourceId(), key.getFilter(), key.getProjections());
                    }
                  });

  // Running total of classes defined by defineClasses(); exposed via getGeneratedClasses().
  private final AtomicLong generatedClasses = new AtomicLong();

  /**
   * Creates a compiler whose generated code resolves functions through the function registry of
   * {@code metadata}.
   */
  @Inject
  public ExpressionCompiler(Metadata metadata) {
    this.metadata = metadata;
  }

  /** Total number of classes this compiler has defined so far. */
  @Managed
  public long getGeneratedClasses() {
    return this.generatedClasses.get();
  }

  /** Number of filter-and-project factories currently held in the cache. */
  @Managed
  public long getCachedFilterAndProjectOperators() {
    return this.operatorFactories.size();
  }

  /** Number of scan-filter-and-project factories currently held in the cache. */
  @Managed
  public long getCachedScanFilterAndProjectOperators() {
    return this.sourceOperatorFactories.size();
  }

  /**
   * Returns an operator factory for {@code operatorId} that filters rows with {@code filter} and
   * emits {@code projections}.
   *
   * <p>The generated class is cached per (filter, projections), so repeated calls with equal
   * expressions reuse the same compiled code.
   */
  public OperatorFactory compileFilterAndProjectOperator(
      int operatorId, RowExpression filter, List<RowExpression> projections) {
    OperatorCacheKey key = new OperatorCacheKey(filter, projections, null);
    FilterAndProjectOperatorFactoryFactory factoryFactory = operatorFactories.getUnchecked(key);
    return factoryFactory.create(operatorId);
  }

  /** Creates a fresh class loader for generated classes, parented to this class's own loader. */
  private DynamicClassLoader createClassLoader() {
    ClassLoader parent = getClass().getClassLoader();
    return new DynamicClassLoader(parent);
  }

  /**
   * Compiles a filter-and-project operator class into its own class loader. Returns a factory
   * that binds the class's {@code (OperatorContext, Iterable)} constructor to the projection
   * output types.
   */
  @VisibleForTesting
  public FilterAndProjectOperatorFactoryFactory internalCompileFilterAndProjectOperator(
      RowExpression filter, List<RowExpression> projections) {
    TypedOperatorClass compiled =
        compileFilterAndProjectOperator(filter, projections, createClassLoader());
    Class<? extends Operator> operatorClass = compiled.getOperatorClass();

    Constructor<? extends Operator> constructor;
    try {
      constructor = operatorClass.getConstructor(OperatorContext.class, Iterable.class);
    } catch (NoSuchMethodException e) {
      throw Throwables.propagate(e);
    }

    return new FilterAndProjectOperatorFactoryFactory(constructor, compiled.getTypes());
  }

  /**
   * Generates and defines a {@code FilterAndProjectOperator_<n>} class that extends
   * {@code AbstractFilterAndProjectOperator}, and returns it with the projection output types.
   *
   * <p>The generated class contains:
   *
   * <ul>
   *   <li>a {@code (OperatorContext, Iterable<Type>)} constructor that calls the superclass
   *       constructor and stores {@code operatorContext.getSession()} in a {@code session} field;
   *   <li>a static {@code callSites} field, filled in after the class is defined;
   *   <li>a row-oriented filter-and-project loop over a page;
   *   <li>cursor-based and block-based overloads of {@code filter} and of each
   *       {@code project_<i>};
   *   <li>a constant {@code toString}.
   * </ul>
   *
   * @param filter the row filter expression
   * @param projections the projection expressions; their types become the output types, in order
   * @param classLoader the loader in which the generated class is defined
   */
  private TypedOperatorClass compileFilterAndProjectOperator(
      RowExpression filter, List<RowExpression> projections, DynamicClassLoader classLoader) {
    // Collects the method handles bound while compiling expressions; installed into the
    // generated class's static callSites field once the class is defined.
    CallSiteBinder callSiteBinder = new CallSiteBinder();

    ClassDefinition classDefinition =
        new ClassDefinition(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC, FINAL),
            typeFromPathName("FilterAndProjectOperator_" + CLASS_ID.incrementAndGet()),
            type(AbstractFilterAndProjectOperator.class));

    // declare fields
    FieldDefinition sessionField =
        classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class);
    classDefinition.declareField(a(PRIVATE, VOLATILE, STATIC), "callSites", Map.class);

    // constructor
    classDefinition
        .declareConstructor(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC),
            arg("operatorContext", OperatorContext.class),
            arg("types", type(Iterable.class, Type.class)))
        .getBody()
        .comment("super(operatorContext, types);")
        .pushThis()
        .getVariable("operatorContext")
        .getVariable("types")
        .invokeConstructor(
            AbstractFilterAndProjectOperator.class, OperatorContext.class, Iterable.class)
        .comment("this.session = operatorContext.getSession();")
        .pushThis()
        .getVariable("operatorContext")
        .invokeVirtual(OperatorContext.class, "getSession", ConnectorSession.class)
        .putField(sessionField)
        .ret();

    generateFilterAndProjectRowOriented(classDefinition, filter, projections);

    //
    // filter method
    //
    // Both overloads are generated: true = RecordCursor source, false = (position, blocks...).
    generateFilterMethod(callSiteBinder, classDefinition, filter, true);
    generateFilterMethod(callSiteBinder, classDefinition, filter, false);

    //
    // project methods
    //
    // One project_<i> pair (cursor and block variants) per projection; types collects the
    // output type of each projection in order.
    List<Type> types = new ArrayList<>();
    int projectionIndex = 0;
    for (RowExpression projection : projections) {
      generateProjectMethod(
          callSiteBinder, classDefinition, "project_" + projectionIndex, projection, true);
      generateProjectMethod(
          callSiteBinder, classDefinition, "project_" + projectionIndex, projection, false);
      types.add(projection.getType());
      projectionIndex++;
    }

    //
    // toString method
    //
    generateToString(
        classDefinition,
        toStringHelper(classDefinition.getType().getJavaClassName())
            .add("filter", filter)
            .add("projections", projections)
            .toString());

    Class<? extends Operator> filterAndProjectClass =
        defineClass(classDefinition, Operator.class, classLoader);
    // Must happen after definition: the static field only exists on the loaded class.
    setCallSitesField(filterAndProjectClass, callSiteBinder.getBindings());

    return new TypedOperatorClass(filterAndProjectClass, types);
  }

  /**
   * Returns a source operator factory for {@code operatorId} that reads {@code columns} through
   * {@code dataStreamProvider}, filters rows with {@code filter} and emits {@code projections}.
   *
   * <p>The generated class is cached per (filter, projections, sourceId).
   */
  public SourceOperatorFactory compileScanFilterAndProjectOperator(
      int operatorId,
      PlanNodeId sourceId,
      DataStreamProvider dataStreamProvider,
      List<ColumnHandle> columns,
      RowExpression filter,
      List<RowExpression> projections) {
    ScanFilterAndProjectOperatorFactoryFactory factoryFactory =
        sourceOperatorFactories.getUnchecked(new OperatorCacheKey(filter, projections, sourceId));
    return factoryFactory.create(operatorId, dataStreamProvider, columns);
  }

  /**
   * Compiles a scan-filter-and-project operator class into its own class loader. Returns a
   * factory bound to the class's
   * {@code (OperatorContext, PlanNodeId, DataStreamProvider, Iterable, Iterable)} constructor,
   * the given {@code sourceId}, and the projection output types.
   */
  @VisibleForTesting
  public ScanFilterAndProjectOperatorFactoryFactory internalCompileScanFilterAndProjectOperator(
      PlanNodeId sourceId, RowExpression filter, List<RowExpression> projections) {
    TypedOperatorClass compiled =
        compileScanFilterAndProjectOperator(filter, projections, createClassLoader());
    Class<? extends SourceOperator> sourceOperatorClass =
        compiled.getOperatorClass().asSubclass(SourceOperator.class);

    Constructor<? extends SourceOperator> constructor;
    try {
      constructor =
          sourceOperatorClass.getConstructor(
              OperatorContext.class,
              PlanNodeId.class,
              DataStreamProvider.class,
              Iterable.class,
              Iterable.class);
    } catch (NoSuchMethodException e) {
      throw Throwables.propagate(e);
    }

    return new ScanFilterAndProjectOperatorFactoryFactory(
        constructor, sourceId, compiled.getTypes());
  }

  /**
   * Generates and defines a {@code ScanFilterAndProjectOperator_<n>} class that extends
   * {@code AbstractScanFilterAndProjectOperator}, and returns it with the projection output types.
   *
   * <p>The generated class contains:
   *
   * <ul>
   *   <li>a {@code (OperatorContext, PlanNodeId, DataStreamProvider, Iterable<ColumnHandle>,
   *       Iterable<Type>)} constructor that calls the superclass constructor and stores
   *       {@code operatorContext.getSession()} in a {@code session} field;
   *   <li>a static {@code callSites} field;
   *   <li>both the page-based and the cursor-based {@code filterAndProjectRowOriented} loops;
   *   <li>cursor and block overloads of {@code filter} and of each {@code project_<i>};
   *   <li>a constant {@code toString}.
   * </ul>
   *
   * @param filter the row filter expression
   * @param projections the projection expressions; their types become the output types, in order
   * @param classLoader the loader in which the generated class is defined
   */
  private TypedOperatorClass compileScanFilterAndProjectOperator(
      RowExpression filter, List<RowExpression> projections, DynamicClassLoader classLoader) {
    // Collects method handles bound during expression compilation; installed into the static
    // callSites field after the class is defined.
    CallSiteBinder callSiteBinder = new CallSiteBinder();

    ClassDefinition classDefinition =
        new ClassDefinition(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC, FINAL),
            typeFromPathName("ScanFilterAndProjectOperator_" + CLASS_ID.incrementAndGet()),
            type(AbstractScanFilterAndProjectOperator.class));

    // declare fields
    FieldDefinition sessionField =
        classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class);
    classDefinition.declareField(a(PRIVATE, VOLATILE, STATIC), "callSites", Map.class);

    // constructor
    classDefinition
        .declareConstructor(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC),
            arg("operatorContext", OperatorContext.class),
            arg("sourceId", PlanNodeId.class),
            arg("dataStreamProvider", DataStreamProvider.class),
            arg("columns", type(Iterable.class, ColumnHandle.class)),
            arg("types", type(Iterable.class, Type.class)))
        .getBody()
        .comment("super(operatorContext, sourceId, dataStreamProvider, columns, types);")
        .pushThis()
        .getVariable("operatorContext")
        .getVariable("sourceId")
        .getVariable("dataStreamProvider")
        .getVariable("columns")
        .getVariable("types")
        .invokeConstructor(
            AbstractScanFilterAndProjectOperator.class,
            OperatorContext.class,
            PlanNodeId.class,
            DataStreamProvider.class,
            Iterable.class,
            Iterable.class)
        .comment("this.session = operatorContext.getSession();")
        .pushThis()
        .getVariable("operatorContext")
        .invokeVirtual(OperatorContext.class, "getSession", ConnectorSession.class)
        .putField(sessionField)
        .ret();

    // A scan operator can receive either pages or a record cursor, so both loops are generated.
    generateFilterAndProjectRowOriented(classDefinition, filter, projections);
    generateFilterAndProjectCursorMethod(classDefinition, projections);

    //
    // filter method
    //
    // true = RecordCursor source, false = (position, blocks...) source.
    generateFilterMethod(callSiteBinder, classDefinition, filter, true);
    generateFilterMethod(callSiteBinder, classDefinition, filter, false);

    //
    // project methods
    //
    // One project_<i> pair per projection; types records each projection's output type in order.
    List<Type> types = new ArrayList<>();
    int projectionIndex = 0;
    for (RowExpression projection : projections) {
      generateProjectMethod(
          callSiteBinder, classDefinition, "project_" + projectionIndex, projection, true);
      generateProjectMethod(
          callSiteBinder, classDefinition, "project_" + projectionIndex, projection, false);
      types.add(projection.getType());
      projectionIndex++;
    }

    //
    // toString method
    //
    generateToString(
        classDefinition,
        toStringHelper(classDefinition.getType().getJavaClassName())
            .add("filter", filter)
            .add("projections", projections)
            .toString());

    Class<? extends SourceOperator> filterAndProjectClass =
        defineClass(classDefinition, SourceOperator.class, classLoader);
    // The static field can only be set once the class exists.
    setCallSitesField(filterAndProjectClass, callSiteBinder.getBindings());

    return new TypedOperatorClass(filterAndProjectClass, types);
  }

  /**
   * Adds a {@code public String toString()} to the generated class that returns a constant.
   *
   * <p>{@code string} is truncated to its first 100 characters plus {@code "..."} when it is
   * longer than 100 characters.
   */
  private void generateToString(ClassDefinition classDefinition, String string) {
    // Constant strings can't be too large or the bytecode becomes invalid, so cap the length.
    String constant = (string.length() > 100) ? string.substring(0, 100) + "..." : string;

    MethodDefinition toStringMethod =
        classDefinition.declareMethod(
            new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "toString", type(String.class));
    toStringMethod.getBody().push(constant).retObject();
  }

  /**
   * Generates {@code void filterAndProjectRowOriented(Page page, PageBuilder pageBuilder)}.
   *
   * <p>The generated method works as follows:
   *
   * <ol>
   *   <li>Load every block referenced by the filter or by any projection into a local named
   *       {@code block_<channel>}.
   *   <li>For each position in {@code [0, page.getPositionCount())}, call
   *       {@code filter(position, blocks...)}.
   *   <li>When the filter returns true, either call {@code pageBuilder.declarePosition()} (if
   *       there are no projections) or call
   *       {@code project_<i>(position, blocks..., pageBuilder.getBlockBuilder(i))} for each
   *       projection.
   * </ol>
   */
  private void generateFilterAndProjectRowOriented(
      ClassDefinition classDefinition, RowExpression filter, List<RowExpression> projections) {
    MethodDefinition filterAndProjectMethod =
        classDefinition.declareMethod(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC),
            "filterAndProjectRowOriented",
            type(void.class),
            arg("page", com.facebook.presto.operator.Page.class),
            arg("pageBuilder", PageBuilder.class));

    CompilerContext compilerContext = filterAndProjectMethod.getCompilerContext();

    LocalVariableDefinition positionVariable =
        compilerContext.declareVariable(int.class, "position");

    LocalVariableDefinition rowsVariable = compilerContext.declareVariable(int.class, "rows");
    filterAndProjectMethod
        .getBody()
        .comment("int rows = page.getPositionCount();")
        .getVariable("page")
        .invokeVirtual(com.facebook.presto.operator.Page.class, "getPositionCount", int.class)
        .putVariable(rowsVariable);

    // Only channels actually referenced are loaded; the filter/project method signatures below
    // use the same sorted channel lists, so argument order matches.
    List<Integer> allInputChannels =
        getInputChannels(Iterables.concat(projections, ImmutableList.of(filter)));
    for (int channel : allInputChannels) {
      LocalVariableDefinition blockVariable =
          compilerContext.declareVariable(
              com.facebook.presto.spi.block.Block.class, "block_" + channel);
      filterAndProjectMethod
          .getBody()
          .comment("Block %s = page.getBlock(%s);", blockVariable.getName(), channel)
          .getVariable("page")
          .push(channel)
          .invokeVirtual(
              com.facebook.presto.operator.Page.class,
              "getBlock",
              com.facebook.presto.spi.block.Block.class,
              int.class)
          .putVariable(blockVariable);
    }

    //
    // for loop body
    //

    // for (position = 0; position < rows; position++)
    ForLoopBuilder forLoop =
        forLoopBuilder(compilerContext)
            .comment("for (position = 0; position < rows; position++)")
            .initialize(new Block(compilerContext).putVariable(positionVariable, 0))
            .condition(
                new Block(compilerContext)
                    .getVariable(positionVariable)
                    .getVariable(rowsVariable)
                    .invokeStatic(
                        CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
            .update(new Block(compilerContext).incrementVariable(positionVariable, (byte) 1));

    Block forLoopBody = new Block(compilerContext);

    IfStatementBuilder ifStatement =
        new IfStatementBuilder(compilerContext).comment("if (filter(position, blocks...)");
    Block condition = new Block(compilerContext);
    condition.pushThis();
    condition.getVariable(positionVariable);
    List<Integer> filterInputChannels = getInputChannels(filter);
    for (int channel : filterInputChannels) {
      condition.getVariable("block_" + channel);
    }
    condition.invokeVirtual(
        classDefinition.getType(),
        "filter",
        type(boolean.class),
        ImmutableList.<ParameterizedType>builder()
            .add(type(int.class))
            .addAll(
                nCopies(
                    filterInputChannels.size(), type(com.facebook.presto.spi.block.Block.class)))
            .build());
    ifStatement.condition(condition);

    Block trueBlock = new Block(compilerContext);
    if (projections.isEmpty()) {
      // No output columns: only the row count grows.
      trueBlock
          .comment("pageBuilder.declarePosition()")
          .getVariable("pageBuilder")
          .invokeVirtual(PageBuilder.class, "declarePosition", void.class);
    } else {
      // For each projection i: project_i(position, block_<c>..., pageBuilder.getBlockBuilder(i));
      for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) {
        trueBlock.comment(
            "project_%s(position, blocks..., pageBuilder.getBlockBuilder(%s))",
            projectionIndex, projectionIndex);
        trueBlock.pushThis();
        List<Integer> projectionInputs = getInputChannels(projections.get(projectionIndex));
        trueBlock.getVariable(positionVariable);
        for (int channel : projectionInputs) {
          trueBlock.getVariable("block_" + channel);
        }

        // pageBuilder.getBlockBuilder(projectionIndex)
        trueBlock
            .getVariable("pageBuilder")
            .push(projectionIndex)
            .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class);

        // project_<i>(position, block_0, block_1, ..., blockBuilder)
        trueBlock.invokeVirtual(
            classDefinition.getType(),
            "project_" + projectionIndex,
            type(void.class),
            ImmutableList.<ParameterizedType>builder()
                .add(type(int.class))
                .addAll(
                    nCopies(
                        projectionInputs.size(), type(com.facebook.presto.spi.block.Block.class)))
                .add(type(BlockBuilder.class))
                .build());
      }
    }
    ifStatement.ifTrue(trueBlock);

    forLoopBody.append(ifStatement.build());
    filterAndProjectMethod.getBody().append(forLoop.body(forLoopBody).build());

    filterAndProjectMethod.getBody().ret();
  }

  /**
   * Generates {@code int filterAndProjectRowOriented(RecordCursor cursor, PageBuilder
   * pageBuilder)}.
   *
   * <p>The generated loop runs at most 16384 iterations. Each iteration:
   *
   * <ol>
   *   <li>stops if {@code pageBuilder.isFull()};
   *   <li>stops if {@code cursor.advanceNextPosition()} returns false;
   *   <li>otherwise, if {@code filter(cursor)} is true, declares a position (no projections) or
   *       calls {@code project_<i>(cursor, pageBuilder.getBlockBuilder(i))} for each projection.
   * </ol>
   *
   * <p>The returned {@code completedPositions} counts loop iterations that advanced the cursor,
   * including rows the filter rejected. It does not count rows appended to the page.
   */
  private void generateFilterAndProjectCursorMethod(
      ClassDefinition classDefinition, List<RowExpression> projections) {
    MethodDefinition filterAndProjectMethod =
        classDefinition.declareMethod(
            new CompilerContext(BOOTSTRAP_METHOD),
            a(PUBLIC),
            "filterAndProjectRowOriented",
            type(int.class),
            arg("cursor", RecordCursor.class),
            arg("pageBuilder", PageBuilder.class));

    CompilerContext compilerContext = filterAndProjectMethod.getCompilerContext();

    LocalVariableDefinition completedPositionsVariable =
        compilerContext.declareVariable(int.class, "completedPositions");
    filterAndProjectMethod
        .getBody()
        .comment("int completedPositions = 0;")
        .putVariable(completedPositionsVariable, 0);

    //
    // for loop body
    //
    // Both early exits jump to this label, skipping the increment for the final iteration.
    LabelNode done = new LabelNode("done");
    ForLoopBuilder forLoop =
        ForLoop.forLoopBuilder(compilerContext)
            .initialize(NOP)
            .condition(
                new Block(compilerContext)
                    .comment("completedPositions < 16384")
                    .getVariable(completedPositionsVariable)
                    .push(16384)
                    .invokeStatic(
                        CompilerOperations.class, "lessThan", boolean.class, int.class, int.class))
            .update(
                new Block(compilerContext)
                    .comment("completedPositions++")
                    .incrementVariable(completedPositionsVariable, (byte) 1));

    // The body Block is attached now and filled in below; it is appended to by reference.
    Block forLoopBody = new Block(compilerContext);
    forLoop.body(forLoopBody);

    forLoopBody
        .comment("if (pageBuilder.isFull()) break;")
        .append(
            new Block(compilerContext)
                .getVariable("pageBuilder")
                .invokeVirtual(PageBuilder.class, "isFull", boolean.class)
                .ifTrueGoto(done));

    forLoopBody
        .comment("if (!cursor.advanceNextPosition()) break;")
        .append(
            new Block(compilerContext)
                .getVariable("cursor")
                .invokeInterface(RecordCursor.class, "advanceNextPosition", boolean.class)
                .ifFalseGoto(done));

    // if (filter(cursor))
    IfStatementBuilder ifStatement = new IfStatementBuilder(compilerContext);
    ifStatement.condition(
        new Block(compilerContext)
            .pushThis()
            .getVariable("cursor")
            .invokeVirtual(
                classDefinition.getType(),
                "filter",
                type(boolean.class),
                type(RecordCursor.class)));

    Block trueBlock = new Block(compilerContext);
    ifStatement.ifTrue(trueBlock);
    if (projections.isEmpty()) {
      // pageBuilder.declarePosition();
      trueBlock
          .getVariable("pageBuilder")
          .invokeVirtual(PageBuilder.class, "declarePosition", void.class);
    } else {
      // For each projection i: project_i(cursor, pageBuilder.getBlockBuilder(i));
      for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) {
        trueBlock.pushThis();
        trueBlock.getVariable("cursor");

        // pageBuilder.getBlockBuilder(projectionIndex)
        trueBlock
            .getVariable("pageBuilder")
            .push(projectionIndex)
            .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class);

        // project_<i>(cursor, blockBuilder)
        trueBlock.invokeVirtual(
            classDefinition.getType(),
            "project_" + projectionIndex,
            type(void.class),
            type(RecordCursor.class),
            type(BlockBuilder.class));
      }
    }
    forLoopBody.append(ifStatement.build());

    filterAndProjectMethod
        .getBody()
        .append(forLoop.build())
        .visitLabel(done)
        .comment("return completedPositions;")
        .getVariable("completedPositions")
        .retInt();
  }

  /**
   * Generates a {@code boolean filter(...)} method that evaluates {@code filter}.
   *
   * <p>When {@code sourceIsCursor} is true the signature is {@code filter(RecordCursor cursor)}.
   * Otherwise it is {@code filter(int position, Block block_<c>...)}, with one block parameter per
   * input channel the filter references. A null result (the {@code wasNull} local set by the
   * compiled expression) is treated as {@code false}.
   */
  private void generateFilterMethod(
      CallSiteBinder callSiteBinder,
      ClassDefinition classDefinition,
      RowExpression filter,
      boolean sourceIsCursor) {
    MethodDefinition filterMethod;
    if (sourceIsCursor) {
      filterMethod =
          classDefinition.declareMethod(
              new CompilerContext(BOOTSTRAP_METHOD),
              a(PUBLIC),
              "filter",
              type(boolean.class),
              arg("cursor", RecordCursor.class));
    } else {
      filterMethod =
          classDefinition.declareMethod(
              new CompilerContext(BOOTSTRAP_METHOD),
              a(PUBLIC),
              "filter",
              type(boolean.class),
              ImmutableList.<NamedParameterDefinition>builder()
                  .add(arg("position", int.class))
                  .addAll(toBlockParameters(getInputChannels(filter)))
                  .build());
    }

    filterMethod.comment("Filter: %s", filter.toString());

    // Declared before compiling the expression so the generated code can read and write it.
    filterMethod.getCompilerContext().declareVariable(type(boolean.class), "wasNull");
    // Bytecode that loads this.session; handed to the expression visitor for session-aware calls.
    Block getSessionByteCode =
        new Block(filterMethod.getCompilerContext())
            .pushThis()
            .getField(classDefinition.getType(), "session", type(ConnectorSession.class));
    ByteCodeNode body =
        compileExpression(
            callSiteBinder,
            filter,
            sourceIsCursor,
            filterMethod.getCompilerContext(),
            getSessionByteCode);

    // Emits: wasNull = false; <value>; if (wasNull) { drop value; push false } return value;
    LabelNode end = new LabelNode("end");
    filterMethod
        .getBody()
        .comment("boolean wasNull = false;")
        .putVariable("wasNull", false)
        .append(body)
        .getVariable("wasNull")
        .ifFalseGoto(end)
        .pop(boolean.class)
        .push(false)
        .visitLabel(end)
        .retBoolean();
  }

  /**
   * Translates {@code expression} into bytecode within {@code context}. The expression reads
   * either a record cursor or blocks, depending on {@code sourceIsCursor}, and resolves functions
   * through the metadata's function registry.
   */
  private ByteCodeNode compileExpression(
      CallSiteBinder callSiteBinder,
      RowExpression expression,
      boolean sourceIsCursor,
      CompilerContext context,
      Block getSessionByteCode) {
    return expression.accept(
        new ByteCodeExpressionVisitor(
            callSiteBinder, getSessionByteCode, metadata.getFunctionRegistry(), sourceIsCursor),
        context);
  }

  /**
   * Generates a {@code void <methodName>(...)} method that evaluates {@code projection} and
   * appends the result to its {@code output} BlockBuilder.
   *
   * <p>When {@code sourceIsCursor} is true the signature is
   * {@code (RecordCursor cursor, BlockBuilder output)}. Otherwise it is
   * {@code (int position, Block block_<c>..., BlockBuilder output)}. A null result is appended
   * with {@code appendNull()}.
   *
   * @return the Java type of the projection's value
   * @throws UnsupportedOperationException if the projection's Java type is not boolean, long,
   *     double or Slice
   */
  private Class<?> generateProjectMethod(
      CallSiteBinder callSiteBinder,
      ClassDefinition classDefinition,
      String methodName,
      RowExpression projection,
      boolean sourceIsCursor) {
    MethodDefinition projectionMethod;
    if (sourceIsCursor) {
      projectionMethod =
          classDefinition.declareMethod(
              new CompilerContext(BOOTSTRAP_METHOD),
              a(PUBLIC),
              methodName,
              type(void.class),
              arg("cursor", RecordCursor.class),
              arg("output", BlockBuilder.class));
    } else {
      ImmutableList.Builder<NamedParameterDefinition> parameters = ImmutableList.builder();
      parameters.add(arg("position", int.class));
      parameters.addAll(toBlockParameters(getInputChannels(projection)));
      parameters.add(arg("output", BlockBuilder.class));

      projectionMethod =
          classDefinition.declareMethod(
              new CompilerContext(BOOTSTRAP_METHOD),
              a(PUBLIC),
              methodName,
              type(void.class),
              parameters.build());
    }

    projectionMethod.comment("Projection: %s", projection.toString());

    // generate body code
    CompilerContext context = projectionMethod.getCompilerContext();
    context.declareVariable(type(boolean.class), "wasNull");
    // Bytecode that loads this.session for session-aware function calls.
    Block getSessionByteCode =
        new Block(context)
            .pushThis()
            .getField(classDefinition.getType(), "session", type(ConnectorSession.class));

    ByteCodeNode body =
        compileExpression(callSiteBinder, projection, sourceIsCursor, context, getSessionByteCode);

    // Stack after this: [output, value]. Both branches below consume that pair.
    projectionMethod
        .getBody()
        .comment("boolean wasNull = false;")
        .putVariable("wasNull", false)
        .getVariable("output")
        .append(body);

    // Non-null path: output.appendX(value) and discard the returned BlockBuilder.
    Type projectionType = projection.getType();
    Block notNullBlock = new Block(context);
    if (projectionType.getJavaType() == boolean.class) {
      notNullBlock
          .comment("output.append(<booleanStackValue>);")
          .invokeInterface(BlockBuilder.class, "appendBoolean", BlockBuilder.class, boolean.class)
          .pop();
    } else if (projectionType.getJavaType() == long.class) {
      notNullBlock
          .comment("output.append(<longStackValue>);")
          .invokeInterface(BlockBuilder.class, "appendLong", BlockBuilder.class, long.class)
          .pop();
    } else if (projectionType.getJavaType() == double.class) {
      notNullBlock
          .comment("output.append(<doubleStackValue>);")
          .invokeInterface(BlockBuilder.class, "appendDouble", BlockBuilder.class, double.class)
          .pop();
    } else if (projectionType.getJavaType() == Slice.class) {
      notNullBlock
          .comment("output.append(<sliceStackValue>);")
          .invokeInterface(BlockBuilder.class, "appendSlice", BlockBuilder.class, Slice.class)
          .pop();
    } else {
      throw new UnsupportedOperationException("Type " + projectionType + " can not be output yet");
    }

    // Null path: drop the placeholder value, then output.appendNull().
    Block nullBlock =
        new Block(context)
            .comment("output.appendNull();")
            .pop(projectionType.getJavaType())
            .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class)
            .pop();

    projectionMethod
        .getBody()
        .comment("if the result was null, appendNull; otherwise append the value")
        .append(
            new IfStatement(
                context, new Block(context).getVariable("wasNull"), nullBlock, notNullBlock))
        .ret();

    return projectionType.getJavaType();
  }

  /** Returns the sorted, de-duplicated input channels referenced by {@code expression}. */
  private static List<Integer> getInputChannels(RowExpression expression) {
    ImmutableList<RowExpression> single = ImmutableList.of(expression);
    return getInputChannels(single);
  }

  /**
   * Returns the sorted, de-duplicated fields of every {@code InputReferenceExpression} found
   * among {@code expressions} and their sub-expressions.
   */
  private static List<Integer> getInputChannels(Iterable<RowExpression> expressions) {
    TreeSet<Integer> sortedChannels = new TreeSet<>();
    for (RowExpression candidate : Expressions.subExpressions(expressions)) {
      if (!(candidate instanceof InputReferenceExpression)) {
        continue;
      }
      InputReferenceExpression reference = (InputReferenceExpression) candidate;
      sortedChannels.add(reference.getField());
    }
    return ImmutableList.copyOf(sortedChannels);
  }

  /** Pairs a generated operator class with the output types of its projections. */
  private static class TypedOperatorClass {
    private final Class<? extends Operator> operatorClass;
    private final List<Type> types;

    private TypedOperatorClass(Class<? extends Operator> operatorClass, List<Type> types) {
      this.types = types;
      this.operatorClass = operatorClass;
    }

    /** The generated operator class. */
    private Class<? extends Operator> getOperatorClass() {
      return this.operatorClass;
    }

    /** Output types in projection order (the list passed in, not a copy). */
    private List<Type> getTypes() {
      return this.types;
    }
  }

  /**
   * Builds one {@code Block block_<channel>} parameter per entry of {@code inputChannels}, in the
   * same order.
   */
  private static List<NamedParameterDefinition> toBlockParameters(List<Integer> inputChannels) {
    ImmutableList.Builder<NamedParameterDefinition> blockArgs = ImmutableList.builder();
    for (int inputChannel : inputChannels) {
      String parameterName = "block_" + inputChannel;
      blockArgs.add(arg(parameterName, com.facebook.presto.spi.block.Block.class));
    }
    return blockArgs.build();
  }

  /**
   * Defines a single generated class in {@code classLoader} and returns it as a subclass of
   * {@code superType}.
   *
   * @throws ClassCastException if the defined class is not a subtype of {@code superType}
   */
  private <T> Class<? extends T> defineClass(
      ClassDefinition classDefinition, Class<T> superType, DynamicClassLoader classLoader) {
    Map<String, Class<?>> definedClasses =
        defineClasses(ImmutableList.of(classDefinition), classLoader);
    Class<?> definedClass = definedClasses.values().iterator().next();
    return definedClass.asSubclass(superType);
  }

  /**
   * Converts {@code classDefinitions} to bytecode and defines them in {@code classLoader}.
   *
   * <p>Optional debug output is controlled by static constants:
   *
   * <ul>
   *   <li>{@code DUMP_BYTE_CODE_TREE} prints the bytecode tree to stdout;
   *   <li>{@code RUN_ASM_VERIFIER} runs ASM's verifier on each class;
   *   <li>{@code DUMP_BYTE_CODE_RAW} traces the raw bytecode to stderr;
   *   <li>a non-null {@code DUMP_CLASS_FILES_TO} writes each class file under that directory.
   * </ul>
   *
   * <p>Class-file write failures are logged and do not prevent the classes from being defined.
   * Each defined class increments the {@code generatedClasses} counter.
   *
   * @param classDefinitions the classes to generate, in definition order
   * @param classLoader the loader that will own the generated classes
   * @return the map returned by {@code classLoader.defineClasses}; its input is keyed by Java
   *     class name
   */
  private Map<String, Class<?>> defineClasses(
      List<ClassDefinition> classDefinitions, DynamicClassLoader classLoader) {
    ClassInfoLoader classInfoLoader =
        ClassInfoLoader.createClassInfoLoader(classDefinitions, classLoader);

    if (DUMP_BYTE_CODE_TREE) {
      DumpByteCodeVisitor dumpByteCode = new DumpByteCodeVisitor(System.out);
      for (ClassDefinition classDefinition : classDefinitions) {
        dumpByteCode.visitClass(classDefinition);
      }
    }

    // LinkedHashMap keeps definition order for dumping and defining.
    Map<String, byte[]> byteCodes = new LinkedHashMap<>();
    for (ClassDefinition classDefinition : classDefinitions) {
      ClassWriter cw = new SmartClassWriter(classInfoLoader);
      classDefinition.visit(cw);
      byte[] byteCode = cw.toByteArray();
      if (RUN_ASM_VERIFIER) {
        ClassReader reader = new ClassReader(byteCode);
        CheckClassAdapter.verify(reader, classLoader, true, new PrintWriter(System.out));
      }
      byteCodes.put(classDefinition.getType().getJavaClassName(), byteCode);
    }

    String dumpClassPath = DUMP_CLASS_FILES_TO.get();
    if (dumpClassPath != null) {
      for (Entry<String, byte[]> entry : byteCodes.entrySet()) {
        File file =
            new File(
                dumpClassPath,
                ParameterizedType.typeFromJavaClassName(entry.getKey()).getClassName() + ".class");
        try {
          // Pass the path as a format argument so '%' characters in it are not interpreted.
          log.debug("ClassFile: %s", file.getAbsolutePath());
          Files.createParentDirs(file);
          Files.write(entry.getValue(), file);
        } catch (IOException e) {
          // Best effort: a dump failure must not break compilation.
          log.error(e, "Failed to write generated class file to: %s", file.getAbsolutePath());
        }
      }
    }
    if (DUMP_BYTE_CODE_RAW) {
      for (byte[] byteCode : byteCodes.values()) {
        ClassReader classReader = new ClassReader(byteCode);
        classReader.accept(
            new TraceClassVisitor(new PrintWriter(System.err)), ClassReader.SKIP_FRAMES);
      }
    }
    Map<String, Class<?>> classes = classLoader.defineClasses(byteCodes);
    generatedClasses.addAndGet(classes.size());
    return classes;
  }

  /**
   * Stores {@code callSites} into the static {@code callSites} field of a generated class.
   * Reflective failures are rethrown unchecked.
   */
  private static void setCallSitesField(Class<?> clazz, Map<Long, MethodHandle> callSites) {
    Field callSitesField;
    try {
      callSitesField = clazz.getDeclaredField("callSites");
    } catch (NoSuchFieldException e) {
      throw Throwables.propagate(e);
    }

    // The field is declared private in the generated class.
    callSitesField.setAccessible(true);
    try {
      callSitesField.set(null, callSites);
    } catch (IllegalAccessException e) {
      throw Throwables.propagate(e);
    }
  }

  /**
   * Cache key for compiled operators. It holds the filter, an immutable copy of the projections,
   * and the source plan node id. The id is {@code null} for keys created by
   * {@code compileFilterAndProjectOperator}.
   */
  private static final class OperatorCacheKey {
    private final RowExpression filter;
    private final List<RowExpression> projections;
    private final PlanNodeId sourceId;

    private OperatorCacheKey(
        RowExpression filter, List<RowExpression> projections, PlanNodeId sourceId) {
      this.filter = filter;
      this.projections = ImmutableList.copyOf(projections);
      this.sourceId = sourceId;
    }

    private RowExpression getFilter() {
      return this.filter;
    }

    private List<RowExpression> getProjections() {
      return this.projections;
    }

    private PlanNodeId getSourceId() {
      return this.sourceId;
    }

    @Override
    public int hashCode() {
      return Objects.hashCode(this.filter, this.projections, this.sourceId);
    }

    @Override
    public boolean equals(Object obj) {
      if (obj == this) {
        return true;
      }
      // The class is final, so instanceof is equivalent to an exact-class check (and rejects null).
      if (!(obj instanceof OperatorCacheKey)) {
        return false;
      }
      OperatorCacheKey that = (OperatorCacheKey) obj;
      boolean sameFilter = Objects.equal(this.filter, that.filter);
      boolean sameSource = Objects.equal(this.sourceId, that.sourceId);
      return sameFilter && sameSource && Objects.equal(this.projections, that.projections);
    }

    @Override
    public String toString() {
      return toStringHelper(this)
          .add("filter", this.filter)
          .add("projections", this.projections)
          .add("sourceId", this.sourceId)
          .toString();
    }
  }

  /**
   * Binds a filter-and-project operator constructor to its output types and hands out {@link
   * FilterAndProjectOperatorFactory} instances for specific operator ids.
   */
  private static class FilterAndProjectOperatorFactoryFactory {
    private final Constructor<? extends Operator> constructor;
    private final List<Type> types;

    /**
     * @param constructor operator constructor taking {@code (OperatorContext, List<Type>)}
     * @param types output types; copied defensively
     * @throws NullPointerException if either argument is null
     */
    public FilterAndProjectOperatorFactoryFactory(
        Constructor<? extends Operator> constructor, List<Type> types) {
      checkNotNull(constructor, "constructor is null");
      checkNotNull(types, "types is null");
      this.constructor = constructor;
      this.types = ImmutableList.copyOf(types);
    }

    /** Creates a factory that builds operators registered under {@code operatorId}. */
    public OperatorFactory create(int operatorId) {
      OperatorFactory factory = new FilterAndProjectOperatorFactory(constructor, operatorId, types);
      return factory;
    }
  }

  /**
   * {@link OperatorFactory} that instantiates a filter-and-project operator reflectively through
   * a {@code (OperatorContext, List<Type>)} constructor. Once {@link #close()} is called, no more
   * operators may be created.
   */
  private static class FilterAndProjectOperatorFactory implements OperatorFactory {
    private final Constructor<? extends Operator> constructor;
    private final int operatorId;
    private final List<Type> types;
    private boolean closed;

    public FilterAndProjectOperatorFactory(
        Constructor<? extends Operator> constructor, int operatorId, List<Type> types) {
      checkNotNull(constructor, "constructor is null");
      this.constructor = constructor;
      this.operatorId = operatorId;
      checkNotNull(types, "types is null");
      this.types = ImmutableList.copyOf(types);
    }

    @Override
    public List<Type> getTypes() {
      return types;
    }

    /**
     * Registers a new operator context on {@code driverContext}, named after the operator class,
     * and constructs the operator with it.
     *
     * @throws IllegalStateException if this factory has been closed
     */
    @Override
    public Operator createOperator(DriverContext driverContext) {
      checkState(!closed, "Factory is already closed");
      String operatorType = constructor.getDeclaringClass().getSimpleName();
      OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, operatorType);
      try {
        return constructor.newInstance(operatorContext, types);
      } catch (InvocationTargetException constructorFailure) {
        // surface the exception thrown by the operator's constructor, not the reflection wrapper
        throw Throwables.propagate(constructorFailure.getCause());
      } catch (ReflectiveOperationException reflectionFailure) {
        // checked exception: wrapping matches Throwables.propagate exactly
        throw new RuntimeException(reflectionFailure);
      }
    }

    @Override
    public void close() {
      closed = true;
    }
  }

  /**
   * Binds a scan-filter-project source operator constructor to its source plan node and output
   * types, and creates {@link ScanFilterAndProjectOperatorFactory} instances once the operator id,
   * data stream provider and columns are supplied.
   */
  private static class ScanFilterAndProjectOperatorFactoryFactory {
    private final Constructor<? extends SourceOperator> constructor;
    private final PlanNodeId sourceId;
    private final List<Type> types;

    /**
     * @throws NullPointerException if any argument is null (checked in the order sourceId,
     *     constructor, types)
     */
    public ScanFilterAndProjectOperatorFactoryFactory(
        Constructor<? extends SourceOperator> constructor, PlanNodeId sourceId, List<Type> types) {
      this.sourceId = checkNotNull(sourceId, "sourceId is null");
      this.constructor = checkNotNull(constructor, "constructor is null");
      List<Type> checkedTypes = checkNotNull(types, "types is null");
      this.types = ImmutableList.copyOf(checkedTypes);
    }

    /** Creates a source operator factory for one operator id over the given columns. */
    public SourceOperatorFactory create(
        int operatorId, DataStreamProvider dataStreamProvider, List<ColumnHandle> columns) {
      SourceOperatorFactory factory =
          new ScanFilterAndProjectOperatorFactory(
              constructor, operatorId, sourceId, dataStreamProvider, columns, types);
      return factory;
    }
  }

  /**
   * {@link SourceOperatorFactory} that instantiates a scan-filter-project operator reflectively
   * through a {@code (OperatorContext, PlanNodeId, DataStreamProvider, List<ColumnHandle>,
   * List<Type>)} constructor. Once {@link #close()} is called, no more operators may be created.
   */
  private static class ScanFilterAndProjectOperatorFactory implements SourceOperatorFactory {
    private final Constructor<? extends SourceOperator> constructor;
    private final int operatorId;
    private final PlanNodeId sourceId;
    private final DataStreamProvider dataStreamProvider;
    private final List<ColumnHandle> columns;
    private final List<Type> types;
    private boolean closed;

    public ScanFilterAndProjectOperatorFactory(
        Constructor<? extends SourceOperator> constructor,
        int operatorId,
        PlanNodeId sourceId,
        DataStreamProvider dataStreamProvider,
        List<ColumnHandle> columns,
        List<Type> types) {
      checkNotNull(constructor, "constructor is null");
      checkNotNull(sourceId, "sourceId is null");
      checkNotNull(dataStreamProvider, "dataStreamProvider is null");
      checkNotNull(columns, "columns is null");
      checkNotNull(types, "types is null");
      this.constructor = constructor;
      this.operatorId = operatorId;
      this.sourceId = sourceId;
      this.dataStreamProvider = dataStreamProvider;
      this.columns = ImmutableList.copyOf(columns);
      this.types = ImmutableList.copyOf(types);
    }

    @Override
    public PlanNodeId getSourceId() {
      return sourceId;
    }

    @Override
    public List<Type> getTypes() {
      return types;
    }

    /**
     * Registers a new operator context on {@code driverContext}, named after the operator class,
     * and constructs the source operator with it.
     *
     * @throws IllegalStateException if this factory has been closed
     */
    @Override
    public SourceOperator createOperator(DriverContext driverContext) {
      checkState(!closed, "Factory is already closed");
      String operatorType = constructor.getDeclaringClass().getSimpleName();
      OperatorContext operatorContext = driverContext.addOperatorContext(operatorId, operatorType);
      try {
        return constructor.newInstance(operatorContext, sourceId, dataStreamProvider, columns, types);
      } catch (InvocationTargetException constructorFailure) {
        // surface the exception thrown by the operator's constructor, not the reflection wrapper
        throw Throwables.propagate(constructorFailure.getCause());
      } catch (ReflectiveOperationException reflectionFailure) {
        // checked exception: wrapping matches Throwables.propagate exactly
        throw new RuntimeException(reflectionFailure);
      }
    }

    @Override
    public void close() {
      closed = true;
    }
  }
}
public class HttpRemoteTask implements RemoteTask {
  private static final Logger log = Logger.get(HttpRemoteTask.class);
  private static final Duration MAX_CLEANUP_RETRY_TIME = new Duration(2, TimeUnit.MINUTES);

  private final TaskId taskId;

  private final Session session;
  private final String nodeId;
  private final PlanFragment planFragment;

  private final AtomicLong nextSplitId = new AtomicLong();

  private final StateMachine<TaskInfo> taskInfo;

  @GuardedBy("this")
  private Future<?> currentRequest;

  @GuardedBy("this")
  private long currentRequestStartNanos;

  @GuardedBy("this")
  private final SetMultimap<PlanNodeId, ScheduledSplit> pendingSplits = HashMultimap.create();

  @GuardedBy("this")
  private volatile int pendingSourceSplitCount;

  @GuardedBy("this")
  private final Set<PlanNodeId> noMoreSplits = new HashSet<>();

  @GuardedBy("this")
  private final AtomicReference<OutputBuffers> outputBuffers = new AtomicReference<>();

  private final ContinuousTaskInfoFetcher continuousTaskInfoFetcher;

  private final HttpClient httpClient;
  private final Executor executor;
  private final ScheduledExecutorService errorScheduledExecutor;
  private final JsonCodec<TaskInfo> taskInfoCodec;
  private final JsonCodec<TaskUpdateRequest> taskUpdateRequestCodec;

  private final RequestErrorTracker updateErrorTracker;
  private final RequestErrorTracker getErrorTracker;

  private final AtomicBoolean needsUpdate = new AtomicBoolean(true);

  private final SplitCountChangeListener splitCountChangeListener;

  public HttpRemoteTask(
      Session session,
      TaskId taskId,
      String nodeId,
      URI location,
      PlanFragment planFragment,
      Multimap<PlanNodeId, Split> initialSplits,
      OutputBuffers outputBuffers,
      HttpClient httpClient,
      Executor executor,
      ScheduledExecutorService errorScheduledExecutor,
      Duration minErrorDuration,
      Duration refreshMaxWait,
      JsonCodec<TaskInfo> taskInfoCodec,
      JsonCodec<TaskUpdateRequest> taskUpdateRequestCodec,
      SplitCountChangeListener splitCountChangeListener) {
    requireNonNull(session, "session is null");
    requireNonNull(taskId, "taskId is null");
    requireNonNull(nodeId, "nodeId is null");
    requireNonNull(location, "location is null");
    requireNonNull(planFragment, "planFragment1 is null");
    requireNonNull(outputBuffers, "outputBuffers is null");
    requireNonNull(httpClient, "httpClient is null");
    requireNonNull(executor, "executor is null");
    requireNonNull(taskInfoCodec, "taskInfoCodec is null");
    requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null");
    requireNonNull(splitCountChangeListener, "splitCountChangeListener is null");

    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      this.taskId = taskId;
      this.session = session;
      this.nodeId = nodeId;
      this.planFragment = planFragment;
      this.outputBuffers.set(outputBuffers);
      this.httpClient = httpClient;
      this.executor = executor;
      this.errorScheduledExecutor = errorScheduledExecutor;
      this.taskInfoCodec = taskInfoCodec;
      this.taskUpdateRequestCodec = taskUpdateRequestCodec;
      this.updateErrorTracker =
          new RequestErrorTracker(
              taskId, location, minErrorDuration, errorScheduledExecutor, "updating task");
      this.getErrorTracker =
          new RequestErrorTracker(
              taskId, location, minErrorDuration, errorScheduledExecutor, "getting info for task");
      this.splitCountChangeListener = splitCountChangeListener;

      for (Entry<PlanNodeId, Split> entry :
          requireNonNull(initialSplits, "initialSplits is null").entries()) {
        ScheduledSplit scheduledSplit =
            new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getValue());
        pendingSplits.put(entry.getKey(), scheduledSplit);
      }
      if (initialSplits.containsKey(planFragment.getPartitionedSource())) {
        pendingSourceSplitCount = initialSplits.get(planFragment.getPartitionedSource()).size();
        fireSplitCountChanged(pendingSourceSplitCount);
      }

      List<BufferInfo> bufferStates =
          outputBuffers
              .getBuffers()
              .keySet()
              .stream()
              .map(outputId -> new BufferInfo(outputId, false, 0, 0, PageBufferInfo.empty()))
              .collect(toImmutableList());

      TaskStats taskStats = new TaskStats(DateTime.now(), null);

      taskInfo =
          new StateMachine<>(
              "task " + taskId,
              executor,
              new TaskInfo(
                  taskId,
                  Optional.empty(),
                  TaskInfo.MIN_VERSION,
                  TaskState.PLANNED,
                  location,
                  DateTime.now(),
                  new SharedBufferInfo(BufferState.OPEN, true, true, 0, 0, 0, 0, bufferStates),
                  ImmutableSet.<PlanNodeId>of(),
                  taskStats,
                  ImmutableList.<ExecutionFailureInfo>of()));

      continuousTaskInfoFetcher = new ContinuousTaskInfoFetcher(refreshMaxWait);
    }
  }

  @Override
  public TaskId getTaskId() {
    return taskId;
  }

  @Override
  public String getNodeId() {
    return nodeId;
  }

  @Override
  public TaskInfo getTaskInfo() {
    return taskInfo.get();
  }

  @Override
  public void start() {
    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      // to start we just need to trigger an update
      scheduleUpdate();

      // begin the info fetcher
      continuousTaskInfoFetcher.start();
    }
  }

  @Override
  public synchronized void addSplits(PlanNodeId sourceId, Iterable<Split> splits) {
    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      requireNonNull(sourceId, "sourceId is null");
      requireNonNull(splits, "splits is null");
      checkState(
          !noMoreSplits.contains(sourceId), "noMoreSplits has already been set for %s", sourceId);

      // only add pending split if not done
      if (!getTaskInfo().getState().isDone()) {
        int added = 0;
        for (Split split : splits) {
          if (pendingSplits.put(
              sourceId, new ScheduledSplit(nextSplitId.getAndIncrement(), split))) {
            added++;
          }
        }
        if (sourceId.equals(planFragment.getPartitionedSource())) {
          pendingSourceSplitCount += added;
          fireSplitCountChanged(added);
        }
        needsUpdate.set(true);
      }

      scheduleUpdate();
    }
  }

  @Override
  public synchronized void noMoreSplits(PlanNodeId sourceId) {
    if (noMoreSplits.add(sourceId)) {
      needsUpdate.set(true);
      scheduleUpdate();
    }
  }

  @Override
  public synchronized void setOutputBuffers(OutputBuffers newOutputBuffers) {
    if (getTaskInfo().getState().isDone()) {
      return;
    }

    if (newOutputBuffers.getVersion() > outputBuffers.get().getVersion()) {
      outputBuffers.set(newOutputBuffers);
      needsUpdate.set(true);
      scheduleUpdate();
    }
  }

  @Override
  public int getPartitionedSplitCount() {
    return pendingSourceSplitCount
        + taskInfo.get().getStats().getQueuedPartitionedDrivers()
        + taskInfo.get().getStats().getRunningPartitionedDrivers();
  }

  @Override
  public int getQueuedPartitionedSplitCount() {
    return pendingSourceSplitCount + taskInfo.get().getStats().getQueuedPartitionedDrivers();
  }

  @Override
  public void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeListener) {
    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      taskInfo.addStateChangeListener(stateChangeListener);
    }
  }

  @Override
  public CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo) {
    return this.taskInfo.getStateChange(taskInfo);
  }

  private synchronized void updateTaskInfo(TaskInfo newValue) {
    updateTaskInfo(newValue, ImmutableList.of());
  }

  private synchronized void updateTaskInfo(TaskInfo newValue, List<TaskSource> sources) {
    if (newValue.getState().isDone()) {
      // splits can be huge so clear the list
      pendingSplits.clear();
      fireSplitCountChanged(-pendingSourceSplitCount);
      pendingSourceSplitCount = 0;
    }

    int oldPartitionedSplitCount = getPartitionedSplitCount();

    // change to new value if old value is not changed and new value has a newer version
    AtomicBoolean workerRestarted = new AtomicBoolean();
    boolean updated =
        taskInfo.setIf(
            newValue,
            oldValue -> {
              // did the worker restart
              if (oldValue.getNodeInstanceId().isPresent()
                  && !oldValue.getNodeInstanceId().equals(newValue.getNodeInstanceId())) {
                workerRestarted.set(true);
                return false;
              }

              if (oldValue.getState().isDone()) {
                // never update if the task has reached a terminal state
                return false;
              }
              if (newValue.getVersion() < oldValue.getVersion()) {
                // don't update to an older version (same version is ok)
                return false;
              }
              return true;
            });

    if (workerRestarted.get()) {
      PrestoException exception =
          new PrestoException(
              WORKER_RESTARTED, format("%s (%s)", WORKER_RESTARTED_ERROR, newValue.getSelf()));
      failTask(exception);
      abort();
    }

    // remove acknowledged splits, which frees memory
    for (TaskSource source : sources) {
      PlanNodeId planNodeId = source.getPlanNodeId();
      int removed = 0;
      for (ScheduledSplit split : source.getSplits()) {
        if (pendingSplits.remove(planNodeId, split)) {
          removed++;
        }
      }
      if (planNodeId.equals(planFragment.getPartitionedSource())) {
        pendingSourceSplitCount -= removed;
      }
    }

    if (updated) {
      if (getTaskInfo().getState().isDone()) {
        fireSplitCountChanged(-oldPartitionedSplitCount);
      } else {
        fireSplitCountChanged(getPartitionedSplitCount() - oldPartitionedSplitCount);
      }
    }
  }

  private void fireSplitCountChanged(int delta) {
    if (delta != 0) {
      splitCountChangeListener.splitCountChanged(delta);
    }
  }

  private synchronized void scheduleUpdate() {
    // don't update if the task hasn't been started yet or if it is already finished
    if (!needsUpdate.get() || taskInfo.get().getState().isDone()) {
      return;
    }

    // if we have an old request outstanding, cancel it
    if (currentRequest != null
        && Duration.nanosSince(currentRequestStartNanos).compareTo(new Duration(2, SECONDS)) >= 0) {
      needsUpdate.set(true);
      currentRequest.cancel(true);
      currentRequest = null;
      currentRequestStartNanos = 0;
    }

    // if there is a request already running, wait for it to complete
    if (this.currentRequest != null && !this.currentRequest.isDone()) {
      return;
    }

    // if throttled due to error, asynchronously wait for timeout and try again
    ListenableFuture<?> errorRateLimit = updateErrorTracker.acquireRequestPermit();
    if (!errorRateLimit.isDone()) {
      errorRateLimit.addListener(this::scheduleUpdate, executor);
      return;
    }

    List<TaskSource> sources = getSources();
    TaskUpdateRequest updateRequest =
        new TaskUpdateRequest(
            session.toSessionRepresentation(), planFragment, sources, outputBuffers.get());

    Request request =
        preparePost()
            .setUri(uriBuilderFrom(taskInfo.get().getSelf()).addParameter("summarize").build())
            .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString())
            .setBodyGenerator(jsonBodyGenerator(taskUpdateRequestCodec, updateRequest))
            .build();

    updateErrorTracker.startRequest();

    ListenableFuture<JsonResponse<TaskInfo>> future =
        httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec));
    currentRequest = future;
    currentRequestStartNanos = System.nanoTime();

    // The needsUpdate flag needs to be set to false BEFORE adding the Future callback since
    // callback might change the flag value
    // and does so without grabbing the instance lock.
    needsUpdate.set(false);

    Futures.addCallback(
        future,
        new SimpleHttpResponseHandler<>(new UpdateResponseHandler(sources), request.getUri()),
        executor);
  }

  private synchronized List<TaskSource> getSources() {
    return Stream.concat(
            Stream.of(planFragment.getPartitionedSourceNode()),
            planFragment.getRemoteSourceNodes().stream())
        .filter(Objects::nonNull)
        .map(PlanNode::getId)
        .map(this::getSource)
        .filter(Objects::nonNull)
        .collect(toImmutableList());
  }

  private TaskSource getSource(PlanNodeId planNodeId) {
    Set<ScheduledSplit> splits = pendingSplits.get(planNodeId);
    boolean noMoreSplits = this.noMoreSplits.contains(planNodeId);
    TaskSource element = null;
    if (!splits.isEmpty() || noMoreSplits) {
      element = new TaskSource(planNodeId, splits, noMoreSplits);
    }
    return element;
  }

  @Override
  public synchronized void cancel() {
    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      if (getTaskInfo().getState().isDone()) {
        return;
      }

      URI uri = getTaskInfo().getSelf();
      if (uri == null) {
        return;
      }

      // send cancel to task and ignore response
      Request request =
          prepareDelete()
              .setUri(
                  uriBuilderFrom(uri)
                      .addParameter("abort", "false")
                      .addParameter("summarize")
                      .build())
              .build();
      scheduleAsyncCleanupRequest(new Backoff(MAX_CLEANUP_RETRY_TIME), request, "cancel");
    }
  }

  @Override
  public synchronized void abort() {
    try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) {
      // clear pending splits to free memory
      fireSplitCountChanged(-pendingSourceSplitCount);
      pendingSplits.clear();
      pendingSourceSplitCount = 0;

      // cancel pending request
      if (currentRequest != null) {
        currentRequest.cancel(true);
        currentRequest = null;
        currentRequestStartNanos = 0;
      }

      // mark task as canceled (if not already done)
      TaskInfo taskInfo = getTaskInfo();
      URI uri = taskInfo.getSelf();

      updateTaskInfo(
          new TaskInfo(
              taskInfo.getTaskId(),
              taskInfo.getNodeInstanceId(),
              TaskInfo.MAX_VERSION,
              TaskState.ABORTED,
              uri,
              taskInfo.getLastHeartbeat(),
              taskInfo.getOutputBuffers(),
              taskInfo.getNoMoreSplits(),
              taskInfo.getStats(),
              ImmutableList.<ExecutionFailureInfo>of()));

      // send abort to task and ignore response
      Request request =
          prepareDelete().setUri(uriBuilderFrom(uri).addParameter("summarize").build()).build();
      scheduleAsyncCleanupRequest(new Backoff(MAX_CLEANUP_RETRY_TIME), request, "abort");
    }
  }

  private void scheduleAsyncCleanupRequest(Backoff cleanupBackoff, Request request, String action) {
    Futures.addCallback(
        httpClient.executeAsync(request, createStatusResponseHandler()),
        new FutureCallback<StatusResponse>() {
          @Override
          public void onSuccess(StatusResponse result) {
            // assume any response is good enough
          }

          @Override
          public void onFailure(Throwable t) {
            if (t instanceof RejectedExecutionException) {
              // client has been shutdown
              return;
            }

            // record failure
            if (cleanupBackoff.failure()) {
              logError(t, "Unable to %s task at %s", action, request.getUri());
              return;
            }

            // reschedule
            long delayNanos = cleanupBackoff.getBackoffDelayNanos();
            if (delayNanos == 0) {
              scheduleAsyncCleanupRequest(cleanupBackoff, request, action);
            } else {
              errorScheduledExecutor.schedule(
                  () -> scheduleAsyncCleanupRequest(cleanupBackoff, request, action),
                  delayNanos,
                  NANOSECONDS);
            }
          }
        },
        executor);
  }

  /** Move the task directly to the failed state */
  private void failTask(Throwable cause) {
    TaskInfo taskInfo = getTaskInfo();
    if (!taskInfo.getState().isDone()) {
      log.debug(cause, "Remote task failed: %s", taskInfo.getSelf());
    }
    updateTaskInfo(
        new TaskInfo(
            taskInfo.getTaskId(),
            taskInfo.getNodeInstanceId(),
            TaskInfo.MAX_VERSION,
            TaskState.FAILED,
            taskInfo.getSelf(),
            taskInfo.getLastHeartbeat(),
            taskInfo.getOutputBuffers(),
            taskInfo.getNoMoreSplits(),
            taskInfo.getStats(),
            ImmutableList.of(toFailure(cause))));
  }

  @Override
  public String toString() {
    return toStringHelper(this).addValue(getTaskInfo()).toString();
  }

  private class UpdateResponseHandler implements SimpleHttpResponseCallback<TaskInfo> {
    private final List<TaskSource> sources;

    private UpdateResponseHandler(List<TaskSource> sources) {
      this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null"));
    }

    @Override
    public void success(TaskInfo value) {
      try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) {
        try {
          synchronized (HttpRemoteTask.this) {
            currentRequest = null;
          }
          updateTaskInfo(value, sources);
          updateErrorTracker.requestSucceeded();
        } finally {
          scheduleUpdate();
        }
      }
    }

    @Override
    public void failed(Throwable cause) {
      try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) {
        try {
          synchronized (HttpRemoteTask.this) {
            currentRequest = null;
          }

          // on failure assume we need to update again
          needsUpdate.set(true);

          // if task not already done, record error
          TaskInfo taskInfo = getTaskInfo();
          if (!taskInfo.getState().isDone()) {
            updateErrorTracker.requestFailed(cause);
          }
        } catch (Error e) {
          failTask(e);
          abort();
          throw e;
        } catch (RuntimeException e) {
          failTask(e);
          abort();
        } finally {
          scheduleUpdate();
        }
      }
    }

    @Override
    public void fatal(Throwable cause) {
      try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) {
        failTask(cause);
      }
    }
  }

  /**
   * Continuous update loop for task info. Wait for a short period for task state to change, and if
   * it does not, return the current state of the task. This will cause stats to be updated at a
   * regular interval, and state changes will be immediately recorded.
   */
  private class ContinuousTaskInfoFetcher implements SimpleHttpResponseCallback<TaskInfo> {
    private final Duration refreshMaxWait;

    @GuardedBy("this")
    private boolean running;

    @GuardedBy("this")
    private ListenableFuture<JsonResponse<TaskInfo>> future;

    public ContinuousTaskInfoFetcher(Duration refreshMaxWait) {
      this.refreshMaxWait = refreshMaxWait;
    }

    public synchronized void start() {
      if (running) {
        // already running
        return;
      }
      running = true;
      scheduleNextRequest();
    }

    public synchronized void stop() {
      running = false;
      if (future != null) {
        future.cancel(true);
        future = null;
      }
    }

    private synchronized void scheduleNextRequest() {
      // stopped or done?
      TaskInfo taskInfo = HttpRemoteTask.this.taskInfo.get();
      if (!running || taskInfo.getState().isDone()) {
        return;
      }

      // outstanding request?
      if (future != null && !future.isDone()) {
        // this should never happen
        log.error("Can not reschedule update because an update is already running");
        return;
      }

      // if throttled due to error, asynchronously wait for timeout and try again
      ListenableFuture<?> errorRateLimit = getErrorTracker.acquireRequestPermit();
      if (!errorRateLimit.isDone()) {
        errorRateLimit.addListener(this::scheduleNextRequest, executor);
        return;
      }

      Request request =
          prepareGet()
              .setUri(uriBuilderFrom(taskInfo.getSelf()).addParameter("summarize").build())
              .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString())
              .setHeader(PrestoHeaders.PRESTO_CURRENT_STATE, taskInfo.getState().toString())
              .setHeader(PrestoHeaders.PRESTO_MAX_WAIT, refreshMaxWait.toString())
              .build();

      getErrorTracker.startRequest();

      future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec));
      Futures.addCallback(
          future, new SimpleHttpResponseHandler<>(this, request.getUri()), executor);
    }

    @Override
    public void success(TaskInfo value) {
      try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) {
        synchronized (this) {
          future = null;
        }

        try {
          updateTaskInfo(value, ImmutableList.<TaskSource>of());
          getErrorTracker.requestSucceeded();
        } finally {
          scheduleNextRequest();
        }
      }
    }

    @Override
    public void failed(Throwable cause) {
      try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) {
        synchronized (this) {
          future = null;
        }

        try {
          // if task not already done, record error
          TaskInfo taskInfo = getTaskInfo();
          if (!taskInfo.getState().isDone()) {
            getErrorTracker.requestFailed(cause);
          }
        } catch (Error e) {
          failTask(e);
          abort();
          throw e;
        } catch (RuntimeException e) {
          failTask(e);
          abort();
        } finally {
          // there is no back off here so we can get a lot of error messages when a server spins
          // down, but it typically goes away quickly because the queries get canceled
          scheduleNextRequest();
        }
      }
    }

    @Override
    public void fatal(Throwable cause) {
      try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) {
        synchronized (this) {
          future = null;
        }

        failTask(cause);
      }
    }
  }

  public static class SimpleHttpResponseHandler<T> implements FutureCallback<JsonResponse<T>> {
    private final SimpleHttpResponseCallback<T> callback;

    private final URI uri;

    public SimpleHttpResponseHandler(SimpleHttpResponseCallback<T> callback, URI uri) {
      this.callback = callback;
      this.uri = uri;
    }

    @Override
    public void onSuccess(JsonResponse<T> response) {
      try {
        if (response.getStatusCode() == HttpStatus.OK.code() && response.hasValue()) {
          callback.success(response.getValue());
        } else if (response.getStatusCode() == HttpStatus.SERVICE_UNAVAILABLE.code()) {
          callback.failed(new ServiceUnavailableException(uri));
        } else {
          // Something is broken in the server or the client, so fail the task immediately (includes
          // 500 errors)
          Exception cause = response.getException();
          if (cause == null) {
            if (response.getStatusCode() == HttpStatus.OK.code()) {
              cause =
                  new PrestoException(
                      REMOTE_TASK_ERROR, format("Expected response from %s is empty", uri));
            } else {
              cause =
                  new PrestoException(
                      REMOTE_TASK_ERROR,
                      format(
                          "Expected response code from %s to be %s, but was %s: %s%n%s",
                          uri,
                          HttpStatus.OK.code(),
                          response.getStatusCode(),
                          response.getStatusMessage(),
                          response.getResponseBody()));
            }
          }
          callback.fatal(cause);
        }
      } catch (Throwable t) {
        // this should never happen
        callback.fatal(t);
      }
    }

    @Override
    public void onFailure(Throwable t) {
      callback.failed(t);
    }
  }

  @ThreadSafe
  private static class RequestErrorTracker {
    private final TaskId taskId;
    private final URI taskUri;
    private final ScheduledExecutorService scheduledExecutor;
    private final String jobDescription;
    private final Backoff backoff;

    private final Queue<Throwable> errorsSinceLastSuccess = new ConcurrentLinkedQueue<>();

    public RequestErrorTracker(
        TaskId taskId,
        URI taskUri,
        Duration minErrorDuration,
        ScheduledExecutorService scheduledExecutor,
        String jobDescription) {
      this.taskId = taskId;
      this.taskUri = taskUri;
      this.scheduledExecutor = scheduledExecutor;
      this.backoff = new Backoff(minErrorDuration);
      this.jobDescription = jobDescription;
    }

    public ListenableFuture<?> acquireRequestPermit() {
      long delayNanos = backoff.getBackoffDelayNanos();

      if (delayNanos == 0) {
        return Futures.immediateFuture(null);
      }

      ListenableFutureTask<Object> futureTask = ListenableFutureTask.create(() -> null);
      scheduledExecutor.schedule(futureTask, delayNanos, NANOSECONDS);
      return futureTask;
    }

    public void startRequest() {
      // before scheduling a new request clear the error timer
      // we consider a request to be "new" if there are no current failures
      if (backoff.getFailureCount() == 0) {
        requestSucceeded();
      }
    }

    public void requestSucceeded() {
      backoff.success();
      errorsSinceLastSuccess.clear();
    }

    public void requestFailed(Throwable reason) throws PrestoException {
      // cancellation is not a failure
      if (reason instanceof CancellationException) {
        return;
      }

      if (reason instanceof RejectedExecutionException) {
        throw new PrestoException(REMOTE_TASK_ERROR, reason);
      }

      // log failure message
      if (isExpectedError(reason)) {
        // don't print a stack for a known errors
        log.warn("Error " + jobDescription + " %s: %s: %s", taskId, reason.getMessage(), taskUri);
      } else {
        log.warn(reason, "Error " + jobDescription + " %s: %s", taskId, taskUri);
      }

      // remember the first 10 errors
      if (errorsSinceLastSuccess.size() < 10) {
        errorsSinceLastSuccess.add(reason);
      }

      // fail the task, if we have more than X failures in a row and more than Y seconds have passed
      // since the last request
      if (backoff.failure()) {
        // it is weird to mark the task failed locally and then cancel the remote task, but there is
        // no way to tell a remote task that it is failed
        PrestoException exception =
            new PrestoException(
                TOO_MANY_REQUESTS_FAILED,
                format(
                    "%s (%s - %s failures, time since last success %s)",
                    WORKER_NODE_ERROR,
                    taskUri,
                    backoff.getFailureCount(),
                    backoff.getTimeSinceLastSuccess().convertTo(SECONDS)));
        errorsSinceLastSuccess.forEach(exception::addSuppressed);
        throw exception;
      }
    }
  }

  public interface SimpleHttpResponseCallback<T> {
    void success(T value);

    void failed(Throwable cause);

    void fatal(Throwable cause);
  }

  private static void logError(Throwable t, String format, Object... args) {
    if (isExpectedError(t)) {
      log.error(format + ": %s", ObjectArrays.concat(args, t));
    } else {
      log.error(t, format, args);
    }
  }

  private static boolean isExpectedError(Throwable t) {
    while (t != null) {
      if ((t instanceof SocketException)
          || (t instanceof SocketTimeoutException)
          || (t instanceof EOFException)
          || (t instanceof TimeoutException)
          || (t instanceof ServiceUnavailableException)) {
        return true;
      }
      t = t.getCause();
    }
    return false;
  }

  private static class ServiceUnavailableException extends RuntimeException {
    public ServiceUnavailableException(URI uri) {
      super("Server returned SERVICE_UNAVAILABLE: " + uri);
    }
  }
}
/**
 * {@link DeploymentManager} that keeps a single deployment on the local file system under a base
 * directory.
 *
 * <p>Files used inside {@code baseDir}:
 *
 * <ul>
 *   <li>{@code airship-deployment.json} - JSON description of the current deployment
 *   <li>{@code airship-slot-id.txt} - the slot id (a UUID)
 *   <li>{@code installation/} - the unpacked binary and config of the current deployment ({@link
 *       #load} falls back to {@code deployment/} when {@code installation/} is not a directory)
 *   <li>{@code data/} - data directory handed to deployments
 * </ul>
 */
public class DirectoryDeploymentManager implements DeploymentManager {
  private static final Logger log = Logger.get(DirectoryDeploymentManager.class);
  private final JsonCodec<DeploymentRepresentation> jsonCodec =
      jsonCodec(DeploymentRepresentation.class);

  private final UUID slotId;
  private final String location;
  private final Duration tarTimeout;

  private final File baseDir;
  private final File deploymentFile;
  private Deployment deployment;

  /**
   * Creates a manager rooted at {@code baseDir}. The directory is created if needed, any
   * previously saved deployment is loaded, and the slot id is read (or generated and persisted).
   *
   * @param baseDir directory holding all slot state; created if missing
   * @param location slot location; must start with {@code /}
   * @param tarTimeout timeout passed to tar extraction during {@link #install}
   * @throws NullPointerException if {@code baseDir} or {@code location} is null
   * @throws IllegalArgumentException if an argument is invalid, or on-disk state cannot be read or
   *     written
   */
  public DirectoryDeploymentManager(File baseDir, String location, Duration tarTimeout) {
    Preconditions.checkNotNull(location, "location is null");
    Preconditions.checkArgument(location.startsWith("/"), "location must start with /");
    this.location = location;
    this.tarTimeout = tarTimeout;

    Preconditions.checkNotNull(baseDir, "baseDir is null");
    baseDir.mkdirs();
    Preconditions.checkArgument(
        baseDir.isDirectory(), "baseDir is not a directory: " + baseDir.getAbsolutePath());
    this.baseDir = baseDir;

    // verify the deployment file is readable and writable, then load the deployment it describes
    deploymentFile = new File(baseDir, "airship-deployment.json");
    if (deploymentFile.exists()) {
      Preconditions.checkArgument(
          deploymentFile.canRead(),
          "Can not read deployment file %s",
          deploymentFile.getAbsolutePath());
      Preconditions.checkArgument(
          deploymentFile.canWrite(),
          "Can not write deployment file %s",
          deploymentFile.getAbsolutePath());
      try {
        Deployment existing = load(deploymentFile);
        Preconditions.checkArgument(
            existing.getDeploymentDir().isDirectory(),
            "Deployment directory is not a directory: %s",
            existing.getDeploymentDir());
        this.deployment = existing;
      } catch (IOException e) {
        throw new IllegalArgumentException(
            "Invalid deployment file: " + deploymentFile.getAbsolutePath(), e);
      }
    }

    slotId = loadOrCreateSlotId(new File(baseDir, "airship-slot-id.txt"));
  }

  /**
   * Reads the slot id from {@code slotIdFile}. If the file is missing, or does not contain a valid
   * UUID, a new random id is generated and written to the file.
   *
   * @throws IllegalArgumentException if the file exists but cannot be read, or the new id cannot
   *     be written
   */
  private static UUID loadOrCreateSlotId(File slotIdFile) {
    if (slotIdFile.exists()) {
      Preconditions.checkArgument(
          slotIdFile.canRead(), "can not read " + slotIdFile.getAbsolutePath());
      String slotIdString;
      try {
        slotIdString = Files.toString(slotIdFile, UTF_8).trim();
      } catch (IOException e) {
        // do not silently replace the slot id just because the read failed
        throw new IllegalArgumentException("can not read " + slotIdFile.getAbsolutePath(), e);
      }
      try {
        return UUID.fromString(slotIdString);
      } catch (IllegalArgumentException ignored) {
        // the content is not a UUID: discard the file and fall through to generate a new id
        log.warn(
            "Invalid slot id [%s]: attempting to delete airship-slot-id.txt file and recreating"
                + " a new one",
            slotIdString);
        slotIdFile.delete();
      }
    }

    UUID uuid = UUID.randomUUID();
    try {
      Files.write(uuid.toString(), slotIdFile, UTF_8);
    } catch (IOException e) {
      throw new IllegalArgumentException("can not write " + slotIdFile.getAbsolutePath(), e);
    }
    return uuid;
  }

  @Override
  public UUID getSlotId() {
    return slotId;
  }

  /** Returns the slot location passed to the constructor. */
  public String getLocation() {
    return location;
  }

  /**
   * Installs {@code installation}, replacing the current deployment.
   *
   * <p>The binary is downloaded and unpacked, and the config bundle is applied, inside a temporary
   * directory. Only after that succeeds is the existing deployment removed, the new deployment
   * file saved and the unpacked tree moved to {@code baseDir/installation}. The temporary
   * directory is removed afterwards (failures to delete it are only logged).
   *
   * <p>NOTE(review): if saving the deployment file or the final move fails, the old deployment has
   * already been deleted, so the slot is left without a deployment.
   *
   * @param installation what to install; must not be null
   * @return the new deployment
   * @throws RuntimeException if any installation step fails
   */
  @Override
  public Deployment install(Installation installation) {
    Preconditions.checkNotNull(installation, "installation is null");

    File deploymentDir = new File(baseDir, "installation");

    Assignment assignment = installation.getAssignment();

    Deployment newDeployment =
        new Deployment(
            slotId, location, deploymentDir, getDataDir(), assignment, installation.getResources());
    File tempDir = createTempDir(baseDir, "tmp-install");
    try {
      // download the binary
      File binary = new File(tempDir, "airship-binary.tar.gz");
      try {
        Files.copy(Resources.newInputStreamSupplier(installation.getBinaryFile().toURL()), binary);
      } catch (IOException e) {
        throw new RuntimeException(
            "Unable to download binary "
                + assignment.getBinary()
                + " from "
                + installation.getBinaryFile(),
            e);
      }

      // unpack the binary into a temp unpack dir
      File unpackDir = new File(tempDir, "unpack");
      unpackDir.mkdirs();
      try {
        extractTar(binary, unpackDir, tarTimeout);
      } catch (CommandFailedException e) {
        throw new RuntimeException(
            "Unable to extract tar file " + assignment.getBinary() + ": " + e.getMessage(), e);
      }

      // find the archive root dir (it should be the only file in the temp unpack dir)
      List<File> files = listFiles(unpackDir);
      if (files.size() != 1) {
        throw new RuntimeException(
            "Invalid tar file: file does not have a root directory " + assignment.getBinary());
      }
      File binaryRootDir = files.get(0);

      // unpack config bundle
      try {
        URL url = installation.getConfigFile().toURL();
        ConfigUtils.unpackConfig(Resources.newInputStreamSupplier(url), binaryRootDir);
      } catch (Exception e) {
        throw new RuntimeException(
            "Unable to extract config bundle " + assignment.getConfig() + ": " + e.getMessage(),
            e);
      }

      // installation is good, clear the current deployment
      if (this.deployment != null) {
        this.deploymentFile.delete();
        deleteRecursively(this.deployment.getDeploymentDir());
        this.deployment = null;
      }

      // save deployment versions file
      try {
        save(newDeployment);
      } catch (IOException e) {
        throw new RuntimeException("Unable to save deployment file", e);
      }

      // move the binary root directory to the final target
      try {
        Files.move(binaryRootDir, deploymentDir);
      } catch (IOException e) {
        throw new RuntimeException("Unable to move deployment to final location", e);
      }
    } finally {
      if (!deleteRecursively(tempDir)) {
        log.warn("Unable to delete temp directory: %s", tempDir.getAbsolutePath());
      }
    }

    this.deployment = newDeployment;
    return newDeployment;
  }

  @Override
  public Deployment getDeployment() {
    return deployment;
  }

  /** Deletes the whole base directory and forgets the current deployment. */
  @Override
  public void terminate() {
    deleteRecursively(baseDir);
    deployment = null;
  }

  /** Returns {@code baseDir/data}, creating it if necessary. */
  @Override
  public File hackGetDataDir() {
    return getDataDir();
  }

  /**
   * Writes {@code deployment} as JSON to {@code airship-deployment.json} in the base directory.
   *
   * @throws IOException if the file cannot be written
   */
  public void save(Deployment deployment) throws IOException {
    String json = jsonCodec.toJson(DeploymentRepresentation.from(deployment));
    Files.write(json, deploymentFile, UTF_8);
  }

  /**
   * Reads a deployment from the JSON in {@code deploymentFile}. The deployment directory is
   * {@code baseDir/installation}, or {@code baseDir/deployment} when the former is not a
   * directory.
   *
   * @throws IOException if the file cannot be read
   */
  public Deployment load(File deploymentFile) throws IOException {
    String json = Files.toString(deploymentFile, UTF_8);
    DeploymentRepresentation data = jsonCodec.fromJson(json);
    File deploymentDir = new File(baseDir, "installation");
    if (!deploymentDir.isDirectory()) {
      deploymentDir = new File(baseDir, "deployment");
    }
    return data.toDeployment(deploymentDir, getDataDir(), location);
  }

  /**
   * Returns {@code baseDir/data}, creating it if needed.
   *
   * @throws RuntimeException if the directory does not exist and cannot be created
   */
  private File getDataDir() {
    File dataDir = new File(baseDir, "data");
    dataDir.mkdirs();
    if (!dataDir.isDirectory()) {
      throw new RuntimeException(
          String.format("Unable to create data dir %s", dataDir.getAbsolutePath()));
    }
    return dataDir;
  }
}
Example #26
0
/**
 * Hadoop {@link FileSystem} backed by Amazon S3 through the AWS SDK.
 *
 * <p>The bucket is the host of the file system URI and object keys are derived from paths.
 * Directories are not stored as objects: a path is reported as a directory when a prefix listing
 * under it returns entries. Writes are staged in a local temporary file and uploaded when the
 * output stream is closed. {@code append}, {@code rename} and {@code delete} are not supported.
 *
 * <p>Settings are read from the Hadoop {@link Configuration} using the {@code presto.s3.*} keys
 * below, falling back to {@link HiveClientConfig} defaults.
 */
public class PrestoS3FileSystem extends FileSystem {
  public static final String S3_SSL_ENABLED = "presto.s3.ssl.enabled";
  public static final String S3_MAX_ERROR_RETRIES = "presto.s3.max-error-retries";
  public static final String S3_MAX_CLIENT_RETRIES = "presto.s3.max-client-retries";
  public static final String S3_MAX_BACKOFF_TIME = "presto.s3.max-backoff-time";
  public static final String S3_MAX_RETRY_TIME = "presto.s3.max-retry-time";
  public static final String S3_CONNECT_TIMEOUT = "presto.s3.connect-timeout";
  public static final String S3_SOCKET_TIMEOUT = "presto.s3.socket-timeout";
  public static final String S3_MAX_CONNECTIONS = "presto.s3.max-connections";
  public static final String S3_STAGING_DIRECTORY = "presto.s3.staging-directory";
  public static final String S3_MULTIPART_MIN_FILE_SIZE = "presto.s3.multipart.min-file-size";
  public static final String S3_MULTIPART_MIN_PART_SIZE = "presto.s3.multipart.min-part-size";

  private static final Logger log = Logger.get(PrestoS3FileSystem.class);

  // block size reported in the FileStatus of every file
  private static final DataSize BLOCK_SIZE = new DataSize(32, MEGABYTE);
  // forward seeks within max(available bytes, this size) are done by skipping, not reopening
  private static final DataSize MAX_SKIP_SIZE = new DataSize(1, MEGABYTE);

  private final TransferManagerConfiguration transferConfig = new TransferManagerConfiguration();
  private URI uri;
  private Path workingDirectory;
  private AmazonS3 s3;
  private File stagingDirectory;
  private int maxClientRetries;
  private Duration maxBackoffTime;
  private Duration maxRetryTime;

  /**
   * Reads the {@code presto.s3.*} settings from {@code conf} and creates the S3 client. The file
   * system URI is reduced to {@code scheme://authority}, and the working directory is the root.
   */
  @Override
  public void initialize(URI uri, Configuration conf) throws IOException {
    checkNotNull(uri, "uri is null");
    checkNotNull(conf, "conf is null");
    super.initialize(uri, conf);
    setConf(conf);

    this.uri = URI.create(uri.getScheme() + "://" + uri.getAuthority());
    this.workingDirectory = new Path("/").makeQualified(this.uri, new Path("/"));

    HiveClientConfig defaults = new HiveClientConfig();
    this.stagingDirectory =
        new File(conf.get(S3_STAGING_DIRECTORY, defaults.getS3StagingDirectory().toString()));
    this.maxClientRetries = conf.getInt(S3_MAX_CLIENT_RETRIES, defaults.getS3MaxClientRetries());
    this.maxBackoffTime =
        Duration.valueOf(conf.get(S3_MAX_BACKOFF_TIME, defaults.getS3MaxBackoffTime().toString()));
    this.maxRetryTime =
        Duration.valueOf(conf.get(S3_MAX_RETRY_TIME, defaults.getS3MaxRetryTime().toString()));
    int maxErrorRetries = conf.getInt(S3_MAX_ERROR_RETRIES, defaults.getS3MaxErrorRetries());
    boolean sslEnabled = conf.getBoolean(S3_SSL_ENABLED, defaults.isS3SslEnabled());
    Duration connectTimeout =
        Duration.valueOf(conf.get(S3_CONNECT_TIMEOUT, defaults.getS3ConnectTimeout().toString()));
    Duration socketTimeout =
        Duration.valueOf(conf.get(S3_SOCKET_TIMEOUT, defaults.getS3SocketTimeout().toString()));
    int maxConnections = conf.getInt(S3_MAX_CONNECTIONS, defaults.getS3MaxConnections());
    long minFileSize =
        conf.getLong(S3_MULTIPART_MIN_FILE_SIZE, defaults.getS3MultipartMinFileSize().toBytes());
    long minPartSize =
        conf.getLong(S3_MULTIPART_MIN_PART_SIZE, defaults.getS3MultipartMinPartSize().toBytes());

    ClientConfiguration configuration = new ClientConfiguration();
    configuration.setMaxErrorRetry(maxErrorRetries);
    configuration.setProtocol(sslEnabled ? Protocol.HTTPS : Protocol.HTTP);
    configuration.setConnectionTimeout(Ints.checkedCast(connectTimeout.toMillis()));
    configuration.setSocketTimeout(Ints.checkedCast(socketTimeout.toMillis()));
    configuration.setMaxConnections(maxConnections);

    this.s3 = new AmazonS3Client(getAwsCredentials(uri, conf), configuration);

    transferConfig.setMultipartUploadThreshold(minFileSize);
    transferConfig.setMinimumUploadPartSize(minPartSize);
  }

  @Override
  public URI getUri() {
    return uri;
  }

  @Override
  public Path getWorkingDirectory() {
    return workingDirectory;
  }

  @Override
  public void setWorkingDirectory(Path path) {
    workingDirectory = path;
  }

  /** Eagerly collects {@link #listLocatedStatus} for {@code path} into an array. */
  @Override
  public FileStatus[] listStatus(Path path) throws IOException {
    List<LocatedFileStatus> list = new ArrayList<>();
    RemoteIterator<LocatedFileStatus> iterator = listLocatedStatus(path);
    while (iterator.hasNext()) {
      list.add(iterator.next());
    }
    return toArray(list, LocatedFileStatus.class);
  }

  /**
   * Lists the immediate children of {@code path}. Further listing pages are fetched while
   * iterating, and {@link AmazonClientException}s raised then are rethrown as {@link IOException}.
   */
  @Override
  public RemoteIterator<LocatedFileStatus> listLocatedStatus(final Path path) throws IOException {
    return new RemoteIterator<LocatedFileStatus>() {
      private final Iterator<LocatedFileStatus> iterator = listPrefix(path);

      @Override
      public boolean hasNext() throws IOException {
        try {
          return iterator.hasNext();
        } catch (AmazonClientException e) {
          throw new IOException(e);
        }
      }

      @Override
      public LocatedFileStatus next() throws IOException {
        try {
          return iterator.next();
        } catch (AmazonClientException e) {
          throw new IOException(e);
        }
      }
    };
  }

  /**
   * Returns the status of {@code path}. If an object exists at the key it is reported as a file.
   * Otherwise the path is reported as a directory if listing it as a prefix returns entries. The
   * bucket root (empty name) is a directory if its metadata lookup returns non-null.
   *
   * @throws FileNotFoundException if none of the above applies
   */
  @Override
  public FileStatus getFileStatus(Path path) throws IOException {
    if (path.getName().isEmpty()) {
      // the bucket root requires special handling
      if (getS3ObjectMetadata(path) != null) {
        return new FileStatus(0, true, 1, 0, 0, qualifiedPath(path));
      }
      throw new FileNotFoundException("File does not exist: " + path);
    }

    ObjectMetadata metadata = getS3ObjectMetadata(path);

    if (metadata == null) {
      // check if this path is a directory
      Iterator<LocatedFileStatus> iterator = listPrefix(path);
      if ((iterator != null) && iterator.hasNext()) {
        return new FileStatus(0, true, 1, 0, 0, qualifiedPath(path));
      }
      throw new FileNotFoundException("File does not exist: " + path);
    }

    return new FileStatus(
        metadata.getContentLength(),
        false,
        1,
        BLOCK_SIZE.toBytes(),
        lastModifiedTime(metadata),
        qualifiedPath(path));
  }

  /** Opens {@code path} through a buffered {@code PrestoS3InputStream} that retries reads. */
  @Override
  public FSDataInputStream open(Path path, int bufferSize) throws IOException {
    return new FSDataInputStream(
        new BufferedFSInputStream(
            new PrestoS3InputStream(
                s3, uri.getHost(), path, maxClientRetries, maxBackoffTime, maxRetryTime),
            bufferSize));
  }

  /**
   * Creates {@code path}. Data is staged in a temporary file under the staging directory and
   * uploaded when the returned stream is closed. {@code permission}, {@code bufferSize}, {@code
   * replication}, {@code blockSize} and {@code progress} are ignored.
   *
   * @throws IOException if {@code overwrite} is false and the path exists, or staging fails
   */
  @Override
  public FSDataOutputStream create(
      Path path,
      FsPermission permission,
      boolean overwrite,
      int bufferSize,
      short replication,
      long blockSize,
      Progressable progress)
      throws IOException {
    if ((!overwrite) && exists(path)) {
      throw new IOException("File already exists: " + path);
    }

    String key = keyFromPath(qualifiedPath(path));

    createDirectories(stagingDirectory.toPath());
    File tempFile = createTempFile(stagingDirectory.toPath(), "presto-s3-", ".tmp").toFile();
    try {
      return new FSDataOutputStream(
          new PrestoS3OutputStream(s3, transferConfig, uri.getHost(), key, tempFile), statistics);
    } catch (IOException | RuntimeException e) {
      // PrestoS3OutputStream deletes the temp file on close; if construction failed nothing else
      // will, so clean it up here instead of leaking it in the staging directory
      if (!tempFile.delete()) {
        log.warn("Could not delete temporary file: %s", tempFile);
      }
      throw e;
    }
  }

  @Override
  public FSDataOutputStream append(Path f, int bufferSize, Progressable progress)
      throws IOException {
    throw new UnsupportedOperationException("append");
  }

  @Override
  public boolean rename(Path src, Path dst) throws IOException {
    throw new UnsupportedOperationException("rename");
  }

  @Override
  public boolean delete(Path f, boolean recursive) throws IOException {
    throw new UnsupportedOperationException("delete");
  }

  @Override
  public boolean mkdirs(Path f, FsPermission permission) throws IOException {
    // no need to do anything for S3
    return true;
  }

  /**
   * Lists the entries directly below {@code path} using a {@code /} delimiter. Common prefixes
   * become directory statuses and objects become file statuses. The first page is requested
   * immediately; later pages are requested lazily while the previous listing is truncated.
   */
  private Iterator<LocatedFileStatus> listPrefix(Path path) {
    String key = keyFromPath(path);
    if (!key.isEmpty()) {
      key += "/";
    }

    ListObjectsRequest request =
        new ListObjectsRequest().withBucketName(uri.getHost()).withPrefix(key).withDelimiter("/");

    Iterator<ObjectListing> listings =
        new AbstractSequentialIterator<ObjectListing>(s3.listObjects(request)) {
          @Override
          protected ObjectListing computeNext(ObjectListing previous) {
            if (!previous.isTruncated()) {
              return null;
            }
            return s3.listNextBatchOfObjects(previous);
          }
        };

    return Iterators.concat(Iterators.transform(listings, this::statusFromListing));
  }

  /** Converts one listing page into statuses: common prefixes first, then objects. */
  private Iterator<LocatedFileStatus> statusFromListing(ObjectListing listing) {
    return Iterators.concat(
        statusFromPrefixes(listing.getCommonPrefixes()),
        statusFromObjects(listing.getObjectSummaries()));
  }

  /** Builds a directory status for each common prefix. */
  private Iterator<LocatedFileStatus> statusFromPrefixes(List<String> prefixes) {
    List<LocatedFileStatus> list = new ArrayList<>();
    for (String prefix : prefixes) {
      Path path = qualifiedPath(new Path("/" + prefix));
      FileStatus status = new FileStatus(0, true, 1, 0, 0, path);
      list.add(createLocatedFileStatus(status));
    }
    return list.iterator();
  }

  /** Builds a file status for each object summary, skipping keys that end with {@code /}. */
  private Iterator<LocatedFileStatus> statusFromObjects(List<S3ObjectSummary> objects) {
    List<LocatedFileStatus> list = new ArrayList<>();
    for (S3ObjectSummary object : objects) {
      if (!object.getKey().endsWith("/")) {
        FileStatus status =
            new FileStatus(
                object.getSize(),
                false,
                1,
                BLOCK_SIZE.toBytes(),
                object.getLastModified().getTime(),
                qualifiedPath(new Path("/" + object.getKey())));
        list.add(createLocatedFileStatus(status));
      }
    }
    return list.iterator();
  }

  /**
   * This exception is for stopping retries for S3 calls that shouldn't be retried. For example,
   * "Caused by: com.amazonaws.services.s3.model.AmazonS3Exception: Forbidden (Service: Amazon S3;
   * Status Code: 403 ..."
   */
  private static class UnrecoverableS3OperationException extends Exception {
    public UnrecoverableS3OperationException(Throwable cause) {
      super(cause);
    }
  }

  /**
   * Returns the metadata of the object at {@code path}, or {@code null} if S3 answers 404. A 403
   * stops retrying immediately. Other failures are retried with exponential backoff, bounded by
   * the configured client retries, backoff time and retry time.
   */
  private ObjectMetadata getS3ObjectMetadata(final Path path) throws IOException {
    try {
      return retry()
          .maxAttempts(maxClientRetries)
          .exponentialBackoff(new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0)
          .stopOn(InterruptedException.class, UnrecoverableS3OperationException.class)
          .run(
              "getS3ObjectMetadata",
              () -> {
                try {
                  return s3.getObjectMetadata(uri.getHost(), keyFromPath(path));
                } catch (AmazonS3Exception e) {
                  if (e.getStatusCode() == SC_NOT_FOUND) {
                    return null;
                  } else if (e.getStatusCode() == SC_FORBIDDEN) {
                    throw new UnrecoverableS3OperationException(e);
                  }
                  throw Throwables.propagate(e);
                }
              });
    } catch (InterruptedException e) {
      Thread.currentThread().interrupt();
      throw Throwables.propagate(e);
    } catch (Exception e) {
      Throwables.propagateIfInstanceOf(e, IOException.class);
      throw Throwables.propagate(e);
    }
  }

  /** Qualifies {@code path} against this file system's URI and working directory. */
  private Path qualifiedPath(Path path) {
    return path.makeQualified(this.uri, getWorkingDirectory());
  }

  /** Wraps {@code status} together with the block locations reported by this file system. */
  private LocatedFileStatus createLocatedFileStatus(FileStatus status) {
    try {
      BlockLocation[] fakeLocation = getFileBlockLocations(status, 0, status.getLen());
      return new LocatedFileStatus(status, fakeLocation);
    } catch (IOException e) {
      throw Throwables.propagate(e);
    }
  }

  /** Returns the last-modified time in epoch millis, or 0 if the metadata has none. */
  private static long lastModifiedTime(ObjectMetadata metadata) {
    Date date = metadata.getLastModified();
    return (date != null) ? date.getTime() : 0;
  }

  /**
   * Converts an absolute path to an S3 key by removing one leading and one trailing {@code /}.
   *
   * @throws IllegalArgumentException if {@code path} is not absolute
   */
  private static String keyFromPath(Path path) {
    checkArgument(path.isAbsolute(), "Path is not absolute: %s", path);
    String key = nullToEmpty(path.toUri().getPath());
    if (key.startsWith("/")) {
      key = key.substring(1);
    }
    if (key.endsWith("/")) {
      key = key.substring(0, key.length() - 1);
    }
    return key;
  }

  /** Resolves AWS credentials from {@code uri} and {@code conf} using Hadoop's S3Credentials. */
  private static AWSCredentials getAwsCredentials(URI uri, Configuration conf) {
    S3Credentials credentials = new S3Credentials();
    credentials.initialize(uri, conf);
    return new BasicAWSCredentials(credentials.getAccessKey(), credentials.getSecretAccessKey());
  }

  /**
   * Seekable input stream over one S3 object. The underlying object stream is opened lazily with
   * a ranged GET starting at the current position, and is reopened after a seek that cannot be
   * satisfied by skipping.
   */
  private static class PrestoS3InputStream extends FSInputStream {
    private final AmazonS3 s3;
    private final String host;
    private final Path path;
    private final int maxClientRetry;
    private final Duration maxBackoffTime;
    private final Duration maxRetryTime;

    private boolean closed;
    private S3ObjectInputStream in;
    private long position;

    public PrestoS3InputStream(
        AmazonS3 s3,
        String host,
        Path path,
        int maxClientRetry,
        Duration maxBackoffTime,
        Duration maxRetryTime) {
      this.s3 = checkNotNull(s3, "s3 is null");
      this.host = checkNotNull(host, "host is null");
      this.path = checkNotNull(path, "path is null");

      checkArgument(maxClientRetry >= 0, "maxClientRetries cannot be negative");
      this.maxClientRetry = maxClientRetry;
      this.maxBackoffTime = checkNotNull(maxBackoffTime, "maxBackoffTime is null");
      this.maxRetryTime = checkNotNull(maxRetryTime, "maxRetryTime is null");
    }

    @Override
    public void close() throws IOException {
      closed = true;
      closeStream();
    }

    /**
     * Moves to {@code pos}. Small forward seeks skip bytes on the open stream; anything else closes
     * the stream so the next read reopens it at the new position.
     */
    @Override
    public void seek(long pos) throws IOException {
      checkState(!closed, "already closed");
      checkArgument(pos >= 0, "position is negative: %s", pos);

      if ((in != null) && (pos == position)) {
        // already at specified position
        return;
      }

      if ((in != null) && (pos > position)) {
        // seeking forwards
        long skip = pos - position;
        if (skip <= max(in.available(), MAX_SKIP_SIZE.toBytes())) {
          // already buffered or seek is small enough
          if (in.skip(skip) == skip) {
            position = pos;
            return;
          }
        }
      }

      // close the stream and open at desired position
      position = pos;
      closeStream();
      openStream();
    }

    @Override
    public long getPos() throws IOException {
      return position;
    }

    @Override
    public int read() throws IOException {
      // This stream is wrapped with BufferedInputStream, so this method should never be called
      throw new UnsupportedOperationException();
    }

    /**
     * Reads into {@code buffer}, (re)opening the object stream as needed. A failed attempt closes
     * the stream so the retry reopens it at the current position.
     *
     * @throws IllegalStateException if the stream has been closed
     */
    @Override
    public int read(final byte[] buffer, final int offset, final int length) throws IOException {
      // without this check a read after close() would silently reopen the S3 object
      checkState(!closed, "already closed");
      try {
        int bytesRead =
            retry()
                .maxAttempts(maxClientRetry)
                .exponentialBackoff(
                    new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0)
                .stopOn(InterruptedException.class)
                .run(
                    "readStream",
                    () -> {
                      openStream();
                      try {
                        return in.read(buffer, offset, length);
                      } catch (Exception e) {
                        closeStream();
                        throw e;
                      }
                    });

        if (bytesRead != -1) {
          position += bytesRead;
        }
        return bytesRead;
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw Throwables.propagate(e);
      } catch (Exception e) {
        Throwables.propagateIfInstanceOf(e, IOException.class);
        throw Throwables.propagate(e);
      }
    }

    @Override
    public boolean seekToNewSource(long targetPos) throws IOException {
      return false;
    }

    /**
     * Issues a GET for the object from byte {@code start} onwards. A 403 stops retrying
     * immediately; other failures are retried with exponential backoff.
     */
    private S3Object getS3Object(final Path path, final long start) throws IOException {
      try {
        return retry()
            .maxAttempts(maxClientRetry)
            .exponentialBackoff(
                new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0)
            .stopOn(InterruptedException.class, UnrecoverableS3OperationException.class)
            .run(
                "getS3Object",
                () -> {
                  try {
                    return s3.getObject(
                        new GetObjectRequest(host, keyFromPath(path))
                            .withRange(start, Long.MAX_VALUE));
                  } catch (AmazonServiceException e) {
                    if (e.getStatusCode() == SC_FORBIDDEN) {
                      throw new UnrecoverableS3OperationException(e);
                    }
                    throw Throwables.propagate(e);
                  }
                });
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw Throwables.propagate(e);
      } catch (Exception e) {
        Throwables.propagateIfInstanceOf(e, IOException.class);
        throw Throwables.propagate(e);
      }
    }

    /** Opens the object stream at the current position if it is not already open. */
    private void openStream() throws IOException {
      if (in == null) {
        in = getS3Object(path, position).getObjectContent();
      }
    }

    /**
     * Aborts and forgets the object stream, if open. NOTE(review): abort() rather than close() is
     * presumably used to avoid reading the rest of the object - confirm against the SDK docs.
     */
    private void closeStream() throws IOException {
      if (in != null) {
        try {
          in.abort();
        } catch (AbortedException ignored) {
          // thrown if the current thread is in the interrupted state
        }
        in = null;
      }
    }
  }

  /**
   * Output stream that writes to a local temporary file and uploads it with a {@link
   * TransferManager} on {@link #close()}. The temporary file is deleted after the upload attempt,
   * whether or not it succeeded.
   */
  private static class PrestoS3OutputStream extends FilterOutputStream {
    private final TransferManager transferManager;
    private final String host;
    private final String key;
    private final File tempFile;

    private boolean closed;

    public PrestoS3OutputStream(
        AmazonS3 s3, TransferManagerConfiguration config, String host, String key, File tempFile)
        throws IOException {
      super(
          new BufferedOutputStream(
              new FileOutputStream(checkNotNull(tempFile, "tempFile is null"))));

      transferManager = new TransferManager(checkNotNull(s3, "s3 is null"));
      transferManager.setConfiguration(checkNotNull(config, "config is null"));

      this.host = checkNotNull(host, "host is null");
      this.key = checkNotNull(key, "key is null");
      this.tempFile = tempFile;

      log.debug("OutputStream for key '%s' using file: %s", key, tempFile);
    }

    /** Flushes and closes the temp file, uploads it, then deletes it. Repeated calls are no-ops. */
    @Override
    public void close() throws IOException {
      if (closed) {
        return;
      }
      closed = true;

      try {
        super.close();
        uploadObject();
      } finally {
        if (!tempFile.delete()) {
          log.warn("Could not delete temporary file: %s", tempFile);
        }
        // close transfer manager but keep underlying S3 client open
        transferManager.shutdownNow(false);
      }
    }

    /**
     * Uploads the temp file to {@code host}/{@code key} and waits for completion.
     *
     * @throws IOException if the upload fails
     * @throws InterruptedIOException if interrupted while waiting (the interrupt flag is restored)
     */
    private void uploadObject() throws IOException {
      try {
        log.debug(
            "Starting upload for host: %s, key: %s, file: %s, size: %s",
            host, key, tempFile, tempFile.length());
        Upload upload = transferManager.upload(host, key, tempFile);

        if (log.isDebugEnabled()) {
          upload.addProgressListener(createProgressListener(upload));
        }

        upload.waitForCompletion();
        log.debug("Completed upload for host: %s, key: %s", host, key);
      } catch (AmazonClientException e) {
        throw new IOException(e);
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new InterruptedIOException();
      }
    }

    /**
     * Returns a listener that logs at debug level each change of progress event type, and the
     * transfer percentage every time it advances by at least 10 points.
     */
    private ProgressListener createProgressListener(final Transfer transfer) {
      return new ProgressListener() {
        private ProgressEventType previousType;
        private double previousTransferred;

        @Override
        public synchronized void progressChanged(ProgressEvent progressEvent) {
          ProgressEventType eventType = progressEvent.getEventType();
          if (previousType != eventType) {
            log.debug("Upload progress event (%s/%s): %s", host, key, eventType);
            previousType = eventType;
          }

          double transferred = transfer.getProgress().getPercentTransferred();
          if (transferred >= (previousTransferred + 10.0)) {
            log.debug("Upload percentage (%s/%s): %.0f%%", host, key, transferred);
            previousTransferred = transferred;
          }
        }
      };
    }
  }
}
Example #27
0
/**
 * HTTP client that pulls pages from a single remote output buffer at {@code location}.
 *
 * <p>Each request is either a GET for the pages that follow the current sequence token or, once the
 * remote side has reported the buffer complete, a DELETE that acknowledges completion. Results are
 * delivered to a {@link ClientCallback}.
 *
 * <p>After a failed request the next one is delayed. The delay starts at 1 ms, doubles per
 * consecutive failure up to 100 ms, and is reset to 0 by a successful request. Once requests have
 * been failing for longer than {@code minErrorDuration}, failures that are not {@link
 * PrestoException}s are wrapped in a timeout exception (GET) or a close-failure {@link
 * PrestoException} (DELETE).
 */
@ThreadSafe
public final class HttpPageBufferClient implements Closeable {
  // Bounds of the retry delay applied after failed requests (see increaseErrorDelay).
  private static final int INITIAL_DELAY_MILLIS = 1;
  private static final int MAX_DELAY_MILLIS = 100;

  private static final Logger log = Logger.get(HttpPageBufferClient.class);

  /**
   * Receives the results of the requests issued by an {@link HttpPageBufferClient}.
   *
   * <p>For each request, {@link #addPages} will be called zero or more times, followed by either
   * {@link #requestComplete} or {@link #clientFinished} (if the buffer is complete). If the client
   * is closed, {@code requestComplete} or {@code clientFinished} may never be called.
   *
   * <p><b>NOTE:</b> Implementations of this interface are not allowed to perform blocking
   * operations.
   */
  public interface ClientCallback {
    /**
     * Offers pages received from the remote buffer.
     *
     * @return {@code true} if the pages were accepted (counted as received), {@code false} if they
     *     were rejected (counted as rejected)
     */
    boolean addPages(HttpPageBufferClient client, List<Page> pages);

    /** Called after a GET request succeeds or after any request fails. */
    void requestComplete(HttpPageBufferClient client);

    /** Called after the DELETE acknowledging a complete remote buffer succeeds. */
    void clientFinished(HttpPageBufferClient client);

    /**
     * Called with a {@link PrestoException} raised by a request, or with any error thrown while a
     * scheduled request was being initiated.
     */
    void clientFailed(HttpPageBufferClient client, Throwable cause);
  }

  private final HttpClient httpClient;
  private final DataSize maxResponseSize;
  private final Duration minErrorDuration;
  private final URI location;
  private final ClientCallback clientCallback;
  private final BlockEncodingSerde blockEncodingSerde;
  private final ScheduledExecutorService executor;

  // Measures how long requests have been failing; reset by every successful GET.
  @GuardedBy("this")
  private final Stopwatch errorStopwatch;

  @GuardedBy("this")
  private boolean closed;

  // The in-flight HTTP request, or null when no request is running.
  @GuardedBy("this")
  private HttpResponseFuture<?> future;

  @GuardedBy("this")
  private DateTime lastUpdate = DateTime.now();

  // Sequence token of the next batch of pages to request.
  @GuardedBy("this")
  private long token;

  @GuardedBy("this")
  private boolean scheduled;

  // Set once the remote buffer reports completion; the next request is then a DELETE.
  @GuardedBy("this")
  private boolean completed;

  @GuardedBy("this")
  private long errorDelayMillis;

  // Remote task instance id learned from the first response; later responses must match it.
  @GuardedBy("this")
  private String taskInstanceId;

  private final AtomicLong rowsReceived = new AtomicLong();
  private final AtomicInteger pagesReceived = new AtomicInteger();

  private final AtomicLong rowsRejected = new AtomicLong();
  private final AtomicInteger pagesRejected = new AtomicInteger();

  private final AtomicInteger requestsScheduled = new AtomicInteger();
  private final AtomicInteger requestsCompleted = new AtomicInteger();
  private final AtomicInteger requestsFailed = new AtomicInteger();

  /** Creates a client that measures error duration with a new, unstarted stopwatch. */
  public HttpPageBufferClient(
      HttpClient httpClient,
      DataSize maxResponseSize,
      Duration minErrorDuration,
      URI location,
      ClientCallback clientCallback,
      BlockEncodingSerde blockEncodingSerde,
      ScheduledExecutorService executor) {
    this(
        httpClient,
        maxResponseSize,
        minErrorDuration,
        location,
        clientCallback,
        blockEncodingSerde,
        executor,
        Stopwatch.createUnstarted());
  }

  /**
   * Creates a client.
   *
   * @param maxResponseSize value sent in the {@code PRESTO_MAX_SIZE} header of every GET
   * @param minErrorDuration how long requests must keep failing before a transport failure is
   *     turned into a timeout (GET) or close-failure (DELETE) exception
   * @param location URI of the remote buffer; GET requests append the current token to it
   * @param errorStopwatch stopwatch used to measure how long requests have been failing; it is
   *     reset here, discarding any prior state
   * @throws NullPointerException if any argument is null
   */
  public HttpPageBufferClient(
      HttpClient httpClient,
      DataSize maxResponseSize,
      Duration minErrorDuration,
      URI location,
      ClientCallback clientCallback,
      BlockEncodingSerde blockEncodingSerde,
      ScheduledExecutorService executor,
      Stopwatch errorStopwatch) {
    this.httpClient = requireNonNull(httpClient, "httpClient is null");
    this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null");
    this.minErrorDuration = requireNonNull(minErrorDuration, "minErrorDuration is null");
    this.location = requireNonNull(location, "location is null");
    this.clientCallback = requireNonNull(clientCallback, "clientCallback is null");
    this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
    this.executor = requireNonNull(executor, "executor is null");
    this.errorStopwatch = requireNonNull(errorStopwatch, "errorStopwatch is null").reset();
  }

  /** Returns a snapshot of this client's state and its request/page/row counters. */
  public synchronized PageBufferClientStatus getStatus() {
    String state;
    if (closed) {
      state = "closed";
    } else if (future != null) {
      state = "running";
    } else if (scheduled) {
      state = "scheduled";
    } else if (completed) {
      state = "completed";
    } else {
      state = "queued";
    }
    String httpRequestState = "not scheduled";
    if (future != null) {
      httpRequestState = future.getState();
    }

    long rejectedRows = rowsRejected.get();
    int rejectedPages = pagesRejected.get();

    return new PageBufferClientStatus(
        location,
        state,
        lastUpdate,
        rowsReceived.get(),
        pagesReceived.get(),
        rejectedRows == 0 ? OptionalLong.empty() : OptionalLong.of(rejectedRows),
        rejectedPages == 0 ? OptionalInt.empty() : OptionalInt.of(rejectedPages),
        requestsScheduled.get(),
        requestsCompleted.get(),
        requestsFailed.get(),
        httpRequestState);
  }

  /** Returns whether an HTTP request is currently in flight. */
  public synchronized boolean isRunning() {
    return future != null;
  }

  /**
   * Closes this client. Any in-flight request is cancelled, and the first call also sends a DELETE
   * to abort the remote output buffer. Later calls do not send another DELETE.
   */
  @Override
  public void close() {
    boolean shouldSendDelete;
    Future<?> future;
    synchronized (this) {
      shouldSendDelete = !closed;

      closed = true;

      future = this.future;

      this.future = null;

      lastUpdate = DateTime.now();
    }

    if (future != null && !future.isDone()) {
      future.cancel(true);
    }

    // abort the output buffer on the remote node; response of delete is ignored
    if (shouldSendDelete) {
      sendDelete();
    }
  }

  /**
   * Schedules the next request on {@code executor} after the current error delay. Does nothing if
   * the client is closed, a request is in flight, or a request is already scheduled.
   */
  public synchronized void scheduleRequest() {
    if (closed || (future != null) || scheduled) {
      return;
    }
    scheduled = true;

    // start before scheduling to include error delay. The stopwatch can still be running if a
    // previous initiateRequest() threw before any response callback stopped it, and Guava's
    // Stopwatch.start() throws IllegalStateException on a running stopwatch, which would leave
    // `scheduled` stuck at true.
    if (!errorStopwatch.isRunning()) {
      errorStopwatch.start();
    }

    executor.schedule(
        () -> {
          try {
            initiateRequest();
          } catch (Throwable t) {
            // should not happen, but be safe and fail the operator
            clientCallback.clientFailed(HttpPageBufferClient.this, t);
          }
        },
        errorDelayMillis,
        TimeUnit.MILLISECONDS);

    lastUpdate = DateTime.now();
    requestsScheduled.incrementAndGet();
  }

  /** Issues a DELETE if the remote buffer is complete, otherwise a GET for the next pages. */
  private synchronized void initiateRequest() {
    scheduled = false;
    if (closed || (future != null)) {
      return;
    }

    if (completed) {
      sendDelete();
    } else {
      sendGetResults();
    }

    lastUpdate = DateTime.now();
  }

  /** Sends an asynchronous GET for the pages at the current token and installs its callbacks. */
  private synchronized void sendGetResults() {
    URI uri = HttpUriBuilder.uriBuilderFrom(location).appendPath(String.valueOf(token)).build();
    HttpResponseFuture<PagesResponse> resultFuture =
        httpClient.executeAsync(
            prepareGet().setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()).setUri(uri).build(),
            new PageResponseHandler(blockEncodingSerde));

    future = resultFuture;
    Futures.addCallback(
        resultFuture,
        new FutureCallback<PagesResponse>() {
          @Override
          public void onSuccess(PagesResponse result) {
            checkNotHoldsLock();

            resetErrors();

            List<Page> pages;
            try {
              synchronized (HttpPageBufferClient.this) {
                if (taskInstanceId == null) {
                  taskInstanceId = result.getTaskInstanceId();
                }

                if (!isNullOrEmpty(taskInstanceId)
                    && !result.getTaskInstanceId().equals(taskInstanceId)) {
                  // TODO: update error message
                  throw new PrestoException(REMOTE_TASK_MISMATCH, REMOTE_TASK_MISMATCH_ERROR);
                }

                // only consume pages for the token we asked for; anything else is stale
                if (result.getToken() == token) {
                  pages = result.getPages();
                  token = result.getNextToken();
                } else {
                  pages = ImmutableList.of();
                }
              }
            } catch (PrestoException e) {
              handleFailure(e, resultFuture);
              return;
            }

            // add pages
            if (clientCallback.addPages(HttpPageBufferClient.this, pages)) {
              pagesReceived.addAndGet(pages.size());
              rowsReceived.addAndGet(pages.stream().mapToLong(Page::getPositionCount).sum());
            } else {
              pagesRejected.addAndGet(pages.size());
              rowsRejected.addAndGet(pages.stream().mapToLong(Page::getPositionCount).sum());
            }

            synchronized (HttpPageBufferClient.this) {
              // client is complete, acknowledge it by sending it a delete in the next request
              if (result.isClientComplete()) {
                completed = true;
              }
              if (future == resultFuture) {
                future = null;
                errorDelayMillis = 0;
              }
              lastUpdate = DateTime.now();
            }
            requestsCompleted.incrementAndGet();
            clientCallback.requestComplete(HttpPageBufferClient.this);
          }

          @Override
          public void onFailure(Throwable t) {
            log.debug("Request to %s failed %s", uri, t);
            checkNotHoldsLock();

            Duration errorDuration = elapsedErrorDuration();

            t = rewriteException(t);
            if (!(t instanceof PrestoException) && errorDuration.compareTo(minErrorDuration) > 0) {
              String message =
                  format("%s (%s - requests failed for %s)", WORKER_NODE_ERROR, uri, errorDuration);
              t = new PageTransportTimeoutException(message, t);
            }
            handleFailure(t, resultFuture);
          }
        },
        executor);
  }

  /** Sends an asynchronous DELETE to the remote buffer; on success the client becomes closed. */
  private synchronized void sendDelete() {
    HttpResponseFuture<StatusResponse> resultFuture =
        httpClient.executeAsync(
            prepareDelete().setUri(location).build(), createStatusResponseHandler());
    future = resultFuture;
    Futures.addCallback(
        resultFuture,
        new FutureCallback<StatusResponse>() {
          @Override
          public void onSuccess(@Nullable StatusResponse result) {
            checkNotHoldsLock();
            synchronized (HttpPageBufferClient.this) {
              closed = true;
              if (future == resultFuture) {
                future = null;
                errorDelayMillis = 0;
              }
              lastUpdate = DateTime.now();
            }
            requestsCompleted.incrementAndGet();
            clientCallback.clientFinished(HttpPageBufferClient.this);
          }

          @Override
          public void onFailure(Throwable t) {
            checkNotHoldsLock();

            log.error("Request to delete %s failed %s", location, t);
            Duration errorDuration = elapsedErrorDuration();
            if (!(t instanceof PrestoException) && errorDuration.compareTo(minErrorDuration) > 0) {
              String message =
                  format(
                      "Error closing remote buffer (%s - requests failed for %s)",
                      location, errorDuration);
              t = new PrestoException(REMOTE_BUFFER_CLOSE_FAILED, message, t);
            }
            handleFailure(t, resultFuture);
          }
        },
        executor);
  }

  /** Logs an error if the current thread holds this client's monitor while running a callback. */
  private void checkNotHoldsLock() {
    if (Thread.holdsLock(HttpPageBufferClient.this)) {
      log.error("Can not handle callback while holding a lock on this");
    }
  }

  /**
   * Records a failed request: reports {@link PrestoException}s to {@code clientFailed}, grows the
   * retry delay, clears {@code future} if it is still {@code expectedFuture}, and finally calls
   * {@code requestComplete}.
   */
  private void handleFailure(Throwable t, HttpResponseFuture<?> expectedFuture) {
    // Can not delegate to other callback while holding a lock on this
    checkNotHoldsLock();

    requestsFailed.incrementAndGet();
    requestsCompleted.incrementAndGet();

    if (t instanceof PrestoException) {
      clientCallback.clientFailed(HttpPageBufferClient.this, t);
    }

    synchronized (HttpPageBufferClient.this) {
      increaseErrorDelay();
      if (future == expectedFuture) {
        future = null;
      }
      lastUpdate = DateTime.now();
    }
    clientCallback.requestComplete(HttpPageBufferClient.this);
  }

  /** Two clients are equal when they point at the same remote buffer {@code location}. */
  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (o == null || getClass() != o.getClass()) {
      return false;
    }
    HttpPageBufferClient that = (HttpPageBufferClient) o;
    return location.equals(that.location);
  }

  @Override
  public int hashCode() {
    return location.hashCode();
  }

  @Override
  public String toString() {
    String state;
    synchronized (this) {
      if (closed) {
        state = "CLOSED";
      } else if (future != null) {
        state = "RUNNING";
      } else {
        state = "QUEUED";
      }
    }
    return toStringHelper(this).add("location", location).addValue(state).toString();
  }

  /** Maps an HTTP-level "response too large" failure to a page-level exception. */
  private static Throwable rewriteException(Throwable t) {
    if (t instanceof ResponseTooLargeException) {
      return new PageTooLargeException();
    }
    return t;
  }

  /** Stops the error stopwatch (if running) and returns its elapsed time in seconds. */
  private synchronized Duration elapsedErrorDuration() {
    if (errorStopwatch.isRunning()) {
      errorStopwatch.stop();
    }
    long nanos = errorStopwatch.elapsed(TimeUnit.NANOSECONDS);
    return new Duration(nanos, TimeUnit.NANOSECONDS).convertTo(TimeUnit.SECONDS);
  }

  /** Sets the retry delay to 1 ms after the first failure, then doubles it up to 100 ms. */
  private synchronized void increaseErrorDelay() {
    if (errorDelayMillis == 0) {
      errorDelayMillis = INITIAL_DELAY_MILLIS;
    } else {
      errorDelayMillis = min(errorDelayMillis * 2, MAX_DELAY_MILLIS);
    }
  }

  /** Stops and zeroes the error stopwatch. */
  private synchronized void resetErrors() {
    errorStopwatch.reset();
  }

  /** Decodes a GET response from the remote buffer into a {@link PagesResponse}. */
  public static class PageResponseHandler
      implements ResponseHandler<PagesResponse, RuntimeException> {
    private final BlockEncodingSerde blockEncodingSerde;

    public PageResponseHandler(BlockEncodingSerde blockEncodingSerde) {
      this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
    }

    @Override
    public PagesResponse handleException(Request request, Exception exception) {
      throw propagate(request, exception);
    }

    /**
     * Converts the response into a {@link PagesResponse}.
     *
     * @throws PageTransportErrorException on a status other than 200/204, a missing or unexpected
     *     content type, or a missing protocol header
     */
    @Override
    public PagesResponse handle(Request request, Response response) {
      // no content means no content was created within the wait period, but query is still ok
      // if job is finished, complete is set in the response
      if (response.getStatusCode() == HttpStatus.NO_CONTENT.code()) {
        return createEmptyPagesResponse(
            getTaskInstanceId(response),
            getToken(response),
            getNextToken(response),
            getComplete(response));
      }

      // otherwise we must have gotten an OK response, everything else is considered fatal
      if (response.getStatusCode() != HttpStatus.OK.code()) {
        StringBuilder body = new StringBuilder();
        // decode explicitly as UTF-8 instead of the JVM's platform default charset
        try (BufferedReader reader =
            new BufferedReader(
                new InputStreamReader(
                    response.getInputStream(), java.nio.charset.StandardCharsets.UTF_8))) {
          // Get up to 1000 lines for debugging
          for (int i = 0; i < 1000; i++) {
            String line = reader.readLine();
            // Don't output more than 100KB
            if (line == null || body.length() + line.length() > 100 * 1024) {
              break;
            }
            body.append(line).append('\n');
          }
        } catch (RuntimeException | IOException ignored) {
          // Ignored. Just return whatever message we were able to decode
        }
        throw new PageTransportErrorException(
            format(
                "Expected response code to be 200, but was %s %s: %s%n%s",
                response.getStatusCode(),
                response.getStatusMessage(),
                request.getUri(),
                body.toString()));
      }

      // invalid content type can happen when an error page is returned, but is unlikely given the
      // above 200
      String contentType = response.getHeader(CONTENT_TYPE);
      if (contentType == null) {
        throw new PageTransportErrorException(
            format("%s header is not set: %s %s", CONTENT_TYPE, request.getUri(), response));
      }
      if (!mediaTypeMatches(contentType, PRESTO_PAGES_TYPE)) {
        throw new PageTransportErrorException(
            format(
                "Expected %s response from server but got %s: %s",
                PRESTO_PAGES_TYPE, contentType, request.getUri()));
      }

      String taskInstanceId = getTaskInstanceId(response);
      long token = getToken(response);
      long nextToken = getNextToken(response);
      boolean complete = getComplete(response);

      try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) {
        List<Page> pages = ImmutableList.copyOf(readPages(blockEncodingSerde, input));
        return createPagesResponse(taskInstanceId, token, nextToken, pages, complete);
      } catch (IOException e) {
        // same message and cause as the deprecated Throwables.propagate, but a more specific type
        throw new java.io.UncheckedIOException(e);
      }
    }

    private static String getTaskInstanceId(Response response) {
      String taskInstanceId = response.getHeader(PRESTO_TASK_INSTANCE_ID);
      if (taskInstanceId == null) {
        throw new PageTransportErrorException(
            format("Expected %s header", PRESTO_TASK_INSTANCE_ID));
      }
      return taskInstanceId;
    }

    private static long getToken(Response response) {
      String tokenHeader = response.getHeader(PRESTO_PAGE_TOKEN);
      if (tokenHeader == null) {
        throw new PageTransportErrorException(format("Expected %s header", PRESTO_PAGE_TOKEN));
      }
      return Long.parseLong(tokenHeader);
    }

    private static long getNextToken(Response response) {
      String nextTokenHeader = response.getHeader(PRESTO_PAGE_NEXT_TOKEN);
      if (nextTokenHeader == null) {
        throw new PageTransportErrorException(format("Expected %s header", PRESTO_PAGE_NEXT_TOKEN));
      }
      return Long.parseLong(nextTokenHeader);
    }

    private static boolean getComplete(Response response) {
      String bufferComplete = response.getHeader(PRESTO_BUFFER_COMPLETE);
      if (bufferComplete == null) {
        throw new PageTransportErrorException(format("Expected %s header", PRESTO_BUFFER_COMPLETE));
      }
      return Boolean.parseBoolean(bufferComplete);
    }

    /** Returns whether {@code value} parses as a media type within {@code range}. */
    private static boolean mediaTypeMatches(String value, MediaType range) {
      try {
        return MediaType.parse(value).is(range);
      } catch (IllegalArgumentException | IllegalStateException e) {
        return false;
      }
    }
  }

  /** Immutable result of one GET against the remote buffer. */
  public static class PagesResponse {
    public static PagesResponse createPagesResponse(
        String taskInstanceId, long token, long nextToken, Iterable<Page> pages, boolean complete) {
      return new PagesResponse(taskInstanceId, token, nextToken, pages, complete);
    }

    public static PagesResponse createEmptyPagesResponse(
        String taskInstanceId, long token, long nextToken, boolean complete) {
      return new PagesResponse(
          taskInstanceId, token, nextToken, ImmutableList.<Page>of(), complete);
    }

    private final String taskInstanceId;
    private final long token;
    private final long nextToken;
    private final List<Page> pages;
    private final boolean clientComplete;

    private PagesResponse(
        String taskInstanceId,
        long token,
        long nextToken,
        Iterable<Page> pages,
        boolean clientComplete) {
      this.taskInstanceId = taskInstanceId;
      this.token = token;
      this.nextToken = nextToken;
      this.pages = ImmutableList.copyOf(pages);
      this.clientComplete = clientComplete;
    }

    public long getToken() {
      return token;
    }

    public long getNextToken() {
      return nextToken;
    }

    public List<Page> getPages() {
      return pages;
    }

    public boolean isClientComplete() {
      return clientComplete;
    }

    public String getTaskInstanceId() {
      return taskInstanceId;
    }

    @Override
    public String toString() {
      return toStringHelper(this)
          .add("token", token)
          .add("nextToken", nextToken)
          .add("pagesSize", pages.size())
          .add("clientComplete", clientComplete)
          .toString();
    }
  }
}
Example #28
0
/**
 * {@link QueryManager} that runs SQL queries in this process.
 *
 * <p>Query text is parsed with {@link SqlParser}. The {@link QueryExecutionFactory} registered for
 * the parsed statement's class creates the execution, which is started on a cached
 * "query-scheduler" thread pool. A scheduled "query-management" pool runs every 200 ms to remove
 * expired queries and to fail queries whose client has stopped sending heartbeats.
 */
@ThreadSafe
public class SqlQueryManager implements QueryManager {
  private static final Logger log = Logger.get(SqlQueryManager.class);

  private final ExecutorService queryExecutor;
  private final ThreadPoolExecutorMBean queryExecutorMBean;

  private final int maxQueryHistory;
  private final Duration maxQueryAge;

  private final ConcurrentMap<QueryId, QueryExecution> queries = new ConcurrentHashMap<>();

  private final Duration clientTimeout;

  private final ScheduledExecutorService queryManagementExecutor;
  private final ThreadPoolExecutorMBean queryManagementExecutorMBean;

  private final QueryMonitor queryMonitor;
  private final LocationFactory locationFactory;
  private final QueryIdGenerator queryIdGenerator;

  private final Map<Class<? extends Statement>, QueryExecutionFactory<?>> executionFactories;

  private final SqlQueryManagerStats stats = new SqlQueryManagerStats();

  @Inject
  public SqlQueryManager(
      QueryManagerConfig config,
      QueryMonitor queryMonitor,
      QueryIdGenerator queryIdGenerator,
      LocationFactory locationFactory,
      Map<Class<? extends Statement>, QueryExecutionFactory<?>> executionFactories) {
    checkNotNull(config, "config is null");

    this.executionFactories = checkNotNull(executionFactories, "executionFactories is null");

    this.queryExecutor = Executors.newCachedThreadPool(threadsNamed("query-scheduler-%d"));
    this.queryExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) queryExecutor);

    this.queryMonitor = checkNotNull(queryMonitor, "queryMonitor is null");
    this.locationFactory = checkNotNull(locationFactory, "locationFactory is null");
    this.queryIdGenerator = checkNotNull(queryIdGenerator, "queryIdGenerator is null");

    this.maxQueryAge = config.getMaxQueryAge();
    this.maxQueryHistory = config.getMaxQueryHistory();
    this.clientTimeout = config.getClientTimeout();

    queryManagementExecutor =
        Executors.newScheduledThreadPool(
            config.getQueryManagerExecutorPoolSize(), threadsNamed("query-management-%d"));
    queryManagementExecutorMBean =
        new ThreadPoolExecutorMBean((ThreadPoolExecutor) queryManagementExecutor);
    queryManagementExecutor.scheduleAtFixedRate(
        new Runnable() {
          @Override
          public void run() {
            // each task is guarded separately so a failure in one never skips the other, and
            // nothing escapes to cancel the periodic schedule
            try {
              removeExpiredQueries();
            } catch (Throwable e) {
              log.warn(e, "Error removing old queries");
            }
            try {
              failAbandonedQueries();
            } catch (Throwable e) {
              log.warn(e, "Error failing abandoned queries");
            }
          }
        },
        200,
        200,
        TimeUnit.MILLISECONDS);
  }

  /** Shuts down both executors, interrupting running tasks. */
  @PreDestroy
  public void stop() {
    queryManagementExecutor.shutdownNow();
    queryExecutor.shutdownNow();
  }

  /** Returns info for every tracked query, skipping queries whose info cannot be obtained. */
  @Override
  public List<QueryInfo> getAllQueryInfo() {
    return ImmutableList.copyOf(
        filter(
            transform(
                queries.values(),
                new Function<QueryExecution, QueryInfo>() {
                  @Override
                  public QueryInfo apply(QueryExecution queryExecution) {
                    try {
                      return queryExecution.getQueryInfo();
                    } catch (RuntimeException ignored) {
                      return null;
                    }
                  }
                }),
            Predicates.notNull()));
  }

  /**
   * Records a client heartbeat for the query and delegates to its {@code waitForStateChange}.
   *
   * @return the query's result, or {@code maxWait} immediately if the query is unknown
   */
  @Override
  public Duration waitForStateChange(QueryId queryId, QueryState currentState, Duration maxWait)
      throws InterruptedException {
    Preconditions.checkNotNull(queryId, "queryId is null");
    Preconditions.checkNotNull(maxWait, "maxWait is null");

    QueryExecution query = queries.get(queryId);
    if (query == null) {
      return maxWait;
    }

    query.recordHeartbeat();
    return query.waitForStateChange(currentState, maxWait);
  }

  /**
   * Records a client heartbeat for the query and returns its info.
   *
   * @throws NoSuchElementException if no query with this id is tracked
   */
  @Override
  public QueryInfo getQueryInfo(QueryId queryId) {
    checkNotNull(queryId, "queryId is null");

    QueryExecution query = queries.get(queryId);
    if (query == null) {
      throw new NoSuchElementException();
    }

    query.recordHeartbeat();
    return query.getQueryInfo();
  }

  /**
   * Parses, registers, and asynchronously starts a new query.
   *
   * <p>If the text cannot be parsed, a failed query is registered and its info returned instead.
   *
   * @throws IllegalArgumentException if {@code query} is empty
   * @throws IllegalStateException if no execution factory is registered for the statement type
   */
  @Override
  public QueryInfo createQuery(Session session, String query) {
    checkNotNull(query, "query is null");
    Preconditions.checkArgument(!query.isEmpty(), "query must not be empty string");

    QueryId queryId = queryIdGenerator.createNextQueryId();

    Statement statement;
    try {
      statement = SqlParser.createStatement(query);
    } catch (ParsingException e) {
      return createFailedQuery(session, query, queryId, e);
    }

    QueryExecutionFactory<?> queryExecutionFactory = executionFactories.get(statement.getClass());
    Preconditions.checkState(
        queryExecutionFactory != null,
        "Unsupported statement type %s",
        statement.getClass().getName());
    final QueryExecution queryExecution =
        queryExecutionFactory.createQueryExecution(queryId, query, session, statement);
    queryMonitor.createdEvent(queryExecution.getQueryInfo());

    // record stats and emit the completion event once the query reaches a terminal state
    queryExecution.addStateChangeListener(
        new StateChangeListener<QueryState>() {
          @Override
          public void stateChanged(QueryState newValue) {
            if (newValue.isDone()) {
              QueryInfo info = queryExecution.getQueryInfo();

              stats.queryFinished(info);
              queryMonitor.completionEvent(info);
            }
          }
        });

    queries.put(queryId, queryExecution);

    // start the query in the background
    queryExecutor.submit(new QueryStarter(queryExecution, stats));

    return queryExecution.getQueryInfo();
  }

  /** Cancels the query if it is tracked; unknown ids are ignored. */
  @Override
  public void cancelQuery(QueryId queryId) {
    checkNotNull(queryId, "queryId is null");

    log.debug("Cancel query %s", queryId);

    QueryExecution query = queries.get(queryId);
    if (query != null) {
      query.cancel();
    }
  }

  /** Cancels one stage of a tracked query; stages of unknown queries are ignored. */
  @Override
  public void cancelStage(StageId stageId) {
    Preconditions.checkNotNull(stageId, "stageId is null");

    log.debug("Cancel stage %s", stageId);

    QueryExecution query = queries.get(stageId.getQueryId());
    if (query != null) {
      query.cancelStage(stageId);
    }
  }

  @Managed
  @Flatten
  public SqlQueryManagerStats getStats() {
    return stats;
  }

  @Managed(description = "Query scheduler executor")
  @Nested
  public ThreadPoolExecutorMBean getExecutor() {
    return queryExecutorMBean;
  }

  @Managed(description = "Query garbage collector executor")
  @Nested
  public ThreadPoolExecutorMBean getManagementExecutor() {
    return queryManagementExecutorMBean;
  }

  /** Stops tracking the query and cancels it if it was tracked. */
  public void removeQuery(QueryId queryId) {
    Preconditions.checkNotNull(queryId, "queryId is null");

    log.debug("Remove query %s", queryId);

    QueryExecution query = queries.remove(queryId);
    if (query != null) {
      query.cancel();
    }
  }

  /**
   * Removes completed queries after a waiting period.
   *
   * <p>Queries that have an end time are visited oldest first. A query is removed when its client
   * is abandoned (see {@link #isAbandoned}) and either it ended more than {@code maxQueryAge} ago
   * or more than {@code maxQueryHistory} ended queries are still being retained. A failure
   * inspecting one query is logged and does not stop the sweep.
   */
  public void removeExpiredQueries() {
    List<QueryExecution> sortedQueries =
        IterableTransformer.on(queries.values())
            .select(compose(not(isNull()), endTimeGetter()))
            .orderBy(Ordering.natural().onResultOf(endTimeGetter()))
            .list();

    int toRemove = Math.max(sortedQueries.size() - maxQueryHistory, 0);
    DateTime oldestAllowedQuery = DateTime.now().minus(maxQueryAge.toMillis());

    for (QueryExecution queryExecution : sortedQueries) {
      // Remember the id as soon as it is known so the catch block never calls getQueryInfo()
      // again: that call may be what failed, and an exception thrown from the catch block would
      // abort the sweep for every remaining query.
      QueryId queryId = null;
      try {
        QueryInfo queryInfo = queryExecution.getQueryInfo();
        queryId = queryInfo.getQueryId();
        DateTime endTime = queryInfo.getQueryStats().getEndTime();
        if ((endTime.isBefore(oldestAllowedQuery) || toRemove > 0) && isAbandoned(queryExecution)) {
          removeQuery(queryId);
          --toRemove;
        }
      } catch (RuntimeException e) {
        log.warn(e, "Error while inspecting age of query %s", queryId);
      }
    }
  }

  /**
   * Fails every query that is not yet done and whose client is abandoned (see {@link
   * #isAbandoned}). A failure inspecting one query is logged and does not stop the sweep.
   */
  public void failAbandonedQueries() {
    for (QueryExecution queryExecution : queries.values()) {
      // see removeExpiredQueries: never call getQueryInfo() from the catch block
      QueryId queryId = null;
      try {
        QueryInfo queryInfo = queryExecution.getQueryInfo();
        queryId = queryInfo.getQueryId();
        if (queryInfo.getState().isDone()) {
          continue;
        }

        if (isAbandoned(queryExecution)) {
          log.info("Failing abandoned query %s", queryId);
          queryExecution.fail(
              new AbandonedException(
                  "Query " + queryId,
                  queryInfo.getQueryStats().getLastHeartbeat(),
                  DateTime.now()));
        }
      } catch (RuntimeException e) {
        log.warn(e, "Error while inspecting age of query %s", queryId);
      }
    }
  }

  /**
   * Returns whether the query's last heartbeat is older than {@code clientTimeout}. A query with
   * no recorded heartbeat is never considered abandoned.
   */
  private boolean isAbandoned(QueryExecution query) {
    DateTime oldestAllowedHeartbeat = DateTime.now().minus(clientTimeout.toMillis());
    DateTime lastHeartbeat = query.getQueryInfo().getQueryStats().getLastHeartbeat();

    return lastHeartbeat != null && lastHeartbeat.isBefore(oldestAllowedHeartbeat);
  }

  /** Registers an already-failed query (e.g. unparseable text), emits its events, returns info. */
  private QueryInfo createFailedQuery(
      Session session, String query, QueryId queryId, Throwable cause) {
    URI self = locationFactory.createQueryLocation(queryId);
    QueryExecution execution =
        new FailedQueryExecution(queryId, query, session, self, queryExecutor, cause);

    queries.put(queryId, execution);
    queryMonitor.createdEvent(execution.getQueryInfo());
    queryMonitor.completionEvent(execution.getQueryInfo());
    stats.queryFinished(execution.getQueryInfo());

    return execution.getQueryInfo();
  }

  /** Extracts a query's end time, which is null for queries that have not ended. */
  private static Function<QueryExecution, DateTime> endTimeGetter() {
    return new Function<QueryExecution, DateTime>() {
      @Nullable
      @Override
      public DateTime apply(QueryExecution input) {
        return input.getQueryInfo().getQueryStats().getEndTime();
      }
    };
  }

  /** Starts a query under a "Query-&lt;id&gt;" thread name and counts it as started. */
  private static class QueryStarter implements Runnable {
    private final QueryExecution queryExecution;
    private final SqlQueryManagerStats stats;

    public QueryStarter(QueryExecution queryExecution, SqlQueryManagerStats stats) {
      this.queryExecution = queryExecution;
      this.stats = stats;
    }

    @Override
    public void run() {
      try (SetThreadName setThreadName =
          new SetThreadName("Query-%s", queryExecution.getQueryInfo().getQueryId())) {
        stats.queryStarted();
        queryExecution.start();
      }
    }
  }
}
Example #29
0
/**
 * Utilities for naming, assembling, and loading generated bytecode classes.
 *
 * <p>{@code DUMP_BYTE_CODE_TREE}, {@code DUMP_BYTE_CODE_RAW} and {@code RUN_ASM_VERIFIER} are
 * compile-time debugging switches and are all off. When {@code DUMP_CLASS_FILES_TO} holds a
 * non-null directory path, each generated class is additionally written there as a
 * {@code .class} file.
 */
public final class CompilerUtils {
  private static final Logger log = Logger.get(CompilerUtils.class);

  private static final boolean DUMP_BYTE_CODE_TREE = false;
  private static final boolean DUMP_BYTE_CODE_RAW = false;
  private static final boolean RUN_ASM_VERIFIER = false; // verifier doesn't work right now
  private static final AtomicReference<String> DUMP_CLASS_FILES_TO = new AtomicReference<>();
  // Process-wide counter appended to generated class names to keep them unique.
  private static final AtomicLong CLASS_ID = new AtomicLong();

  private CompilerUtils() {}

  /**
   * Returns a unique class name in the {@code com.facebook.presto.$gen} package derived from
   * {@code baseName}.
   *
   * <p>Only the simple-name part is sanitized. Sanitizing the full dotted name would also turn
   * the package separators into {@code '_'}, silently placing every generated class in the
   * default package.
   *
   * @param baseName stem of the class name; characters that are not valid Java identifier parts
   *     (including {@code '.'}) are replaced with {@code '_'}
   * @return the generated class type
   */
  public static ParameterizedType makeClassName(String baseName) {
    String simpleName = toJavaIdentifierString(baseName + "_" + CLASS_ID.incrementAndGet());
    return ParameterizedType.typeFromJavaClassName("com.facebook.presto.$gen." + simpleName);
  }

  /**
   * Replaces every code point that is not a valid Java identifier part with {@code '_'}.
   *
   * <p>{@code '.'} is not an identifier part, so this must be applied to simple names, not to
   * package-qualified names.
   *
   * @param className the text to sanitize
   * @return {@code className} with invalid code points replaced
   */
  public static String toJavaIdentifierString(String className) {
    int[] codePoints =
        className.codePoints().map(c -> Character.isJavaIdentifierPart(c) ? c : '_').toArray();
    return new String(codePoints, 0, codePoints.length);
  }

  /**
   * Defines a single generated class in {@code classLoader}, initializes it, and returns it
   * typed as a subclass of {@code superType}.
   *
   * @throws ClassCastException if the generated class is not a subclass of {@code superType}
   */
  public static <T> Class<? extends T> defineClass(
      ClassDefinition classDefinition, Class<T> superType, DynamicClassLoader classLoader) {
    Class<?> clazz =
        defineClasses(ImmutableList.of(classDefinition), classLoader).values().iterator().next();
    return clazz.asSubclass(superType);
  }

  /**
   * Same as {@link #defineClass(ClassDefinition, Class, DynamicClassLoader)}, using a new
   * {@link DynamicClassLoader} built from {@code parentClassLoader} and {@code callSiteBindings}.
   */
  public static <T> Class<? extends T> defineClass(
      ClassDefinition classDefinition,
      Class<T> superType,
      Map<Long, MethodHandle> callSiteBindings,
      ClassLoader parentClassLoader) {
    // The 3-argument overload already performs the asSubclass check.
    return defineClass(
        classDefinition,
        superType,
        new DynamicClassLoader(parentClassLoader, callSiteBindings));
  }

  /**
   * Generates bytecode for {@code classDefinitions}, defines the classes in {@code classLoader},
   * and eagerly initializes each of them.
   *
   * @return the map returned by {@code classLoader.defineClasses}
   * @throws RuntimeException with a {@link VerifyError} cause if initializing a class fails
   *     verification
   */
  private static Map<String, Class<?>> defineClasses(
      List<ClassDefinition> classDefinitions, DynamicClassLoader classLoader) {
    ClassInfoLoader classInfoLoader =
        ClassInfoLoader.createClassInfoLoader(classDefinitions, classLoader);

    if (DUMP_BYTE_CODE_TREE) {
      ByteArrayOutputStream out = new ByteArrayOutputStream();
      DumpByteCodeVisitor dumpByteCode = new DumpByteCodeVisitor(new PrintStream(out));
      for (ClassDefinition classDefinition : classDefinitions) {
        dumpByteCode.visitClass(classDefinition);
      }
      System.out.println(new String(out.toByteArray(), StandardCharsets.UTF_8));
    }

    // Keyed by Java class name; LinkedHashMap keeps the definition order.
    Map<String, byte[]> byteCodes = new LinkedHashMap<>();
    for (ClassDefinition classDefinition : classDefinitions) {
      ClassWriter cw = new SmartClassWriter(classInfoLoader);
      classDefinition.visit(cw);
      byte[] byteCode = cw.toByteArray();
      if (RUN_ASM_VERIFIER) {
        ClassReader reader = new ClassReader(byteCode);
        CheckClassAdapter.verify(reader, classLoader, true, new PrintWriter(System.out));
      }
      byteCodes.put(classDefinition.getType().getJavaClassName(), byteCode);
    }

    String dumpClassPath = DUMP_CLASS_FILES_TO.get();
    if (dumpClassPath != null) {
      for (Map.Entry<String, byte[]> entry : byteCodes.entrySet()) {
        File file =
            new File(
                dumpClassPath,
                ParameterizedType.typeFromJavaClassName(entry.getKey()).getClassName() + ".class");
        try {
          log.debug("ClassFile: %s", file.getAbsolutePath());
          Files.createParentDirs(file);
          Files.write(entry.getValue(), file);
        } catch (IOException e) {
          // Dumping is a debugging aid: log the failure but still define the classes.
          log.error(e, "Failed to write generated class file to: %s", file.getAbsolutePath());
        }
      }
    }
    if (DUMP_BYTE_CODE_RAW) {
      for (byte[] byteCode : byteCodes.values()) {
        ClassReader classReader = new ClassReader(byteCode);
        classReader.accept(
            new TraceClassVisitor(new PrintWriter(System.err)), ClassReader.SKIP_FRAMES);
      }
    }
    Map<String, Class<?>> classes = classLoader.defineClasses(byteCodes);
    try {
      // Force initialization now so bytecode problems surface here rather than at first use.
      for (Class<?> clazz : classes.values()) {
        Reflection.initialize(clazz);
      }
    } catch (VerifyError e) {
      throw new RuntimeException(e);
    }
    return classes;
  }
}
Example #30
0
/**
 * Runs the standard set of Presto benchmarks and writes each benchmark's results under an
 * output directory.
 *
 * <p>For every benchmark named {@code N}, the following files are written:
 *
 * <ul>
 *   <li>{@code json/N.json}
 *   <li>{@code json-avg/N.json}
 *   <li>{@code csv/N.csv}
 *   <li>{@code ods/N.json}
 * </ul>
 */
public class BenchmarkSuite {
  private static final Logger LOGGER = Logger.get(BenchmarkSuite.class);

  /**
   * Builds the full list of benchmarks, in execution order.
   *
   * @param executor executor passed to the local query runners created here
   * @return hand-built, SQL, sampled-SQL (name-prefixed with {@code sampled_}) and statistics
   *     benchmarks
   */
  public static List<AbstractBenchmark> createBenchmarks(ExecutorService executor) {
    LocalQueryRunner localQueryRunner = createLocalQueryRunner(executor);
    LocalQueryRunner localSampledQueryRunner = createLocalSampledQueryRunner(executor);

    return ImmutableList.<AbstractBenchmark>of(
        // hand built benchmarks
        new CountAggregationBenchmark(localQueryRunner),
        new DoubleSumAggregationBenchmark(localQueryRunner),
        new HashAggregationBenchmark(localQueryRunner),
        new PredicateFilterBenchmark(localQueryRunner),
        new RawStreamingBenchmark(localQueryRunner),
        new Top100Benchmark(localQueryRunner),
        new OrderByBenchmark(localQueryRunner),
        new HashBuildBenchmark(localQueryRunner),
        new HashJoinBenchmark(localQueryRunner),
        new HashBuildAndJoinBenchmark(localQueryRunner),
        new HandTpchQuery1(localQueryRunner),
        new HandTpchQuery6(localQueryRunner),

        // sql benchmarks
        new GroupBySumWithArithmeticSqlBenchmark(localQueryRunner),
        new CountAggregationSqlBenchmark(localQueryRunner),
        new SqlDoubleSumAggregationBenchmark(localQueryRunner),
        new CountWithFilterSqlBenchmark(localQueryRunner),
        new GroupByAggregationSqlBenchmark(localQueryRunner),
        new PredicateFilterSqlBenchmark(localQueryRunner),
        new RawStreamingSqlBenchmark(localQueryRunner),
        new Top100SqlBenchmark(localQueryRunner),
        new SqlHashJoinBenchmark(localQueryRunner),
        new SqlJoinWithPredicateBenchmark(localQueryRunner),
        new VarBinaryMaxAggregationSqlBenchmark(localQueryRunner),
        new SqlDistinctMultipleFields(localQueryRunner),
        new SqlDistinctSingleField(localQueryRunner),
        new SqlTpchQuery1(localQueryRunner),
        new SqlTpchQuery6(localQueryRunner),
        new SqlLikeBenchmark(localQueryRunner),
        new SqlInBenchmark(localQueryRunner),
        new SqlSemiJoinInPredicateBenchmark(localQueryRunner),
        new SqlRegexpLikeBenchmark(localQueryRunner),
        new SqlApproximatePercentileBenchmark(localQueryRunner),
        new SqlBetweenBenchmark(localQueryRunner),

        // Sampled sql benchmarks
        new RenamingBenchmark(
            "sampled_", new GroupBySumWithArithmeticSqlBenchmark(localSampledQueryRunner)),
        new RenamingBenchmark(
            "sampled_", new CountAggregationSqlBenchmark(localSampledQueryRunner)),
        new RenamingBenchmark(
            "sampled_", new SqlJoinWithPredicateBenchmark(localSampledQueryRunner)),
        new RenamingBenchmark(
            "sampled_", new SqlDoubleSumAggregationBenchmark(localSampledQueryRunner)),

        // statistics benchmarks
        new StatisticsBenchmark.LongVarianceBenchmark(localQueryRunner),
        new StatisticsBenchmark.LongVariancePopBenchmark(localQueryRunner),
        new StatisticsBenchmark.DoubleVarianceBenchmark(localQueryRunner),
        new StatisticsBenchmark.DoubleVariancePopBenchmark(localQueryRunner),
        new StatisticsBenchmark.LongStdDevBenchmark(localQueryRunner),
        new StatisticsBenchmark.LongStdDevPopBenchmark(localQueryRunner),
        new StatisticsBenchmark.DoubleStdDevBenchmark(localQueryRunner),
        new StatisticsBenchmark.DoubleStdDevPopBenchmark(localQueryRunner),
        new SqlApproximateCountDistinctLongBenchmark(localQueryRunner),
        new SqlApproximateCountDistinctDoubleBenchmark(localQueryRunner),
        new SqlApproximateCountDistinctVarBinaryBenchmark(localQueryRunner));
  }

  private final String outputDirectory;

  /** @param outputDirectory root directory for all result files; must not be null */
  public BenchmarkSuite(String outputDirectory) {
    this.outputDirectory = checkNotNull(outputDirectory, "outputDirectory is null");
  }

  /** Returns a {@link File} for {@code fileName}, creating its parent directories if needed. */
  private static File createOutputFile(String fileName) throws IOException {
    File outputFile = new File(fileName);
    Files.createParentDirs(outputFile);
    return outputFile;
  }

  /**
   * Opens {@code <outputDirectory>/<subdirectory>/<benchmarkName>.<extension>} for writing,
   * creating parent directories as needed. The caller owns (and must close) the stream.
   */
  private OutputStream createOutputStream(
      String subdirectory, String benchmarkName, String extension) throws IOException {
    return new FileOutputStream(
        createOutputFile(
            String.format(
                "%s/%s/%s.%s", outputDirectory, subdirectory, benchmarkName, extension)));
  }

  /**
   * Runs every benchmark once for JVM warmup, then runs each again while writing its results.
   * The executor created here is shut down when the method exits, even on failure.
   */
  public void runAllBenchmarks() throws IOException {
    ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test"));
    try {
      List<AbstractBenchmark> benchmarks = createBenchmarks(executor);

      LOGGER.info("=== Pre-running all benchmarks for JVM warmup ===");
      for (AbstractBenchmark benchmark : benchmarks) {
        benchmark.runBenchmark();
      }

      LOGGER.info("=== Actually running benchmarks for metrics ===");
      for (AbstractBenchmark benchmark : benchmarks) {
        String benchmarkName = benchmark.getBenchmarkName();
        try (OutputStream jsonOut = createOutputStream("json", benchmarkName, "json");
            OutputStream jsonAvgOut = createOutputStream("json-avg", benchmarkName, "json");
            OutputStream csvOut = createOutputStream("csv", benchmarkName, "csv");
            OutputStream odsOut = createOutputStream("ods", benchmarkName, "json")) {
          benchmark.runBenchmark(
              new ForwardingBenchmarkResultWriter(
                  ImmutableList.of(
                      new JsonBenchmarkResultWriter(jsonOut),
                      new JsonAvgBenchmarkResultWriter(jsonAvgOut),
                      new SimpleLineBenchmarkResultWriter(csvOut),
                      new OdsBenchmarkResultWriter(
                          "presto.benchmark." + benchmarkName, odsOut))));
        }
      }
    } finally {
      executor.shutdownNow();
    }
  }

  /** Result hook that forwards every callback to each of a fixed list of hooks, in order. */
  private static class ForwardingBenchmarkResultWriter implements BenchmarkResultHook {
    private final List<BenchmarkResultHook> benchmarkResultHooks;

    private ForwardingBenchmarkResultWriter(List<BenchmarkResultHook> benchmarkResultHooks) {
      checkNotNull(benchmarkResultHooks, "benchmarkResultHooks is null");
      this.benchmarkResultHooks = ImmutableList.copyOf(benchmarkResultHooks);
    }

    @Override
    public BenchmarkResultHook addResults(Map<String, Long> results) {
      checkNotNull(results, "results is null");
      for (BenchmarkResultHook benchmarkResultHook : benchmarkResultHooks) {
        benchmarkResultHook.addResults(results);
      }
      return this;
    }

    @Override
    public void finished() {
      for (BenchmarkResultHook benchmarkResultHook : benchmarkResultHooks) {
        benchmarkResultHook.finished();
      }
    }
  }

  /** Entry point; requires the {@code outputDirectory} system property. */
  public static void main(String[] args) throws IOException {
    String outputDirectory =
        checkNotNull(System.getProperty("outputDirectory"), "Must specify -DoutputDirectory=...");
    new BenchmarkSuite(outputDirectory).runAllBenchmarks();
  }
}