public final class ResumableTasks { private static final Logger log = Logger.get(ResumableTasks.class); private ResumableTasks() {} public static void submit(Executor executor, ResumableTask task) { AtomicReference<Runnable> runnableReference = new AtomicReference<>(); Runnable runnable = () -> { ResumableTask.TaskStatus status = safeProcessTask(task); if (!status.isFinished()) { status.getContinuationFuture().thenRun(() -> executor.execute(runnableReference.get())); } }; runnableReference.set(runnable); executor.execute(runnable); } private static ResumableTask.TaskStatus safeProcessTask(ResumableTask task) { try { return task.process(); } catch (Throwable t) { log.warn(t, "ResumableTask completed exceptionally"); return ResumableTask.TaskStatus.finished(); } } }
private void dropTable(SchemaTableName table) { try { metastoreClient.dropTable(table.getSchemaName(), table.getTableName()); } catch (RuntimeException e) { Logger.get(getClass()).warn(e, "Failed to drop table: %s", table); } }
public class JmxAgent { private final int registryPort; private final int serverPort; private static final Logger log = Logger.get(JmxAgent.class); private final JMXServiceURL url; @Inject public JmxAgent(JmxConfig config) throws IOException { if (config.getRmiRegistryPort() == null) { registryPort = NetUtils.findUnusedPort(); } else { registryPort = config.getRmiRegistryPort(); } if (config.getRmiServerPort() == null) { serverPort = NetUtils.findUnusedPort(); } else { serverPort = config.getRmiServerPort(); } try { // This is how the jdk jmx agent constructs its url url = new JMXServiceURL("rmi", null, registryPort); } catch (MalformedURLException e) { // should not happen... throw new AssertionError(e); } } public JMXServiceURL getURL() { return url; } @PostConstruct public void start() throws IOException { // This is somewhat of a hack, but the jmx agent in Oracle/OpenJDK doesn't // have a programmatic API for starting it and controlling its parameters System.setProperty("com.sun.management.jmxremote", "true"); System.setProperty("com.sun.management.jmxremote.port", Integer.toString(registryPort)); System.setProperty("com.sun.management.jmxremote.rmi.port", Integer.toString(serverPort)); System.setProperty("com.sun.management.jmxremote.authenticate", "false"); System.setProperty("com.sun.management.jmxremote.ssl", "false"); try { Agent.startAgent(); } catch (Exception e) { throw Throwables.propagate(e); } log.info("JMX Agent listening on %s:%s", url.getHost(), url.getPort()); } }
@Test public void testViewCreation() { try { verifyViewCreation(); } finally { try { metadata.dropView(SESSION, temporaryCreateView); } catch (RuntimeException e) { Logger.get(getClass()).warn(e, "Failed to drop view: %s", temporaryCreateView); } } }
public static final class Utils { public static final Logger log = Logger.get(AliasDao.class); public static void createTables(AliasDao dao) { dao.createAliasTable(); } public static void createTablesWithRetry(AliasDao dao) throws InterruptedException { Duration delay = new Duration(10, TimeUnit.SECONDS); while (true) { try { createTables(dao); return; } catch (UnableToObtainConnectionException e) { log.warn( "Failed to connect to database. Will retry again in %s. Exception: %s", delay, e.getMessage()); Thread.sleep(delay.toMillis()); } } } }
@ThreadSafe public class TaskExecutor { private static final Logger log = Logger.get(TaskExecutor.class); // each task is guaranteed a minimum number of splits private static final int GUARANTEED_SPLITS_PER_TASK = 3; // each time we run a split, run it for this length before returning to the pool private static final Duration SPLIT_RUN_QUANTA = new Duration(1, TimeUnit.SECONDS); private static final AtomicLong NEXT_RUNNER_ID = new AtomicLong(); private static final AtomicLong NEXT_WORKER_ID = new AtomicLong(); private final ExecutorService executor; private final ThreadPoolExecutorMBean executorMBean; private final int runnerThreads; private final int minimumNumberOfTasks; private final Ticker ticker; @GuardedBy("this") private final List<TaskHandle> tasks; private final Set<PrioritizedSplitRunner> allSplits = new HashSet<>(); private final PriorityBlockingQueue<PrioritizedSplitRunner> pendingSplits; private final Set<PrioritizedSplitRunner> runningSplits = Sets.newSetFromMap(new ConcurrentHashMap<PrioritizedSplitRunner, Boolean>()); private final Set<PrioritizedSplitRunner> blockedSplits = Sets.newSetFromMap(new ConcurrentHashMap<PrioritizedSplitRunner, Boolean>()); private final AtomicLongArray completedTasksPerLevel = new AtomicLongArray(5); private final DistributionStat queuedTime = new DistributionStat(); private final DistributionStat wallTime = new DistributionStat(); private volatile boolean closed; @Inject public TaskExecutor(TaskManagerConfig config) { this(checkNotNull(config, "config is null").getMaxShardProcessorThreads()); } public TaskExecutor(int runnerThreads) { this(runnerThreads, Ticker.systemTicker()); } @VisibleForTesting public TaskExecutor(int runnerThreads, Ticker ticker) { checkArgument(runnerThreads > 0, "runnerThreads must be at least 1"); // we manage thread pool size directly, so create an unlimited pool this.executor = Executors.newCachedThreadPool(threadsNamed("task-processor-%d")); this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor); this.runnerThreads = runnerThreads; this.ticker = checkNotNull(ticker, "ticker is null"); // we assume we need at least two tasks per runner thread to keep the system busy this.minimumNumberOfTasks = 2 * this.runnerThreads; this.pendingSplits = new PriorityBlockingQueue<>(Runtime.getRuntime().availableProcessors() * 10); this.tasks = new LinkedList<>(); } @PostConstruct public synchronized void start() { checkState(!closed, "TaskExecutor is closed"); for (int i = 0; i < runnerThreads; i++) { addRunnerThread(); } } @PreDestroy public synchronized void stop() { closed = true; executor.shutdownNow(); } @Override public synchronized String toString() { return Objects.toStringHelper(this) .add("runnerThreads", runnerThreads) .add("allSplits", allSplits.size()) .add("pendingSplits", pendingSplits.size()) .add("runningSplits", runningSplits.size()) .add("blockedSplits", blockedSplits.size()) .toString(); } private synchronized void addRunnerThread() { try { executor.execute(new Runner()); } catch (RejectedExecutionException ignored) { } } public synchronized TaskHandle addTask(TaskId taskId) { TaskHandle taskHandle = new TaskHandle(checkNotNull(taskId, "taskId is null")); tasks.add(taskHandle); return taskHandle; } public synchronized void removeTask(TaskHandle taskHandle) { taskHandle.destroy(); tasks.remove(taskHandle); // record completed stats long threadUsageNanos = taskHandle.getThreadUsageNanos(); int priorityLevel = calculatePriorityLevel(threadUsageNanos); 
completedTasksPerLevel.incrementAndGet(priorityLevel); } public synchronized ListenableFuture<?> enqueueSplit( TaskHandle taskHandle, SplitRunner taskSplit) { PrioritizedSplitRunner prioritizedSplitRunner = new PrioritizedSplitRunner(taskHandle, taskSplit, ticker); taskHandle.addSplit(prioritizedSplitRunner); scheduleTaskIfNecessary(taskHandle); addNewEntrants(); return prioritizedSplitRunner.getFinishedFuture(); } public synchronized ListenableFuture<?> forceRunSplit( TaskHandle taskHandle, SplitRunner taskSplit) { PrioritizedSplitRunner prioritizedSplitRunner = new PrioritizedSplitRunner(taskHandle, taskSplit, ticker); // Note: we do not record queued time for forced splits startSplit(prioritizedSplitRunner); return prioritizedSplitRunner.getFinishedFuture(); } private synchronized void splitFinished(PrioritizedSplitRunner split) { allSplits.remove(split); pendingSplits.remove(split); TaskHandle taskHandle = split.getTaskHandle(); taskHandle.splitComplete(split); wallTime.add(System.nanoTime() - split.createdNanos); scheduleTaskIfNecessary(taskHandle); addNewEntrants(); } private synchronized void scheduleTaskIfNecessary(TaskHandle taskHandle) { // if task has less than the minimum guaranteed splits running, // immediately schedule a new split for this task. This assures // that a task gets its fair amount of consideration (you have to // have splits to be considered for running on a thread). if (taskHandle.getRunningSplits() < GUARANTEED_SPLITS_PER_TASK) { PrioritizedSplitRunner split = taskHandle.pollNextSplit(); if (split != null) { startSplit(split); queuedTime.add(System.nanoTime() - split.createdNanos); } } } private synchronized void addNewEntrants() { int running = allSplits.size(); for (int i = 0; i < minimumNumberOfTasks - running; i++) { PrioritizedSplitRunner split = pollNextSplitWorker(); if (split == null) { break; } queuedTime.add(System.nanoTime() - split.createdNanos); startSplit(split); } } private synchronized void startSplit(PrioritizedSplitRunner split) { allSplits.add(split); pendingSplits.put(split); } private synchronized PrioritizedSplitRunner pollNextSplitWorker() { // todo find a better algorithm for this // find the first task that produces a split, then move that task to the // end of the task list, so we get round robin for (Iterator<TaskHandle> iterator = tasks.iterator(); iterator.hasNext(); ) { TaskHandle task = iterator.next(); PrioritizedSplitRunner split = task.pollNextSplit(); if (split != null) { // move task to end of list iterator.remove(); // CAUTION: we are modifying the list in the loop which would normally // cause a ConcurrentModificationException but we exit immediately tasks.add(task); return split; } } return null; } @NotThreadSafe public static class TaskHandle { private final TaskId taskId; private final Queue<PrioritizedSplitRunner> queuedSplits = new ArrayDeque<>(10); private final List<PrioritizedSplitRunner> runningSplits = new ArrayList<>(10); private final AtomicLong taskThreadUsageNanos = new AtomicLong(); private TaskHandle(TaskId taskId) { this.taskId = taskId; } private long addThreadUsageNanos(long durationNanos) { return taskThreadUsageNanos.addAndGet(durationNanos); } private TaskId getTaskId() { return taskId; } private void destroy() { for (PrioritizedSplitRunner runningSplit : runningSplits) { runningSplit.destroy(); } runningSplits.clear(); for (PrioritizedSplitRunner queuedSplit : queuedSplits) { queuedSplit.destroy(); } queuedSplits.clear(); } private void addSplit(PrioritizedSplitRunner split) { 
queuedSplits.add(split); } private int getRunningSplits() { return runningSplits.size(); } private long getThreadUsageNanos() { return taskThreadUsageNanos.get(); } private PrioritizedSplitRunner pollNextSplit() { PrioritizedSplitRunner split = queuedSplits.poll(); if (split != null) { runningSplits.add(split); } return split; } private void splitComplete(PrioritizedSplitRunner split) { runningSplits.remove(split); split.destroy(); } @Override public String toString() { return Objects.toStringHelper(this).add("taskId", taskId).toString(); } } private static class PrioritizedSplitRunner implements Comparable<PrioritizedSplitRunner> { private final long createdNanos = System.nanoTime(); private final TaskHandle taskHandle; private final long workerId; private final SplitRunner split; private final Ticker ticker; private final SettableFuture<?> finishedFuture = SettableFuture.create(); private final AtomicBoolean initialized = new AtomicBoolean(); private final AtomicBoolean destroyed = new AtomicBoolean(); private final AtomicInteger priorityLevel = new AtomicInteger(); private final AtomicLong threadUsageNanos = new AtomicLong(); private final AtomicLong lastRun = new AtomicLong(); private PrioritizedSplitRunner(TaskHandle taskHandle, SplitRunner split, Ticker ticker) { this.taskHandle = taskHandle; this.split = split; this.ticker = ticker; this.workerId = NEXT_WORKER_ID.getAndIncrement(); } private TaskHandle getTaskHandle() { return taskHandle; } private SettableFuture<?> getFinishedFuture() { return finishedFuture; } public void initializeIfNecessary() { if (initialized.compareAndSet(false, true)) { split.initialize(); } } public void destroy() { try { split.close(); } catch (RuntimeException e) { log.error(e, "Error closing split for task %s", taskHandle.getTaskId()); } destroyed.set(true); } public boolean isFinished() { boolean finished = split.isFinished(); if (finished) { finishedFuture.set(null); } return finished || destroyed.get(); } public ListenableFuture<?> process() throws Exception { try { long start = ticker.read(); ListenableFuture<?> blocked = split.processFor(SPLIT_RUN_QUANTA); long endTime = ticker.read(); // update priority level based on total thread usage of task long durationNanos = endTime - start; long threadUsageNanos = taskHandle.addThreadUsageNanos(durationNanos); this.threadUsageNanos.set(threadUsageNanos); priorityLevel.set(calculatePriorityLevel(threadUsageNanos)); // record last run for prioritization within a level lastRun.set(endTime); return blocked; } catch (Throwable e) { finishedFuture.setException(e); throw e; } } public boolean updatePriorityLevel() { int newPriority = calculatePriorityLevel(taskHandle.getThreadUsageNanos()); if (newPriority == priorityLevel.getAndSet(newPriority)) { return false; } // update thread usage if the level changed threadUsageNanos.set(taskHandle.getThreadUsageNanos()); return true; } @Override public int compareTo(PrioritizedSplitRunner o) { int level = priorityLevel.get(); int result = Ints.compare(level, o.priorityLevel.get()); if (result != 0) { return result; } if (level < 4) { result = Long.compare(threadUsageNanos.get(), o.threadUsageNanos.get()); } else { result = Long.compare(lastRun.get(), o.lastRun.get()); } if (result != 0) { return result; } return Longs.compare(workerId, o.workerId); } @Override public String toString() { return String.format( "Split %-15s %s %s", taskHandle.getTaskId(), priorityLevel, new Duration(threadUsageNanos.get(), TimeUnit.NANOSECONDS) .convertToMostSuccinctTimeUnit()); } } private 
static int calculatePriorityLevel(long threadUsageNanos) { long millis = TimeUnit.NANOSECONDS.toMillis(threadUsageNanos); int priorityLevel; if (millis < 1000) { priorityLevel = 0; } else if (millis < 10_000) { priorityLevel = 1; } else if (millis < 60_000) { priorityLevel = 2; } else if (millis < 300_000) { priorityLevel = 3; } else { priorityLevel = 4; } return priorityLevel; } private class Runner implements Runnable { private final long runnerId = NEXT_RUNNER_ID.getAndIncrement(); @Override public void run() { try (SetThreadName runnerName = new SetThreadName("SplitRunner-%s", runnerId)) { while (!closed && !Thread.currentThread().isInterrupted()) { // select next worker final PrioritizedSplitRunner split; try { split = pendingSplits.take(); if (split.updatePriorityLevel()) { // priority level changed, return split to queue for re-prioritization pendingSplits.put(split); continue; } } catch (InterruptedException e) { Thread.currentThread().interrupt(); return; } try (SetThreadName splitName = new SetThreadName(split.toString())) { runningSplits.add(split); boolean finished; ListenableFuture<?> blocked; try { split.initializeIfNecessary(); blocked = split.process(); finished = split.isFinished(); } finally { runningSplits.remove(split); } if (finished) { log.debug("%s is finished", split); splitFinished(split); } else { if (blocked.isDone()) { pendingSplits.put(split); } else { blockedSplits.add(split); blocked.addListener( new Runnable() { @Override public void run() { blockedSplits.remove(split); split.updatePriorityLevel(); pendingSplits.put(split); } }, executor); } } } catch (Throwable t) { log.error(t, "Error processing %s", split); splitFinished(split); } } } finally { // unless we have been closed, we need to replace this thread if (!closed) { addRunnerThread(); } } } } // // STATS // @Managed public int getTasks() { return tasks.size(); } @Managed public int getRunnerThreads() { return runnerThreads; } @Managed public int getMinimumNumberOfTasks() { return minimumNumberOfTasks; } @Managed public int getTotalSplits() { return allSplits.size(); } @Managed public int getPendingSplits() { return pendingSplits.size(); } @Managed public int getRunningSplits() { return runningSplits.size(); } @Managed public int getBlockedSplits() { return blockedSplits.size(); } @Managed public long getCompletedTasksLevel0() { return completedTasksPerLevel.get(0); } @Managed public long getCompletedTasksLevel1() { return completedTasksPerLevel.get(1); } @Managed public long getCompletedTasksLevel2() { return completedTasksPerLevel.get(2); } @Managed public long getCompletedTasksLevel3() { return completedTasksPerLevel.get(3); } @Managed public long getCompletedTasksLevel4() { return completedTasksPerLevel.get(4); } @Managed public long getRunningTasksLevel0() { return calculateRunningTasksForLevel(0); } @Managed public long getRunningTasksLevel1() { return calculateRunningTasksForLevel(1); } @Managed public long getRunningTasksLevel2() { return calculateRunningTasksForLevel(2); } @Managed public long getRunningTasksLevel3() { return calculateRunningTasksForLevel(3); } @Managed public long getRunningTasksLevel4() { return calculateRunningTasksForLevel(4); } @Managed @Nested public DistributionStat getQueuedTime() { return queuedTime; } @Managed @Nested public DistributionStat getWallTime() { return wallTime; } private synchronized int calculateRunningTasksForLevel(int level) { int count = 0; for (TaskHandle task : tasks) { if (calculatePriorityLevel(task.getThreadUsageNanos()) == level) { count++; } } 
return count; } @Managed(description = "Task processor executor") @Nested public ThreadPoolExecutorMBean getProcessorExecutor() { return executorMBean; } }
public class OrganizationJobFactory implements JobFactory { private static final Logger log = Logger.get(OrganizationJobFactory.class); private final MetadataDao metadataDao; private final ShardManager shardManager; private final ShardCompactor compactor; @Inject public OrganizationJobFactory( @ForMetadata IDBI dbi, ShardManager shardManager, ShardCompactor compactor) { this.metadataDao = onDemandDao(dbi, MetadataDao.class); this.shardManager = requireNonNull(shardManager, "shardManager is null"); this.compactor = requireNonNull(compactor, "compactor is null"); } @Override public Runnable create(OrganizationSet organizationSet) { return new OrganizationJob(organizationSet); } private class OrganizationJob implements Runnable { private final OrganizationSet organizationSet; public OrganizationJob(OrganizationSet organizationSet) { this.organizationSet = requireNonNull(organizationSet, "organizationSet is null"); } @Override public void run() { try { runJob( organizationSet.getTableId(), organizationSet.getBucketNumber(), organizationSet.getShards()); } catch (IOException e) { throw Throwables.propagate(e); } } private void runJob(long tableId, OptionalInt bucketNumber, Set<UUID> shardUuids) throws IOException { long transactionId = shardManager.beginTransaction(); try { runJob(transactionId, bucketNumber, tableId, shardUuids); } catch (Throwable e) { shardManager.rollbackTransaction(transactionId); throw e; } } private void runJob( long transactionId, OptionalInt bucketNumber, long tableId, Set<UUID> shardUuids) throws IOException { TableMetadata metadata = getTableMetadata(tableId); // This job could be in the queue for quite some time, so before doing any expensive // operations, // filter out shards that no longer exist, reducing the possibility of failure shardUuids = shardManager.getExistingShardUuids(tableId, shardUuids); if (shardUuids.size() <= 1) { return; } List<ShardInfo> newShards = performCompaction(transactionId, bucketNumber, shardUuids, metadata); log.info( "Compacted shards %s into %s", shardUuids, newShards.stream().map(ShardInfo::getShardUuid).collect(toList())); shardManager.replaceShardUuids( transactionId, tableId, metadata.getColumns(), shardUuids, newShards, OptionalLong.empty()); } private TableMetadata getTableMetadata(long tableId) { List<TableColumn> sortColumns = metadataDao.listSortColumns(tableId); List<Long> sortColumnIds = sortColumns.stream().map(TableColumn::getColumnId).collect(toList()); List<ColumnInfo> columns = metadataDao .listTableColumns(tableId) .stream() .map(TableColumn::toColumnInfo) .collect(toList()); return new TableMetadata(tableId, columns, sortColumnIds); } private List<ShardInfo> performCompaction( long transactionId, OptionalInt bucketNumber, Set<UUID> shardUuids, TableMetadata tableMetadata) throws IOException { if (tableMetadata.getSortColumnIds().isEmpty()) { return compactor.compact( transactionId, bucketNumber, shardUuids, tableMetadata.getColumns()); } return compactor.compactSorted( transactionId, bucketNumber, shardUuids, tableMetadata.getColumns(), tableMetadata.getSortColumnIds(), nCopies(tableMetadata.getSortColumnIds().size(), ASC_NULLS_FIRST)); } } }
public class BatchProcessor<T> { private static final Logger log = Logger.get(BatchProcessor.class); private final BatchHandler<T> handler; private final int maxBatchSize; private final BlockingQueue<T> queue; private final String name; private ExecutorService executor; private volatile Future<?> future; private final AtomicLong processedEntries = new AtomicLong(); private final AtomicLong droppedEntries = new AtomicLong(); private final AtomicLong errors = new AtomicLong(); public BatchProcessor(String name, BatchHandler<T> handler, int maxBatchSize, int queueSize) { Preconditions.checkNotNull(name, "name is null"); Preconditions.checkNotNull(handler, "handler is null"); Preconditions.checkArgument(queueSize > 0, "queue size needs to be a positive integer"); Preconditions.checkArgument(maxBatchSize > 0, "max batch size needs to be a positive integer"); this.name = name; this.handler = handler; this.maxBatchSize = maxBatchSize; this.queue = new ArrayBlockingQueue<T>(queueSize); } @PostConstruct public synchronized void start() { if (future == null) { executor = newSingleThreadExecutor(threadsNamed("batch-processor-" + name + "-%d")); future = executor.submit( new Runnable() { public void run() { while (!Thread.interrupted()) { final List<T> entries = new ArrayList<T>(maxBatchSize); try { T first = queue.take(); entries.add(first); queue.drainTo(entries, maxBatchSize - 1); handler.processBatch(Collections.unmodifiableList(entries)); processedEntries.addAndGet(entries.size()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } catch (Throwable t) { errors.incrementAndGet(); log.warn(t, "Error handling batch"); } // TODO: expose timestamp of last execution via jmx } } }); } } @Managed public long getProcessedEntries() { return processedEntries.get(); } @Managed public long getDroppedEntries() { return droppedEntries.get(); } @Managed public long getErrors() { return errors.get(); } @Managed public long getQueueSize() { return queue.size(); } @PreDestroy public synchronized void stop() { if (future != null) { future.cancel(true); executor.shutdownNow(); future = null; } } public void put(T entry) { Preconditions.checkState(!future.isCancelled(), "Processor is not running"); Preconditions.checkNotNull(entry, "entry is null"); while (!queue.offer(entry)) { // throw away oldest and try again if (queue.poll() != null) { droppedEntries.incrementAndGet(); } } } public static interface BatchHandler<T> { void processBatch(Collection<T> entries); } }
public class RiakSplitManager implements ConnectorSplitManager { private static final Logger log = Logger.get(RiakSplitManager.class); private final String connectorId; private final RiakClient riakClient; private final RiakConfig riakConfig; private final DirectConnection directConnection; @Inject public RiakSplitManager( RiakConnectorId connectorId, RiakClient riakClient, RiakConfig config, DirectConnection directConnection) { this.connectorId = checkNotNull(connectorId, "connectorId is null").toString(); this.riakClient = checkNotNull(riakClient, "client is null"); this.riakConfig = checkNotNull(config); this.directConnection = checkNotNull(directConnection); } // TODO: get the right partitions right here @Override public ConnectorPartitionResult getPartitions( ConnectorTableHandle tableHandle, TupleDomain<ColumnHandle> tupleDomain) { checkArgument( tableHandle instanceof RiakTableHandle, "tableHandle is not an instance of RiakTableHandle"); RiakTableHandle riakTableHandle = (RiakTableHandle) tableHandle; log.info("==========================tupleDomain============================="); log.info(tupleDomain.toString()); try { String parentTable = PRSubTable.parentTableName(riakTableHandle.getTableName()); SchemaTableName parentSchemaTable = new SchemaTableName(riakTableHandle.getSchemaName(), parentTable); PRTable table = riakClient.getTable(parentSchemaTable); List<String> indexedColumns = new LinkedList<String>(); for (RiakColumn riakColumn : table.getColumns()) { if (riakColumn.getIndex()) { indexedColumns.add(riakColumn.getName()); } } // Riak connector has only one partition List<ConnectorPartition> partitions = ImmutableList.<ConnectorPartition>of( new RiakPartition( riakTableHandle.getSchemaName(), riakTableHandle.getTableName(), tupleDomain, indexedColumns)); // Riak connector does not do any additional processing/filtering with the TupleDomain, so // just return the whole TupleDomain return new ConnectorPartitionResult(partitions, tupleDomain); } catch (Exception e) { log.error("interrupted: %s", e.toString()); throw new TableNotFoundException(riakTableHandle.toSchemaTableName()); } } // TODO: return correct splits from partitions @Override public ConnectorSplitSource getPartitionSplits( ConnectorTableHandle tableHandle, List<ConnectorPartition> partitions) { checkNotNull(partitions, "partitions is null"); checkArgument(partitions.size() == 1, "Expected one partition but got %s", partitions.size()); ConnectorPartition partition = partitions.get(0); checkArgument( partition instanceof RiakPartition, "partition is not an instance of RiakPartition"); // RiakPartition riakPartition = (RiakPartition) partition; RiakTableHandle riakTableHandle = (RiakTableHandle) tableHandle; try { String parentTable = PRSubTable.parentTableName(riakTableHandle.getTableName()); SchemaTableName parentSchemaTable = new SchemaTableName(riakTableHandle.getSchemaName(), parentTable); PRTable table = riakClient.getTable(parentSchemaTable); log.debug("> %s", table.getColumns().toString()); // add all nodes at the cluster here List<ConnectorSplit> splits = Lists.newArrayList(); String hosts = riakClient.getHosts(); log.debug(hosts); if (riakConfig.getLocalNode() != null) { // TODO: make coverageSplits here // try { DirectConnection conn = directConnection; // conn.connect(riak); // conn.ping(); Coverage coverage = new Coverage(conn); coverage.plan(); List<SplitTask> splitTasks = coverage.getSplits(); log.debug("print coverage plan=============="); log.debug(coverage.toString()); for (SplitTask split : 
splitTasks) { log.info("============printing split data at " + split.getHost() + "==============="); // log.debug(((OtpErlangObject)split.getTask()).toString()); log.info(split.toString()); CoverageSplit coverageSplit = new CoverageSplit( riakTableHandle, // maybe toplevel or subtable table, // toplevel PRTable split.getHost(), split.toString(), partition.getTupleDomain()); // log.info(new JsonCodecFactory().jsonCodec(CoverageSplit.class).toJson(coverageSplit)); splits.add(coverageSplit); } } else { // TODO: in Riak connector, you only need single access point for each presto worker??? log.error("localNode must be set and working"); log.debug(hosts); // splits.add(new CoverageSplit(connectorId, riakTableHandle.getSchemaName(), // riakTableHandle.getTableName(), hosts, // partition.getTupleDomain(), // ((RiakPartition) partition).getIndexedColumns())); } log.debug( "table %s.%s has %d splits.", riakTableHandle.getSchemaName(), riakTableHandle.getTableName(), splits.size()); Collections.shuffle(splits); return new FixedSplitSource(connectorId, splits); } catch (Exception e) { throw new TableNotFoundException(riakTableHandle.toSchemaTableName()); } // this can happen if table is removed during a query } }
final class ShardIterator extends AbstractIterator<BucketShards> implements ResultIterator<BucketShards> { private static final Logger log = Logger.get(ShardIterator.class); private final Map<Integer, String> nodeMap = new HashMap<>(); private final boolean merged; private final Map<Integer, String> bucketToNode; private final ShardDao dao; private final Connection connection; private final PreparedStatement statement; private final ResultSet resultSet; private boolean first = true; public ShardIterator( long tableId, boolean merged, Optional<Map<Integer, String>> bucketToNode, TupleDomain<RaptorColumnHandle> effectivePredicate, IDBI dbi) { this.merged = merged; this.bucketToNode = bucketToNode.orElse(null); ShardPredicate predicate = ShardPredicate.create(effectivePredicate, bucketToNode.isPresent()); String sql; if (bucketToNode.isPresent()) { sql = "SELECT shard_uuid, bucket_number FROM %s WHERE %s ORDER BY bucket_number"; } else { sql = "SELECT shard_uuid, node_ids FROM %s WHERE %s"; } sql = format(sql, shardIndexTable(tableId), predicate.getPredicate()); dao = onDemandDao(dbi, ShardDao.class); fetchNodes(); try { connection = dbi.open().getConnection(); statement = connection.prepareStatement(sql); enableStreamingResults(statement); predicate.bind(statement); log.debug("Running query: %s", statement); resultSet = statement.executeQuery(); } catch (SQLException e) { close(); throw metadataError(e); } } @Override protected BucketShards computeNext() { try { return merged ? computeMerged() : compute(); } catch (SQLException e) { throw metadataError(e); } } @SuppressWarnings({"UnusedDeclaration", "EmptyTryBlock"}) @Override public void close() { // use try-with-resources to close everything properly try (Connection connection = this.connection; Statement statement = this.statement; ResultSet resultSet = this.resultSet) { // do nothing } catch (SQLException ignored) { } } /** Compute split-per-shard (separate split for each shard). */ private BucketShards compute() throws SQLException { if (!resultSet.next()) { return endOfData(); } UUID shardUuid = uuidFromBytes(resultSet.getBytes("shard_uuid")); Set<String> nodeIdentifiers; OptionalInt bucketNumber = OptionalInt.empty(); if (bucketToNode != null) { int bucket = resultSet.getInt("bucket_number"); bucketNumber = OptionalInt.of(bucket); nodeIdentifiers = ImmutableSet.of(getBucketNode(bucket)); } else { List<Integer> nodeIds = intArrayFromBytes(resultSet.getBytes("node_ids")); nodeIdentifiers = getNodeIdentifiers(nodeIds, shardUuid); } ShardNodes shard = new ShardNodes(shardUuid, nodeIdentifiers); return new BucketShards(bucketNumber, ImmutableSet.of(shard)); } /** Compute split-per-bucket (single split for all shards in a bucket). 
*/ private BucketShards computeMerged() throws SQLException { if (resultSet.isAfterLast()) { return endOfData(); } if (first) { first = false; if (!resultSet.next()) { return endOfData(); } } int bucketNumber = resultSet.getInt("bucket_number"); ImmutableSet.Builder<ShardNodes> shards = ImmutableSet.builder(); do { UUID shardUuid = uuidFromBytes(resultSet.getBytes("shard_uuid")); int bucket = resultSet.getInt("bucket_number"); Set<String> nodeIdentifiers = ImmutableSet.of(getBucketNode(bucket)); shards.add(new ShardNodes(shardUuid, nodeIdentifiers)); } while (resultSet.next() && resultSet.getInt("bucket_number") == bucketNumber); return new BucketShards(OptionalInt.of(bucketNumber), shards.build()); } private String getBucketNode(int bucket) { String node = bucketToNode.get(bucket); if (node == null) { throw new PrestoException(RAPTOR_ERROR, "No node mapping for bucket: " + bucket); } return node; } private Set<String> getNodeIdentifiers(List<Integer> nodeIds, UUID shardUuid) { Function<Integer, String> fetchNode = id -> fetchNode(id, shardUuid); return nodeIds.stream().map(id -> nodeMap.computeIfAbsent(id, fetchNode)).collect(toSet()); } private String fetchNode(int id, UUID shardUuid) { String node = dao.getNodeIdentifier(id); if (node == null) { throw new PrestoException( RAPTOR_ERROR, format("Missing node ID [%s] for shard: %s", id, shardUuid)); } return node; } private void fetchNodes() { for (RaptorNode node : dao.getNodes()) { nodeMap.put(node.getNodeId(), node.getNodeIdentifier()); } } private static void enableStreamingResults(Statement statement) throws SQLException { if (statement.isWrapperFor(com.mysql.jdbc.Statement.class)) { statement.unwrap(com.mysql.jdbc.Statement.class).enableStreamingResults(); } } }
public class HttpServiceInventory implements ServiceInventory { private static final Logger log = Logger.get(HttpServiceInventory.class); private final Repository repository; private final JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec; private final Set<String> invalidServiceInventory = Collections.newSetFromMap(new ConcurrentHashMap<String, Boolean>()); private final File cacheDir; @Inject public HttpServiceInventory( Repository repository, JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec, CoordinatorConfig config) { this(repository, descriptorsJsonCodec, new File(config.getServiceInventoryCacheDir())); } public HttpServiceInventory( Repository repository, JsonCodec<List<ServiceDescriptor>> descriptorsJsonCodec, File cacheDir) { Preconditions.checkNotNull(repository, "repository is null"); Preconditions.checkNotNull(descriptorsJsonCodec, "descriptorsJsonCodec is null"); this.repository = repository; this.descriptorsJsonCodec = descriptorsJsonCodec; this.cacheDir = cacheDir; } @Override public ImmutableList<ServiceDescriptor> getServiceInventory(Iterable<SlotStatus> allSlotStatus) { ImmutableList.Builder<ServiceDescriptor> newDescriptors = ImmutableList.builder(); for (SlotStatus slotStatus : allSlotStatus) { // if the self reference is null, the slot is totally offline so skip for now if (slotStatus.getSelf() == null) { continue; } List<ServiceDescriptor> serviceDescriptors = getServiceInventory(slotStatus); if (serviceDescriptors == null) { continue; } for (ServiceDescriptor serviceDescriptor : serviceDescriptors) { newDescriptors.add( new ServiceDescriptor( null, slotStatus.getId().toString(), serviceDescriptor.getType(), serviceDescriptor.getPool(), slotStatus.getLocation(), slotStatus.getState() == SlotLifecycleState.RUNNING ? ServiceState.RUNNING : ServiceState.STOPPED, interpolateProperties(serviceDescriptor.getProperties(), slotStatus))); } } return newDescriptors.build(); } private Map<String, String> interpolateProperties( Map<String, String> properties, SlotStatus slotStatus) { ImmutableMap.Builder<String, String> builder = ImmutableMap.builder(); for (Entry<String, String> entry : properties.entrySet()) { String key = entry.getKey(); String value = entry.getValue(); value = value.replaceAll(Pattern.quote("${airship.host}"), slotStatus.getSelf().getHost()); builder.put(key, value); } return builder.build(); } private List<ServiceDescriptor> getServiceInventory(SlotStatus slotStatus) { Assignment assignment = slotStatus.getAssignment(); if (assignment == null) { return null; } String config = assignment.getConfig(); File cacheFile = getCacheFile(config); if (cacheFile.canRead()) { try { String json = CharStreams.toString(Files.newReaderSupplier(cacheFile, Charsets.UTF_8)); List<ServiceDescriptor> descriptors = descriptorsJsonCodec.fromJson(json); invalidServiceInventory.remove(config); return descriptors; } catch (Exception ignored) { // delete the bad cache file cacheFile.delete(); } } InputSupplier<? 
extends InputStream> configFile = ConfigUtils.newConfigEntrySupplier(repository, config, "airship-service-inventory.json"); if (configFile == null) { return null; } try { String json; try { json = CharStreams.toString(CharStreams.newReaderSupplier(configFile, Charsets.UTF_8)); } catch (FileNotFoundException e) { // no service inventory in the config, so replace with json null so caching works json = "null"; } invalidServiceInventory.remove(config); // cache json cacheFile.getParentFile().mkdirs(); Files.write(json, cacheFile, Charsets.UTF_8); List<ServiceDescriptor> descriptors = descriptorsJsonCodec.fromJson(json); return descriptors; } catch (Exception e) { if (invalidServiceInventory.add(config)) { log.error(e, "Unable to read service inventory for %s", config); } } return null; } private File getCacheFile(String config) { String cacheName = config; if (cacheName.startsWith("@")) { cacheName = cacheName.substring(1); } cacheName = cacheName.replaceAll("[^a-zA-Z0-9_.-]", "_"); cacheName = cacheName + "_" + DigestUtils.md5Hex(cacheName); return new File(cacheDir, cacheName).getAbsoluteFile(); } }
public class PredicatePushDown extends PlanOptimizer { private static final Logger log = Logger.get(PredicatePushDown.class); private final Metadata metadata; private final SplitManager splitManager; private final boolean experimentalSyntaxEnabled; public PredicatePushDown( Metadata metadata, SplitManager splitManager, boolean experimentalSyntaxEnabled) { this.metadata = checkNotNull(metadata, "metadata is null"); this.splitManager = checkNotNull(splitManager, "splitManager is null"); this.experimentalSyntaxEnabled = experimentalSyntaxEnabled; } @Override public PlanNode optimize( PlanNode plan, Session session, Map<Symbol, Type> types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { checkNotNull(plan, "plan is null"); checkNotNull(session, "session is null"); checkNotNull(types, "types is null"); checkNotNull(idAllocator, "idAllocator is null"); return PlanRewriter.rewriteWith( new Rewriter( symbolAllocator, idAllocator, metadata, splitManager, session, experimentalSyntaxEnabled), plan, BooleanLiteral.TRUE_LITERAL); } private static class Rewriter extends PlanNodeRewriter<Expression> { private final SymbolAllocator symbolAllocator; private final PlanNodeIdAllocator idAllocator; private final Metadata metadata; private final SplitManager splitManager; private final Session session; private final boolean experimentalSyntaxEnabled; private Rewriter( SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Metadata metadata, SplitManager splitManager, Session session, boolean experimentalSyntaxEnabled) { this.symbolAllocator = checkNotNull(symbolAllocator, "symbolAllocator is null"); this.idAllocator = checkNotNull(idAllocator, "idAllocator is null"); this.metadata = checkNotNull(metadata, "metadata is null"); this.splitManager = checkNotNull(splitManager, "splitManager is null"); this.session = checkNotNull(session, "session is null"); this.experimentalSyntaxEnabled = experimentalSyntaxEnabled; } @Override public PlanNode rewriteNode( PlanNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { PlanNode rewrittenNode = planRewriter.defaultRewrite(node, BooleanLiteral.TRUE_LITERAL); if (!inheritedPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { // Drop in a FilterNode b/c we cannot push our predicate down any further rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, inheritedPredicate); } return rewrittenNode; } @Override public PlanNode rewriteProject( ProjectNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { Expression inlinedPredicate = ExpressionTreeRewriter.rewriteWith( new ExpressionSymbolInliner(node.getOutputMap()), inheritedPredicate); return planRewriter.defaultRewrite(node, inlinedPredicate); } @Override public PlanNode rewriteMarkDistinct( MarkDistinctNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { checkState( !DependencyExtractor.extractUnique(inheritedPredicate).contains(node.getMarkerSymbol()), "predicate depends on marker symbol"); return planRewriter.defaultRewrite(node, inheritedPredicate); } @Override public PlanNode rewriteSort( SortNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { return planRewriter.defaultRewrite(node, inheritedPredicate); } @Override public PlanNode rewriteUnion( UnionNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { boolean modified = false; ImmutableList.Builder<PlanNode> builder = ImmutableList.builder(); for (int i = 0; i < 
node.getSources().size(); i++) { Expression sourcePredicate = ExpressionTreeRewriter.rewriteWith( new ExpressionSymbolInliner(node.sourceSymbolMap(i)), inheritedPredicate); PlanNode source = node.getSources().get(i); PlanNode rewrittenSource = planRewriter.rewrite(source, sourcePredicate); if (rewrittenSource != source) { modified = true; } builder.add(rewrittenSource); } if (modified) { return new UnionNode(node.getId(), builder.build(), node.getSymbolMapping()); } return node; } @Override public PlanNode rewriteFilter( FilterNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { return planRewriter.rewrite( node.getSource(), combineConjuncts(node.getPredicate(), inheritedPredicate)); } @Override public PlanNode rewriteJoin( JoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { boolean isCrossJoin = (node.getType() == JoinNode.Type.CROSS); // See if we can rewrite outer joins in terms of a plain inner join node = tryNormalizeToInnerJoin(node, inheritedPredicate); Expression leftEffectivePredicate = EffectivePredicateExtractor.extract(node.getLeft()); Expression rightEffectivePredicate = EffectivePredicateExtractor.extract(node.getRight()); Expression joinPredicate = extractJoinPredicate(node); Expression leftPredicate; Expression rightPredicate; Expression postJoinPredicate; Expression newJoinPredicate; switch (node.getType()) { case INNER: InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin( inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols()); leftPredicate = innerJoinPushDownResult.getLeftPredicate(); rightPredicate = innerJoinPushDownResult.getRightPredicate(); postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = innerJoinPushDownResult.getJoinPredicate(); break; case LEFT: OuterJoinPushDownResult leftOuterJoinPushDownResult = processOuterJoin( inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate, node.getLeft().getOutputSymbols()); leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate(); rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate(); postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = joinPredicate; // Use the same as the original break; case RIGHT: OuterJoinPushDownResult rightOuterJoinPushDownResult = processOuterJoin( inheritedPredicate, rightEffectivePredicate, leftEffectivePredicate, joinPredicate, node.getRight().getOutputSymbols()); leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate(); rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate(); postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate(); newJoinPredicate = joinPredicate; // Use the same as the original break; default: throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); } PlanNode leftSource = planRewriter.rewrite(node.getLeft(), leftPredicate); PlanNode rightSource = planRewriter.rewrite(node.getRight(), rightPredicate); PlanNode output = node; if (leftSource != node.getLeft() || rightSource != node.getRight() || !newJoinPredicate.equals(joinPredicate)) { List<JoinNode.EquiJoinClause> criteria = node.getCriteria(); // Rewrite criteria and add projections if there is a new join predicate if (!newJoinPredicate.equals(joinPredicate) || isCrossJoin) { // Create identity projections for all existing symbols ImmutableMap.Builder<Symbol, Expression> leftProjections = 
ImmutableMap.builder(); leftProjections.putAll( IterableTransformer.<Symbol>on(node.getLeft().getOutputSymbols()) .toMap(symbolToQualifiedNameReference()) .map()); ImmutableMap.Builder<Symbol, Expression> rightProjections = ImmutableMap.builder(); rightProjections.putAll( IterableTransformer.<Symbol>on(node.getRight().getOutputSymbols()) .toMap(symbolToQualifiedNameReference()) .map()); // HACK! we don't support cross joins right now, so put in a simple fake join predicate // instead if all of the join clauses got simplified out // TODO: remove this code when cross join support is added Iterable<Expression> simplifiedJoinConjuncts = transform(extractConjuncts(newJoinPredicate), simplifyExpressions()); simplifiedJoinConjuncts = filter( simplifiedJoinConjuncts, not(Predicates.<Expression>equalTo(BooleanLiteral.TRUE_LITERAL))); if (Iterables.isEmpty(simplifiedJoinConjuncts)) { simplifiedJoinConjuncts = ImmutableList.<Expression>of( new ComparisonExpression( ComparisonExpression.Type.EQUAL, new LongLiteral("0"), new LongLiteral("0"))); } // Create new projections for the new join clauses ImmutableList.Builder<JoinNode.EquiJoinClause> builder = ImmutableList.builder(); for (Expression conjunct : simplifiedJoinConjuncts) { checkState( joinEqualityExpression(node.getLeft().getOutputSymbols()).apply(conjunct), "Expected join predicate to be a valid join equality"); ComparisonExpression equality = (ComparisonExpression) conjunct; boolean alignedComparison = Iterables.all( DependencyExtractor.extractUnique(equality.getLeft()), in(node.getLeft().getOutputSymbols())); Expression leftExpression = (alignedComparison) ? equality.getLeft() : equality.getRight(); Expression rightExpression = (alignedComparison) ? equality.getRight() : equality.getLeft(); Symbol leftSymbol = symbolAllocator.newSymbol(leftExpression, extractType(leftExpression)); leftProjections.put(leftSymbol, leftExpression); Symbol rightSymbol = symbolAllocator.newSymbol(rightExpression, extractType(rightExpression)); rightProjections.put(rightSymbol, rightExpression); builder.add(new JoinNode.EquiJoinClause(leftSymbol, rightSymbol)); } leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); criteria = builder.build(); } output = new JoinNode(node.getId(), node.getType(), leftSource, rightSource, criteria); } if (!postJoinPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate); } return output; } private OuterJoinPushDownResult processOuterJoin( Expression inheritedPredicate, Expression outerEffectivePredicate, Expression innerEffectivePredicate, Expression joinPredicate, Collection<Symbol> outerSymbols) { checkArgument( Iterables.all( DependencyExtractor.extractUnique(outerEffectivePredicate), in(outerSymbols)), "outerEffectivePredicate must only contain symbols from outerSymbols"); checkArgument( Iterables.all( DependencyExtractor.extractUnique(innerEffectivePredicate), not(in(outerSymbols))), "innerEffectivePredicate must not contain symbols from outerSymbols"); ImmutableList.Builder<Expression> outerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> innerPushdownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> postJoinConjuncts = ImmutableList.builder(); // Strip out non-deterministic conjuncts postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic()))); 
inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); outerEffectivePredicate = stripNonDeterministicConjuncts(outerEffectivePredicate); innerEffectivePredicate = stripNonDeterministicConjuncts(innerEffectivePredicate); joinPredicate = stripNonDeterministicConjuncts(joinPredicate); // Generate equality inferences EqualityInference inheritedInference = createEqualityInference(inheritedPredicate); EqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate); EqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerSymbols)); Expression outerOnlyInheritedEqualities = combineConjuncts(equalityPartition.getScopeEqualities()); EqualityInference potentialNullSymbolInference = createEqualityInference( outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); EqualityInference potentialNullSymbolInferenceWithoutInnerInferred = createEqualityInference( outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); // Sort through conjuncts in inheritedPredicate that were not used for inference for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerSymbols)); if (outerRewritten != null) { outerPushdownConjuncts.add(outerRewritten); // A conjunct can only be pushed down into an inner side if it can be rewritten in terms // of the outer side Expression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerSymbols))); if (innerRewritten != null) { innerPushdownConjuncts.add(innerRewritten); } } else { postJoinConjuncts.add(conjunct); } } // See if we can push down any outer or join predicates to the inner side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(and(outerEffectivePredicate, joinPredicate))) { Expression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerSymbols))); if (rewritten != null) { innerPushdownConjuncts.add(rewritten); } } // TODO: consider adding join predicate optimizations to outer joins // Add the equalities from the inferences back in outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); innerPushdownConjuncts.addAll( potentialNullSymbolInferenceWithoutInnerInferred .generateEqualitiesPartitionedBy(not(in(outerSymbols))) .getScopeEqualities()); return new OuterJoinPushDownResult( combineConjuncts(outerPushdownConjuncts.build()), combineConjuncts(innerPushdownConjuncts.build()), combineConjuncts(postJoinConjuncts.build())); } private static class OuterJoinPushDownResult { private final Expression outerJoinPredicate; private final Expression innerJoinPredicate; private final Expression postJoinPredicate; private OuterJoinPushDownResult( Expression outerJoinPredicate, Expression innerJoinPredicate, Expression postJoinPredicate) { this.outerJoinPredicate = outerJoinPredicate; this.innerJoinPredicate = innerJoinPredicate; this.postJoinPredicate = postJoinPredicate; } private Expression getOuterJoinPredicate() { return outerJoinPredicate; } private Expression getInnerJoinPredicate() { return innerJoinPredicate; } private Expression getPostJoinPredicate() { return postJoinPredicate; } } private InnerJoinPushDownResult processInnerJoin( Expression 
inheritedPredicate, Expression leftEffectivePredicate, Expression rightEffectivePredicate, Expression joinPredicate, Collection<Symbol> leftSymbols) { checkArgument( Iterables.all(DependencyExtractor.extractUnique(leftEffectivePredicate), in(leftSymbols)), "leftEffectivePredicate must only contain symbols from leftSymbols"); checkArgument( Iterables.all( DependencyExtractor.extractUnique(rightEffectivePredicate), not(in(leftSymbols))), "rightEffectivePredicate must not contain symbols from leftSymbols"); ImmutableList.Builder<Expression> leftPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> rightPushDownConjuncts = ImmutableList.builder(); ImmutableList.Builder<Expression> joinConjuncts = ImmutableList.builder(); // Strip out non-deterministic conjuncts joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(deterministic()))); inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(deterministic()))); joinPredicate = stripNonDeterministicConjuncts(joinPredicate); leftEffectivePredicate = stripNonDeterministicConjuncts(leftEffectivePredicate); rightEffectivePredicate = stripNonDeterministicConjuncts(rightEffectivePredicate); // Generate equality inferences EqualityInference allInference = createEqualityInference( inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate); EqualityInference allInferenceWithoutLeftInferred = createEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate); EqualityInference allInferenceWithoutRightInferred = createEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate); // Sort through conjuncts in inheritedPredicate that were not used for inference for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (leftRewrittenConjunct != null) { leftPushDownConjuncts.add(leftRewrittenConjunct); } Expression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rightRewrittenConjunct != null) { rightPushDownConjuncts.add(rightRewrittenConjunct); } // Drop predicate after join only if unable to push down to either side if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { joinConjuncts.add(conjunct); } } // See if we can push the right effective predicate to the left side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(rightEffectivePredicate)) { Expression rewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (rewritten != null) { leftPushDownConjuncts.add(rewritten); } } // See if we can push the left effective predicate to the right side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(leftEffectivePredicate)) { Expression rewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rewritten != null) { rightPushDownConjuncts.add(rewritten); } } // See if we can push any parts of the join predicates to either side for (Expression conjunct : EqualityInference.nonInferrableConjuncts(joinPredicate)) { Expression leftRewritten = allInference.rewriteExpression(conjunct, in(leftSymbols)); if (leftRewritten != null) { leftPushDownConjuncts.add(leftRewritten); } Expression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftSymbols))); if (rightRewritten != null) { rightPushDownConjuncts.add(rightRewritten); 
} if (leftRewritten == null && rightRewritten == null) { joinConjuncts.add(conjunct); } } // Add equalities from the inference back in leftPushDownConjuncts.addAll( allInferenceWithoutLeftInferred .generateEqualitiesPartitionedBy(in(leftSymbols)) .getScopeEqualities()); rightPushDownConjuncts.addAll( allInferenceWithoutRightInferred .generateEqualitiesPartitionedBy(not(in(leftSymbols))) .getScopeEqualities()); joinConjuncts.addAll( allInference .generateEqualitiesPartitionedBy(in(leftSymbols)) .getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as // part of the join predicate // Since we only currently support equality in join conjuncts, factor out the non-equality // conjuncts to a post-join filter List<Expression> joinConjunctsList = joinConjuncts.build(); List<Expression> postJoinConjuncts = ImmutableList.copyOf(filter(joinConjunctsList, not(joinEqualityExpression(leftSymbols)))); joinConjunctsList = ImmutableList.copyOf(filter(joinConjunctsList, joinEqualityExpression(leftSymbols))); return new InnerJoinPushDownResult( combineConjuncts(leftPushDownConjuncts.build()), combineConjuncts(rightPushDownConjuncts.build()), combineConjuncts(joinConjunctsList), combineConjuncts(postJoinConjuncts)); } private static class InnerJoinPushDownResult { private final Expression leftPredicate; private final Expression rightPredicate; private final Expression joinPredicate; private final Expression postJoinPredicate; private InnerJoinPushDownResult( Expression leftPredicate, Expression rightPredicate, Expression joinPredicate, Expression postJoinPredicate) { this.leftPredicate = leftPredicate; this.rightPredicate = rightPredicate; this.joinPredicate = joinPredicate; this.postJoinPredicate = postJoinPredicate; } private Expression getLeftPredicate() { return leftPredicate; } private Expression getRightPredicate() { return rightPredicate; } private Expression getJoinPredicate() { return joinPredicate; } private Expression getPostJoinPredicate() { return postJoinPredicate; } } private static Expression extractJoinPredicate(JoinNode joinNode) { ImmutableList.Builder<Expression> builder = ImmutableList.builder(); for (JoinNode.EquiJoinClause equiJoinClause : joinNode.getCriteria()) { builder.add(equalsExpression(equiJoinClause.getLeft(), equiJoinClause.getRight())); } return combineConjuncts(builder.build()); } private static Expression equalsExpression(Symbol symbol1, Symbol symbol2) { return new ComparisonExpression( ComparisonExpression.Type.EQUAL, new QualifiedNameReference(symbol1.toQualifiedName()), new QualifiedNameReference(symbol2.toQualifiedName())); } // TODO: temporary addition to infer result type from expression. 
fix this with the new planner // refactoring (martint) private Type extractType(Expression expression) { ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(new Analysis(), session, metadata, experimentalSyntaxEnabled); List<Field> fields = IterableTransformer.<Symbol>on(DependencyExtractor.extractUnique(expression)) .transform( new Function<Symbol, Field>() { @Override public Field apply(Symbol symbol) { return Field.newUnqualified( symbol.getName(), symbolAllocator.getTypes().get(symbol)); } }) .list(); return expressionAnalyzer.analyze( expression, new TupleDescriptor(fields), new AnalysisContext()); } private JoinNode tryNormalizeToInnerJoin(JoinNode node, Expression inheritedPredicate) { Preconditions.checkArgument( EnumSet.of(INNER, RIGHT, LEFT, CROSS).contains(node.getType()), "Unsupported join type: %s", node.getType()); if (node.getType() == JoinNode.Type.CROSS) { return new JoinNode( node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria()); } if (node.getType() == JoinNode.Type.INNER || node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputSymbols(), inheritedPredicate) || node.getType() == JoinNode.Type.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputSymbols(), inheritedPredicate)) { return node; } return new JoinNode( node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria()); } private boolean canConvertOuterToInner( List<Symbol> innerSymbolsForOuterJoin, Expression inheritedPredicate) { Set<Symbol> innerSymbols = ImmutableSet.copyOf(innerSymbolsForOuterJoin); for (Expression conjunct : extractConjuncts(inheritedPredicate)) { if (DeterminismEvaluator.isDeterministic(conjunct)) { // Ignore a conjunct for this test if we can not deterministically get responses from it Object response = nullInputEvaluator(innerSymbols, conjunct); if (response == null || response instanceof NullLiteral || Boolean.FALSE.equals(response)) { // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for // the inner side symbols of an outer join // then this conjunct removes all effects of the outer join, and effectively turns this // into an equivalent of an inner join. // So, let's just rewrite this join as an INNER join return true; } } } return false; } // Temporary implementation for joins because the SimplifyExpressions optimizers can not run // properly on join clauses private Function<Expression, Expression> simplifyExpressions() { return new Function<Expression, Expression>() { @Override public Expression apply(Expression expression) { ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session); return LiteralInterpreter.toExpression(optimizer.optimize(NoOpSymbolResolver.INSTANCE)); } }; } /** Evaluates an expression's response to binding the specified input symbols to NULL */ private Object nullInputEvaluator(final Collection<Symbol> nullSymbols, Expression expression) { return ExpressionInterpreter.expressionOptimizer(expression, metadata, session) .optimize( new SymbolResolver() { @Override public Object getValue(Symbol symbol) { return nullSymbols.contains(symbol) ? 
null : new QualifiedNameReference(symbol.toQualifiedName()); } }); } private static Predicate<Expression> joinEqualityExpression( final Collection<Symbol> leftSymbols) { return new Predicate<Expression>() { @Override public boolean apply(Expression expression) { // At this point in time, our join predicates need to be deterministic if (isDeterministic(expression) && expression instanceof ComparisonExpression) { ComparisonExpression comparison = (ComparisonExpression) expression; if (comparison.getType() == ComparisonExpression.Type.EQUAL) { Set<Symbol> symbols1 = DependencyExtractor.extractUnique(comparison.getLeft()); Set<Symbol> symbols2 = DependencyExtractor.extractUnique(comparison.getRight()); return (Iterables.all(symbols1, in(leftSymbols)) && Iterables.all(symbols2, not(in(leftSymbols)))) || (Iterables.all(symbols2, in(leftSymbols)) && Iterables.all(symbols1, not(in(leftSymbols)))); } } return false; } }; } @Override public PlanNode rewriteSemiJoin( SemiJoinNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { Expression sourceEffectivePredicate = EffectivePredicateExtractor.extract(node.getSource()); List<Expression> sourceConjuncts = new ArrayList<>(); List<Expression> filteringSourceConjuncts = new ArrayList<>(); List<Expression> postJoinConjuncts = new ArrayList<>(); // TODO: see if there are predicates that can be inferred from the semi join output // Push inherited and source predicates to filtering source via a contrived join predicate // (but needs to avoid touching NULL values in the filtering source) Expression joinPredicate = equalsExpression(node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol()); EqualityInference joinInference = createEqualityInference(inheritedPredicate, sourceEffectivePredicate, joinPredicate); for (Expression conjunct : Iterables.concat( EqualityInference.nonInferrableConjuncts(inheritedPredicate), EqualityInference.nonInferrableConjuncts(sourceEffectivePredicate))) { Expression rewrittenConjunct = joinInference.rewriteExpression(conjunct, equalTo(node.getFilteringSourceJoinSymbol())); if (rewrittenConjunct != null && DeterminismEvaluator.isDeterministic(rewrittenConjunct)) { // Alter conjunct to include an OR filteringSourceJoinSymbol IS NULL disjunct Expression rewrittenConjunctOrNull = expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol())) .apply(rewrittenConjunct); filteringSourceConjuncts.add(rewrittenConjunctOrNull); } } EqualityInference.EqualityPartition joinInferenceEqualityPartition = joinInference.generateEqualitiesPartitionedBy( equalTo(node.getFilteringSourceJoinSymbol())); filteringSourceConjuncts.addAll( ImmutableList.copyOf( transform( joinInferenceEqualityPartition.getScopeEqualities(), expressionOrNullSymbols(equalTo(node.getFilteringSourceJoinSymbol()))))); // Push inheritedPredicates down to the source if they don't involve the semi join output EqualityInference inheritedInference = createEqualityInference(inheritedPredicate); for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression rewrittenConjunct = inheritedInference.rewriteExpression(conjunct, in(node.getSource().getOutputSymbols())); // Since each source row is reflected exactly once in the output, ok to push // non-deterministic predicates down if (rewrittenConjunct != null) { sourceConjuncts.add(rewrittenConjunct); } else { postJoinConjuncts.add(conjunct); } } // Add the inherited equality predicates back in EqualityInference.EqualityPartition equalityPartition = 
inheritedInference.generateEqualitiesPartitionedBy( in(node.getSource().getOutputSymbols())); sourceConjuncts.addAll(equalityPartition.getScopeEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = planRewriter.rewrite(node.getSource(), combineConjuncts(sourceConjuncts)); PlanNode rewrittenFilteringSource = planRewriter.rewrite( node.getFilteringSource(), combineConjuncts(filteringSourceConjuncts)); PlanNode output = node; if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) { output = new SemiJoinNode( node.getId(), rewrittenSource, rewrittenFilteringSource, node.getSourceJoinSymbol(), node.getFilteringSourceJoinSymbol(), node.getSemiJoinOutput()); } if (!postJoinConjuncts.isEmpty()) { output = new FilterNode(idAllocator.getNextId(), output, combineConjuncts(postJoinConjuncts)); } return output; } @Override public PlanNode rewriteAggregation( AggregationNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { EqualityInference equalityInference = createEqualityInference(inheritedPredicate); List<Expression> pushdownConjuncts = new ArrayList<>(); List<Expression> postAggregationConjuncts = new ArrayList<>(); // Strip out non-deterministic conjuncts postAggregationConjuncts.addAll( ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(deterministic())))); inheritedPredicate = stripNonDeterministicConjuncts(inheritedPredicate); // Sort non-equality predicates by those that can be pushed down and those that cannot for (Expression conjunct : EqualityInference.nonInferrableConjuncts(inheritedPredicate)) { Expression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getGroupBy())); if (rewrittenConjunct != null) { pushdownConjuncts.add(rewrittenConjunct); } else { postAggregationConjuncts.add(conjunct); } } // Add the equality predicates back in EqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getGroupBy())); pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); PlanNode rewrittenSource = planRewriter.rewrite(node.getSource(), combineConjuncts(pushdownConjuncts)); PlanNode output = node; if (rewrittenSource != node.getSource()) { output = new AggregationNode( node.getId(), rewrittenSource, node.getGroupBy(), node.getAggregations(), node.getFunctions(), node.getMasks(), node.getStep(), node.getSampleWeight(), node.getConfidence()); } if (!postAggregationConjuncts.isEmpty()) { output = new FilterNode( idAllocator.getNextId(), output, combineConjuncts(postAggregationConjuncts)); } return output; } @Override public PlanNode rewriteSample( SampleNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { return planRewriter.defaultRewrite(node, inheritedPredicate); } @Override public PlanNode rewriteTableScan( TableScanNode node, Expression inheritedPredicate, PlanRewriter<Expression> planRewriter) { DomainTranslator.ExtractionResult extractionResult = DomainTranslator.fromPredicate( inheritedPredicate, symbolAllocator.getTypes(), node.getAssignments()); Expression extractionRemainingExpression = extractionResult.getRemainingExpression(); TupleDomain tupleDomain = 
extractionResult.getTupleDomain(); if (node.getGeneratedPartitions().isPresent()) { // Add back in the TupleDomain that was used to generate the previous set of Partitions if // present // And just for kicks, throw in the domain summary too (as that can only help prune down the // ranges) // The domains should never widen between each pass. tupleDomain = tupleDomain .intersect(node.getGeneratedPartitions().get().getTupleDomainInput()) .intersect(node.getPartitionsDomainSummary()); } PartitionResult matchingPartitions = splitManager.getPartitions(node.getTable(), Optional.of(tupleDomain)); List<Partition> partitions = matchingPartitions.getPartitions(); TupleDomain undeterminedTupleDomain = matchingPartitions.getUndeterminedTupleDomain(); Expression unevaluatedDomainPredicate = DomainTranslator.toPredicate( undeterminedTupleDomain, ImmutableBiMap.copyOf(node.getAssignments()).inverse()); // Construct the post scan predicate. Add the unevaluated TupleDomain back first since those // are generally cheaper to evaluate than anything we can't extract Expression postScanPredicate = combineConjuncts(unevaluatedDomainPredicate, extractionRemainingExpression); // Do some early partition pruning partitions = ImmutableList.copyOf( filter( partitions, not(shouldPrunePartition(postScanPredicate, node.getAssignments())))); GeneratedPartitions generatedPartitions = new GeneratedPartitions(tupleDomain, partitions); PlanNode output = node; if (!node.getGeneratedPartitions().equals(Optional.of(generatedPartitions))) { // Only overwrite the originalConstraint if it was previously null Expression originalConstraint = node.getOriginalConstraint() == null ? inheritedPredicate : node.getOriginalConstraint(); output = new TableScanNode( node.getId(), node.getTable(), node.getOutputSymbols(), node.getAssignments(), originalConstraint, Optional.of(generatedPartitions)); } if (!postScanPredicate.equals(BooleanLiteral.TRUE_LITERAL)) { output = new FilterNode(idAllocator.getNextId(), output, postScanPredicate); } return output; } private Predicate<Partition> shouldPrunePartition( final Expression predicate, final Map<Symbol, ColumnHandle> symbolToColumn) { return new Predicate<Partition>() { @Override public boolean apply(Partition partition) { Map<ColumnHandle, Comparable<?>> columnFixedValueAssignments = partition.getTupleDomain().extractFixedValues(); Map<ColumnHandle, Comparable<?>> translatableAssignments = Maps.filterKeys(columnFixedValueAssignments, in(symbolToColumn.values())); Map<Symbol, Comparable<?>> symbolFixedValueAssignments = DomainUtils.columnHandleToSymbol(translatableAssignments, symbolToColumn); LookupSymbolResolver inputs = new LookupSymbolResolver( ImmutableMap.<Symbol, Object>copyOf(symbolFixedValueAssignments)); // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true // and so the partition should be pruned for (Expression expression : extractConjuncts(predicate)) { ExpressionInterpreter optimizer = ExpressionInterpreter.expressionOptimizer(expression, metadata, session); Object optimized = optimizer.optimize(inputs); if (Boolean.FALSE.equals(optimized) || optimized == null || optimized instanceof NullLiteral) { return true; } } return false; } }; } } }
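The rewriter above repeatedly splits each inherited conjunct by which side of the join its symbols come from: a conjunct that references only left-side symbols is pushed to the left input, one that references only right-side symbols is pushed to the right, and anything that straddles both sides stays with the join (or a post-join filter). A minimal, self-contained sketch of that partitioning step follows; it is not Presto's API, and it uses plain strings for expressions and symbols where the real code derives the symbol set with DependencyExtractor.extractUnique(expression) and also tries equality-inference rewrites before giving up on a conjunct.

import java.util.*;

// Hypothetical sketch: partition conjuncts by which join side their symbols come from.
public class ConjunctPartitionSketch {
    public static void main(String[] args) {
        Set<String> leftSymbols = new HashSet<>(Arrays.asList("l_orderkey", "l_price"));

        // conjunct text -> symbols it references (stand-in for an Expression tree)
        Map<String, Set<String>> conjuncts = new LinkedHashMap<>();
        conjuncts.put("l_price > 10", Collections.singleton("l_price"));
        conjuncts.put("o_status = 'F'", Collections.singleton("o_status"));
        conjuncts.put("l_orderkey = o_key", new HashSet<>(Arrays.asList("l_orderkey", "o_key")));

        List<String> pushToLeft = new ArrayList<>();
        List<String> pushToRight = new ArrayList<>();
        List<String> keepInJoin = new ArrayList<>();

        for (Map.Entry<String, Set<String>> entry : conjuncts.entrySet()) {
            Set<String> symbols = entry.getValue();
            if (leftSymbols.containsAll(symbols)) {
                pushToLeft.add(entry.getKey());        // only left-side symbols
            }
            else if (Collections.disjoint(symbols, leftSymbols)) {
                pushToRight.add(entry.getKey());       // only right-side symbols
            }
            else {
                keepInJoin.add(entry.getKey());        // straddles both sides
            }
        }

        System.out.println("left:  " + pushToLeft);
        System.out.println("right: " + pushToRight);
        System.out.println("join:  " + keepInJoin);
    }
}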
// // NOTE: As a general strategy the methods should "stage" a change and only // process the actual change before lock release (DriverLockResult.close()). // The assures that only one thread will be working with the operators at a // time and state changer threads are not blocked. // public class Driver { private static final Logger log = Logger.get(Driver.class); private final DriverContext driverContext; private final List<Operator> operators; private final Map<PlanNodeId, SourceOperator> sourceOperators; private final ConcurrentMap<PlanNodeId, TaskSource> newSources = new ConcurrentHashMap<>(); private final AtomicReference<State> state = new AtomicReference<>(State.ALIVE); private final ReentrantLock exclusiveLock = new ReentrantLock(); @GuardedBy("this") private Thread lockHolder; @GuardedBy("exclusiveLock") private Map<PlanNodeId, TaskSource> currentSources = new ConcurrentHashMap<>(); private enum State { ALIVE, NEED_DESTRUCTION, DESTROYED } public Driver(DriverContext driverContext, Operator firstOperator, Operator... otherOperators) { this( checkNotNull(driverContext, "driverContext is null"), ImmutableList.<Operator>builder() .add(checkNotNull(firstOperator, "firstOperator is null")) .add(checkNotNull(otherOperators, "otherOperators is null")) .build()); } public Driver(DriverContext driverContext, List<Operator> operators) { this.driverContext = checkNotNull(driverContext, "driverContext is null"); this.operators = ImmutableList.copyOf(checkNotNull(operators, "operators is null")); checkArgument(!operators.isEmpty(), "There must be at least one operator"); ImmutableMap.Builder<PlanNodeId, SourceOperator> sourceOperators = ImmutableMap.builder(); for (Operator operator : operators) { if (operator instanceof SourceOperator) { SourceOperator sourceOperator = (SourceOperator) operator; sourceOperators.put(sourceOperator.getSourceId(), sourceOperator); } } this.sourceOperators = sourceOperators.build(); } public DriverContext getDriverContext() { return driverContext; } public Set<PlanNodeId> getSourceIds() { return sourceOperators.keySet(); } public void close() { // mark the service for destruction if (!state.compareAndSet(State.ALIVE, State.NEED_DESTRUCTION)) { return; } // if we can get the lock, attempt a clean shutdown; otherwise someone else will shutdown try (DriverLockResult lockResult = tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS)) { // if we did not get the lock, interrupt the lock holder if (!lockResult.wasAcquired()) { // there is a benign race condition here were the lock holder // can be change between attempting to get lock and grabbing // the synchronized lock here, but in either case we want to // interrupt the lock holder thread synchronized (this) { if (lockHolder != null) { lockHolder.interrupt(); } } } // clean shutdown is automatically triggered during lock release } } public boolean isFinished() { checkLockNotHeld("Can not check finished status while holding the driver lock"); // if we can get the lock, attempt a clean shutdown; otherwise someone else will shutdown try (DriverLockResult lockResult = tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS)) { if (lockResult.wasAcquired()) { boolean finished = state.get() != State.ALIVE || driverContext.isDone() || operators.get(operators.size() - 1).isFinished(); if (finished) { state.compareAndSet(State.ALIVE, State.NEED_DESTRUCTION); } return finished; } else { // did not get the lock, so we can't check operators, or destroy return state.get() != State.ALIVE || driverContext.isDone(); 
} } } public void updateSource(TaskSource source) { checkLockNotHeld("Can not update sources while holding the driver lock"); // does this driver have an operator for the specified source? if (!sourceOperators.containsKey(source.getPlanNodeId())) { return; } // stage the new updates while (true) { // attempt to update directly to the new source TaskSource currentNewSource = newSources.putIfAbsent(source.getPlanNodeId(), source); // if update succeeded, just break if (currentNewSource == null) { break; } // merge source into the current new source TaskSource newSource = currentNewSource.update(source); // if this is not a new source, just return if (newSource == currentNewSource) { break; } // attempt to replace the currentNewSource with the new source if (newSources.replace(source.getPlanNodeId(), currentNewSource, newSource)) { break; } // someone else updated while we were processing } // attempt to get the lock and process the updates we staged above // updates will be processed in close if and only if we got the lock tryLockAndProcessPendingStateChanges(0, TimeUnit.MILLISECONDS).close(); } private void processNewSources() { checkLockHeld("Lock must be held to call processNewSources"); // only update if the driver is still alive if (state.get() != State.ALIVE) { return; } // copy the pending sources // it is ok to "miss" a source added during the copy as it will be // handled on the next call to this method Map<PlanNodeId, TaskSource> sources = new HashMap<>(newSources); for (Entry<PlanNodeId, TaskSource> entry : sources.entrySet()) { // Remove the entries we are going to process from the newSources map. // It is ok if someone already updated the entry; we will catch it on // the next iteration. newSources.remove(entry.getKey(), entry.getValue()); processNewSource(entry.getValue()); } } private void processNewSource(TaskSource source) { checkLockHeld("Lock must be held to call processNewSources"); // create new source Set<ScheduledSplit> newSplits; TaskSource currentSource = currentSources.get(source.getPlanNodeId()); if (currentSource == null) { newSplits = source.getSplits(); currentSources.put(source.getPlanNodeId(), source); } else { // merge the current source and the specified source TaskSource newSource = currentSource.update(source); // if this is not a new source, just return if (newSource == currentSource) { return; } // find the new splits to add newSplits = Sets.difference(newSource.getSplits(), currentSource.getSplits()); currentSources.put(source.getPlanNodeId(), newSource); } // add new splits for (ScheduledSplit newSplit : newSplits) { Split split = newSplit.getSplit(); SourceOperator sourceOperator = sourceOperators.get(source.getPlanNodeId()); if (sourceOperator != null) { sourceOperator.addSplit(split); } } // set no more splits if (source.isNoMoreSplits()) { sourceOperators.get(source.getPlanNodeId()).noMoreSplits(); } } public ListenableFuture<?> processFor(Duration duration) { checkLockNotHeld("Can not process for a duration while holding the driver lock"); checkNotNull(duration, "duration is null"); long maxRuntime = duration.roundTo(TimeUnit.NANOSECONDS); long start = System.nanoTime(); do { ListenableFuture<?> future = process(); if (!future.isDone()) { return future; } } while (System.nanoTime() - start < maxRuntime && !isFinished()); return NOT_BLOCKED; } public ListenableFuture<?> process() { checkLockNotHeld("Can not process while holding the driver lock"); try (DriverLockResult lockResult = tryLockAndProcessPendingStateChanges(100, TimeUnit.MILLISECONDS)) { 
try { if (!lockResult.wasAcquired()) { // this is unlikely to happen unless the driver is being // destroyed and in that case the caller should notice notice // this state change by calling isFinished return NOT_BLOCKED; } driverContext.start(); if (!newSources.isEmpty()) { processNewSources(); } for (int i = 0; i < operators.size() - 1 && !driverContext.isDone(); i++) { // check if current operator is blocked Operator current = operators.get(i); ListenableFuture<?> blocked = current.isBlocked(); if (!blocked.isDone()) { current.getOperatorContext().recordBlocked(blocked); return blocked; } // check if next operator is blocked Operator next = operators.get(i + 1); blocked = next.isBlocked(); if (!blocked.isDone()) { next.getOperatorContext().recordBlocked(blocked); return blocked; } // if current operator is finished... if (current.isFinished()) { // let next operator know there will be no more data next.getOperatorContext().startIntervalTimer(); next.finish(); next.getOperatorContext().recordFinish(); } else { // if next operator needs input... if (next.needsInput()) { // get an output page from current operator current.getOperatorContext().startIntervalTimer(); Page page = current.getOutput(); current.getOperatorContext().recordGetOutput(page); // if we got an output page, add it to the next operator if (page != null) { next.getOperatorContext().startIntervalTimer(); next.addInput(page); next.getOperatorContext().recordAddInput(page); } } } } return NOT_BLOCKED; } catch (Throwable t) { driverContext.failed(t); throw t; } } } private void destroyIfNecessary() { checkLockHeld("Lock must be held to call destroyIfNecessary"); if (!state.compareAndSet(State.NEED_DESTRUCTION, State.DESTROYED)) { return; } Throwable inFlightException = null; try { // call finish on every operator; if error occurs, just bail out; we will still call close for (Operator operator : operators) { operator.finish(); } } catch (Throwable t) { // record in flight exception so we can add suppressed exceptions below inFlightException = t; throw t; } finally { // record the current interrupted status (and clear the flag); we'll reset it later boolean wasInterrupted = Thread.interrupted(); // if we get an error while closing a driver, record it and we will throw it at the end try { for (Operator operator : operators) { if (operator instanceof AutoCloseable) { try { ((AutoCloseable) operator).close(); } catch (InterruptedException t) { // don't record the stack wasInterrupted = true; } catch (Throwable t) { inFlightException = addSuppressedException( inFlightException, t, "Error closing operator %s for task %s", operator.getOperatorContext().getOperatorId(), driverContext.getTaskId()); } } } driverContext.finished(); } catch (Throwable t) { // this shouldn't happen but be safe inFlightException = addSuppressedException( inFlightException, t, "Error destroying driver for task %s", driverContext.getTaskId()); } finally { // reset the interrupted flag if (wasInterrupted) { Thread.currentThread().interrupt(); } } if (inFlightException != null) { // this will always be an Error or Runtime throw Throwables.propagate(inFlightException); } } } private Throwable addSuppressedException( Throwable inFlightException, Throwable newException, String message, Object... 
args) { if (newException instanceof Error) { if (inFlightException == null) { inFlightException = newException; } else { inFlightException.addSuppressed(newException); } } else { // log normal exceptions instead of rethrowing them log.error(newException, message, args); } return inFlightException; } private DriverLockResult tryLockAndProcessPendingStateChanges(int timeout, TimeUnit unit) { checkLockNotHeld("Can not acquire the driver lock while already holding the driver lock"); return new DriverLockResult(timeout, unit); } private synchronized void checkLockNotHeld(String message) { checkState(Thread.currentThread() != lockHolder, message); } private synchronized void checkLockHeld(String message) { checkState(Thread.currentThread() == lockHolder, message); } private class DriverLockResult implements AutoCloseable { private final boolean acquired; private DriverLockResult(int timeout, TimeUnit unit) { boolean acquired = false; try { acquired = exclusiveLock.tryLock(timeout, unit); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } this.acquired = acquired; if (acquired) { synchronized (Driver.this) { lockHolder = Thread.currentThread(); } } } public boolean wasAcquired() { return acquired; } @Override public void close() { if (!acquired) { return; } // before releasing the lock, process any new sources and/or destroy the driver try { try { processNewSources(); } finally { destroyIfNecessary(); } } finally { synchronized (Driver.this) { lockHolder = null; } exclusiveLock.unlock(); } } } }
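Driver's concurrency model is worth highlighting: callers stage state changes into concurrent structures without blocking, and whichever thread manages to acquire the exclusive lock applies the staged work before releasing it (DriverLockResult.close()), so the operators are only ever touched by one thread at a time. A stripped-down, hypothetical sketch of that "stage now, apply under the lock" pattern using only JDK types:

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical sketch of Driver's staging pattern: any thread may stage work,
// only the thread that wins tryLock applies it.
public class StagedWorkSketch {
    private final ReentrantLock exclusiveLock = new ReentrantLock();
    private final Queue<Runnable> staged = new ConcurrentLinkedQueue<>();

    public void submit(Runnable work) {
        staged.add(work);                               // stage without blocking
        tryLockAndDrain(0, TimeUnit.MILLISECONDS);      // apply only if we can get the lock
    }

    private void tryLockAndDrain(long timeout, TimeUnit unit) {
        boolean acquired = false;
        try {
            acquired = exclusiveLock.tryLock(timeout, unit);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        if (!acquired) {
            return;                                     // the current lock holder drains before releasing
        }
        try {
            Runnable work;
            while ((work = staged.poll()) != null) {
                work.run();                             // single-threaded access to the protected state
            }
        }
        finally {
            exclusiveLock.unlock();
        }
    }

    public static void main(String[] args) {
        StagedWorkSketch sketch = new StagedWorkSketch();
        sketch.submit(() -> System.out.println("applied under the lock"));
    }
}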
public class ClusterMemoryManager implements ClusterMemoryPoolManager { private static final Logger log = Logger.get(ClusterMemoryManager.class); private final ExecutorService listenerExecutor = Executors.newSingleThreadExecutor(); private final NodeManager nodeManager; private final LocationFactory locationFactory; private final HttpClient httpClient; private final MBeanExporter exporter; private final JsonCodec<MemoryInfo> memoryInfoCodec; private final JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec; private final DataSize maxQueryMemory; private final Duration maxQueryCpuTime; private final boolean enabled; private final boolean killOnOutOfMemory; private final Duration killOnOutOfMemoryDelay; private final String coordinatorId; private final AtomicLong memoryPoolAssignmentsVersion = new AtomicLong(); private final AtomicLong clusterMemoryUsageBytes = new AtomicLong(); private final AtomicLong clusterMemoryBytes = new AtomicLong(); private final AtomicLong queriesKilledDueToOutOfMemory = new AtomicLong(); private final Map<String, RemoteNodeMemory> nodes = new HashMap<>(); @GuardedBy("this") private final Map<MemoryPoolId, List<Consumer<MemoryPoolInfo>>> changeListeners = new HashMap<>(); @GuardedBy("this") private final Map<MemoryPoolId, ClusterMemoryPool> pools = new HashMap<>(); @GuardedBy("this") private long lastTimeNotOutOfMemory = System.nanoTime(); @GuardedBy("this") private QueryId lastKilledQuery; @Inject public ClusterMemoryManager( @ForMemoryManager HttpClient httpClient, NodeManager nodeManager, LocationFactory locationFactory, MBeanExporter exporter, JsonCodec<MemoryInfo> memoryInfoCodec, JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec, QueryIdGenerator queryIdGenerator, ServerConfig serverConfig, MemoryManagerConfig config, QueryManagerConfig queryManagerConfig) { requireNonNull(config, "config is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.locationFactory = requireNonNull(locationFactory, "locationFactory is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.exporter = requireNonNull(exporter, "exporter is null"); this.memoryInfoCodec = requireNonNull(memoryInfoCodec, "memoryInfoCodec is null"); this.assignmentsRequestJsonCodec = requireNonNull(assignmentsRequestJsonCodec, "assignmentsRequestJsonCodec is null"); this.maxQueryMemory = config.getMaxQueryMemory(); this.maxQueryCpuTime = queryManagerConfig.getQueryMaxCpuTime(); this.coordinatorId = queryIdGenerator.getCoordinatorId(); this.enabled = serverConfig.isCoordinator(); this.killOnOutOfMemoryDelay = config.getKillOnOutOfMemoryDelay(); this.killOnOutOfMemory = config.isKillOnOutOfMemory(); } @Override public synchronized void addChangeListener( MemoryPoolId poolId, Consumer<MemoryPoolInfo> listener) { changeListeners.computeIfAbsent(poolId, id -> new ArrayList<>()).add(listener); } public synchronized void process(Iterable<QueryExecution> queries) { if (!enabled) { return; } boolean outOfMemory = isClusterOutOfMemory(); if (!outOfMemory) { lastTimeNotOutOfMemory = System.nanoTime(); } boolean queryKilled = false; long totalBytes = 0; for (QueryExecution query : queries) { long bytes = query.getTotalMemoryReservation(); DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession()); long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes()); totalBytes += bytes; if (resourceOvercommit(query.getSession()) && outOfMemory) { // If a query has requested resource 
overcommit, only kill it if the cluster has run out of // memory DataSize memory = succinctBytes(bytes); query.fail( new PrestoException( CLUSTER_OUT_OF_MEMORY, format( "The cluster is out of memory and %s=true, so this query was killed. It was using %s of memory", RESOURCE_OVERCOMMIT, memory))); queryKilled = true; } if (!resourceOvercommit(query.getSession()) && bytes > queryMemoryLimit) { DataSize maxMemory = succinctBytes(queryMemoryLimit); query.fail(exceededGlobalLimit(maxMemory)); queryKilled = true; } } clusterMemoryUsageBytes.set(totalBytes); if (killOnOutOfMemory) { boolean shouldKillQuery = nanosSince(lastTimeNotOutOfMemory).compareTo(killOnOutOfMemoryDelay) > 0 && outOfMemory; boolean lastKilledQueryIsGone = (lastKilledQuery == null); if (!lastKilledQueryIsGone) { ClusterMemoryPool generalPool = pools.get(GENERAL_POOL); if (generalPool != null) { lastKilledQueryIsGone = generalPool.getQueryMemoryReservations().containsKey(lastKilledQuery); } } if (shouldKillQuery && lastKilledQueryIsGone && !queryKilled) { // Kill the biggest query in the general pool QueryExecution biggestQuery = null; long maxMemory = -1; for (QueryExecution query : queries) { long bytesUsed = query.getTotalMemoryReservation(); if (bytesUsed > maxMemory && query.getMemoryPool().getId().equals(GENERAL_POOL)) { biggestQuery = query; maxMemory = bytesUsed; } } if (biggestQuery != null) { biggestQuery.fail( new PrestoException( CLUSTER_OUT_OF_MEMORY, "The cluster is out of memory, and your query was killed. Please try again in a few minutes.")); queriesKilledDueToOutOfMemory.incrementAndGet(); lastKilledQuery = biggestQuery.getQueryId(); } } } Map<MemoryPoolId, Integer> countByPool = new HashMap<>(); for (QueryExecution query : queries) { MemoryPoolId id = query.getMemoryPool().getId(); countByPool.put(id, countByPool.getOrDefault(id, 0) + 1); } updatePools(countByPool); updateNodes(updateAssignments(queries)); // check if CPU usage is over limit for (QueryExecution query : queries) { Duration cpuTime = query.getTotalCpuTime(); Duration sessionLimit = getQueryMaxCpuTime(query.getSession()); Duration limit = maxQueryCpuTime.compareTo(sessionLimit) < 0 ? maxQueryCpuTime : sessionLimit; if (cpuTime.compareTo(limit) > 0) { query.fail(new ExceededCpuLimitException(limit)); } } } @VisibleForTesting synchronized Map<MemoryPoolId, ClusterMemoryPool> getPools() { return ImmutableMap.copyOf(pools); } private synchronized boolean isClusterOutOfMemory() { ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL); ClusterMemoryPool generalPool = pools.get(GENERAL_POOL); return reservedPool != null && generalPool != null && reservedPool.getAssignedQueries() > 0 && generalPool.getBlockedNodes() > 0; } private synchronized MemoryPoolAssignmentsRequest updateAssignments( Iterable<QueryExecution> queries) { ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL); ClusterMemoryPool generalPool = pools.get(GENERAL_POOL); long version = memoryPoolAssignmentsVersion.incrementAndGet(); // Check that all previous assignments have propagated to the visible nodes. 
This doesn't // account for temporary network issues, // and is more of a safety check than a guarantee if (reservedPool != null && generalPool != null && allAssignmentsHavePropagated(queries)) { if (reservedPool.getAssignedQueries() == 0 && generalPool.getBlockedNodes() > 0) { QueryExecution biggestQuery = null; long maxMemory = -1; for (QueryExecution queryExecution : queries) { if (resourceOvercommit(queryExecution.getSession())) { // Don't promote queries that requested resource overcommit to the reserved pool, // since their memory usage is unbounded. continue; } long bytesUsed = queryExecution.getTotalMemoryReservation(); if (bytesUsed > maxMemory) { biggestQuery = queryExecution; maxMemory = bytesUsed; } } if (biggestQuery != null) { biggestQuery.setMemoryPool(new VersionedMemoryPoolId(RESERVED_POOL, version)); } } } ImmutableList.Builder<MemoryPoolAssignment> assignments = ImmutableList.builder(); for (QueryExecution queryExecution : queries) { assignments.add( new MemoryPoolAssignment( queryExecution.getQueryId(), queryExecution.getMemoryPool().getId())); } return new MemoryPoolAssignmentsRequest(coordinatorId, version, assignments.build()); } private boolean allAssignmentsHavePropagated(Iterable<QueryExecution> queries) { if (nodes.isEmpty()) { // Assignments can't have propagated, if there are no visible nodes. return false; } long newestAssignment = ImmutableList.copyOf(queries) .stream() .map(QueryExecution::getMemoryPool) .mapToLong(VersionedMemoryPoolId::getVersion) .min() .orElse(-1); long mostOutOfDateNode = nodes .values() .stream() .mapToLong(RemoteNodeMemory::getCurrentAssignmentVersion) .min() .orElse(Long.MAX_VALUE); return newestAssignment <= mostOutOfDateNode; } private void updateNodes(MemoryPoolAssignmentsRequest assignments) { ImmutableSet.Builder<Node> builder = ImmutableSet.builder(); Set<Node> aliveNodes = builder .addAll(nodeManager.getNodes(ACTIVE)) .addAll(nodeManager.getNodes(SHUTTING_DOWN)) .build(); ImmutableSet<String> aliveNodeIds = aliveNodes.stream().map(Node::getNodeIdentifier).collect(toImmutableSet()); // Remove nodes that don't exist anymore // Make a copy to materialize the set difference Set<String> deadNodes = ImmutableSet.copyOf(difference(nodes.keySet(), aliveNodeIds)); nodes.keySet().removeAll(deadNodes); // Add new nodes for (Node node : aliveNodes) { if (!nodes.containsKey(node.getNodeIdentifier())) { nodes.put( node.getNodeIdentifier(), new RemoteNodeMemory( httpClient, memoryInfoCodec, assignmentsRequestJsonCodec, locationFactory.createMemoryInfoLocation(node))); } } // Schedule refresh for (RemoteNodeMemory node : nodes.values()) { node.asyncRefresh(assignments); } } private synchronized void updatePools(Map<MemoryPoolId, Integer> queryCounts) { // Update view of cluster memory and pools List<MemoryInfo> nodeMemoryInfos = nodes .values() .stream() .map(RemoteNodeMemory::getInfo) .filter(Optional::isPresent) .map(Optional::get) .collect(toImmutableList()); long totalClusterMemory = nodeMemoryInfos .stream() .map(MemoryInfo::getTotalNodeMemory) .mapToLong(DataSize::toBytes) .sum(); clusterMemoryBytes.set(totalClusterMemory); Set<MemoryPoolId> activePoolIds = nodeMemoryInfos .stream() .flatMap(info -> info.getPools().keySet().stream()) .collect(toImmutableSet()); // Make a copy to materialize the set difference Set<MemoryPoolId> removedPools = ImmutableSet.copyOf(difference(pools.keySet(), activePoolIds)); for (MemoryPoolId removed : removedPools) { unexport(pools.get(removed)); pools.remove(removed); if 
(changeListeners.containsKey(removed)) { for (Consumer<MemoryPoolInfo> listener : changeListeners.get(removed)) { listenerExecutor.execute( () -> listener.accept(new MemoryPoolInfo(0, 0, ImmutableMap.of()))); } } } for (MemoryPoolId id : activePoolIds) { ClusterMemoryPool pool = pools.computeIfAbsent( id, poolId -> { ClusterMemoryPool newPool = new ClusterMemoryPool(poolId); String objectName = ObjectNames.builder(ClusterMemoryPool.class, newPool.getId().toString()) .build(); try { exporter.export(objectName, newPool); } catch (JmxException e) { log.error(e, "Error exporting memory pool %s", poolId); } return newPool; }); pool.update(nodeMemoryInfos, queryCounts.getOrDefault(pool.getId(), 0)); if (changeListeners.containsKey(id)) { MemoryPoolInfo info = pool.getInfo(); for (Consumer<MemoryPoolInfo> listener : changeListeners.get(id)) { listenerExecutor.execute(() -> listener.accept(info)); } } } } @PreDestroy public synchronized void destroy() { try { for (ClusterMemoryPool pool : pools.values()) { unexport(pool); } pools.clear(); } finally { listenerExecutor.shutdownNow(); } } private void unexport(ClusterMemoryPool pool) { try { String objectName = ObjectNames.builder(ClusterMemoryPool.class, pool.getId().toString()).build(); exporter.unexport(objectName); } catch (JmxException e) { log.error(e, "Failed to unexport pool %s", pool.getId()); } } @Managed public long getClusterMemoryUsageBytes() { return clusterMemoryUsageBytes.get(); } @Managed public long getClusterMemoryBytes() { return clusterMemoryBytes.get(); } @Managed public long getQueriesKilledDueToOutOfMemory() { return queriesKilledDueToOutOfMemory.get(); } }
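The OOM killer in the ClusterMemoryManager above only fires when several conditions hold at once: the cluster has been out of memory for longer than killOnOutOfMemoryDelay, the previously killed query has left the general pool, and no query was already failed in the current pass; the victim is then the general-pool query with the largest reservation. A hypothetical, self-contained sketch of that guard plus the victim selection (queries reduced to name-to-bytes entries instead of QueryExecution objects):

import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical sketch of the OOM-kill decision: guard conditions plus victim selection.
public class OomKillerSketch {
    static String pickVictim(
            Map<String, Long> generalPoolReservations,
            long nanosSinceNotOutOfMemory,
            long killDelayNanos,
            boolean outOfMemory,
            boolean lastKilledQueryIsGone,
            boolean queryKilledThisPass) {
        boolean shouldKill = outOfMemory && nanosSinceNotOutOfMemory > killDelayNanos;
        if (!shouldKill || !lastKilledQueryIsGone || queryKilledThisPass) {
            return null;                                // do nothing this pass
        }
        String victim = null;
        long maxMemory = -1;
        for (Map.Entry<String, Long> entry : generalPoolReservations.entrySet()) {
            if (entry.getValue() > maxMemory) {         // biggest query in the general pool
                victim = entry.getKey();
                maxMemory = entry.getValue();
            }
        }
        return victim;
    }

    public static void main(String[] args) {
        Map<String, Long> reservations = new LinkedHashMap<>();
        reservations.put("query_1", 512L << 20);        // 512 MiB
        reservations.put("query_2", 2048L << 20);       // 2 GiB
        String victim = pickVictim(reservations, 31_000_000_000L, 30_000_000_000L, true, true, false);
        System.out.println("kill: " + victim);          // query_2
    }
}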
public class ClusterMemoryManager { private static final Logger log = Logger.get(ClusterMemoryManager.class); private final NodeManager nodeManager; private final LocationFactory locationFactory; private final HttpClient httpClient; private final MBeanExporter exporter; private final JsonCodec<MemoryInfo> memoryInfoCodec; private final JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec; private final DataSize maxQueryMemory; private final boolean enabled; private final String coordinatorId; private final AtomicLong memoryPoolAssignmentsVersion = new AtomicLong(); private final AtomicLong clusterMemoryUsageBytes = new AtomicLong(); private final AtomicLong clusterMemoryBytes = new AtomicLong(); private final Map<String, RemoteNodeMemory> nodes = new HashMap<>(); @GuardedBy("this") private final Map<MemoryPoolId, ClusterMemoryPool> pools = new HashMap<>(); @Inject public ClusterMemoryManager( @ForMemoryManager HttpClient httpClient, NodeManager nodeManager, LocationFactory locationFactory, MBeanExporter exporter, JsonCodec<MemoryInfo> memoryInfoCodec, JsonCodec<MemoryPoolAssignmentsRequest> assignmentsRequestJsonCodec, QueryIdGenerator queryIdGenerator, ServerConfig serverConfig, MemoryManagerConfig config) { requireNonNull(config, "config is null"); this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); this.locationFactory = requireNonNull(locationFactory, "locationFactory is null"); this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.exporter = requireNonNull(exporter, "exporter is null"); this.memoryInfoCodec = requireNonNull(memoryInfoCodec, "memoryInfoCodec is null"); this.assignmentsRequestJsonCodec = requireNonNull(assignmentsRequestJsonCodec, "assignmentsRequestJsonCodec is null"); this.maxQueryMemory = config.getMaxQueryMemory(); this.coordinatorId = queryIdGenerator.getCoordinatorId(); this.enabled = config.isClusterMemoryManagerEnabled() && serverConfig.isCoordinator(); } public void process(Iterable<QueryExecution> queries) { if (!enabled) { return; } long totalBytes = 0; for (QueryExecution query : queries) { long bytes = query.getTotalMemoryReservation(); DataSize sessionMaxQueryMemory = getQueryMaxMemory(query.getSession()); long queryMemoryLimit = Math.min(maxQueryMemory.toBytes(), sessionMaxQueryMemory.toBytes()); totalBytes += bytes; if (bytes > queryMemoryLimit) { query.fail( new ExceededMemoryLimitException( "Query", DataSize.succinctDataSize(queryMemoryLimit, Unit.BYTE))); } } clusterMemoryUsageBytes.set(totalBytes); Map<MemoryPoolId, Integer> countByPool = new HashMap<>(); for (QueryExecution query : queries) { MemoryPoolId id = query.getMemoryPool().getId(); countByPool.put(id, countByPool.getOrDefault(id, 0) + 1); } updatePools(countByPool); updateNodes(updateAssignments(queries)); } @VisibleForTesting synchronized Map<MemoryPoolId, ClusterMemoryPool> getPools() { return ImmutableMap.copyOf(pools); } private MemoryPoolAssignmentsRequest updateAssignments(Iterable<QueryExecution> queries) { ClusterMemoryPool reservedPool = pools.get(RESERVED_POOL); ClusterMemoryPool generalPool = pools.get(GENERAL_POOL); long version = memoryPoolAssignmentsVersion.incrementAndGet(); // Check that all previous assignments have propagated to the visible nodes. 
This doesn't // account for temporary network issues, // and is more of a safety check than a guarantee if (reservedPool != null && generalPool != null && allAssignmentsHavePropagated(queries)) { if (reservedPool.getQueries() == 0 && generalPool.getBlockedNodes() > 0) { QueryExecution biggestQuery = null; long maxMemory = -1; for (QueryExecution queryExecution : queries) { long bytesUsed = queryExecution.getTotalMemoryReservation(); if (bytesUsed > maxMemory) { biggestQuery = queryExecution; maxMemory = bytesUsed; } } for (QueryExecution queryExecution : queries) { if (queryExecution.getQueryId().equals(biggestQuery.getQueryId())) { queryExecution.setMemoryPool(new VersionedMemoryPoolId(RESERVED_POOL, version)); } } } } ImmutableList.Builder<MemoryPoolAssignment> assignments = ImmutableList.builder(); for (QueryExecution queryExecution : queries) { assignments.add( new MemoryPoolAssignment( queryExecution.getQueryId(), queryExecution.getMemoryPool().getId())); } return new MemoryPoolAssignmentsRequest(coordinatorId, version, assignments.build()); } private boolean allAssignmentsHavePropagated(Iterable<QueryExecution> queries) { if (nodes.isEmpty()) { // Assignments can't have propagated, if there are no visible nodes. return false; } long newestAssignment = ImmutableList.copyOf(queries) .stream() .map(QueryExecution::getMemoryPool) .mapToLong(VersionedMemoryPoolId::getVersion) .min() .orElse(-1); long mostOutOfDateNode = nodes .values() .stream() .mapToLong(RemoteNodeMemory::getCurrentAssignmentVersion) .min() .orElse(Long.MAX_VALUE); return newestAssignment <= mostOutOfDateNode; } private void updateNodes(MemoryPoolAssignmentsRequest assignments) { Set<Node> activeNodes = nodeManager.getActiveNodes(); ImmutableSet<String> activeNodeIds = activeNodes.stream().map(Node::getNodeIdentifier).collect(toImmutableSet()); // Remove nodes that don't exist anymore // Make a copy to materialize the set difference Set<String> deadNodes = ImmutableSet.copyOf(difference(nodes.keySet(), activeNodeIds)); nodes.keySet().removeAll(deadNodes); // Add new nodes for (Node node : activeNodes) { if (!nodes.containsKey(node.getNodeIdentifier())) { nodes.put( node.getNodeIdentifier(), new RemoteNodeMemory( httpClient, memoryInfoCodec, assignmentsRequestJsonCodec, locationFactory.createMemoryInfoLocation(node))); } } // Schedule refresh for (RemoteNodeMemory node : nodes.values()) { node.asyncRefresh(assignments); } } private synchronized void updatePools(Map<MemoryPoolId, Integer> queryCounts) { // Update view of cluster memory and pools List<MemoryInfo> nodeMemoryInfos = nodes .values() .stream() .map(RemoteNodeMemory::getInfo) .filter(Optional::isPresent) .map(Optional::get) .collect(toImmutableList()); long totalClusterMemory = nodeMemoryInfos .stream() .map(MemoryInfo::getTotalNodeMemory) .mapToLong(DataSize::toBytes) .sum(); clusterMemoryBytes.set(totalClusterMemory); Set<MemoryPoolId> activePoolIds = nodeMemoryInfos .stream() .flatMap(info -> info.getPools().keySet().stream()) .collect(toImmutableSet()); // Make a copy to materialize the set difference Set<MemoryPoolId> removedPools = ImmutableSet.copyOf(difference(pools.keySet(), activePoolIds)); for (MemoryPoolId removed : removedPools) { unexport(pools.get(removed)); pools.remove(removed); } for (MemoryPoolId id : activePoolIds) { ClusterMemoryPool pool = pools.computeIfAbsent( id, poolId -> { ClusterMemoryPool newPool = new ClusterMemoryPool(poolId); String objectName = ObjectNames.builder(ClusterMemoryPool.class, newPool.getId().toString()) .build(); 
try { exporter.export(objectName, newPool); } catch (JmxException e) { log.error(e, "Error exporting memory pool %s", poolId); } return newPool; }); pool.update(nodeMemoryInfos, queryCounts.getOrDefault(pool.getId(), 0)); } } @PreDestroy public synchronized void destroy() { for (ClusterMemoryPool pool : pools.values()) { unexport(pool); } pools.clear(); } private void unexport(ClusterMemoryPool pool) { try { String objectName = ObjectNames.builder(ClusterMemoryPool.class, pool.getId().toString()).build(); exporter.unexport(objectName); } catch (JmxException e) { log.error(e, "Failed to unexport pool %s", pool.getId()); } } @Managed public long getClusterMemoryUsageBytes() { return clusterMemoryUsageBytes.get(); } @Managed public long getClusterMemoryBytes() { return clusterMemoryBytes.get(); } }
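Both versions of the manager share the allAssignmentsHavePropagated() check: a query is only promoted to the reserved pool once even the most out-of-date node has acknowledged an assignment version at least as new as the lowest version carried by the running queries. A small hypothetical sketch of that comparison over plain version numbers (the real code pulls the values from VersionedMemoryPoolId and RemoteNodeMemory):

import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the propagation check used before reserved-pool promotion.
public class AssignmentPropagationSketch {
    static boolean allAssignmentsHavePropagated(List<Long> queryAssignmentVersions, List<Long> nodeAckedVersions) {
        if (nodeAckedVersions.isEmpty()) {
            return false;                               // nothing can have propagated without visible nodes
        }
        long oldestQueryAssignment = queryAssignmentVersions.stream()
                .mapToLong(Long::longValue)
                .min()
                .orElse(-1);
        long mostOutOfDateNode = nodeAckedVersions.stream()
                .mapToLong(Long::longValue)
                .min()
                .orElse(Long.MAX_VALUE);
        return oldestQueryAssignment <= mostOutOfDateNode;
    }

    public static void main(String[] args) {
        System.out.println(allAssignmentsHavePropagated(Arrays.asList(41L, 42L), Arrays.asList(41L, 43L))); // true
        System.out.println(allAssignmentsHavePropagated(Arrays.asList(44L, 45L), Arrays.asList(41L, 43L))); // false
    }
}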
public class ShutdownUtil { private static final Logger log = Logger.get(ShutdownUtil.class); public static void shutdownChannelFactory( ChannelFactory channelFactory, ExecutorService bossExecutor, ExecutorService workerExecutor, ChannelGroup allChannels) { // Close all channels if (allChannels != null) { closeChannels(allChannels); } // Shutdown the channel factory if (channelFactory != null) { channelFactory.shutdown(); } // Stop boss threads if (bossExecutor != null) { shutdownExecutor(bossExecutor, "bossExecutor"); } // Finally stop I/O workers if (workerExecutor != null) { shutdownExecutor(workerExecutor, "workerExecutor"); } // Release any other resources netty might be holding onto via this channelFactory if (channelFactory != null) { channelFactory.releaseExternalResources(); } } public static void closeChannels(ChannelGroup allChannels) { if (allChannels.size() > 0) { // TODO : allow an option here to control if we need to drain connections and wait instead of // killing them all try { log.info("Closing %s open client connections", allChannels.size()); if (!allChannels.close().await(5, TimeUnit.SECONDS)) { log.warn("Failed to close all open client connections"); } } catch (InterruptedException e) { log.warn("Interrupted while closing client connections"); Thread.currentThread().interrupt(); } } } // TODO : make wait time configurable ? public static void shutdownExecutor(ExecutorService executor, final String name) { executor.shutdown(); try { log.info("Waiting for %s to shutdown", name); if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { log.warn("%s did not shutdown properly", name); } } catch (InterruptedException e) { log.warn("Interrupted while waiting for %s to shutdown", name); Thread.currentThread().interrupt(); } } }
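ShutdownUtil's executor handling is a common pattern worth isolating: call shutdown(), wait a bounded time for termination, log if the pool did not drain, and restore the interrupt flag if the wait itself is interrupted. A minimal JDK-only sketch, with the timeout hard-coded to the same 5 seconds used above:

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Minimal sketch of a bounded, interrupt-safe executor shutdown.
public class ShutdownSketch {
    static void shutdownExecutor(ExecutorService executor, String name) {
        executor.shutdown();                            // stop accepting new tasks
        try {
            if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
                System.err.printf("%s did not shut down within 5s%n", name);
            }
        }
        catch (InterruptedException e) {
            System.err.printf("interrupted while waiting for %s to shut down%n", name);
            Thread.currentThread().interrupt();         // restore the interrupt flag for callers
        }
    }

    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        executor.execute(() -> System.out.println("work"));
        shutdownExecutor(executor, "workerExecutor");
    }
}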
final class CodeCacheGcTrigger { private static final Logger log = Logger.get(CodeCacheGcTrigger.class); private static final AtomicBoolean installed = new AtomicBoolean(); private final Duration interval; private final int collectionThreshold; @Inject public CodeCacheGcTrigger(CodeCacheGcConfig config) { this.interval = config.getCodeCacheCheckInterval(); this.collectionThreshold = config.getCodeCacheCollectionThreshold(); } @PostConstruct public void start() { installCodeCacheGcTrigger(); } public void installCodeCacheGcTrigger() { if (installed.getAndSet(true)) { return; } // Hack to work around bugs in java 8 (8u45+) related to code cache management. // See // http://openjdk.5641.n7.nabble.com/JIT-stops-compiling-after-a-while-java-8u45-td259603.html // for more info. MemoryPoolMXBean codeCacheMbean = findCodeCacheMBean(); Thread gcThread = new Thread( () -> { while (!Thread.currentThread().isInterrupted()) { long used = codeCacheMbean.getUsage().getUsed(); long max = codeCacheMbean.getUsage().getMax(); if (used > 0.95 * max) { log.error("Code Cache is more than 95% full. JIT may stop working."); } if (used > (max * collectionThreshold) / 100) { // Due to some obscure bug in hotspot (java 8), once the code cache fills up the // JIT stops compiling // By forcing a GC, we let the code cache evictor make room before the cache fills // up. log.info("Triggering GC to avoid Code Cache eviction bugs"); System.gc(); } try { TimeUnit.MILLISECONDS.sleep(interval.toMillis()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } } }); gcThread.setDaemon(true); gcThread.setName("Code-Cache-GC-Trigger"); gcThread.start(); } private static MemoryPoolMXBean findCodeCacheMBean() { for (MemoryPoolMXBean bean : ManagementFactory.getMemoryPoolMXBeans()) { if (bean.getName().equals("Code Cache")) { return bean; } } throw new RuntimeException("Could not obtain a reference to the 'Code Cache' MemoryPoolMXBean"); } }
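The trigger above polls the "Code Cache" MemoryPoolMXBean and forces a GC once usage crosses a configured fraction of the pool's maximum, as a workaround for the JDK 8 code-cache flushing bug it links to. A small JDK-only sketch of the threshold check; the warn and GC thresholds here are illustrative stand-ins for the values that CodeCacheGcConfig supplies, and the name match is loosened because newer JDKs split the pool into several "CodeHeap" segments:

import java.lang.management.ManagementFactory;
import java.lang.management.MemoryPoolMXBean;
import java.lang.management.MemoryUsage;

// Sketch of the code-cache threshold check; thresholds are illustrative.
public class CodeCacheCheckSketch {
    public static void main(String[] args) {
        for (MemoryPoolMXBean bean : ManagementFactory.getMemoryPoolMXBeans()) {
            if (!bean.getName().contains("Code")) {
                continue;                               // only look at code-cache pools
            }
            MemoryUsage usage = bean.getUsage();
            long used = usage.getUsed();
            long max = usage.getMax();
            if (max <= 0) {
                continue;                               // max can be undefined (-1)
            }
            if (used > 0.95 * max) {
                System.err.printf("%s is more than 95%% full%n", bean.getName());
            }
            if (used > (max * 70) / 100) {              // 70% stands in for the configured threshold
                System.out.printf("would trigger System.gc() for %s%n", bean.getName());
            }
        }
    }
}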
public class StatusPrinter { private static final Logger log = Logger.get(StatusPrinter.class); private static final int CTRL_C = 3; private static final int CTRL_P = 16; private final long start = System.nanoTime(); private final StatementClient client; private final PrintStream out; private final ConsolePrinter console; private boolean debug; public StatusPrinter(StatementClient client, PrintStream out) { this.client = client; this.out = out; this.console = new ConsolePrinter(out); this.debug = client.isDebug(); } /* Query 16, RUNNING, 1 node, 855 splits http://my.server:8080/v1/query/16?pretty Splits: 646 queued, 34 running, 175 done CPU Time: 33.7s total, 191K rows/s, 16.6MB/s, 22% active Per Node: 2.5 parallelism, 473K rows/s, 41.1MB/s Parallelism: 2.5 0:13 [6.45M rows, 560MB] [ 473K rows/s, 41.1MB/s] [=========>> ] 20% STAGES ROWS ROWS/s BYTES BYTES/s PEND RUN DONE 0.........R 13.8M 336K 1.99G 49.5M 0 1 706 1.......R 666K 41.5K 82.1M 5.12M 563 65 79 2.....R 4.58M 234K 620M 31.6M 406 65 236 */ public void printInitialStatusUpdates() { long lastPrint = System.nanoTime(); try { while (client.isValid()) { try { // exit status loop if there is there is pending output if (client.current().getData() != null) { return; } // check if time to update screen boolean update = nanosSince(lastPrint).getValue(SECONDS) >= 0.5; // check for keyboard input int key = readKey(); if (key == CTRL_P) { partialCancel(); } else if (key == CTRL_C) { updateScreen(); update = false; client.close(); } else if (toUpperCase(key) == 'D') { debug = !debug; console.resetScreen(); update = true; } // update screen if (update) { updateScreen(); lastPrint = System.nanoTime(); } // fetch next results (server will wait for a while if no data) client.advance(); } catch (RuntimeException e) { log.debug(e, "error printing status"); if (debug) { e.printStackTrace(out); } } } } finally { console.resetScreen(); } } private void updateScreen() { console.repositionCursor(); printQueryInfo(client.current()); } public void printFinalInfo() { Duration wallTime = nanosSince(start); QueryResults results = client.finalResults(); StatementStats stats = results.getStats(); int nodes = stats.getNodes(); if ((nodes == 0) || (stats.getTotalSplits() == 0)) { return; } // blank line out.println(); // Query 12, FINISHED, 1 node String querySummary = String.format( "Query %s, %s, %,d %s", results.getId(), stats.getState(), nodes, pluralize("node", nodes)); out.println(querySummary); if (debug) { out.println(results.getInfoUri().toString()); } // Splits: 1000 total, 842 done (84.20%) String splitsSummary = String.format( "Splits: %,d total, %,d done (%.2f%%)", stats.getTotalSplits(), stats.getCompletedSplits(), percentage(stats.getCompletedSplits(), stats.getTotalSplits())); out.println(splitsSummary); if (debug) { // CPU Time: 565.2s total, 26K rows/s, 3.85MB/s Duration cpuTime = millis(stats.getCpuTimeMillis()); String cpuTimeSummary = String.format( "CPU Time: %.1fs total, %5s rows/s, %8s, %d%% active", cpuTime.getValue(SECONDS), formatCountRate(stats.getProcessedRows(), cpuTime, false), formatDataRate(bytes(stats.getProcessedBytes()), cpuTime, true), (int) percentage(stats.getCpuTimeMillis(), stats.getWallTimeMillis())); out.println(cpuTimeSummary); double parallelism = cpuTime.getValue(MILLISECONDS) / wallTime.getValue(MILLISECONDS); // Per Node: 3.5 parallelism, 83.3K rows/s, 0.7 MB/s String perNodeSummary = String.format( "Per Node: %.1f parallelism, %5s rows/s, %8s", parallelism / nodes, formatCountRate((double) stats.getProcessedRows() / 
nodes, wallTime, false), formatDataRate(bytes(stats.getProcessedBytes() / nodes), wallTime, true)); reprintLine(perNodeSummary); out.println(String.format("Parallelism: %.1f", parallelism)); } // 0:32 [2.12GB, 15M rows] [67MB/s, 463K rows/s] String statsLine = String.format( "%s [%s rows, %s] [%s rows/s, %s]", formatTime(wallTime), formatCount(stats.getProcessedRows()), formatDataSize(bytes(stats.getProcessedBytes()), true), formatCountRate(stats.getProcessedRows(), wallTime, false), formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true)); out.println(statsLine); // blank line out.println(); } private void printQueryInfo(QueryResults results) { StatementStats stats = results.getStats(); Duration wallTime = nanosSince(start); // cap progress at 99%, otherwise it looks weird when the query is still running and it says // 100% int progressPercentage = (int) min(99, percentage(stats.getCompletedSplits(), stats.getTotalSplits())); if (console.isRealTerminal()) { // blank line reprintLine(""); int terminalWidth = console.getWidth(); if (terminalWidth < 75) { reprintLine("WARNING: Terminal"); reprintLine("must be at least"); reprintLine("80 characters wide"); reprintLine(""); reprintLine(stats.getState()); reprintLine(String.format("%s %d%%", formatTime(wallTime), progressPercentage)); return; } int nodes = stats.getNodes(); // Query 10, RUNNING, 1 node, 778 splits String querySummary = String.format( "Query %s, %s, %,d %s, %,d splits", results.getId(), stats.getState(), nodes, pluralize("node", nodes), stats.getTotalSplits()); reprintLine(querySummary); String url = results.getInfoUri().toString(); if (debug && (url.length() < terminalWidth)) { reprintLine(url); } if ((nodes == 0) || (stats.getTotalSplits() == 0)) { return; } if (debug) { // Splits: 620 queued, 34 running, 124 done String splitsSummary = String.format( "Splits: %,d queued, %,d running, %,d done", stats.getQueuedSplits(), stats.getRunningSplits(), stats.getCompletedSplits()); reprintLine(splitsSummary); // CPU Time: 56.5s total, 36.4K rows/s, 4.44MB/s, 60% active Duration cpuTime = millis(stats.getCpuTimeMillis()); String cpuTimeSummary = String.format( "CPU Time: %.1fs total, %5s rows/s, %8s, %d%% active", cpuTime.getValue(SECONDS), formatCountRate(stats.getProcessedRows(), cpuTime, false), formatDataRate(bytes(stats.getProcessedBytes()), cpuTime, true), (int) percentage(stats.getCpuTimeMillis(), stats.getWallTimeMillis())); reprintLine(cpuTimeSummary); double parallelism = cpuTime.getValue(MILLISECONDS) / wallTime.getValue(MILLISECONDS); // Per Node: 3.5 parallelism, 83.3K rows/s, 0.7 MB/s String perNodeSummary = String.format( "Per Node: %.1f parallelism, %5s rows/s, %8s", parallelism / nodes, formatCountRate((double) stats.getProcessedRows() / nodes, wallTime, false), formatDataRate(bytes(stats.getProcessedBytes() / nodes), wallTime, true)); reprintLine(perNodeSummary); reprintLine(String.format("Parallelism: %.1f", parallelism)); } assert terminalWidth >= 75; int progressWidth = (min(terminalWidth, 100) - 75) + 17; // progress bar is 17-42 characters wide if (stats.isScheduled()) { String progressBar = formatProgressBar( progressWidth, stats.getCompletedSplits(), max(0, stats.getRunningSplits()), stats.getTotalSplits()); // 0:17 [ 103MB, 802K rows] [5.74MB/s, 44.9K rows/s] [=====>> // ] 10% String progressLine = String.format( "%s [%5s rows, %6s] [%5s rows/s, %8s] [%s] %d%%", formatTime(wallTime), formatCount(stats.getProcessedRows()), formatDataSize(bytes(stats.getProcessedBytes()), true), 
formatCountRate(stats.getProcessedRows(), wallTime, false), formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true), progressBar, progressPercentage); reprintLine(progressLine); } else { String progressBar = formatProgressBar( progressWidth, Ints.saturatedCast(nanosSince(start).roundTo(SECONDS))); // 0:17 [ 103MB, 802K rows] [5.74MB/s, 44.9K rows/s] [ <=> // ] String progressLine = String.format( "%s [%5s rows, %6s] [%5s rows/s, %8s] [%s]", formatTime(wallTime), formatCount(stats.getProcessedRows()), formatDataSize(bytes(stats.getProcessedBytes()), true), formatCountRate(stats.getProcessedRows(), wallTime, false), formatDataRate(bytes(stats.getProcessedBytes()), wallTime, true), progressBar); reprintLine(progressLine); } // todo Mem: 1949M shared, 7594M private // blank line reprintLine(""); // STAGE S ROWS RPS BYTES BPS QUEUED RUN DONE String stagesHeader = String.format( "%10s%1s %5s %6s %5s %7s %6s %5s %5s", "STAGE", "S", "ROWS", "ROWS/s", "BYTES", "BYTES/s", "QUEUED", "RUN", "DONE"); reprintLine(stagesHeader); printStageTree(stats.getRootStage(), "", new AtomicInteger()); } else { // Query 31 [S] i[2.7M 67.3MB 62.7MBps] o[35 6.1KB 1KBps] splits[252/16/380] String querySummary = String.format( "Query %s [%s] i[%s %s %s] o[%s %s %s] splits[%,d/%,d/%,d]", results.getId(), stats.getState(), formatCount(stats.getProcessedRows()), formatDataSize(bytes(stats.getProcessedBytes()), false), formatDataRate(bytes(stats.getProcessedBytes()), wallTime, false), formatCount(stats.getProcessedRows()), formatDataSize(bytes(stats.getProcessedBytes()), false), formatDataRate(bytes(stats.getProcessedBytes()), wallTime, false), stats.getQueuedSplits(), stats.getRunningSplits(), stats.getCompletedSplits()); reprintLine(querySummary); } } private void printStageTree(StageStats stage, String indent, AtomicInteger stageNumberCounter) { Duration elapsedTime = nanosSince(start); // STAGE S ROWS ROWS/s BYTES BYTES/s QUEUED RUN DONE // 0......Q 26M 9077M 9993G 9077M 9077M 9077M 9077M // 2....R 17K 627M 673M 627M 627M 627M 627M // 3..C 999 627M 673M 627M 627M 627M 627M // 4....R 26M 627M 673T 627M 627M 627M 627M // 5..F 29T 627M 673M 627M 627M 627M 627M String id = String.valueOf(stageNumberCounter.getAndIncrement()); String name = indent + id; name += Strings.repeat(".", max(0, 10 - name.length())); String bytesPerSecond; String rowsPerSecond; if (stage.isDone()) { bytesPerSecond = formatDataRate(new DataSize(0, BYTE), new Duration(0, SECONDS), false); rowsPerSecond = formatCountRate(0, new Duration(0, SECONDS), false); } else { bytesPerSecond = formatDataRate(bytes(stage.getProcessedBytes()), elapsedTime, false); rowsPerSecond = formatCountRate(stage.getProcessedRows(), elapsedTime, false); } String stageSummary = String.format( "%10s%1s %5s %6s %5s %7s %6s %5s %5s", name, stageStateCharacter(stage.getState()), formatCount(stage.getProcessedRows()), rowsPerSecond, formatDataSize(bytes(stage.getProcessedBytes()), false), bytesPerSecond, stage.getQueuedSplits(), stage.getRunningSplits(), stage.getCompletedSplits()); reprintLine(stageSummary); for (StageStats subStage : stage.getSubStages()) { printStageTree(subStage, indent + " ", stageNumberCounter); } } private void partialCancel() { try { client.cancelLeafStage(new Duration(1, SECONDS)); } catch (RuntimeException e) { log.debug(e, "error canceling leaf stage"); } } private void reprintLine(String line) { console.reprintLine(line); } private static char stageStateCharacter(String state) { return "FAILED".equals(state) ? 
'X' : state.charAt(0); } private static Duration millis(long millis) { return new Duration(millis, MILLISECONDS); } private static DataSize bytes(long bytes) { return new DataSize(bytes, BYTE); } private static double percentage(double count, double total) { if (total == 0) { return 0; } return min(100, (count * 100.0) / total); } }
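The status-printer math above is compact but easy to misread; below is a minimal standalone sketch (class and method names are mine, not from the source above) of the capped-percentage and progress-bar-width arithmetic it uses.

// Illustrative sketch only -- names are not from the original source.
import static java.lang.Math.min;

public final class ProgressMathSketch {
    private ProgressMathSketch() {}

    // Same behavior as the percentage() helper above: guard against division by zero and cap at 100.
    static double percentage(double count, double total) {
        return (total == 0) ? 0 : min(100, (count * 100.0) / total);
    }

    // Cap at 99 so a still-running query never displays "100%".
    static int progressPercentage(long completedSplits, long totalSplits) {
        return (int) min(99, percentage(completedSplits, totalSplits));
    }

    // 75 columns are reserved for the surrounding counters, leaving a bar of 17-42 characters.
    static int progressBarWidth(int terminalWidth) {
        return (min(terminalWidth, 100) - 75) + 17;
    }

    public static void main(String[] args) {
        System.out.println(progressPercentage(778, 780)); // 99, even though the query is nearly done
        System.out.println(progressBarWidth(80));         // 22
        System.out.println(progressBarWidth(160));        // 42 (clamped by min(width, 100))
    }
}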
public class QueryMonitor { private static final Logger log = Logger.get(QueryMonitor.class); private final ObjectMapper objectMapper; private final EventClient eventClient; private final String environment; @Inject public QueryMonitor(ObjectMapper objectMapper, EventClient eventClient, NodeInfo nodeInfo) { this.objectMapper = checkNotNull(objectMapper, "objectMapper is null"); this.eventClient = checkNotNull(eventClient, "eventClient is null"); this.environment = checkNotNull(nodeInfo, "nodeInfo is null").getEnvironment(); } public void createdEvent(QueryInfo queryInfo) { eventClient.post( new QueryCreatedEvent( queryInfo.getQueryId(), queryInfo.getSession().getUser(), queryInfo.getSession().getSource(), environment, queryInfo.getSession().getCatalog(), queryInfo.getSession().getSchema(), queryInfo.getSession().getRemoteUserAddress(), queryInfo.getSession().getUserAgent(), queryInfo.getSelf(), queryInfo.getQuery(), queryInfo.getQueryStats().getCreateTime())); } public void completionEvent(QueryInfo queryInfo) { try { QueryStats queryStats = queryInfo.getQueryStats(); FailureInfo failureInfo = queryInfo.getFailureInfo(); String failureType = failureInfo == null ? null : failureInfo.getType(); String failureMessage = failureInfo == null ? null : failureInfo.getMessage(); eventClient.post( new QueryCompletionEvent( queryInfo.getQueryId(), queryInfo.getSession().getUser(), queryInfo.getSession().getSource(), environment, queryInfo.getSession().getCatalog(), queryInfo.getSession().getSchema(), queryInfo.getSession().getRemoteUserAddress(), queryInfo.getSession().getUserAgent(), queryInfo.getState(), queryInfo.getSelf(), queryInfo.getFieldNames(), queryInfo.getQuery(), queryStats.getCreateTime(), queryStats.getExecutionStartTime(), queryStats.getEndTime(), queryStats.getQueuedTime(), queryStats.getAnalysisTime(), queryStats.getDistributedPlanningTime(), queryStats.getTotalScheduledTime(), queryStats.getTotalCpuTime(), queryStats.getRawInputDataSize(), queryStats.getRawInputPositions(), queryStats.getTotalDrivers(), queryInfo.getErrorCode(), failureType, failureMessage, objectMapper.writeValueAsString(queryInfo.getOutputStage()), objectMapper.writeValueAsString(queryInfo.getFailureInfo()), objectMapper.writeValueAsString(queryInfo.getInputs()))); logQueryTimeline(queryInfo); } catch (JsonProcessingException e) { throw Throwables.propagate(e); } } private void logQueryTimeline(QueryInfo queryInfo) { try { QueryStats queryStats = queryInfo.getQueryStats(); DateTime queryStartTime = queryStats.getCreateTime(); DateTime queryEndTime = queryStats.getEndTime(); // query didn't finish cleanly if (queryStartTime == null || queryEndTime == null) { return; } // planning duration -- start to end of planning Duration planning = queryStats.getTotalPlanningTime(); if (planning == null) { planning = new Duration(0, MILLISECONDS); } List<StageInfo> stages = StageInfo.getAllStages(queryInfo.getOutputStage()); // long lastSchedulingCompletion = 0; long firstTaskStartTime = queryEndTime.getMillis(); long lastTaskStartTime = queryStartTime.getMillis() + planning.toMillis(); long lastTaskEndTime = queryStartTime.getMillis() + planning.toMillis(); for (StageInfo stage : stages) { // only consider leaf stages if (!stage.getSubStages().isEmpty()) { continue; } for (TaskInfo taskInfo : stage.getTasks()) { TaskStats taskStats = taskInfo.getStats(); DateTime firstStartTime = taskStats.getFirstStartTime(); if (firstStartTime != null) { firstTaskStartTime = Math.min(firstStartTime.getMillis(), firstTaskStartTime); } 
DateTime lastStartTime = taskStats.getLastStartTime(); if (lastStartTime != null) { lastTaskStartTime = Math.max(lastStartTime.getMillis(), lastTaskStartTime); } DateTime endTime = taskStats.getEndTime(); if (endTime != null) { lastTaskEndTime = Math.max(endTime.getMillis(), lastTaskEndTime); } } } Duration elapsed = millis(queryEndTime.getMillis() - queryStartTime.getMillis()); Duration scheduling = millis(firstTaskStartTime - queryStartTime.getMillis() - planning.toMillis()); Duration running = millis(lastTaskEndTime - firstTaskStartTime); Duration finishing = millis(queryEndTime.getMillis() - lastTaskEndTime); log.info( "TIMELINE: Query %s :: elapsed %s :: planning %s :: scheduling %s :: running %s :: finishing %s :: begin %s :: end %s", queryInfo.getQueryId(), elapsed, planning, scheduling, running, finishing, queryStartTime, queryEndTime); } catch (Exception e) { log.error(e, "Error logging query timeline"); } } public void splitCompletionEvent(TaskId taskId, DriverStats driverStats) { splitCompletionEvent(taskId, driverStats, null, null); } public void splitFailedEvent(TaskId taskId, DriverStats driverStats, Throwable cause) { splitCompletionEvent(taskId, driverStats, cause.getClass().getName(), cause.getMessage()); } private void splitCompletionEvent( TaskId taskId, DriverStats driverStats, @Nullable String failureType, @Nullable String failureMessage) { Duration timeToStart = null; if (driverStats.getStartTime() != null) { timeToStart = millis(driverStats.getStartTime().getMillis() - driverStats.getCreateTime().getMillis()); } Duration timeToEnd = null; if (driverStats.getEndTime() != null) { timeToEnd = millis(driverStats.getEndTime().getMillis() - driverStats.getCreateTime().getMillis()); } try { eventClient.post( new SplitCompletionEvent( taskId.getQueryId(), taskId.getStageId(), taskId, environment, driverStats.getQueuedTime(), driverStats.getStartTime(), timeToStart, timeToEnd, driverStats.getRawInputDataSize(), driverStats.getRawInputPositions(), driverStats.getRawInputReadTime(), driverStats.getElapsedTime(), driverStats.getTotalCpuTime(), driverStats.getTotalUserTime(), failureType, failureMessage, objectMapper.writeValueAsString(driverStats))); } catch (JsonProcessingException e) { log.error(e, "Error posting split completion event for task %s", taskId); } } private static Duration millis(long millis) { if (millis < 0) { millis = 0; } return new Duration(millis, MILLISECONDS); } }
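The timeline logging in QueryMonitor derives four phases from a handful of timestamps; the following self-contained sketch (hypothetical class name, hard-coded example timestamps) walks the same arithmetic with the same clamp-to-zero behavior as the millis() helper.

// Illustrative sketch only -- class name and timestamps are invented for the example.
public final class QueryTimelineSketch {
    private QueryTimelineSketch() {}

    static long clampToZero(long millis) {
        return millis < 0 ? 0 : millis;
    }

    public static void main(String[] args) {
        long queryStart = 1_000;     // query created
        long planningMillis = 200;   // total planning time
        long firstTaskStart = 1_350; // earliest leaf-task start
        long lastTaskEnd = 4_800;    // latest leaf-task end
        long queryEnd = 5_000;       // query finished

        long elapsed = clampToZero(queryEnd - queryStart);                           // 4000
        long scheduling = clampToZero(firstTaskStart - queryStart - planningMillis); // 150
        long running = clampToZero(lastTaskEnd - firstTaskStart);                    // 3450
        long finishing = clampToZero(queryEnd - lastTaskEnd);                        // 200

        System.out.printf("elapsed=%dms planning=%dms scheduling=%dms running=%dms finishing=%dms%n",
                elapsed, planningMillis, scheduling, running, finishing);
    }
}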
@ThreadSafe class RequestErrorTracker { private static final Logger log = Logger.get(RequestErrorTracker.class); private final TaskId taskId; private final URI taskUri; private final ScheduledExecutorService scheduledExecutor; private final String jobDescription; private final Backoff backoff; private final Queue<Throwable> errorsSinceLastSuccess = new ConcurrentLinkedQueue<>(); public RequestErrorTracker( TaskId taskId, URI taskUri, Duration minErrorDuration, ScheduledExecutorService scheduledExecutor, String jobDescription) { this.taskId = taskId; this.taskUri = taskUri; this.scheduledExecutor = scheduledExecutor; this.backoff = new Backoff(minErrorDuration); this.jobDescription = jobDescription; } public ListenableFuture<?> acquireRequestPermit() { long delayNanos = backoff.getBackoffDelayNanos(); if (delayNanos == 0) { return Futures.immediateFuture(null); } ListenableFutureTask<Object> futureTask = ListenableFutureTask.create(() -> null); scheduledExecutor.schedule(futureTask, delayNanos, NANOSECONDS); return futureTask; } public void startRequest() { // before scheduling a new request clear the error timer // we consider a request to be "new" if there are no current failures if (backoff.getFailureCount() == 0) { requestSucceeded(); } } public void requestSucceeded() { backoff.success(); errorsSinceLastSuccess.clear(); } public void requestFailed(Throwable reason) throws PrestoException { // cancellation is not a failure if (reason instanceof CancellationException) { return; } if (reason instanceof RejectedExecutionException) { throw new PrestoException(REMOTE_TASK_ERROR, reason); } // log failure message if (isExpectedError(reason)) { // don't print a stack for known errors log.warn("Error " + jobDescription + " %s: %s: %s", taskId, reason.getMessage(), taskUri); } else { log.warn(reason, "Error " + jobDescription + " %s: %s", taskId, taskUri); } // remember the first 10 errors if (errorsSinceLastSuccess.size() < 10) { errorsSinceLastSuccess.add(reason); } // fail the task, if we have more than X failures in a row and more than Y seconds have passed // since the last request if (backoff.failure()) { // it is weird to mark the task failed locally and then cancel the remote task, but there is // no way to tell a remote task that it has failed PrestoException exception = new PrestoException( TOO_MANY_REQUESTS_FAILED, format( "%s (%s %s - %s failures, time since last success %s)", WORKER_NODE_ERROR, jobDescription, taskUri, backoff.getFailureCount(), backoff.getTimeSinceLastSuccess().convertTo(SECONDS))); errorsSinceLastSuccess.forEach(exception::addSuppressed); throw exception; } } static void logError(Throwable t, String format, Object... args) { if (isExpectedError(t)) { log.error(format + ": %s", ObjectArrays.concat(args, t)); } else { log.error(t, format, args); } } private static boolean isExpectedError(Throwable t) { while (t != null) { if ((t instanceof SocketException) || (t instanceof SocketTimeoutException) || (t instanceof EOFException) || (t instanceof TimeoutException) || (t instanceof ServiceUnavailableException)) { return true; } t = t.getCause(); } return false; } }
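The cause-chain walk in isExpectedError is a reusable idiom for deciding when to log without a stack trace; here is a JDK-only sketch of it (ServiceUnavailableException is left out to stay on JDK classes, and the class name is illustrative).

// Illustrative sketch of the "expected error" classification idiom; not the original class.
import java.io.EOFException;
import java.net.SocketException;
import java.net.SocketTimeoutException;
import java.util.concurrent.TimeoutException;

public final class ExpectedErrorSketch {
    private ExpectedErrorSketch() {}

    // Walk the cause chain and treat common transient network failures as "expected".
    static boolean isExpectedError(Throwable t) {
        while (t != null) {
            if ((t instanceof SocketException)
                    || (t instanceof SocketTimeoutException)
                    || (t instanceof EOFException)
                    || (t instanceof TimeoutException)) {
                return true;
            }
            t = t.getCause();
        }
        return false;
    }

    public static void main(String[] args) {
        Throwable wrapped = new RuntimeException(new SocketTimeoutException("read timed out"));
        System.out.println(isExpectedError(wrapped));                      // true: the cause chain is inspected
        System.out.println(isExpectedError(new IllegalStateException()));  // false: unexpected, log the full stack
    }
}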
public class HiveMetadata implements ConnectorMetadata { private static final Logger log = Logger.get(HiveMetadata.class); private final String connectorId; private final boolean allowDropTable; private final boolean allowRenameTable; private final boolean allowCorruptWritesForTesting; private final HiveMetastore metastore; private final HdfsEnvironment hdfsEnvironment; private final DateTimeZone timeZone; private final HiveStorageFormat hiveStorageFormat; private final TypeManager typeManager; @Inject @SuppressWarnings("deprecation") public HiveMetadata( HiveConnectorId connectorId, HiveClientConfig hiveClientConfig, HiveMetastore metastore, HdfsEnvironment hdfsEnvironment, @ForHiveClient ExecutorService executorService, TypeManager typeManager) { this( connectorId, metastore, hdfsEnvironment, DateTimeZone.forTimeZone(hiveClientConfig.getTimeZone()), hiveClientConfig.getAllowDropTable(), hiveClientConfig.getAllowRenameTable(), hiveClientConfig.getAllowCorruptWritesForTesting(), hiveClientConfig.getHiveStorageFormat(), typeManager); } public HiveMetadata( HiveConnectorId connectorId, HiveMetastore metastore, HdfsEnvironment hdfsEnvironment, DateTimeZone timeZone, boolean allowDropTable, boolean allowRenameTable, boolean allowCorruptWritesForTesting, HiveStorageFormat hiveStorageFormat, TypeManager typeManager) { this.connectorId = checkNotNull(connectorId, "connectorId is null").toString(); this.allowDropTable = allowDropTable; this.allowRenameTable = allowRenameTable; this.allowCorruptWritesForTesting = allowCorruptWritesForTesting; this.metastore = checkNotNull(metastore, "metastore is null"); this.hdfsEnvironment = checkNotNull(hdfsEnvironment, "hdfsEnvironment is null"); this.timeZone = checkNotNull(timeZone, "timeZone is null"); this.hiveStorageFormat = hiveStorageFormat; this.typeManager = checkNotNull(typeManager, "typeManager is null"); if (!allowCorruptWritesForTesting && !timeZone.equals(DateTimeZone.getDefault())) { log.warn( "Hive writes are disabled. " + "To write data to Hive, your JVM timezone must match the Hive storage timezone. 
" + "Add -Duser.timezone=%s to your JVM arguments", timeZone.getID()); } } public HiveMetastore getMetastore() { return metastore; } @Override public List<String> listSchemaNames(ConnectorSession session) { return metastore.getAllDatabases(); } @Override public HiveTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) { checkNotNull(tableName, "tableName is null"); try { metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); return new HiveTableHandle( connectorId, tableName.getSchemaName(), tableName.getTableName(), session); } catch (NoSuchObjectException e) { // table was not found return null; } } @Override public ConnectorTableMetadata getTableMetadata(ConnectorTableHandle tableHandle) { checkNotNull(tableHandle, "tableHandle is null"); SchemaTableName tableName = schemaTableName(tableHandle); return getTableMetadata(tableName); } private ConnectorTableMetadata getTableMetadata(SchemaTableName tableName) { try { Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); if (table.getTableType().equals(TableType.VIRTUAL_VIEW.name())) { throw new TableNotFoundException(tableName); } List<HiveColumnHandle> handles = hiveColumnHandles(typeManager, connectorId, table, false); List<ColumnMetadata> columns = ImmutableList.copyOf(transform(handles, columnMetadataGetter(table, typeManager))); return new ConnectorTableMetadata(tableName, columns, table.getOwner()); } catch (NoSuchObjectException e) { throw new TableNotFoundException(tableName); } } @Override public List<SchemaTableName> listTables(ConnectorSession session, String schemaNameOrNull) { ImmutableList.Builder<SchemaTableName> tableNames = ImmutableList.builder(); for (String schemaName : listSchemas(session, schemaNameOrNull)) { try { for (String tableName : metastore.getAllTables(schemaName)) { tableNames.add(new SchemaTableName(schemaName, tableName)); } } catch (NoSuchObjectException e) { // schema disappeared during listing operation } } return tableNames.build(); } private List<String> listSchemas(ConnectorSession session, String schemaNameOrNull) { if (schemaNameOrNull == null) { return listSchemaNames(session); } return ImmutableList.of(schemaNameOrNull); } @Override public ColumnHandle getSampleWeightColumnHandle(ConnectorTableHandle tableHandle) { SchemaTableName tableName = schemaTableName(tableHandle); try { Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); for (HiveColumnHandle columnHandle : hiveColumnHandles(typeManager, connectorId, table, true)) { if (columnHandle.getName().equals(SAMPLE_WEIGHT_COLUMN_NAME)) { return columnHandle; } } return null; } catch (NoSuchObjectException e) { throw new TableNotFoundException(tableName); } } @Override public boolean canCreateSampledTables(ConnectorSession session) { return true; } @Override public Map<String, ColumnHandle> getColumnHandles(ConnectorTableHandle tableHandle) { SchemaTableName tableName = schemaTableName(tableHandle); try { Table table = metastore.getTable(tableName.getSchemaName(), tableName.getTableName()); ImmutableMap.Builder<String, ColumnHandle> columnHandles = ImmutableMap.builder(); for (HiveColumnHandle columnHandle : hiveColumnHandles(typeManager, connectorId, table, false)) { columnHandles.put(columnHandle.getName(), columnHandle); } return columnHandles.build(); } catch (NoSuchObjectException e) { throw new TableNotFoundException(tableName); } } @Override public Map<SchemaTableName, List<ColumnMetadata>> listTableColumns( ConnectorSession 
session, SchemaTablePrefix prefix) { checkNotNull(prefix, "prefix is null"); ImmutableMap.Builder<SchemaTableName, List<ColumnMetadata>> columns = ImmutableMap.builder(); for (SchemaTableName tableName : listTables(session, prefix)) { try { columns.put(tableName, getTableMetadata(tableName).getColumns()); } catch (HiveViewNotSupportedException e) { // view is not supported } catch (TableNotFoundException e) { // table disappeared during listing operation } } return columns.build(); } private List<SchemaTableName> listTables(ConnectorSession session, SchemaTablePrefix prefix) { if (prefix.getSchemaName() == null || prefix.getTableName() == null) { return listTables(session, prefix.getSchemaName()); } return ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } /** NOTE: This method does not return column comment */ @Override public ColumnMetadata getColumnMetadata( ConnectorTableHandle tableHandle, ColumnHandle columnHandle) { checkType(tableHandle, HiveTableHandle.class, "tableHandle"); return checkType(columnHandle, HiveColumnHandle.class, "columnHandle") .getColumnMetadata(typeManager); } @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata) { checkArgument(!isNullOrEmpty(tableMetadata.getOwner()), "Table owner is null or empty"); SchemaTableName schemaTableName = tableMetadata.getTable(); String schemaName = schemaTableName.getSchemaName(); String tableName = schemaTableName.getTableName(); ImmutableList.Builder<String> columnNames = ImmutableList.builder(); ImmutableList.Builder<Type> columnTypes = ImmutableList.builder(); buildColumnInfo(tableMetadata, columnNames, columnTypes); ImmutableList.Builder<FieldSchema> partitionKeys = ImmutableList.builder(); ImmutableList.Builder<FieldSchema> columns = ImmutableList.builder(); List<String> names = columnNames.build(); List<String> typeNames = columnTypes .build() .stream() .map(HiveType::toHiveType) .map(HiveType::getHiveTypeName) .collect(toList()); for (int i = 0; i < names.size(); i++) { if (tableMetadata.getColumns().get(i).isPartitionKey()) { partitionKeys.add(new FieldSchema(names.get(i), typeNames.get(i), null)); } else { columns.add(new FieldSchema(names.get(i), typeNames.get(i), null)); } } Path targetPath = getTargetPath(schemaName, tableName, schemaTableName); HiveStorageFormat hiveStorageFormat = getHiveStorageFormat(session, this.hiveStorageFormat); SerDeInfo serdeInfo = new SerDeInfo(); serdeInfo.setName(tableName); serdeInfo.setSerializationLib(hiveStorageFormat.getSerDe()); StorageDescriptor sd = new StorageDescriptor(); sd.setLocation(targetPath.toString()); sd.setCols(columns.build()); sd.setSerdeInfo(serdeInfo); sd.setInputFormat(hiveStorageFormat.getInputFormat()); sd.setOutputFormat(hiveStorageFormat.getOutputFormat()); Table table = new Table(); table.setDbName(schemaName); table.setTableName(tableName); table.setOwner(tableMetadata.getOwner()); table.setTableType(TableType.MANAGED_TABLE.toString()); String tableComment = "Created by Presto"; table.setParameters(ImmutableMap.of("comment", tableComment)); table.setPartitionKeys(partitionKeys.build()); table.setSd(sd); metastore.createTable(table); } @Override public void renameTable(ConnectorTableHandle tableHandle, SchemaTableName newTableName) { if (!allowRenameTable) { throw new PrestoException( PERMISSION_DENIED, "Renaming tables is disabled in this Hive catalog"); } HiveTableHandle handle = checkType(tableHandle, HiveTableHandle.class, "tableHandle"); metastore.renameTable( 
handle.getSchemaName(), handle.getTableName(), newTableName.getSchemaName(), newTableName.getTableName()); } @Override public void dropTable(ConnectorTableHandle tableHandle) { HiveTableHandle handle = checkType(tableHandle, HiveTableHandle.class, "tableHandle"); SchemaTableName tableName = schemaTableName(tableHandle); if (!allowDropTable) { throw new PrestoException(PERMISSION_DENIED, "DROP TABLE is disabled in this Hive catalog"); } try { Table table = metastore.getTable(handle.getSchemaName(), handle.getTableName()); if (!handle.getSession().getUser().equals(table.getOwner())) { throw new PrestoException( PERMISSION_DENIED, format( "Unable to drop table '%s': owner of the table is different from session user", table)); } metastore.dropTable(handle.getSchemaName(), handle.getTableName()); } catch (NoSuchObjectException e) { throw new TableNotFoundException(tableName); } } @Override public HiveOutputTableHandle beginCreateTable( ConnectorSession session, ConnectorTableMetadata tableMetadata) { verifyJvmTimeZone(); checkArgument(!isNullOrEmpty(tableMetadata.getOwner()), "Table owner is null or empty"); HiveStorageFormat hiveStorageFormat = getHiveStorageFormat(session, this.hiveStorageFormat); ImmutableList.Builder<String> columnNames = ImmutableList.builder(); ImmutableList.Builder<Type> columnTypes = ImmutableList.builder(); // get the root directory for the database SchemaTableName schemaTableName = tableMetadata.getTable(); String schemaName = schemaTableName.getSchemaName(); String tableName = schemaTableName.getTableName(); buildColumnInfo(tableMetadata, columnNames, columnTypes); Path targetPath = getTargetPath(schemaName, tableName, schemaTableName); if (!useTemporaryDirectory(targetPath)) { return new HiveOutputTableHandle( connectorId, schemaName, tableName, columnNames.build(), columnTypes.build(), tableMetadata.getOwner(), targetPath.toString(), targetPath.toString(), session, hiveStorageFormat); } // use a per-user temporary directory to avoid permission problems // TODO: this should use Hadoop UserGroupInformation String temporaryPrefix = "/tmp/presto-" + StandardSystemProperty.USER_NAME.value(); // create a temporary directory on the same filesystem Path temporaryRoot = new Path(targetPath, temporaryPrefix); Path temporaryPath = new Path(temporaryRoot, randomUUID().toString()); createDirectories(temporaryPath); return new HiveOutputTableHandle( connectorId, schemaName, tableName, columnNames.build(), columnTypes.build(), tableMetadata.getOwner(), targetPath.toString(), temporaryPath.toString(), session, hiveStorageFormat); } @Override public void commitCreateTable( ConnectorOutputTableHandle tableHandle, Collection<Slice> fragments) { HiveOutputTableHandle handle = checkType(tableHandle, HiveOutputTableHandle.class, "tableHandle"); // verify no one raced us to create the target directory Path targetPath = new Path(handle.getTargetPath()); // rename if using a temporary directory if (handle.hasTemporaryPath()) { if (pathExists(targetPath)) { SchemaTableName table = new SchemaTableName(handle.getSchemaName(), handle.getTableName()); throw new PrestoException( HIVE_PATH_ALREADY_EXISTS, format( "Unable to commit creation of table '%s': target directory already exists: %s", table, targetPath)); } // rename the temporary directory to the target rename(new Path(handle.getTemporaryPath()), targetPath); } // create the table in the metastore List<String> types = handle .getColumnTypes() .stream() .map(HiveType::toHiveType) .map(HiveType::getHiveTypeName) .collect(toList()); boolean 
sampled = false; ImmutableList.Builder<FieldSchema> columns = ImmutableList.builder(); for (int i = 0; i < handle.getColumnNames().size(); i++) { String name = handle.getColumnNames().get(i); String type = types.get(i); if (name.equals(SAMPLE_WEIGHT_COLUMN_NAME)) { columns.add(new FieldSchema(name, type, "Presto sample weight column")); sampled = true; } else { columns.add(new FieldSchema(name, type, null)); } } HiveStorageFormat hiveStorageFormat = handle.getHiveStorageFormat(); SerDeInfo serdeInfo = new SerDeInfo(); serdeInfo.setName(handle.getTableName()); serdeInfo.setSerializationLib(hiveStorageFormat.getSerDe()); serdeInfo.setParameters(ImmutableMap.<String, String>of()); StorageDescriptor sd = new StorageDescriptor(); sd.setLocation(targetPath.toString()); sd.setCols(columns.build()); sd.setSerdeInfo(serdeInfo); sd.setInputFormat(hiveStorageFormat.getInputFormat()); sd.setOutputFormat(hiveStorageFormat.getOutputFormat()); sd.setParameters(ImmutableMap.<String, String>of()); Table table = new Table(); table.setDbName(handle.getSchemaName()); table.setTableName(handle.getTableName()); table.setOwner(handle.getTableOwner()); table.setTableType(TableType.MANAGED_TABLE.toString()); String tableComment = "Created by Presto"; if (sampled) { tableComment = "Sampled table created by Presto. Only query this table from Hive if you understand how Presto implements sampling."; } table.setParameters(ImmutableMap.of("comment", tableComment)); table.setPartitionKeys(ImmutableList.<FieldSchema>of()); table.setSd(sd); metastore.createTable(table); } private Path getTargetPath(String schemaName, String tableName, SchemaTableName schemaTableName) { String location = getDatabase(schemaName).getLocationUri(); if (isNullOrEmpty(location)) { throw new PrestoException( HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location is not set", schemaName)); } Path databasePath = new Path(location); if (!pathExists(databasePath)) { throw new PrestoException( HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location does not exist: %s", schemaName, databasePath)); } if (!isDirectory(databasePath)) { throw new PrestoException( HIVE_DATABASE_LOCATION_ERROR, format("Database '%s' location is not a directory: %s", schemaName, databasePath)); } // verify the target directory for the table Path targetPath = new Path(databasePath, tableName); if (pathExists(targetPath)) { throw new PrestoException( HIVE_PATH_ALREADY_EXISTS, format( "Target directory for table '%s' already exists: %s", schemaTableName, targetPath)); } return targetPath; } private Database getDatabase(String database) { try { return metastore.getDatabase(database); } catch (NoSuchObjectException e) { throw new SchemaNotFoundException(database); } } private boolean useTemporaryDirectory(Path path) { try { // skip using temporary directory for S3 return !(hdfsEnvironment.getFileSystem(path) instanceof PrestoS3FileSystem); } catch (IOException e) { throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e); } } private boolean pathExists(Path path) { try { return hdfsEnvironment.getFileSystem(path).exists(path); } catch (IOException e) { throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e); } } private boolean isDirectory(Path path) { try { return hdfsEnvironment.getFileSystem(path).isDirectory(path); } catch (IOException e) { throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed checking path: " + path, e); } } private void createDirectories(Path path) { try { if 
(!hdfsEnvironment.getFileSystem(path).mkdirs(path)) { throw new IOException("mkdirs returned false"); } } catch (IOException e) { throw new PrestoException(HIVE_FILESYSTEM_ERROR, "Failed to create directory: " + path, e); } } private void rename(Path source, Path target) { try { if (!hdfsEnvironment.getFileSystem(source).rename(source, target)) { throw new IOException("rename returned false"); } } catch (IOException e) { throw new PrestoException( HIVE_FILESYSTEM_ERROR, format("Failed to rename %s to %s", source, target), e); } } @Override public void createView( ConnectorSession session, SchemaTableName viewName, String viewData, boolean replace) { if (replace) { try { dropView(session, viewName); } catch (ViewNotFoundException ignored) { } } Map<String, String> properties = ImmutableMap.<String, String>builder() .put("comment", "Presto View") .put(PRESTO_VIEW_FLAG, "true") .build(); FieldSchema dummyColumn = new FieldSchema("dummy", STRING_TYPE_NAME, null); StorageDescriptor sd = new StorageDescriptor(); sd.setCols(ImmutableList.of(dummyColumn)); sd.setSerdeInfo(new SerDeInfo()); Table table = new Table(); table.setDbName(viewName.getSchemaName()); table.setTableName(viewName.getTableName()); table.setOwner(session.getUser()); table.setTableType(TableType.VIRTUAL_VIEW.name()); table.setParameters(properties); table.setViewOriginalText(encodeViewData(viewData)); table.setViewExpandedText("/* Presto View */"); table.setSd(sd); try { metastore.createTable(table); } catch (TableAlreadyExistsException e) { throw new ViewAlreadyExistsException(e.getTableName()); } } @Override public void dropView(ConnectorSession session, SchemaTableName viewName) { String view = getViews(session, viewName.toSchemaTablePrefix()).get(viewName); if (view == null) { throw new ViewNotFoundException(viewName); } try { metastore.dropTable(viewName.getSchemaName(), viewName.getTableName()); } catch (TableNotFoundException e) { throw new ViewNotFoundException(e.getTableName()); } } @Override public List<SchemaTableName> listViews(ConnectorSession session, String schemaNameOrNull) { ImmutableList.Builder<SchemaTableName> tableNames = ImmutableList.builder(); for (String schemaName : listSchemas(session, schemaNameOrNull)) { try { for (String tableName : metastore.getAllViews(schemaName)) { tableNames.add(new SchemaTableName(schemaName, tableName)); } } catch (NoSuchObjectException e) { // schema disappeared during listing operation } } return tableNames.build(); } @Override public Map<SchemaTableName, String> getViews(ConnectorSession session, SchemaTablePrefix prefix) { ImmutableMap.Builder<SchemaTableName, String> views = ImmutableMap.builder(); List<SchemaTableName> tableNames; if (prefix.getTableName() != null) { tableNames = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); } else { tableNames = listViews(session, prefix.getSchemaName()); } for (SchemaTableName schemaTableName : tableNames) { try { Table table = metastore.getTable(schemaTableName.getSchemaName(), schemaTableName.getTableName()); if (HiveUtil.isPrestoView(table)) { views.put(schemaTableName, decodeViewData(table.getViewOriginalText())); } } catch (NoSuchObjectException ignored) { } } return views.build(); } @Override public ConnectorInsertTableHandle beginInsert( ConnectorSession session, ConnectorTableHandle tableHandle) { verifyJvmTimeZone(); throw new PrestoException(NOT_SUPPORTED, "INSERT not yet supported for Hive"); } @Override public void commitInsert(ConnectorInsertTableHandle insertHandle, 
Collection<Slice> fragments) { throw new PrestoException(NOT_SUPPORTED, "INSERT not yet supported for Hive"); } @Override public String toString() { return toStringHelper(this).add("clientId", connectorId).toString(); } private void verifyJvmTimeZone() { if (!allowCorruptWritesForTesting && !timeZone.equals(DateTimeZone.getDefault())) { throw new PrestoException( HIVE_TIMEZONE_MISMATCH, format( "To write Hive data, your JVM timezone must match the Hive storage timezone. Add -Duser.timezone=%s to your JVM arguments.", timeZone.getID())); } } private static void buildColumnInfo( ConnectorTableMetadata tableMetadata, ImmutableList.Builder<String> names, ImmutableList.Builder<Type> types) { for (ColumnMetadata column : tableMetadata.getColumns()) { // TODO: also verify that the OutputFormat supports the type if (!HiveRecordSink.isTypeSupported(column.getType())) { throw new PrestoException( NOT_SUPPORTED, format( "Cannot create table with unsupported type: %s", column.getType().getDisplayName())); } names.add(column.getName()); types.add(column.getType()); } if (tableMetadata.isSampled()) { names.add(SAMPLE_WEIGHT_COLUMN_NAME); types.add(BIGINT); } } private static Function<HiveColumnHandle, ColumnMetadata> columnMetadataGetter( Table table, final TypeManager typeManager) { ImmutableMap.Builder<String, String> builder = ImmutableMap.builder(); for (FieldSchema field : concat(table.getSd().getCols(), table.getPartitionKeys())) { if (field.getComment() != null) { builder.put(field.getName(), field.getComment()); } } final Map<String, String> columnComment = builder.build(); return input -> new ColumnMetadata( input.getName(), typeManager.getType(input.getTypeSignature()), input.isPartitionKey(), columnComment.get(input.getName()), false); } }
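beginCreateTable/commitCreateTable above stage output in a temporary directory and rename it into place on commit, failing if the target already exists; the sketch below shows the same staged-write pattern using java.nio.file instead of the Hadoop FileSystem API, with illustrative names.

// Illustrative sketch of the staged-write commit pattern using java.nio.file, not HDFS.
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.UUID;

public final class StagedWriteSketch {
    private StagedWriteSketch() {}

    // "begin": stage output under a unique temporary directory next to the final location.
    static Path beginWrite(Path targetDir) throws IOException {
        Path temporary = targetDir.resolveSibling("tmp-presto-" + UUID.randomUUID());
        Files.createDirectories(temporary);
        return temporary;
    }

    // "commit": fail if someone raced us to the target, otherwise move the staged directory into place.
    static void commitWrite(Path temporaryDir, Path targetDir) throws IOException {
        if (Files.exists(targetDir)) {
            throw new IOException("target directory already exists: " + targetDir);
        }
        Files.move(temporaryDir, targetDir);
    }

    public static void main(String[] args) throws IOException {
        Path base = Files.createTempDirectory("staged-write-demo");
        Path target = base.resolve("orders");
        Path staging = beginWrite(target);
        Files.write(staging.resolve("part-0"), "row data".getBytes());
        commitWrite(staging, target);
        System.out.println("committed to " + target);
    }
}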
public class Query implements Closeable { private static final Logger log = Logger.get(Query.class); private static final Signal SIGINT = new Signal("INT"); private final AtomicBoolean ignoreUserInterrupt = new AtomicBoolean(); private final AtomicBoolean userAbortedQuery = new AtomicBoolean(); private final StatementClient client; public Query(StatementClient client) { this.client = requireNonNull(client, "client is null"); } public Map<String, String> getSetSessionProperties() { return client.getSetSessionProperties(); } public Set<String> getResetSessionProperties() { return client.getResetSessionProperties(); } public String getStartedTransactionId() { return client.getStartedtransactionId(); } public boolean isClearTransactionId() { return client.isClearTransactionId(); } public void renderOutput(PrintStream out, OutputFormat outputFormat, boolean interactive) { Thread clientThread = Thread.currentThread(); SignalHandler oldHandler = Signal.handle( SIGINT, signal -> { if (ignoreUserInterrupt.get() || client.isClosed()) { return; } userAbortedQuery.set(true); client.close(); clientThread.interrupt(); }); try { renderQueryOutput(out, outputFormat, interactive); } finally { Signal.handle(SIGINT, oldHandler); Thread.interrupted(); // clear interrupt status } } private void renderQueryOutput(PrintStream out, OutputFormat outputFormat, boolean interactive) { StatusPrinter statusPrinter = null; @SuppressWarnings("resource") PrintStream errorChannel = interactive ? out : System.err; if (interactive) { statusPrinter = new StatusPrinter(client, out); statusPrinter.printInitialStatusUpdates(); } else { waitForData(); } if ((!client.isFailed()) && (!client.isGone()) && (!client.isClosed())) { QueryResults results = client.isValid() ? client.current() : client.finalResults(); if (results.getUpdateType() != null) { renderUpdate(out, results); } else if (results.getColumns() == null) { errorChannel.printf("Query %s has no columns\n", results.getId()); return; } else { renderResults(out, outputFormat, interactive, results.getColumns()); } } if (statusPrinter != null) { statusPrinter.printFinalInfo(); } if (client.isClosed()) { errorChannel.println("Query aborted by user"); } else if (client.isGone()) { errorChannel.println("Query is gone (server restarted?)"); } else if (client.isFailed()) { renderFailure(errorChannel); } } private void waitForData() { while (client.isValid() && (client.current().getData() == null)) { client.advance(); } } private void renderUpdate(PrintStream out, QueryResults results) { String status = results.getUpdateType(); if (results.getUpdateCount() != null) { long count = results.getUpdateCount(); status += format(": %s row%s", count, (count != 1) ? 
"s" : ""); } out.println(status); discardResults(); } private void discardResults() { try (OutputHandler handler = new OutputHandler(new NullPrinter())) { handler.processRows(client); } catch (IOException e) { throw Throwables.propagate(e); } } private void renderResults( PrintStream out, OutputFormat outputFormat, boolean interactive, List<Column> columns) { try { doRenderResults(out, outputFormat, interactive, columns); } catch (QueryAbortedException e) { System.out.println("(query aborted by user)"); client.close(); } catch (IOException e) { throw Throwables.propagate(e); } } private void doRenderResults( PrintStream out, OutputFormat format, boolean interactive, List<Column> columns) throws IOException { List<String> fieldNames = Lists.transform(columns, Column::getName); if (interactive) { pageOutput(format, fieldNames); } else { sendOutput(out, format, fieldNames); } } private void pageOutput(OutputFormat format, List<String> fieldNames) throws IOException { try (Pager pager = Pager.create(); Writer writer = createWriter(pager); OutputHandler handler = createOutputHandler(format, writer, fieldNames)) { if (!pager.isNullPager()) { // ignore the user pressing ctrl-C while in the pager ignoreUserInterrupt.set(true); Thread clientThread = Thread.currentThread(); pager .getFinishFuture() .thenRun( () -> { userAbortedQuery.set(true); ignoreUserInterrupt.set(false); clientThread.interrupt(); }); } handler.processRows(client); } catch (RuntimeException | IOException e) { if (userAbortedQuery.get() && !(e instanceof QueryAbortedException)) { throw new QueryAbortedException(e); } throw e; } } private void sendOutput(PrintStream out, OutputFormat format, List<String> fieldNames) throws IOException { try (OutputHandler handler = createOutputHandler(format, createWriter(out), fieldNames)) { handler.processRows(client); } } private static OutputHandler createOutputHandler( OutputFormat format, Writer writer, List<String> fieldNames) { return new OutputHandler(createOutputPrinter(format, writer, fieldNames)); } private static OutputPrinter createOutputPrinter( OutputFormat format, Writer writer, List<String> fieldNames) { switch (format) { case ALIGNED: return new AlignedTablePrinter(fieldNames, writer); case VERTICAL: return new VerticalRecordPrinter(fieldNames, writer); case CSV: return new CsvPrinter(fieldNames, writer, false); case CSV_HEADER: return new CsvPrinter(fieldNames, writer, true); case TSV: return new TsvPrinter(fieldNames, writer, false); case TSV_HEADER: return new TsvPrinter(fieldNames, writer, true); case NULL: return new NullPrinter(); } throw new RuntimeException(format + " not supported"); } private static Writer createWriter(OutputStream out) { return new OutputStreamWriter(out, UTF_8); } @Override public void close() { client.close(); } public void renderFailure(PrintStream out) { QueryResults results = client.finalResults(); QueryError error = results.getError(); checkState(error != null); out.printf("Query %s failed: %s%n", results.getId(), error.getMessage()); if (client.isDebug() && (error.getFailureInfo() != null)) { error.getFailureInfo().toException().printStackTrace(out); } if (error.getErrorLocation() != null) { renderErrorLocation(client.getQuery(), error.getErrorLocation(), out); } out.println(); } private static void renderErrorLocation(String query, ErrorLocation location, PrintStream out) { List<String> lines = ImmutableList.copyOf(Splitter.on('\n').split(query).iterator()); String errorLine = lines.get(location.getLineNumber() - 1); String good = 
errorLine.substring(0, location.getColumnNumber() - 1); String bad = errorLine.substring(location.getColumnNumber() - 1); if ((location.getLineNumber() == lines.size()) && bad.trim().isEmpty()) { bad = " <EOF>"; } if (REAL_TERMINAL) { Ansi ansi = Ansi.ansi(); ansi.fg(Ansi.Color.CYAN); for (int i = 1; i < location.getLineNumber(); i++) { ansi.a(lines.get(i - 1)).newline(); } ansi.a(good); ansi.fg(Ansi.Color.RED); ansi.a(bad).newline(); for (int i = location.getLineNumber(); i < lines.size(); i++) { ansi.a(lines.get(i)).newline(); } ansi.reset(); out.print(ansi); } else { String prefix = format("LINE %s: ", location.getLineNumber()); String padding = Strings.repeat(" ", prefix.length() + (location.getColumnNumber() - 1)); out.println(prefix + errorLine); out.println(padding + "^"); } } }
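renderErrorLocation above prints the failing query line with a caret under the offending column; below is a plain-text-only sketch of that rendering (the ANSI branch is omitted and the names are illustrative).

// Illustrative sketch of the caret rendering; class and method names are invented.
import java.util.Arrays;
import java.util.List;

public final class ErrorLocationSketch {
    private ErrorLocationSketch() {}

    static void renderErrorLocation(String query, int lineNumber, int columnNumber) {
        List<String> lines = Arrays.asList(query.split("\n", -1));
        String errorLine = lines.get(lineNumber - 1);
        String prefix = String.format("LINE %s: ", lineNumber);
        // pad past the prefix and the characters before the error column, then point at it
        String padding = repeat(" ", prefix.length() + (columnNumber - 1));
        System.out.println(prefix + errorLine);
        System.out.println(padding + "^");
    }

    private static String repeat(String s, int count) {
        StringBuilder builder = new StringBuilder();
        for (int i = 0; i < count; i++) {
            builder.append(s);
        }
        return builder.toString();
    }

    public static void main(String[] args) {
        renderErrorLocation("SELECT *\nFROM orders\nWHERE totalprice <>", 3, 18);
    }
}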
public class ExpressionCompiler { private static final Logger log = Logger.get(ExpressionCompiler.class); private static final AtomicLong CLASS_ID = new AtomicLong(); private static final boolean DUMP_BYTE_CODE_TREE = false; private static final boolean DUMP_BYTE_CODE_RAW = false; private static final boolean RUN_ASM_VERIFIER = false; // verifier doesn't work right now private static final AtomicReference<String> DUMP_CLASS_FILES_TO = new AtomicReference<>(); private final Metadata metadata; private final LoadingCache<OperatorCacheKey, FilterAndProjectOperatorFactoryFactory> operatorFactories = CacheBuilder.newBuilder() .maximumSize(1000) .build( new CacheLoader<OperatorCacheKey, FilterAndProjectOperatorFactoryFactory>() { @Override public FilterAndProjectOperatorFactoryFactory load(OperatorCacheKey key) throws Exception { return internalCompileFilterAndProjectOperator( key.getFilter(), key.getProjections()); } }); private final LoadingCache<OperatorCacheKey, ScanFilterAndProjectOperatorFactoryFactory> sourceOperatorFactories = CacheBuilder.newBuilder() .maximumSize(1000) .build( new CacheLoader<OperatorCacheKey, ScanFilterAndProjectOperatorFactoryFactory>() { @Override public ScanFilterAndProjectOperatorFactoryFactory load(OperatorCacheKey key) throws Exception { return internalCompileScanFilterAndProjectOperator( key.getSourceId(), key.getFilter(), key.getProjections()); } }); private final AtomicLong generatedClasses = new AtomicLong(); @Inject public ExpressionCompiler(Metadata metadata) { this.metadata = metadata; } @Managed public long getGeneratedClasses() { return generatedClasses.get(); } @Managed public long getCachedFilterAndProjectOperators() { return operatorFactories.size(); } @Managed public long getCachedScanFilterAndProjectOperators() { return sourceOperatorFactories.size(); } public OperatorFactory compileFilterAndProjectOperator( int operatorId, RowExpression filter, List<RowExpression> projections) { return operatorFactories .getUnchecked(new OperatorCacheKey(filter, projections, null)) .create(operatorId); } private DynamicClassLoader createClassLoader() { return new DynamicClassLoader(getClass().getClassLoader()); } @VisibleForTesting public FilterAndProjectOperatorFactoryFactory internalCompileFilterAndProjectOperator( RowExpression filter, List<RowExpression> projections) { DynamicClassLoader classLoader = createClassLoader(); // create filter and project page iterator class TypedOperatorClass typedOperatorClass = compileFilterAndProjectOperator(filter, projections, classLoader); Constructor<? 
extends Operator> constructor; try { constructor = typedOperatorClass .getOperatorClass() .getConstructor(OperatorContext.class, Iterable.class); } catch (NoSuchMethodException e) { throw Throwables.propagate(e); } FilterAndProjectOperatorFactoryFactory operatorFactoryFactory = new FilterAndProjectOperatorFactoryFactory(constructor, typedOperatorClass.getTypes()); return operatorFactoryFactory; } private TypedOperatorClass compileFilterAndProjectOperator( RowExpression filter, List<RowExpression> projections, DynamicClassLoader classLoader) { CallSiteBinder callSiteBinder = new CallSiteBinder(); ClassDefinition classDefinition = new ClassDefinition( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC, FINAL), typeFromPathName("FilterAndProjectOperator_" + CLASS_ID.incrementAndGet()), type(AbstractFilterAndProjectOperator.class)); // declare fields FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class); classDefinition.declareField(a(PRIVATE, VOLATILE, STATIC), "callSites", Map.class); // constructor classDefinition .declareConstructor( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), arg("operatorContext", OperatorContext.class), arg("types", type(Iterable.class, Type.class))) .getBody() .comment("super(operatorContext, types);") .pushThis() .getVariable("operatorContext") .getVariable("types") .invokeConstructor( AbstractFilterAndProjectOperator.class, OperatorContext.class, Iterable.class) .comment("this.session = operatorContext.getSession();") .pushThis() .getVariable("operatorContext") .invokeVirtual(OperatorContext.class, "getSession", ConnectorSession.class) .putField(sessionField) .ret(); generateFilterAndProjectRowOriented(classDefinition, filter, projections); // // filter method // generateFilterMethod(callSiteBinder, classDefinition, filter, true); generateFilterMethod(callSiteBinder, classDefinition, filter, false); // // project methods // List<Type> types = new ArrayList<>(); int projectionIndex = 0; for (RowExpression projection : projections) { generateProjectMethod( callSiteBinder, classDefinition, "project_" + projectionIndex, projection, true); generateProjectMethod( callSiteBinder, classDefinition, "project_" + projectionIndex, projection, false); types.add(projection.getType()); projectionIndex++; } // // toString method // generateToString( classDefinition, toStringHelper(classDefinition.getType().getJavaClassName()) .add("filter", filter) .add("projections", projections) .toString()); Class<? 
extends Operator> filterAndProjectClass = defineClass(classDefinition, Operator.class, classLoader); setCallSitesField(filterAndProjectClass, callSiteBinder.getBindings()); return new TypedOperatorClass(filterAndProjectClass, types); } public SourceOperatorFactory compileScanFilterAndProjectOperator( int operatorId, PlanNodeId sourceId, DataStreamProvider dataStreamProvider, List<ColumnHandle> columns, RowExpression filter, List<RowExpression> projections) { OperatorCacheKey cacheKey = new OperatorCacheKey(filter, projections, sourceId); return sourceOperatorFactories .getUnchecked(cacheKey) .create(operatorId, dataStreamProvider, columns); } @VisibleForTesting public ScanFilterAndProjectOperatorFactoryFactory internalCompileScanFilterAndProjectOperator( PlanNodeId sourceId, RowExpression filter, List<RowExpression> projections) { DynamicClassLoader classLoader = createClassLoader(); // create filter and project page iterator class TypedOperatorClass typedOperatorClass = compileScanFilterAndProjectOperator(filter, projections, classLoader); Constructor<? extends SourceOperator> constructor; try { constructor = typedOperatorClass .getOperatorClass() .asSubclass(SourceOperator.class) .getConstructor( OperatorContext.class, PlanNodeId.class, DataStreamProvider.class, Iterable.class, Iterable.class); } catch (NoSuchMethodException e) { throw Throwables.propagate(e); } ScanFilterAndProjectOperatorFactoryFactory operatorFactoryFactory = new ScanFilterAndProjectOperatorFactoryFactory( constructor, sourceId, typedOperatorClass.getTypes()); return operatorFactoryFactory; } private TypedOperatorClass compileScanFilterAndProjectOperator( RowExpression filter, List<RowExpression> projections, DynamicClassLoader classLoader) { CallSiteBinder callSiteBinder = new CallSiteBinder(); ClassDefinition classDefinition = new ClassDefinition( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC, FINAL), typeFromPathName("ScanFilterAndProjectOperator_" + CLASS_ID.incrementAndGet()), type(AbstractScanFilterAndProjectOperator.class)); // declare fields FieldDefinition sessionField = classDefinition.declareField(a(PRIVATE, FINAL), "session", ConnectorSession.class); classDefinition.declareField(a(PRIVATE, VOLATILE, STATIC), "callSites", Map.class); // constructor classDefinition .declareConstructor( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), arg("operatorContext", OperatorContext.class), arg("sourceId", PlanNodeId.class), arg("dataStreamProvider", DataStreamProvider.class), arg("columns", type(Iterable.class, ColumnHandle.class)), arg("types", type(Iterable.class, Type.class))) .getBody() .comment("super(operatorContext, sourceId, dataStreamProvider, columns, types);") .pushThis() .getVariable("operatorContext") .getVariable("sourceId") .getVariable("dataStreamProvider") .getVariable("columns") .getVariable("types") .invokeConstructor( AbstractScanFilterAndProjectOperator.class, OperatorContext.class, PlanNodeId.class, DataStreamProvider.class, Iterable.class, Iterable.class) .comment("this.session = operatorContext.getSession();") .pushThis() .getVariable("operatorContext") .invokeVirtual(OperatorContext.class, "getSession", ConnectorSession.class) .putField(sessionField) .ret(); generateFilterAndProjectRowOriented(classDefinition, filter, projections); generateFilterAndProjectCursorMethod(classDefinition, projections); // // filter method // generateFilterMethod(callSiteBinder, classDefinition, filter, true); generateFilterMethod(callSiteBinder, classDefinition, filter, false); // // project methods // 
List<Type> types = new ArrayList<>(); int projectionIndex = 0; for (RowExpression projection : projections) { generateProjectMethod( callSiteBinder, classDefinition, "project_" + projectionIndex, projection, true); generateProjectMethod( callSiteBinder, classDefinition, "project_" + projectionIndex, projection, false); types.add(projection.getType()); projectionIndex++; } // // toString method // generateToString( classDefinition, toStringHelper(classDefinition.getType().getJavaClassName()) .add("filter", filter) .add("projections", projections) .toString()); Class<? extends SourceOperator> filterAndProjectClass = defineClass(classDefinition, SourceOperator.class, classLoader); setCallSitesField(filterAndProjectClass, callSiteBinder.getBindings()); return new TypedOperatorClass(filterAndProjectClass, types); } private void generateToString(ClassDefinition classDefinition, String string) { // Constant strings can't be too large or the bytecode becomes invalid if (string.length() > 100) { string = string.substring(0, 100) + "..."; } classDefinition .declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "toString", type(String.class)) .getBody() .push(string) .retObject(); } private void generateFilterAndProjectRowOriented( ClassDefinition classDefinition, RowExpression filter, List<RowExpression> projections) { MethodDefinition filterAndProjectMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "filterAndProjectRowOriented", type(void.class), arg("page", com.facebook.presto.operator.Page.class), arg("pageBuilder", PageBuilder.class)); CompilerContext compilerContext = filterAndProjectMethod.getCompilerContext(); LocalVariableDefinition positionVariable = compilerContext.declareVariable(int.class, "position"); LocalVariableDefinition rowsVariable = compilerContext.declareVariable(int.class, "rows"); filterAndProjectMethod .getBody() .comment("int rows = page.getPositionCount();") .getVariable("page") .invokeVirtual(com.facebook.presto.operator.Page.class, "getPositionCount", int.class) .putVariable(rowsVariable); List<Integer> allInputChannels = getInputChannels(Iterables.concat(projections, ImmutableList.of(filter))); for (int channel : allInputChannels) { LocalVariableDefinition blockVariable = compilerContext.declareVariable( com.facebook.presto.spi.block.Block.class, "block_" + channel); filterAndProjectMethod .getBody() .comment("Block %s = page.getBlock(%s);", blockVariable.getName(), channel) .getVariable("page") .push(channel) .invokeVirtual( com.facebook.presto.operator.Page.class, "getBlock", com.facebook.presto.spi.block.Block.class, int.class) .putVariable(blockVariable); } // // for loop body // // for (position = 0; position < rows; position++) ForLoopBuilder forLoop = forLoopBuilder(compilerContext) .comment("for (position = 0; position < rows; position++)") .initialize(new Block(compilerContext).putVariable(positionVariable, 0)) .condition( new Block(compilerContext) .getVariable(positionVariable) .getVariable(rowsVariable) .invokeStatic( CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)) .update(new Block(compilerContext).incrementVariable(positionVariable, (byte) 1)); Block forLoopBody = new Block(compilerContext); IfStatementBuilder ifStatement = new IfStatementBuilder(compilerContext).comment("if (filter(position, blocks...)"); Block condition = new Block(compilerContext); condition.pushThis(); condition.getVariable(positionVariable); List<Integer> filterInputChannels = getInputChannels(filter); for 
(int channel : filterInputChannels) { condition.getVariable("block_" + channel); } condition.invokeVirtual( classDefinition.getType(), "filter", type(boolean.class), ImmutableList.<ParameterizedType>builder() .add(type(int.class)) .addAll( nCopies( filterInputChannels.size(), type(com.facebook.presto.spi.block.Block.class))) .build()); ifStatement.condition(condition); Block trueBlock = new Block(compilerContext); if (projections.isEmpty()) { trueBlock .comment("pageBuilder.declarePosition()") .getVariable("pageBuilder") .invokeVirtual(PageBuilder.class, "declarePosition", void.class); } else { // pageBuilder.getBlockBuilder(0).append(block.getDouble(0); for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) { trueBlock.comment( "project_%s(position, blocks..., pageBuilder.getBlockBuilder(%s))", projectionIndex, projectionIndex); trueBlock.pushThis(); List<Integer> projectionInputs = getInputChannels(projections.get(projectionIndex)); trueBlock.getVariable(positionVariable); for (int channel : projectionInputs) { trueBlock.getVariable("block_" + channel); } // pageBuilder.getBlockBuilder(0) trueBlock .getVariable("pageBuilder") .push(projectionIndex) .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class); // project(position, block_0, block_1, blockBuilder) trueBlock.invokeVirtual( classDefinition.getType(), "project_" + projectionIndex, type(void.class), ImmutableList.<ParameterizedType>builder() .add(type(int.class)) .addAll( nCopies( projectionInputs.size(), type(com.facebook.presto.spi.block.Block.class))) .add(type(BlockBuilder.class)) .build()); } } ifStatement.ifTrue(trueBlock); forLoopBody.append(ifStatement.build()); filterAndProjectMethod.getBody().append(forLoop.body(forLoopBody).build()); filterAndProjectMethod.getBody().ret(); } private void generateFilterAndProjectCursorMethod( ClassDefinition classDefinition, List<RowExpression> projections) { MethodDefinition filterAndProjectMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "filterAndProjectRowOriented", type(int.class), arg("cursor", RecordCursor.class), arg("pageBuilder", PageBuilder.class)); CompilerContext compilerContext = filterAndProjectMethod.getCompilerContext(); LocalVariableDefinition completedPositionsVariable = compilerContext.declareVariable(int.class, "completedPositions"); filterAndProjectMethod .getBody() .comment("int completedPositions = 0;") .putVariable(completedPositionsVariable, 0); // // for loop loop body // LabelNode done = new LabelNode("done"); ForLoopBuilder forLoop = ForLoop.forLoopBuilder(compilerContext) .initialize(NOP) .condition( new Block(compilerContext) .comment("completedPositions < 16384") .getVariable(completedPositionsVariable) .push(16384) .invokeStatic( CompilerOperations.class, "lessThan", boolean.class, int.class, int.class)) .update( new Block(compilerContext) .comment("completedPositions++") .incrementVariable(completedPositionsVariable, (byte) 1)); Block forLoopBody = new Block(compilerContext); forLoop.body(forLoopBody); forLoopBody .comment("if (pageBuilder.isFull()) break;") .append( new Block(compilerContext) .getVariable("pageBuilder") .invokeVirtual(PageBuilder.class, "isFull", boolean.class) .ifTrueGoto(done)); forLoopBody .comment("if (!cursor.advanceNextPosition()) break;") .append( new Block(compilerContext) .getVariable("cursor") .invokeInterface(RecordCursor.class, "advanceNextPosition", boolean.class) .ifFalseGoto(done)); // if (filter(cursor)) IfStatementBuilder 
ifStatement = new IfStatementBuilder(compilerContext); ifStatement.condition( new Block(compilerContext) .pushThis() .getVariable("cursor") .invokeVirtual( classDefinition.getType(), "filter", type(boolean.class), type(RecordCursor.class))); Block trueBlock = new Block(compilerContext); ifStatement.ifTrue(trueBlock); if (projections.isEmpty()) { // pageBuilder.declarePosition(); trueBlock .getVariable("pageBuilder") .invokeVirtual(PageBuilder.class, "declarePosition", void.class); } else { // project_43(block..., pageBuilder.getBlockBuilder(42))); for (int projectionIndex = 0; projectionIndex < projections.size(); projectionIndex++) { trueBlock.pushThis(); trueBlock.getVariable("cursor"); // pageBuilder.getBlockBuilder(0) trueBlock .getVariable("pageBuilder") .push(projectionIndex) .invokeVirtual(PageBuilder.class, "getBlockBuilder", BlockBuilder.class, int.class); // project(block..., blockBuilder) trueBlock.invokeVirtual( classDefinition.getType(), "project_" + projectionIndex, type(void.class), type(RecordCursor.class), type(BlockBuilder.class)); } } forLoopBody.append(ifStatement.build()); filterAndProjectMethod .getBody() .append(forLoop.build()) .visitLabel(done) .comment("return completedPositions;") .getVariable("completedPositions") .retInt(); } private void generateFilterMethod( CallSiteBinder callSiteBinder, ClassDefinition classDefinition, RowExpression filter, boolean sourceIsCursor) { MethodDefinition filterMethod; if (sourceIsCursor) { filterMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "filter", type(boolean.class), arg("cursor", RecordCursor.class)); } else { filterMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), "filter", type(boolean.class), ImmutableList.<NamedParameterDefinition>builder() .add(arg("position", int.class)) .addAll(toBlockParameters(getInputChannels(filter))) .build()); } filterMethod.comment("Filter: %s", filter.toString()); filterMethod.getCompilerContext().declareVariable(type(boolean.class), "wasNull"); Block getSessionByteCode = new Block(filterMethod.getCompilerContext()) .pushThis() .getField(classDefinition.getType(), "session", type(ConnectorSession.class)); ByteCodeNode body = compileExpression( callSiteBinder, filter, sourceIsCursor, filterMethod.getCompilerContext(), getSessionByteCode); LabelNode end = new LabelNode("end"); filterMethod .getBody() .comment("boolean wasNull = false;") .putVariable("wasNull", false) .append(body) .getVariable("wasNull") .ifFalseGoto(end) .pop(boolean.class) .push(false) .visitLabel(end) .retBoolean(); } private ByteCodeNode compileExpression( CallSiteBinder callSiteBinder, RowExpression expression, boolean sourceIsCursor, CompilerContext context, Block getSessionByteCode) { ByteCodeExpressionVisitor visitor = new ByteCodeExpressionVisitor( callSiteBinder, getSessionByteCode, metadata.getFunctionRegistry(), sourceIsCursor); return expression.accept(visitor, context); } private Class<?> generateProjectMethod( CallSiteBinder callSiteBinder, ClassDefinition classDefinition, String methodName, RowExpression projection, boolean sourceIsCursor) { MethodDefinition projectionMethod; if (sourceIsCursor) { projectionMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), methodName, type(void.class), arg("cursor", RecordCursor.class), arg("output", BlockBuilder.class)); } else { ImmutableList.Builder<NamedParameterDefinition> parameters = ImmutableList.builder(); parameters.add(arg("position", 
int.class)); parameters.addAll(toBlockParameters(getInputChannels(projection))); parameters.add(arg("output", BlockBuilder.class)); projectionMethod = classDefinition.declareMethod( new CompilerContext(BOOTSTRAP_METHOD), a(PUBLIC), methodName, type(void.class), parameters.build()); } projectionMethod.comment("Projection: %s", projection.toString()); // generate body code CompilerContext context = projectionMethod.getCompilerContext(); context.declareVariable(type(boolean.class), "wasNull"); Block getSessionByteCode = new Block(context) .pushThis() .getField(classDefinition.getType(), "session", type(ConnectorSession.class)); ByteCodeNode body = compileExpression(callSiteBinder, projection, sourceIsCursor, context, getSessionByteCode); projectionMethod .getBody() .comment("boolean wasNull = false;") .putVariable("wasNull", false) .getVariable("output") .append(body); Type projectionType = projection.getType(); Block notNullBlock = new Block(context); if (projectionType.getJavaType() == boolean.class) { notNullBlock .comment("output.append(<booleanStackValue>);") .invokeInterface(BlockBuilder.class, "appendBoolean", BlockBuilder.class, boolean.class) .pop(); } else if (projectionType.getJavaType() == long.class) { notNullBlock .comment("output.append(<longStackValue>);") .invokeInterface(BlockBuilder.class, "appendLong", BlockBuilder.class, long.class) .pop(); } else if (projectionType.getJavaType() == double.class) { notNullBlock .comment("output.append(<doubleStackValue>);") .invokeInterface(BlockBuilder.class, "appendDouble", BlockBuilder.class, double.class) .pop(); } else if (projectionType.getJavaType() == Slice.class) { notNullBlock .comment("output.append(<sliceStackValue>);") .invokeInterface(BlockBuilder.class, "appendSlice", BlockBuilder.class, Slice.class) .pop(); } else { throw new UnsupportedOperationException("Type " + projectionType + " can not be output yet"); } Block nullBlock = new Block(context) .comment("output.appendNull();") .pop(projectionType.getJavaType()) .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class) .pop(); projectionMethod .getBody() .comment("if the result was null, appendNull; otherwise append the value") .append( new IfStatement( context, new Block(context).getVariable("wasNull"), nullBlock, notNullBlock)) .ret(); return projectionType.getJavaType(); } private static List<Integer> getInputChannels(RowExpression expression) { return getInputChannels(ImmutableList.of(expression)); } private static List<Integer> getInputChannels(Iterable<RowExpression> expressions) { TreeSet<Integer> channels = new TreeSet<>(); for (RowExpression expression : Expressions.subExpressions(expressions)) { if (expression instanceof InputReferenceExpression) { channels.add(((InputReferenceExpression) expression).getField()); } } return ImmutableList.copyOf(channels); } private static class TypedOperatorClass { private final Class<? extends Operator> operatorClass; private final List<Type> types; private TypedOperatorClass(Class<? extends Operator> operatorClass, List<Type> types) { this.operatorClass = operatorClass; this.types = types; } private Class<? 
extends Operator> getOperatorClass() { return operatorClass; } private List<Type> getTypes() { return types; } } private static List<NamedParameterDefinition> toBlockParameters(List<Integer> inputChannels) { ImmutableList.Builder<NamedParameterDefinition> parameters = ImmutableList.builder(); for (int channel : inputChannels) { parameters.add(arg("block_" + channel, com.facebook.presto.spi.block.Block.class)); } return parameters.build(); } private <T> Class<? extends T> defineClass( ClassDefinition classDefinition, Class<T> superType, DynamicClassLoader classLoader) { Class<?> clazz = defineClasses(ImmutableList.of(classDefinition), classLoader).values().iterator().next(); return clazz.asSubclass(superType); } private Map<String, Class<?>> defineClasses( List<ClassDefinition> classDefinitions, DynamicClassLoader classLoader) { ClassInfoLoader classInfoLoader = ClassInfoLoader.createClassInfoLoader(classDefinitions, classLoader); if (DUMP_BYTE_CODE_TREE) { DumpByteCodeVisitor dumpByteCode = new DumpByteCodeVisitor(System.out); for (ClassDefinition classDefinition : classDefinitions) { dumpByteCode.visitClass(classDefinition); } } Map<String, byte[]> byteCodes = new LinkedHashMap<>(); for (ClassDefinition classDefinition : classDefinitions) { ClassWriter cw = new SmartClassWriter(classInfoLoader); classDefinition.visit(cw); byte[] byteCode = cw.toByteArray(); if (RUN_ASM_VERIFIER) { ClassReader reader = new ClassReader(byteCode); CheckClassAdapter.verify(reader, classLoader, true, new PrintWriter(System.out)); } byteCodes.put(classDefinition.getType().getJavaClassName(), byteCode); } String dumpClassPath = DUMP_CLASS_FILES_TO.get(); if (dumpClassPath != null) { for (Entry<String, byte[]> entry : byteCodes.entrySet()) { File file = new File( dumpClassPath, ParameterizedType.typeFromJavaClassName(entry.getKey()).getClassName() + ".class"); try { log.debug("ClassFile: " + file.getAbsolutePath()); Files.createParentDirs(file); Files.write(entry.getValue(), file); } catch (IOException e) { log.error(e, "Failed to write generated class file to: %s" + file.getAbsolutePath()); } } } if (DUMP_BYTE_CODE_RAW) { for (byte[] byteCode : byteCodes.values()) { ClassReader classReader = new ClassReader(byteCode); classReader.accept( new TraceClassVisitor(new PrintWriter(System.err)), ClassReader.SKIP_FRAMES); } } Map<String, Class<?>> classes = classLoader.defineClasses(byteCodes); generatedClasses.addAndGet(classes.size()); return classes; } private static void setCallSitesField(Class<?> clazz, Map<Long, MethodHandle> callSites) { try { Field field = clazz.getDeclaredField("callSites"); field.setAccessible(true); field.set(null, callSites); } catch (IllegalAccessException | NoSuchFieldException e) { throw Throwables.propagate(e); } } private static final class OperatorCacheKey { private final RowExpression filter; private final List<RowExpression> projections; private final PlanNodeId sourceId; private OperatorCacheKey( RowExpression filter, List<RowExpression> projections, PlanNodeId sourceId) { this.filter = filter; this.projections = ImmutableList.copyOf(projections); this.sourceId = sourceId; } private RowExpression getFilter() { return filter; } private List<RowExpression> getProjections() { return projections; } private PlanNodeId getSourceId() { return sourceId; } @Override public int hashCode() { return Objects.hashCode(filter, projections, sourceId); } @Override public boolean equals(Object obj) { if (this == obj) { return true; } if (obj == null || getClass() != obj.getClass()) { return false; } 
OperatorCacheKey other = (OperatorCacheKey) obj; return Objects.equal(this.filter, other.filter) && Objects.equal(this.sourceId, other.sourceId) && Objects.equal(this.projections, other.projections); } @Override public String toString() { return toStringHelper(this) .add("filter", filter) .add("projections", projections) .add("sourceId", sourceId) .toString(); } } private static class FilterAndProjectOperatorFactoryFactory { private final Constructor<? extends Operator> constructor; private final List<Type> types; public FilterAndProjectOperatorFactoryFactory( Constructor<? extends Operator> constructor, List<Type> types) { this.constructor = checkNotNull(constructor, "constructor is null"); this.types = ImmutableList.copyOf(checkNotNull(types, "types is null")); } public OperatorFactory create(int operatorId) { return new FilterAndProjectOperatorFactory(constructor, operatorId, types); } } private static class FilterAndProjectOperatorFactory implements OperatorFactory { private final Constructor<? extends Operator> constructor; private final int operatorId; private final List<Type> types; private boolean closed; public FilterAndProjectOperatorFactory( Constructor<? extends Operator> constructor, int operatorId, List<Type> types) { this.constructor = checkNotNull(constructor, "constructor is null"); this.operatorId = operatorId; this.types = ImmutableList.copyOf(checkNotNull(types, "types is null")); } @Override public List<Type> getTypes() { return types; } @Override public Operator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext( operatorId, constructor.getDeclaringClass().getSimpleName()); try { return constructor.newInstance(operatorContext, types); } catch (InvocationTargetException e) { throw Throwables.propagate(e.getCause()); } catch (ReflectiveOperationException e) { throw Throwables.propagate(e); } } @Override public void close() { closed = true; } } private static class ScanFilterAndProjectOperatorFactoryFactory { private final Constructor<? extends SourceOperator> constructor; private final PlanNodeId sourceId; private final List<Type> types; public ScanFilterAndProjectOperatorFactoryFactory( Constructor<? extends SourceOperator> constructor, PlanNodeId sourceId, List<Type> types) { this.sourceId = checkNotNull(sourceId, "sourceId is null"); this.constructor = checkNotNull(constructor, "constructor is null"); this.types = ImmutableList.copyOf(checkNotNull(types, "types is null")); } public SourceOperatorFactory create( int operatorId, DataStreamProvider dataStreamProvider, List<ColumnHandle> columns) { return new ScanFilterAndProjectOperatorFactory( constructor, operatorId, sourceId, dataStreamProvider, columns, types); } } private static class ScanFilterAndProjectOperatorFactory implements SourceOperatorFactory { private final Constructor<? extends SourceOperator> constructor; private final int operatorId; private final PlanNodeId sourceId; private final DataStreamProvider dataStreamProvider; private final List<ColumnHandle> columns; private final List<Type> types; private boolean closed; public ScanFilterAndProjectOperatorFactory( Constructor<? 
extends SourceOperator> constructor, int operatorId, PlanNodeId sourceId, DataStreamProvider dataStreamProvider, List<ColumnHandle> columns, List<Type> types) { this.constructor = checkNotNull(constructor, "constructor is null"); this.operatorId = operatorId; this.sourceId = checkNotNull(sourceId, "sourceId is null"); this.dataStreamProvider = checkNotNull(dataStreamProvider, "dataStreamProvider is null"); this.columns = ImmutableList.copyOf(checkNotNull(columns, "columns is null")); this.types = ImmutableList.copyOf(checkNotNull(types, "types is null")); } @Override public PlanNodeId getSourceId() { return sourceId; } @Override public List<Type> getTypes() { return types; } @Override public SourceOperator createOperator(DriverContext driverContext) { checkState(!closed, "Factory is already closed"); OperatorContext operatorContext = driverContext.addOperatorContext( operatorId, constructor.getDeclaringClass().getSimpleName()); try { return constructor.newInstance( operatorContext, sourceId, dataStreamProvider, columns, types); } catch (InvocationTargetException e) { throw Throwables.propagate(e.getCause()); } catch (ReflectiveOperationException e) { throw Throwables.propagate(e); } } @Override public void close() { closed = true; } } }
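For orientation, here is a hand-written sketch of the row-oriented loop that generateFilterAndProjectCursorMethod emits as bytecode. The Cursor and Output interfaces and the FilterAndProjectSketch class are simplified, hypothetical stand-ins for RecordCursor and PageBuilder, not the actual generated class.

// Hand-written equivalent of the generated filterAndProjectRowOriented loop.
// Cursor and Output are simplified stand-ins (hypothetical), not the real
// RecordCursor/PageBuilder types used by the compiler above.
interface Cursor { boolean advanceNextPosition(); }
interface Output { boolean isFull(); void declarePosition(); }

abstract class FilterAndProjectSketch {
    protected abstract boolean filter(Cursor cursor);            // mirrors the generated filter(cursor) method
    protected abstract void project0(Cursor cursor, Output out); // mirrors a generated project_N method

    int filterAndProjectRowOriented(Cursor cursor, Output out) {
        int completedPositions = 0;
        for (; completedPositions < 16384; completedPositions++) { // same 16384 batch limit as the bytecode
            if (out.isFull()) {
                break;                                // stop when the page builder is full
            }
            if (!cursor.advanceNextPosition()) {
                break;                                // input exhausted
            }
            if (filter(cursor)) {
                project0(cursor, out);                // one call per projection; declarePosition() when there are none
            }
        }
        return completedPositions;
    }
}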
public class HttpRemoteTask implements RemoteTask { private static final Logger log = Logger.get(HttpRemoteTask.class); private static final Duration MAX_CLEANUP_RETRY_TIME = new Duration(2, TimeUnit.MINUTES); private final TaskId taskId; private final Session session; private final String nodeId; private final PlanFragment planFragment; private final AtomicLong nextSplitId = new AtomicLong(); private final StateMachine<TaskInfo> taskInfo; @GuardedBy("this") private Future<?> currentRequest; @GuardedBy("this") private long currentRequestStartNanos; @GuardedBy("this") private final SetMultimap<PlanNodeId, ScheduledSplit> pendingSplits = HashMultimap.create(); @GuardedBy("this") private volatile int pendingSourceSplitCount; @GuardedBy("this") private final Set<PlanNodeId> noMoreSplits = new HashSet<>(); @GuardedBy("this") private final AtomicReference<OutputBuffers> outputBuffers = new AtomicReference<>(); private final ContinuousTaskInfoFetcher continuousTaskInfoFetcher; private final HttpClient httpClient; private final Executor executor; private final ScheduledExecutorService errorScheduledExecutor; private final JsonCodec<TaskInfo> taskInfoCodec; private final JsonCodec<TaskUpdateRequest> taskUpdateRequestCodec; private final RequestErrorTracker updateErrorTracker; private final RequestErrorTracker getErrorTracker; private final AtomicBoolean needsUpdate = new AtomicBoolean(true); private final SplitCountChangeListener splitCountChangeListener; public HttpRemoteTask( Session session, TaskId taskId, String nodeId, URI location, PlanFragment planFragment, Multimap<PlanNodeId, Split> initialSplits, OutputBuffers outputBuffers, HttpClient httpClient, Executor executor, ScheduledExecutorService errorScheduledExecutor, Duration minErrorDuration, Duration refreshMaxWait, JsonCodec<TaskInfo> taskInfoCodec, JsonCodec<TaskUpdateRequest> taskUpdateRequestCodec, SplitCountChangeListener splitCountChangeListener) { requireNonNull(session, "session is null"); requireNonNull(taskId, "taskId is null"); requireNonNull(nodeId, "nodeId is null"); requireNonNull(location, "location is null"); requireNonNull(planFragment, "planFragment1 is null"); requireNonNull(outputBuffers, "outputBuffers is null"); requireNonNull(httpClient, "httpClient is null"); requireNonNull(executor, "executor is null"); requireNonNull(taskInfoCodec, "taskInfoCodec is null"); requireNonNull(taskUpdateRequestCodec, "taskUpdateRequestCodec is null"); requireNonNull(splitCountChangeListener, "splitCountChangeListener is null"); try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { this.taskId = taskId; this.session = session; this.nodeId = nodeId; this.planFragment = planFragment; this.outputBuffers.set(outputBuffers); this.httpClient = httpClient; this.executor = executor; this.errorScheduledExecutor = errorScheduledExecutor; this.taskInfoCodec = taskInfoCodec; this.taskUpdateRequestCodec = taskUpdateRequestCodec; this.updateErrorTracker = new RequestErrorTracker( taskId, location, minErrorDuration, errorScheduledExecutor, "updating task"); this.getErrorTracker = new RequestErrorTracker( taskId, location, minErrorDuration, errorScheduledExecutor, "getting info for task"); this.splitCountChangeListener = splitCountChangeListener; for (Entry<PlanNodeId, Split> entry : requireNonNull(initialSplits, "initialSplits is null").entries()) { ScheduledSplit scheduledSplit = new ScheduledSplit(nextSplitId.getAndIncrement(), entry.getValue()); pendingSplits.put(entry.getKey(), scheduledSplit); } if 
(initialSplits.containsKey(planFragment.getPartitionedSource())) { pendingSourceSplitCount = initialSplits.get(planFragment.getPartitionedSource()).size(); fireSplitCountChanged(pendingSourceSplitCount); } List<BufferInfo> bufferStates = outputBuffers .getBuffers() .keySet() .stream() .map(outputId -> new BufferInfo(outputId, false, 0, 0, PageBufferInfo.empty())) .collect(toImmutableList()); TaskStats taskStats = new TaskStats(DateTime.now(), null); taskInfo = new StateMachine<>( "task " + taskId, executor, new TaskInfo( taskId, Optional.empty(), TaskInfo.MIN_VERSION, TaskState.PLANNED, location, DateTime.now(), new SharedBufferInfo(BufferState.OPEN, true, true, 0, 0, 0, 0, bufferStates), ImmutableSet.<PlanNodeId>of(), taskStats, ImmutableList.<ExecutionFailureInfo>of())); continuousTaskInfoFetcher = new ContinuousTaskInfoFetcher(refreshMaxWait); } } @Override public TaskId getTaskId() { return taskId; } @Override public String getNodeId() { return nodeId; } @Override public TaskInfo getTaskInfo() { return taskInfo.get(); } @Override public void start() { try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { // to start we just need to trigger an update scheduleUpdate(); // begin the info fetcher continuousTaskInfoFetcher.start(); } } @Override public synchronized void addSplits(PlanNodeId sourceId, Iterable<Split> splits) { try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { requireNonNull(sourceId, "sourceId is null"); requireNonNull(splits, "splits is null"); checkState( !noMoreSplits.contains(sourceId), "noMoreSplits has already been set for %s", sourceId); // only add pending split if not done if (!getTaskInfo().getState().isDone()) { int added = 0; for (Split split : splits) { if (pendingSplits.put( sourceId, new ScheduledSplit(nextSplitId.getAndIncrement(), split))) { added++; } } if (sourceId.equals(planFragment.getPartitionedSource())) { pendingSourceSplitCount += added; fireSplitCountChanged(added); } needsUpdate.set(true); } scheduleUpdate(); } } @Override public synchronized void noMoreSplits(PlanNodeId sourceId) { if (noMoreSplits.add(sourceId)) { needsUpdate.set(true); scheduleUpdate(); } } @Override public synchronized void setOutputBuffers(OutputBuffers newOutputBuffers) { if (getTaskInfo().getState().isDone()) { return; } if (newOutputBuffers.getVersion() > outputBuffers.get().getVersion()) { outputBuffers.set(newOutputBuffers); needsUpdate.set(true); scheduleUpdate(); } } @Override public int getPartitionedSplitCount() { return pendingSourceSplitCount + taskInfo.get().getStats().getQueuedPartitionedDrivers() + taskInfo.get().getStats().getRunningPartitionedDrivers(); } @Override public int getQueuedPartitionedSplitCount() { return pendingSourceSplitCount + taskInfo.get().getStats().getQueuedPartitionedDrivers(); } @Override public void addStateChangeListener(StateChangeListener<TaskInfo> stateChangeListener) { try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { taskInfo.addStateChangeListener(stateChangeListener); } } @Override public CompletableFuture<TaskInfo> getStateChange(TaskInfo taskInfo) { return this.taskInfo.getStateChange(taskInfo); } private synchronized void updateTaskInfo(TaskInfo newValue) { updateTaskInfo(newValue, ImmutableList.of()); } private synchronized void updateTaskInfo(TaskInfo newValue, List<TaskSource> sources) { if (newValue.getState().isDone()) { // splits can be huge so clear the list pendingSplits.clear(); fireSplitCountChanged(-pendingSourceSplitCount); 
pendingSourceSplitCount = 0; } int oldPartitionedSplitCount = getPartitionedSplitCount(); // change to new value if old value is not changed and new value has a newer version AtomicBoolean workerRestarted = new AtomicBoolean(); boolean updated = taskInfo.setIf( newValue, oldValue -> { // did the worker restart if (oldValue.getNodeInstanceId().isPresent() && !oldValue.getNodeInstanceId().equals(newValue.getNodeInstanceId())) { workerRestarted.set(true); return false; } if (oldValue.getState().isDone()) { // never update if the task has reached a terminal state return false; } if (newValue.getVersion() < oldValue.getVersion()) { // don't update to an older version (same version is ok) return false; } return true; }); if (workerRestarted.get()) { PrestoException exception = new PrestoException( WORKER_RESTARTED, format("%s (%s)", WORKER_RESTARTED_ERROR, newValue.getSelf())); failTask(exception); abort(); } // remove acknowledged splits, which frees memory for (TaskSource source : sources) { PlanNodeId planNodeId = source.getPlanNodeId(); int removed = 0; for (ScheduledSplit split : source.getSplits()) { if (pendingSplits.remove(planNodeId, split)) { removed++; } } if (planNodeId.equals(planFragment.getPartitionedSource())) { pendingSourceSplitCount -= removed; } } if (updated) { if (getTaskInfo().getState().isDone()) { fireSplitCountChanged(-oldPartitionedSplitCount); } else { fireSplitCountChanged(getPartitionedSplitCount() - oldPartitionedSplitCount); } } } private void fireSplitCountChanged(int delta) { if (delta != 0) { splitCountChangeListener.splitCountChanged(delta); } } private synchronized void scheduleUpdate() { // don't update if the task hasn't been started yet or if it is already finished if (!needsUpdate.get() || taskInfo.get().getState().isDone()) { return; } // if we have an old request outstanding, cancel it if (currentRequest != null && Duration.nanosSince(currentRequestStartNanos).compareTo(new Duration(2, SECONDS)) >= 0) { needsUpdate.set(true); currentRequest.cancel(true); currentRequest = null; currentRequestStartNanos = 0; } // if there is a request already running, wait for it to complete if (this.currentRequest != null && !this.currentRequest.isDone()) { return; } // if throttled due to error, asynchronously wait for timeout and try again ListenableFuture<?> errorRateLimit = updateErrorTracker.acquireRequestPermit(); if (!errorRateLimit.isDone()) { errorRateLimit.addListener(this::scheduleUpdate, executor); return; } List<TaskSource> sources = getSources(); TaskUpdateRequest updateRequest = new TaskUpdateRequest( session.toSessionRepresentation(), planFragment, sources, outputBuffers.get()); Request request = preparePost() .setUri(uriBuilderFrom(taskInfo.get().getSelf()).addParameter("summarize").build()) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString()) .setBodyGenerator(jsonBodyGenerator(taskUpdateRequestCodec, updateRequest)) .build(); updateErrorTracker.startRequest(); ListenableFuture<JsonResponse<TaskInfo>> future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec)); currentRequest = future; currentRequestStartNanos = System.nanoTime(); // The needsUpdate flag needs to be set to false BEFORE adding the Future callback since // callback might change the flag value // and does so without grabbing the instance lock. 
needsUpdate.set(false); Futures.addCallback( future, new SimpleHttpResponseHandler<>(new UpdateResponseHandler(sources), request.getUri()), executor); } private synchronized List<TaskSource> getSources() { return Stream.concat( Stream.of(planFragment.getPartitionedSourceNode()), planFragment.getRemoteSourceNodes().stream()) .filter(Objects::nonNull) .map(PlanNode::getId) .map(this::getSource) .filter(Objects::nonNull) .collect(toImmutableList()); } private TaskSource getSource(PlanNodeId planNodeId) { Set<ScheduledSplit> splits = pendingSplits.get(planNodeId); boolean noMoreSplits = this.noMoreSplits.contains(planNodeId); TaskSource element = null; if (!splits.isEmpty() || noMoreSplits) { element = new TaskSource(planNodeId, splits, noMoreSplits); } return element; } @Override public synchronized void cancel() { try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { if (getTaskInfo().getState().isDone()) { return; } URI uri = getTaskInfo().getSelf(); if (uri == null) { return; } // send cancel to task and ignore response Request request = prepareDelete() .setUri( uriBuilderFrom(uri) .addParameter("abort", "false") .addParameter("summarize") .build()) .build(); scheduleAsyncCleanupRequest(new Backoff(MAX_CLEANUP_RETRY_TIME), request, "cancel"); } } @Override public synchronized void abort() { try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { // clear pending splits to free memory fireSplitCountChanged(-pendingSourceSplitCount); pendingSplits.clear(); pendingSourceSplitCount = 0; // cancel pending request if (currentRequest != null) { currentRequest.cancel(true); currentRequest = null; currentRequestStartNanos = 0; } // mark task as canceled (if not already done) TaskInfo taskInfo = getTaskInfo(); URI uri = taskInfo.getSelf(); updateTaskInfo( new TaskInfo( taskInfo.getTaskId(), taskInfo.getNodeInstanceId(), TaskInfo.MAX_VERSION, TaskState.ABORTED, uri, taskInfo.getLastHeartbeat(), taskInfo.getOutputBuffers(), taskInfo.getNoMoreSplits(), taskInfo.getStats(), ImmutableList.<ExecutionFailureInfo>of())); // send abort to task and ignore response Request request = prepareDelete().setUri(uriBuilderFrom(uri).addParameter("summarize").build()).build(); scheduleAsyncCleanupRequest(new Backoff(MAX_CLEANUP_RETRY_TIME), request, "abort"); } } private void scheduleAsyncCleanupRequest(Backoff cleanupBackoff, Request request, String action) { Futures.addCallback( httpClient.executeAsync(request, createStatusResponseHandler()), new FutureCallback<StatusResponse>() { @Override public void onSuccess(StatusResponse result) { // assume any response is good enough } @Override public void onFailure(Throwable t) { if (t instanceof RejectedExecutionException) { // client has been shutdown return; } // record failure if (cleanupBackoff.failure()) { logError(t, "Unable to %s task at %s", action, request.getUri()); return; } // reschedule long delayNanos = cleanupBackoff.getBackoffDelayNanos(); if (delayNanos == 0) { scheduleAsyncCleanupRequest(cleanupBackoff, request, action); } else { errorScheduledExecutor.schedule( () -> scheduleAsyncCleanupRequest(cleanupBackoff, request, action), delayNanos, NANOSECONDS); } } }, executor); } /** Move the task directly to the failed state */ private void failTask(Throwable cause) { TaskInfo taskInfo = getTaskInfo(); if (!taskInfo.getState().isDone()) { log.debug(cause, "Remote task failed: %s", taskInfo.getSelf()); } updateTaskInfo( new TaskInfo( taskInfo.getTaskId(), taskInfo.getNodeInstanceId(), TaskInfo.MAX_VERSION, 
TaskState.FAILED, taskInfo.getSelf(), taskInfo.getLastHeartbeat(), taskInfo.getOutputBuffers(), taskInfo.getNoMoreSplits(), taskInfo.getStats(), ImmutableList.of(toFailure(cause)))); } @Override public String toString() { return toStringHelper(this).addValue(getTaskInfo()).toString(); } private class UpdateResponseHandler implements SimpleHttpResponseCallback<TaskInfo> { private final List<TaskSource> sources; private UpdateResponseHandler(List<TaskSource> sources) { this.sources = ImmutableList.copyOf(requireNonNull(sources, "sources is null")); } @Override public void success(TaskInfo value) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { try { synchronized (HttpRemoteTask.this) { currentRequest = null; } updateTaskInfo(value, sources); updateErrorTracker.requestSucceeded(); } finally { scheduleUpdate(); } } } @Override public void failed(Throwable cause) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { try { synchronized (HttpRemoteTask.this) { currentRequest = null; } // on failure assume we need to update again needsUpdate.set(true); // if task not already done, record error TaskInfo taskInfo = getTaskInfo(); if (!taskInfo.getState().isDone()) { updateErrorTracker.requestFailed(cause); } } catch (Error e) { failTask(e); abort(); throw e; } catch (RuntimeException e) { failTask(e); abort(); } finally { scheduleUpdate(); } } } @Override public void fatal(Throwable cause) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { failTask(cause); } } } /** * Continuous update loop for task info. Wait for a short period for task state to change, and if * it does not, return the current state of the task. This will cause stats to be updated at a * regular interval, and state changes will be immediately recorded. */ private class ContinuousTaskInfoFetcher implements SimpleHttpResponseCallback<TaskInfo> { private final Duration refreshMaxWait; @GuardedBy("this") private boolean running; @GuardedBy("this") private ListenableFuture<JsonResponse<TaskInfo>> future; public ContinuousTaskInfoFetcher(Duration refreshMaxWait) { this.refreshMaxWait = refreshMaxWait; } public synchronized void start() { if (running) { // already running return; } running = true; scheduleNextRequest(); } public synchronized void stop() { running = false; if (future != null) { future.cancel(true); future = null; } } private synchronized void scheduleNextRequest() { // stopped or done? TaskInfo taskInfo = HttpRemoteTask.this.taskInfo.get(); if (!running || taskInfo.getState().isDone()) { return; } // outstanding request? 
if (future != null && !future.isDone()) { // this should never happen log.error("Can not reschedule update because an update is already running"); return; } // if throttled due to error, asynchronously wait for timeout and try again ListenableFuture<?> errorRateLimit = getErrorTracker.acquireRequestPermit(); if (!errorRateLimit.isDone()) { errorRateLimit.addListener(this::scheduleNextRequest, executor); return; } Request request = prepareGet() .setUri(uriBuilderFrom(taskInfo.getSelf()).addParameter("summarize").build()) .setHeader(HttpHeaders.CONTENT_TYPE, MediaType.JSON_UTF_8.toString()) .setHeader(PrestoHeaders.PRESTO_CURRENT_STATE, taskInfo.getState().toString()) .setHeader(PrestoHeaders.PRESTO_MAX_WAIT, refreshMaxWait.toString()) .build(); getErrorTracker.startRequest(); future = httpClient.executeAsync(request, createFullJsonResponseHandler(taskInfoCodec)); Futures.addCallback( future, new SimpleHttpResponseHandler<>(this, request.getUri()), executor); } @Override public void success(TaskInfo value) { try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) { synchronized (this) { future = null; } try { updateTaskInfo(value, ImmutableList.<TaskSource>of()); getErrorTracker.requestSucceeded(); } finally { scheduleNextRequest(); } } } @Override public void failed(Throwable cause) { try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) { synchronized (this) { future = null; } try { // if task not already done, record error TaskInfo taskInfo = getTaskInfo(); if (!taskInfo.getState().isDone()) { getErrorTracker.requestFailed(cause); } } catch (Error e) { failTask(e); abort(); throw e; } catch (RuntimeException e) { failTask(e); abort(); } finally { // there is no back off here so we can get a lot of error messages when a server spins // down, but it typically goes away quickly because the queries get canceled scheduleNextRequest(); } } } @Override public void fatal(Throwable cause) { try (SetThreadName ignored = new SetThreadName("ContinuousTaskInfoFetcher-%s", taskId)) { synchronized (this) { future = null; } failTask(cause); } } } public static class SimpleHttpResponseHandler<T> implements FutureCallback<JsonResponse<T>> { private final SimpleHttpResponseCallback<T> callback; private final URI uri; public SimpleHttpResponseHandler(SimpleHttpResponseCallback<T> callback, URI uri) { this.callback = callback; this.uri = uri; } @Override public void onSuccess(JsonResponse<T> response) { try { if (response.getStatusCode() == HttpStatus.OK.code() && response.hasValue()) { callback.success(response.getValue()); } else if (response.getStatusCode() == HttpStatus.SERVICE_UNAVAILABLE.code()) { callback.failed(new ServiceUnavailableException(uri)); } else { // Something is broken in the server or the client, so fail the task immediately (includes // 500 errors) Exception cause = response.getException(); if (cause == null) { if (response.getStatusCode() == HttpStatus.OK.code()) { cause = new PrestoException( REMOTE_TASK_ERROR, format("Expected response from %s is empty", uri)); } else { cause = new PrestoException( REMOTE_TASK_ERROR, format( "Expected response code from %s to be %s, but was %s: %s%n%s", uri, HttpStatus.OK.code(), response.getStatusCode(), response.getStatusMessage(), response.getResponseBody())); } } callback.fatal(cause); } } catch (Throwable t) { // this should never happen callback.fatal(t); } } @Override public void onFailure(Throwable t) { callback.failed(t); } } @ThreadSafe private static class RequestErrorTracker 
{ private final TaskId taskId; private final URI taskUri; private final ScheduledExecutorService scheduledExecutor; private final String jobDescription; private final Backoff backoff; private final Queue<Throwable> errorsSinceLastSuccess = new ConcurrentLinkedQueue<>(); public RequestErrorTracker( TaskId taskId, URI taskUri, Duration minErrorDuration, ScheduledExecutorService scheduledExecutor, String jobDescription) { this.taskId = taskId; this.taskUri = taskUri; this.scheduledExecutor = scheduledExecutor; this.backoff = new Backoff(minErrorDuration); this.jobDescription = jobDescription; } public ListenableFuture<?> acquireRequestPermit() { long delayNanos = backoff.getBackoffDelayNanos(); if (delayNanos == 0) { return Futures.immediateFuture(null); } ListenableFutureTask<Object> futureTask = ListenableFutureTask.create(() -> null); scheduledExecutor.schedule(futureTask, delayNanos, NANOSECONDS); return futureTask; } public void startRequest() { // before scheduling a new request clear the error timer // we consider a request to be "new" if there are no current failures if (backoff.getFailureCount() == 0) { requestSucceeded(); } } public void requestSucceeded() { backoff.success(); errorsSinceLastSuccess.clear(); } public void requestFailed(Throwable reason) throws PrestoException { // cancellation is not a failure if (reason instanceof CancellationException) { return; } if (reason instanceof RejectedExecutionException) { throw new PrestoException(REMOTE_TASK_ERROR, reason); } // log failure message if (isExpectedError(reason)) { // don't print a stack for a known errors log.warn("Error " + jobDescription + " %s: %s: %s", taskId, reason.getMessage(), taskUri); } else { log.warn(reason, "Error " + jobDescription + " %s: %s", taskId, taskUri); } // remember the first 10 errors if (errorsSinceLastSuccess.size() < 10) { errorsSinceLastSuccess.add(reason); } // fail the task, if we have more than X failures in a row and more than Y seconds have passed // since the last request if (backoff.failure()) { // it is weird to mark the task failed locally and then cancel the remote task, but there is // no way to tell a remote task that it is failed PrestoException exception = new PrestoException( TOO_MANY_REQUESTS_FAILED, format( "%s (%s - %s failures, time since last success %s)", WORKER_NODE_ERROR, taskUri, backoff.getFailureCount(), backoff.getTimeSinceLastSuccess().convertTo(SECONDS))); errorsSinceLastSuccess.forEach(exception::addSuppressed); throw exception; } } } public interface SimpleHttpResponseCallback<T> { void success(T value); void failed(Throwable cause); void fatal(Throwable cause); } private static void logError(Throwable t, String format, Object... args) { if (isExpectedError(t)) { log.error(format + ": %s", ObjectArrays.concat(args, t)); } else { log.error(t, format, args); } } private static boolean isExpectedError(Throwable t) { while (t != null) { if ((t instanceof SocketException) || (t instanceof SocketTimeoutException) || (t instanceof EOFException) || (t instanceof TimeoutException) || (t instanceof ServiceUnavailableException)) { return true; } t = t.getCause(); } return false; } private static class ServiceUnavailableException extends RuntimeException { public ServiceUnavailableException(URI uri) { super("Server returned SERVICE_UNAVAILABLE: " + uri); } } }
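The cleanup path above retries a request under a backoff policy until its budget is exhausted. Below is a minimal, self-contained sketch of that schedule-with-backoff pattern; BackoffPolicy and CleanupRetrySketch are hypothetical stand-ins for Presto's Backoff and the real HTTP callback.

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

final class CleanupRetrySketch {
    // Hypothetical stand-in for Presto's Backoff
    interface BackoffPolicy {
        boolean failure();            // records a failure; returns true once the retry budget is exhausted
        long backoffDelayNanos();     // delay to wait before the next attempt
    }

    private final ScheduledExecutorService scheduler = Executors.newSingleThreadScheduledExecutor();

    void runWithRetry(Supplier<Boolean> action, BackoffPolicy backoff) {
        boolean succeeded;
        try {
            succeeded = action.get();
        } catch (RuntimeException e) {
            succeeded = false;
        }
        if (succeeded) {
            return;                   // any successful response is good enough
        }
        if (backoff.failure()) {
            return;                   // retry budget spent; give up, as the cleanup path does after logging
        }
        long delayNanos = backoff.backoffDelayNanos();
        if (delayNanos == 0) {
            runWithRetry(action, backoff);   // retry immediately
        } else {
            scheduler.schedule(() -> runWithRetry(action, backoff), delayNanos, TimeUnit.NANOSECONDS);
        }
    }
}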
public class DirectoryDeploymentManager implements DeploymentManager { private static final Logger log = Logger.get(DirectoryDeploymentManager.class); private final JsonCodec<DeploymentRepresentation> jsonCodec = jsonCodec(DeploymentRepresentation.class); private final UUID slotId; private final String location; private final Duration tarTimeout; private final File baseDir; private final File deploymentFile; private Deployment deployment; public DirectoryDeploymentManager(File baseDir, String location, Duration tarTimeout) { Preconditions.checkNotNull(location, "location is null"); Preconditions.checkArgument(location.startsWith("/"), "location must start with /"); this.location = location; this.tarTimeout = tarTimeout; Preconditions.checkNotNull(baseDir, "baseDir is null"); baseDir.mkdirs(); Preconditions.checkArgument( baseDir.isDirectory(), "baseDir is not a directory: " + baseDir.getAbsolutePath()); this.baseDir = baseDir; // verify deployment file is readable and writable deploymentFile = new File(baseDir, "airship-deployment.json"); if (deploymentFile.exists()) { Preconditions.checkArgument( deploymentFile.canRead(), "Can not read slot-id file %s", deploymentFile.getAbsolutePath()); Preconditions.checkArgument( deploymentFile.canWrite(), "Can not write slot-id file %s", deploymentFile.getAbsolutePath()); } // load deployments if (deploymentFile.exists()) { try { Deployment deployment = load(deploymentFile); Preconditions.checkArgument( deployment.getDeploymentDir().isDirectory(), "Deployment directory is not a directory: %s", deployment.getDeploymentDir()); this.deployment = deployment; } catch (IOException e) { throw new IllegalArgumentException( "Invalid deployment file: " + deploymentFile.getAbsolutePath(), e); } } // load slot-id File slotIdFile = new File(baseDir, "airship-slot-id.txt"); UUID uuid = null; if (slotIdFile.exists()) { Preconditions.checkArgument( slotIdFile.canRead(), "can not read " + slotIdFile.getAbsolutePath()); try { String slotIdString = Files.toString(slotIdFile, UTF_8).trim(); try { uuid = UUID.fromString(slotIdString); } catch (IllegalArgumentException e) { } if (uuid == null) { log.warn( "Invalid slot id [" + slotIdString + "]: attempting to delete airship-slot-id.txt file and recreating a new one"); slotIdFile.delete(); } } catch (IOException e) { Preconditions.checkArgument( slotIdFile.canRead(), "can not read " + slotIdFile.getAbsolutePath()); } } if (uuid == null) { uuid = UUID.randomUUID(); try { Files.write(uuid.toString(), slotIdFile, UTF_8); } catch (IOException e) { Preconditions.checkArgument( slotIdFile.canRead(), "can not write " + slotIdFile.getAbsolutePath()); } } slotId = uuid; } @Override public UUID getSlotId() { return slotId; } public String getLocation() { return location; } @Override public Deployment install(Installation installation) { Preconditions.checkNotNull(installation, "installation is null"); File deploymentDir = new File(baseDir, "installation"); Assignment assignment = installation.getAssignment(); Deployment newDeployment = new Deployment( slotId, location, deploymentDir, getDataDir(), assignment, installation.getResources()); File tempDir = createTempDir(baseDir, "tmp-install"); try { // download the binary File binary = new File(tempDir, "airship-binary.tar.gz"); try { Files.copy(Resources.newInputStreamSupplier(installation.getBinaryFile().toURL()), binary); } catch (IOException e) { throw new RuntimeException( "Unable to download binary " + assignment.getBinary() + " from " + installation.getBinaryFile(), e); } // 
unpack the binary into a temp unpack dir File unpackDir = new File(tempDir, "unpack"); unpackDir.mkdirs(); try { extractTar(binary, unpackDir, tarTimeout); } catch (CommandFailedException e) { throw new RuntimeException( "Unable to extract tar file " + assignment.getBinary() + ": " + e.getMessage()); } // find the archive root dir (it should be the only file in the temp unpack dir) List<File> files = listFiles(unpackDir); if (files.size() != 1) { throw new RuntimeException( "Invalid tar file: file does not have a root directory " + assignment.getBinary()); } File binaryRootDir = files.get(0); // unpack config bundle try { URL url = installation.getConfigFile().toURL(); ConfigUtils.unpackConfig(Resources.newInputStreamSupplier(url), binaryRootDir); } catch (Exception e) { throw new RuntimeException( "Unable to extract config bundle " + assignment.getConfig() + ": " + e.getMessage()); } // installation is good, clear the current deployment if (this.deployment != null) { this.deploymentFile.delete(); deleteRecursively(this.deployment.getDeploymentDir()); this.deployment = null; } // save deployment versions file try { save(newDeployment); } catch (IOException e) { throw new RuntimeException("Unable to save deployment file", e); } // move the binary root directory to the final target try { Files.move(binaryRootDir, deploymentDir); } catch (IOException e) { throw new RuntimeException("Unable to move deployment to final location", e); } } finally { if (!deleteRecursively(tempDir)) { log.warn("Unable to delete temp directory: %s", tempDir.getAbsolutePath()); } } this.deployment = newDeployment; return newDeployment; } @Override public Deployment getDeployment() { return deployment; } @Override public void terminate() { deleteRecursively(baseDir); deployment = null; } @Override public File hackGetDataDir() { return getDataDir(); } public void save(Deployment deployment) throws IOException { String json = jsonCodec.toJson(DeploymentRepresentation.from(deployment)); Files.write(json, deploymentFile, UTF_8); } public Deployment load(File deploymentFile) throws IOException { String json = Files.toString(deploymentFile, UTF_8); DeploymentRepresentation data = jsonCodec.fromJson(json); File deploymentDir = new File(baseDir, "installation"); if (!deploymentDir.isDirectory()) { deploymentDir = new File(baseDir, "deployment"); } Deployment deployment = data.toDeployment(deploymentDir, getDataDir(), location); return deployment; } private File getDataDir() { File dataDir = new File(baseDir, "data"); dataDir.mkdirs(); if (!dataDir.isDirectory()) { throw new RuntimeException( String.format("Unable to create data dir %s", dataDir.getAbsolutePath())); } return dataDir; } }
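install() stages everything in a temporary directory, moves the unpacked root into its final location, and always deletes the temp directory. A minimal sketch of that stage-then-move pattern using java.nio.file follows; the paths and helper method are hypothetical, not the real Airship code.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.stream.Stream;

final class StagedInstallSketch {
    static void install(Path baseDir) throws IOException {
        Path tempDir = Files.createTempDirectory(baseDir, "tmp-install");
        Path deploymentDir = baseDir.resolve("installation");
        try {
            Path unpacked = tempDir.resolve("unpack");
            Files.createDirectories(unpacked);
            // ... download and extract the binary into `unpacked` here ...
            // in the real flow the previous deployment dir has already been removed
            Files.move(unpacked, deploymentDir);
        } finally {
            deleteRecursively(tempDir);   // temp dir is removed whether or not the install succeeded
        }
    }

    static void deleteRecursively(Path dir) throws IOException {
        if (!Files.exists(dir)) {
            return;
        }
        try (Stream<Path> walk = Files.walk(dir)) {
            walk.sorted(Comparator.reverseOrder()).forEach(p -> p.toFile().delete());
        }
    }
}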
public class PrestoS3FileSystem extends FileSystem { public static final String S3_SSL_ENABLED = "presto.s3.ssl.enabled"; public static final String S3_MAX_ERROR_RETRIES = "presto.s3.max-error-retries"; public static final String S3_MAX_CLIENT_RETRIES = "presto.s3.max-client-retries"; public static final String S3_MAX_BACKOFF_TIME = "presto.s3.max-backoff-time"; public static final String S3_MAX_RETRY_TIME = "presto.s3.max-retry-time"; public static final String S3_CONNECT_TIMEOUT = "presto.s3.connect-timeout"; public static final String S3_SOCKET_TIMEOUT = "presto.s3.socket-timeout"; public static final String S3_MAX_CONNECTIONS = "presto.s3.max-connections"; public static final String S3_STAGING_DIRECTORY = "presto.s3.staging-directory"; public static final String S3_MULTIPART_MIN_FILE_SIZE = "presto.s3.multipart.min-file-size"; public static final String S3_MULTIPART_MIN_PART_SIZE = "presto.s3.multipart.min-part-size"; private static final Logger log = Logger.get(PrestoS3FileSystem.class); private static final DataSize BLOCK_SIZE = new DataSize(32, MEGABYTE); private static final DataSize MAX_SKIP_SIZE = new DataSize(1, MEGABYTE); private final TransferManagerConfiguration transferConfig = new TransferManagerConfiguration(); private URI uri; private Path workingDirectory; private AmazonS3 s3; private File stagingDirectory; private int maxClientRetries; private Duration maxBackoffTime; private Duration maxRetryTime; @Override public void initialize(URI uri, Configuration conf) throws IOException { checkNotNull(uri, "uri is null"); checkNotNull(conf, "conf is null"); super.initialize(uri, conf); setConf(conf); this.uri = URI.create(uri.getScheme() + "://" + uri.getAuthority()); this.workingDirectory = new Path("/").makeQualified(this.uri, new Path("/")); HiveClientConfig defaults = new HiveClientConfig(); this.stagingDirectory = new File(conf.get(S3_STAGING_DIRECTORY, defaults.getS3StagingDirectory().toString())); this.maxClientRetries = conf.getInt(S3_MAX_CLIENT_RETRIES, defaults.getS3MaxClientRetries()); this.maxBackoffTime = Duration.valueOf(conf.get(S3_MAX_BACKOFF_TIME, defaults.getS3MaxBackoffTime().toString())); this.maxRetryTime = Duration.valueOf(conf.get(S3_MAX_RETRY_TIME, defaults.getS3MaxRetryTime().toString())); int maxErrorRetries = conf.getInt(S3_MAX_ERROR_RETRIES, defaults.getS3MaxErrorRetries()); boolean sslEnabled = conf.getBoolean(S3_SSL_ENABLED, defaults.isS3SslEnabled()); Duration connectTimeout = Duration.valueOf(conf.get(S3_CONNECT_TIMEOUT, defaults.getS3ConnectTimeout().toString())); Duration socketTimeout = Duration.valueOf(conf.get(S3_SOCKET_TIMEOUT, defaults.getS3SocketTimeout().toString())); int maxConnections = conf.getInt(S3_MAX_CONNECTIONS, defaults.getS3MaxConnections()); long minFileSize = conf.getLong(S3_MULTIPART_MIN_FILE_SIZE, defaults.getS3MultipartMinFileSize().toBytes()); long minPartSize = conf.getLong(S3_MULTIPART_MIN_PART_SIZE, defaults.getS3MultipartMinPartSize().toBytes()); ClientConfiguration configuration = new ClientConfiguration(); configuration.setMaxErrorRetry(maxErrorRetries); configuration.setProtocol(sslEnabled ? 
Protocol.HTTPS : Protocol.HTTP); configuration.setConnectionTimeout(Ints.checkedCast(connectTimeout.toMillis())); configuration.setSocketTimeout(Ints.checkedCast(socketTimeout.toMillis())); configuration.setMaxConnections(maxConnections); this.s3 = new AmazonS3Client(getAwsCredentials(uri, conf), configuration); transferConfig.setMultipartUploadThreshold(minFileSize); transferConfig.setMinimumUploadPartSize(minPartSize); } @Override public URI getUri() { return uri; } @Override public Path getWorkingDirectory() { return workingDirectory; } @Override public void setWorkingDirectory(Path path) { workingDirectory = path; } @Override public FileStatus[] listStatus(Path path) throws IOException { List<LocatedFileStatus> list = new ArrayList<>(); RemoteIterator<LocatedFileStatus> iterator = listLocatedStatus(path); while (iterator.hasNext()) { list.add(iterator.next()); } return toArray(list, LocatedFileStatus.class); } @Override public RemoteIterator<LocatedFileStatus> listLocatedStatus(final Path path) throws IOException { return new RemoteIterator<LocatedFileStatus>() { private final Iterator<LocatedFileStatus> iterator = listPrefix(path); @Override public boolean hasNext() throws IOException { try { return iterator.hasNext(); } catch (AmazonClientException e) { throw new IOException(e); } } @Override public LocatedFileStatus next() throws IOException { try { return iterator.next(); } catch (AmazonClientException e) { throw new IOException(e); } } }; } @Override public FileStatus getFileStatus(Path path) throws IOException { if (path.getName().isEmpty()) { // the bucket root requires special handling if (getS3ObjectMetadata(path) != null) { return new FileStatus(0, true, 1, 0, 0, qualifiedPath(path)); } throw new FileNotFoundException("File does not exist: " + path); } ObjectMetadata metadata = getS3ObjectMetadata(path); if (metadata == null) { // check if this path is a directory Iterator<LocatedFileStatus> iterator = listPrefix(path); if ((iterator != null) && iterator.hasNext()) { return new FileStatus(0, true, 1, 0, 0, qualifiedPath(path)); } throw new FileNotFoundException("File does not exist: " + path); } return new FileStatus( metadata.getContentLength(), false, 1, BLOCK_SIZE.toBytes(), lastModifiedTime(metadata), qualifiedPath(path)); } @Override public FSDataInputStream open(Path path, int bufferSize) throws IOException { return new FSDataInputStream( new BufferedFSInputStream( new PrestoS3InputStream( s3, uri.getHost(), path, maxClientRetries, maxBackoffTime, maxRetryTime), bufferSize)); } @Override public FSDataOutputStream create( Path path, FsPermission permission, boolean overwrite, int bufferSize, short replication, long blockSize, Progressable progress) throws IOException { if ((!overwrite) && exists(path)) { throw new IOException("File already exists:" + path); } createDirectories(stagingDirectory.toPath()); File tempFile = createTempFile(stagingDirectory.toPath(), "presto-s3-", ".tmp").toFile(); String key = keyFromPath(qualifiedPath(path)); return new FSDataOutputStream( new PrestoS3OutputStream(s3, transferConfig, uri.getHost(), key, tempFile), statistics); } @Override public FSDataOutputStream append(Path f, int bufferSize, Progressable progress) throws IOException { throw new UnsupportedOperationException("append"); } @Override public boolean rename(Path src, Path dst) throws IOException { throw new UnsupportedOperationException("rename"); } @Override public boolean delete(Path f, boolean recursive) throws IOException { throw new UnsupportedOperationException("delete"); 
} @Override public boolean mkdirs(Path f, FsPermission permission) throws IOException { // no need to do anything for S3 return true; } private Iterator<LocatedFileStatus> listPrefix(Path path) { String key = keyFromPath(path); if (!key.isEmpty()) { key += "/"; } ListObjectsRequest request = new ListObjectsRequest().withBucketName(uri.getHost()).withPrefix(key).withDelimiter("/"); Iterator<ObjectListing> listings = new AbstractSequentialIterator<ObjectListing>(s3.listObjects(request)) { @Override protected ObjectListing computeNext(ObjectListing previous) { if (!previous.isTruncated()) { return null; } return s3.listNextBatchOfObjects(previous); } }; return Iterators.concat(Iterators.transform(listings, this::statusFromListing)); } private Iterator<LocatedFileStatus> statusFromListing(ObjectListing listing) { return Iterators.concat( statusFromPrefixes(listing.getCommonPrefixes()), statusFromObjects(listing.getObjectSummaries())); } private Iterator<LocatedFileStatus> statusFromPrefixes(List<String> prefixes) { List<LocatedFileStatus> list = new ArrayList<>(); for (String prefix : prefixes) { Path path = qualifiedPath(new Path("/" + prefix)); FileStatus status = new FileStatus(0, true, 1, 0, 0, path); list.add(createLocatedFileStatus(status)); } return list.iterator(); } private Iterator<LocatedFileStatus> statusFromObjects(List<S3ObjectSummary> objects) { List<LocatedFileStatus> list = new ArrayList<>(); for (S3ObjectSummary object : objects) { if (!object.getKey().endsWith("/")) { FileStatus status = new FileStatus( object.getSize(), false, 1, BLOCK_SIZE.toBytes(), object.getLastModified().getTime(), qualifiedPath(new Path("/" + object.getKey()))); list.add(createLocatedFileStatus(status)); } } return list.iterator(); } /** * This exception is for stopping retries for S3 calls that shouldn't be retried. For example, * "Caused by: com.amazonaws.services.s3.model.AmazonS3Exception: Forbidden (Service: Amazon S3; * Status Code: 403 ..." */ private static class UnrecoverableS3OperationException extends Exception { public UnrecoverableS3OperationException(Throwable cause) { super(cause); } } private ObjectMetadata getS3ObjectMetadata(final Path path) throws IOException { try { return retry() .maxAttempts(maxClientRetries) .exponentialBackoff(new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0) .stopOn(InterruptedException.class, UnrecoverableS3OperationException.class) .run( "getS3ObjectMetadata", () -> { try { return s3.getObjectMetadata(uri.getHost(), keyFromPath(path)); } catch (AmazonS3Exception e) { if (e.getStatusCode() == SC_NOT_FOUND) { return null; } else if (e.getStatusCode() == SC_FORBIDDEN) { throw new UnrecoverableS3OperationException(e); } throw Throwables.propagate(e); } }); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); } catch (Exception e) { Throwables.propagateIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } private Path qualifiedPath(Path path) { return path.makeQualified(this.uri, getWorkingDirectory()); } private LocatedFileStatus createLocatedFileStatus(FileStatus status) { try { BlockLocation[] fakeLocation = getFileBlockLocations(status, 0, status.getLen()); return new LocatedFileStatus(status, fakeLocation); } catch (IOException e) { throw Throwables.propagate(e); } } private static long lastModifiedTime(ObjectMetadata metadata) { Date date = metadata.getLastModified(); return (date != null) ? 
date.getTime() : 0; } private static String keyFromPath(Path path) { checkArgument(path.isAbsolute(), "Path is not absolute: %s", path); String key = nullToEmpty(path.toUri().getPath()); if (key.startsWith("/")) { key = key.substring(1); } if (key.endsWith("/")) { key = key.substring(0, key.length() - 1); } return key; } private static AWSCredentials getAwsCredentials(URI uri, Configuration conf) { S3Credentials credentials = new S3Credentials(); credentials.initialize(uri, conf); return new BasicAWSCredentials(credentials.getAccessKey(), credentials.getSecretAccessKey()); } private static class PrestoS3InputStream extends FSInputStream { private final AmazonS3 s3; private final String host; private final Path path; private final int maxClientRetry; private final Duration maxBackoffTime; private final Duration maxRetryTime; private boolean closed; private S3ObjectInputStream in; private long position; public PrestoS3InputStream( AmazonS3 s3, String host, Path path, int maxClientRetry, Duration maxBackoffTime, Duration maxRetryTime) { this.s3 = checkNotNull(s3, "s3 is null"); this.host = checkNotNull(host, "host is null"); this.path = checkNotNull(path, "path is null"); checkArgument(maxClientRetry >= 0, "maxClientRetries cannot be negative"); this.maxClientRetry = maxClientRetry; this.maxBackoffTime = checkNotNull(maxBackoffTime, "maxBackoffTime is null"); this.maxRetryTime = checkNotNull(maxRetryTime, "maxRetryTime is null"); } @Override public void close() throws IOException { closed = true; closeStream(); } @Override public void seek(long pos) throws IOException { checkState(!closed, "already closed"); checkArgument(pos >= 0, "position is negative: %s", pos); if ((in != null) && (pos == position)) { // already at specified position return; } if ((in != null) && (pos > position)) { // seeking forwards long skip = pos - position; if (skip <= max(in.available(), MAX_SKIP_SIZE.toBytes())) { // already buffered or seek is small enough if (in.skip(skip) == skip) { position = pos; return; } } } // close the stream and open at desired position position = pos; closeStream(); openStream(); } @Override public long getPos() throws IOException { return position; } @Override public int read() throws IOException { // This stream is wrapped with BufferedInputStream, so this method should never be called throw new UnsupportedOperationException(); } @Override public int read(final byte[] buffer, final int offset, final int length) throws IOException { try { int bytesRead = retry() .maxAttempts(maxClientRetry) .exponentialBackoff( new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0) .stopOn(InterruptedException.class) .run( "readStream", () -> { openStream(); try { return in.read(buffer, offset, length); } catch (Exception e) { closeStream(); throw e; } }); if (bytesRead != -1) { position += bytesRead; } return bytesRead; } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); } catch (Exception e) { Throwables.propagateIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } @Override public boolean seekToNewSource(long targetPos) throws IOException { return false; } private S3Object getS3Object(final Path path, final long start) throws IOException { try { return retry() .maxAttempts(maxClientRetry) .exponentialBackoff( new Duration(1, TimeUnit.SECONDS), maxBackoffTime, maxRetryTime, 2.0) .stopOn(InterruptedException.class, UnrecoverableS3OperationException.class) .run( "getS3Object", () -> { try { return s3.getObject( new 
GetObjectRequest(host, keyFromPath(path)) .withRange(start, Long.MAX_VALUE)); } catch (AmazonServiceException e) { if (e.getStatusCode() == SC_FORBIDDEN) { throw new UnrecoverableS3OperationException(e); } throw Throwables.propagate(e); } }); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw Throwables.propagate(e); } catch (Exception e) { Throwables.propagateIfInstanceOf(e, IOException.class); throw Throwables.propagate(e); } } private void openStream() throws IOException { if (in == null) { in = getS3Object(path, position).getObjectContent(); } } private void closeStream() throws IOException { if (in != null) { try { in.abort(); } catch (AbortedException ignored) { // thrown if the current thread is in the interrupted state } in = null; } } } private static class PrestoS3OutputStream extends FilterOutputStream { private final TransferManager transferManager; private final String host; private final String key; private final File tempFile; private boolean closed; public PrestoS3OutputStream( AmazonS3 s3, TransferManagerConfiguration config, String host, String key, File tempFile) throws IOException { super( new BufferedOutputStream( new FileOutputStream(checkNotNull(tempFile, "tempFile is null")))); transferManager = new TransferManager(checkNotNull(s3, "s3 is null")); transferManager.setConfiguration(checkNotNull(config, "config is null")); this.host = checkNotNull(host, "host is null"); this.key = checkNotNull(key, "key is null"); this.tempFile = tempFile; log.debug("OutputStream for key '%s' using file: %s", key, tempFile); } @Override public void close() throws IOException { if (closed) { return; } closed = true; try { super.close(); uploadObject(); } finally { if (!tempFile.delete()) { log.warn("Could not delete temporary file: %s", tempFile); } // close transfer manager but keep underlying S3 client open transferManager.shutdownNow(false); } } private void uploadObject() throws IOException { try { log.debug( "Starting upload for host: %s, key: %s, file: %s, size: %s", host, key, tempFile, tempFile.length()); Upload upload = transferManager.upload(host, key, tempFile); if (log.isDebugEnabled()) { upload.addProgressListener(createProgressListener(upload)); } upload.waitForCompletion(); log.debug("Completed upload for host: %s, key: %s", host, key); } catch (AmazonClientException e) { throw new IOException(e); } catch (InterruptedException e) { Thread.currentThread().interrupt(); throw new InterruptedIOException(); } } private ProgressListener createProgressListener(final Transfer transfer) { return new ProgressListener() { private ProgressEventType previousType; private double previousTransferred; @Override public synchronized void progressChanged(ProgressEvent progressEvent) { ProgressEventType eventType = progressEvent.getEventType(); if (previousType != eventType) { log.debug("Upload progress event (%s/%s): %s", host, key, eventType); previousType = eventType; } double transferred = transfer.getProgress().getPercentTransferred(); if (transferred >= (previousTransferred + 10.0)) { log.debug("Upload percentage (%s/%s): %.0f%%", host, key, transferred); previousTransferred = transferred; } } }; } } }
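The presto.s3.* keys declared above are read from the Hadoop Configuration passed to initialize(). A hypothetical configuration sketch follows; the bucket URI and values are examples, not the Hive client defaults.

import java.net.URI;
import org.apache.hadoop.conf.Configuration;

final class PrestoS3ConfigExample {
    public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        conf.setBoolean("presto.s3.ssl.enabled", true);             // S3_SSL_ENABLED
        conf.setInt("presto.s3.max-client-retries", 3);             // S3_MAX_CLIENT_RETRIES
        conf.set("presto.s3.max-backoff-time", "10m");              // S3_MAX_BACKOFF_TIME (airlift Duration format)
        conf.set("presto.s3.connect-timeout", "5s");                // S3_CONNECT_TIMEOUT
        conf.set("presto.s3.staging-directory", "/tmp/presto-s3");  // S3_STAGING_DIRECTORY

        PrestoS3FileSystem fs = new PrestoS3FileSystem();
        fs.initialize(URI.create("s3://example-bucket/"), conf);    // example bucket URI
    }
}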
@ThreadSafe public final class HttpPageBufferClient implements Closeable { private static final int INITIAL_DELAY_MILLIS = 1; private static final int MAX_DELAY_MILLIS = 100; private static final Logger log = Logger.get(HttpPageBufferClient.class); /** * For each request, the addPage method will be called zero or more times, followed by either * requestComplete or clientFinished (if buffer complete). If the client is closed, * requestComplete or bufferFinished may never be called. * * <p><b>NOTE:</b> Implementations of this interface are not allowed to perform blocking * operations. */ public interface ClientCallback { boolean addPages(HttpPageBufferClient client, List<Page> pages); void requestComplete(HttpPageBufferClient client); void clientFinished(HttpPageBufferClient client); void clientFailed(HttpPageBufferClient client, Throwable cause); } private final HttpClient httpClient; private final DataSize maxResponseSize; private final Duration minErrorDuration; private final URI location; private final ClientCallback clientCallback; private final BlockEncodingSerde blockEncodingSerde; private final ScheduledExecutorService executor; @GuardedBy("this") private final Stopwatch errorStopwatch; @GuardedBy("this") private boolean closed; @GuardedBy("this") private HttpResponseFuture<?> future; @GuardedBy("this") private DateTime lastUpdate = DateTime.now(); @GuardedBy("this") private long token; @GuardedBy("this") private boolean scheduled; @GuardedBy("this") private boolean completed; @GuardedBy("this") private long errorDelayMillis; @GuardedBy("this") private String taskInstanceId; private final AtomicLong rowsReceived = new AtomicLong(); private final AtomicInteger pagesReceived = new AtomicInteger(); private final AtomicLong rowsRejected = new AtomicLong(); private final AtomicInteger pagesRejected = new AtomicInteger(); private final AtomicInteger requestsScheduled = new AtomicInteger(); private final AtomicInteger requestsCompleted = new AtomicInteger(); private final AtomicInteger requestsFailed = new AtomicInteger(); public HttpPageBufferClient( HttpClient httpClient, DataSize maxResponseSize, Duration minErrorDuration, URI location, ClientCallback clientCallback, BlockEncodingSerde blockEncodingSerde, ScheduledExecutorService executor) { this( httpClient, maxResponseSize, minErrorDuration, location, clientCallback, blockEncodingSerde, executor, Stopwatch.createUnstarted()); } public HttpPageBufferClient( HttpClient httpClient, DataSize maxResponseSize, Duration minErrorDuration, URI location, ClientCallback clientCallback, BlockEncodingSerde blockEncodingSerde, ScheduledExecutorService executor, Stopwatch errorStopwatch) { this.httpClient = requireNonNull(httpClient, "httpClient is null"); this.maxResponseSize = requireNonNull(maxResponseSize, "maxResponseSize is null"); this.minErrorDuration = requireNonNull(minErrorDuration, "minErrorDuration is null"); this.location = requireNonNull(location, "location is null"); this.clientCallback = requireNonNull(clientCallback, "clientCallback is null"); this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingManager is null"); this.executor = requireNonNull(executor, "executor is null"); this.errorStopwatch = requireNonNull(errorStopwatch, "errorStopwatch is null").reset(); } public synchronized PageBufferClientStatus getStatus() { String state; if (closed) { state = "closed"; } else if (future != null) { state = "running"; } else if (scheduled) { state = "scheduled"; } else if (completed) { state = "completed"; } else { 
state = "queued"; } String httpRequestState = "not scheduled"; if (future != null) { httpRequestState = future.getState(); } long rejectedRows = rowsRejected.get(); int rejectedPages = pagesRejected.get(); return new PageBufferClientStatus( location, state, lastUpdate, rowsReceived.get(), pagesReceived.get(), rejectedRows == 0 ? OptionalLong.empty() : OptionalLong.of(rejectedRows), rejectedPages == 0 ? OptionalInt.empty() : OptionalInt.of(rejectedPages), requestsScheduled.get(), requestsCompleted.get(), requestsFailed.get(), httpRequestState); } public synchronized boolean isRunning() { return future != null; } @Override public void close() { boolean shouldSendDelete; Future<?> future; synchronized (this) { shouldSendDelete = !closed; closed = true; future = this.future; this.future = null; lastUpdate = DateTime.now(); } if (future != null && !future.isDone()) { future.cancel(true); } // abort the output buffer on the remote node; response of delete is ignored if (shouldSendDelete) { sendDelete(); } } public synchronized void scheduleRequest() { if (closed || (future != null) || scheduled) { return; } scheduled = true; // start before scheduling to include error delay errorStopwatch.start(); executor.schedule( () -> { try { initiateRequest(); } catch (Throwable t) { // should not happen, but be safe and fail the operator clientCallback.clientFailed(HttpPageBufferClient.this, t); } }, errorDelayMillis, TimeUnit.MILLISECONDS); lastUpdate = DateTime.now(); requestsScheduled.incrementAndGet(); } private synchronized void initiateRequest() { scheduled = false; if (closed || (future != null)) { return; } if (completed) { sendDelete(); } else { sendGetResults(); } lastUpdate = DateTime.now(); } private synchronized void sendGetResults() { URI uri = HttpUriBuilder.uriBuilderFrom(location).appendPath(String.valueOf(token)).build(); HttpResponseFuture<PagesResponse> resultFuture = httpClient.executeAsync( prepareGet().setHeader(PRESTO_MAX_SIZE, maxResponseSize.toString()).setUri(uri).build(), new PageResponseHandler(blockEncodingSerde)); future = resultFuture; Futures.addCallback( resultFuture, new FutureCallback<PagesResponse>() { @Override public void onSuccess(PagesResponse result) { checkNotHoldsLock(); resetErrors(); List<Page> pages; try { synchronized (HttpPageBufferClient.this) { if (taskInstanceId == null) { taskInstanceId = result.getTaskInstanceId(); } if (!isNullOrEmpty(taskInstanceId) && !result.getTaskInstanceId().equals(taskInstanceId)) { // TODO: update error message throw new PrestoException(REMOTE_TASK_MISMATCH, REMOTE_TASK_MISMATCH_ERROR); } if (result.getToken() == token) { pages = result.getPages(); token = result.getNextToken(); } else { pages = ImmutableList.of(); } } } catch (PrestoException e) { handleFailure(e, resultFuture); return; } // add pages if (clientCallback.addPages(HttpPageBufferClient.this, pages)) { pagesReceived.addAndGet(pages.size()); rowsReceived.addAndGet(pages.stream().mapToLong(Page::getPositionCount).sum()); } else { pagesRejected.addAndGet(pages.size()); rowsRejected.addAndGet(pages.stream().mapToLong(Page::getPositionCount).sum()); } synchronized (HttpPageBufferClient.this) { // client is complete, acknowledge it by sending it a delete in the next request if (result.isClientComplete()) { completed = true; } if (future == resultFuture) { future = null; errorDelayMillis = 0; } lastUpdate = DateTime.now(); } requestsCompleted.incrementAndGet(); clientCallback.requestComplete(HttpPageBufferClient.this); } @Override public void onFailure(Throwable t) { 
log.debug("Request to %s failed %s", uri, t); checkNotHoldsLock(); Duration errorDuration = elapsedErrorDuration(); t = rewriteException(t); if (!(t instanceof PrestoException) && errorDuration.compareTo(minErrorDuration) > 0) { String message = format("%s (%s - requests failed for %s)", WORKER_NODE_ERROR, uri, errorDuration); t = new PageTransportTimeoutException(message, t); } handleFailure(t, resultFuture); } }, executor); } private synchronized void sendDelete() { HttpResponseFuture<StatusResponse> resultFuture = httpClient.executeAsync( prepareDelete().setUri(location).build(), createStatusResponseHandler()); future = resultFuture; Futures.addCallback( resultFuture, new FutureCallback<StatusResponse>() { @Override public void onSuccess(@Nullable StatusResponse result) { checkNotHoldsLock(); synchronized (HttpPageBufferClient.this) { closed = true; if (future == resultFuture) { future = null; errorDelayMillis = 0; } lastUpdate = DateTime.now(); } requestsCompleted.incrementAndGet(); clientCallback.clientFinished(HttpPageBufferClient.this); } @Override public void onFailure(Throwable t) { checkNotHoldsLock(); log.error("Request to delete %s failed %s", location, t); Duration errorDuration = elapsedErrorDuration(); if (!(t instanceof PrestoException) && errorDuration.compareTo(minErrorDuration) > 0) { String message = format( "Error closing remote buffer (%s - requests failed for %s)", location, errorDuration); t = new PrestoException(REMOTE_BUFFER_CLOSE_FAILED, message, t); } handleFailure(t, resultFuture); } }, executor); } private void checkNotHoldsLock() { if (Thread.holdsLock(HttpPageBufferClient.this)) { log.error("Can not handle callback while holding a lock on this"); } } private void handleFailure(Throwable t, HttpResponseFuture<?> expectedFuture) { // Can not delegate to other callback while holding a lock on this checkNotHoldsLock(); requestsFailed.incrementAndGet(); requestsCompleted.incrementAndGet(); if (t instanceof PrestoException) { clientCallback.clientFailed(HttpPageBufferClient.this, t); } synchronized (HttpPageBufferClient.this) { increaseErrorDelay(); if (future == expectedFuture) { future = null; } lastUpdate = DateTime.now(); } clientCallback.requestComplete(HttpPageBufferClient.this); } @Override public boolean equals(Object o) { if (this == o) { return true; } if (o == null || getClass() != o.getClass()) { return false; } HttpPageBufferClient that = (HttpPageBufferClient) o; if (!location.equals(that.location)) { return false; } return true; } @Override public int hashCode() { return location.hashCode(); } @Override public String toString() { String state; synchronized (this) { if (closed) { state = "CLOSED"; } else if (future != null) { state = "RUNNING"; } else { state = "QUEUED"; } } return toStringHelper(this).add("location", location).addValue(state).toString(); } private static Throwable rewriteException(Throwable t) { if (t instanceof ResponseTooLargeException) { return new PageTooLargeException(); } return t; } private synchronized Duration elapsedErrorDuration() { if (errorStopwatch.isRunning()) { errorStopwatch.stop(); } long nanos = errorStopwatch.elapsed(TimeUnit.NANOSECONDS); return new Duration(nanos, TimeUnit.NANOSECONDS).convertTo(TimeUnit.SECONDS); } private synchronized void increaseErrorDelay() { if (errorDelayMillis == 0) { errorDelayMillis = INITIAL_DELAY_MILLIS; } else { errorDelayMillis = min(errorDelayMillis * 2, MAX_DELAY_MILLIS); } } private synchronized void resetErrors() { errorStopwatch.reset(); } public static class 
PageResponseHandler implements ResponseHandler<PagesResponse, RuntimeException> { private final BlockEncodingSerde blockEncodingSerde; public PageResponseHandler(BlockEncodingSerde blockEncodingSerde) { this.blockEncodingSerde = blockEncodingSerde; } @Override public PagesResponse handleException(Request request, Exception exception) { throw propagate(request, exception); } @Override public PagesResponse handle(Request request, Response response) { // no content means no content was created within the wait period, but query is still ok // if job is finished, complete is set in the response if (response.getStatusCode() == HttpStatus.NO_CONTENT.code()) { return createEmptyPagesResponse( getTaskInstanceId(response), getToken(response), getNextToken(response), getComplete(response)); } // otherwise we must have gotten an OK response, everything else is considered fatal if (response.getStatusCode() != HttpStatus.OK.code()) { StringBuilder body = new StringBuilder(); try (BufferedReader reader = new BufferedReader(new InputStreamReader(response.getInputStream()))) { // Get up to 1000 lines for debugging for (int i = 0; i < 1000; i++) { String line = reader.readLine(); // Don't output more than 100KB if (line == null || body.length() + line.length() > 100 * 1024) { break; } body.append(line + "\n"); } } catch (RuntimeException | IOException e) { // Ignored. Just return whatever message we were able to decode } throw new PageTransportErrorException( format( "Expected response code to be 200, but was %s %s: %s%n%s", response.getStatusCode(), response.getStatusMessage(), request.getUri(), body.toString())); } // invalid content type can happen when an error page is returned, but is unlikely given the // above 200 String contentType = response.getHeader(CONTENT_TYPE); if (contentType == null) { throw new PageTransportErrorException( format("%s header is not set: %s %s", CONTENT_TYPE, request.getUri(), response)); } if (!mediaTypeMatches(contentType, PRESTO_PAGES_TYPE)) { throw new PageTransportErrorException( format( "Expected %s response from server but got %s: %s", PRESTO_PAGES_TYPE, contentType, request.getUri())); } String taskInstanceId = getTaskInstanceId(response); long token = getToken(response); long nextToken = getNextToken(response); boolean complete = getComplete(response); try (SliceInput input = new InputStreamSliceInput(response.getInputStream())) { List<Page> pages = ImmutableList.copyOf(readPages(blockEncodingSerde, input)); return createPagesResponse(taskInstanceId, token, nextToken, pages, complete); } catch (IOException e) { throw Throwables.propagate(e); } } private static String getTaskInstanceId(Response response) { String taskInstanceId = response.getHeader(PRESTO_TASK_INSTANCE_ID); if (taskInstanceId == null) { throw new PageTransportErrorException( format("Expected %s header", PRESTO_TASK_INSTANCE_ID)); } return taskInstanceId; } private static long getToken(Response response) { String tokenHeader = response.getHeader(PRESTO_PAGE_TOKEN); if (tokenHeader == null) { throw new PageTransportErrorException(format("Expected %s header", PRESTO_PAGE_TOKEN)); } return Long.parseLong(tokenHeader); } private static long getNextToken(Response response) { String nextTokenHeader = response.getHeader(PRESTO_PAGE_NEXT_TOKEN); if (nextTokenHeader == null) { throw new PageTransportErrorException(format("Expected %s header", PRESTO_PAGE_NEXT_TOKEN)); } return Long.parseLong(nextTokenHeader); } private static boolean getComplete(Response response) { String bufferComplete = 
response.getHeader(PRESTO_BUFFER_COMPLETE); if (bufferComplete == null) { throw new PageTransportErrorException(format("Expected %s header", PRESTO_BUFFER_COMPLETE)); } return Boolean.parseBoolean(bufferComplete); } private static boolean mediaTypeMatches(String value, MediaType range) { try { return MediaType.parse(value).is(range); } catch (IllegalArgumentException | IllegalStateException e) { return false; } } } public static class PagesResponse { public static PagesResponse createPagesResponse( String taskInstanceId, long token, long nextToken, Iterable<Page> pages, boolean complete) { return new PagesResponse(taskInstanceId, token, nextToken, pages, complete); } public static PagesResponse createEmptyPagesResponse( String taskInstanceId, long token, long nextToken, boolean complete) { return new PagesResponse( taskInstanceId, token, nextToken, ImmutableList.<Page>of(), complete); } private final String taskInstanceId; private final long token; private final long nextToken; private final List<Page> pages; private final boolean clientComplete; private PagesResponse( String taskInstanceId, long token, long nextToken, Iterable<Page> pages, boolean clientComplete) { this.taskInstanceId = taskInstanceId; this.token = token; this.nextToken = nextToken; this.pages = ImmutableList.copyOf(pages); this.clientComplete = clientComplete; } public long getToken() { return token; } public long getNextToken() { return nextToken; } public List<Page> getPages() { return pages; } public boolean isClientComplete() { return clientComplete; } public String getTaskInstanceId() { return taskInstanceId; } @Override public String toString() { return toStringHelper(this) .add("token", token) .add("nextToken", nextToken) .add("pagesSize", pages.size()) .add("clientComplete", clientComplete) .toString(); } } }
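The ClientCallback contract documented above (callbacks may not block, addPages reports whether the pages were accepted) is easiest to see with a toy implementation. The callback below is purely hypothetical and only counts pages, reschedules the next request, and remembers the first failure; Page and HttpPageBufferClient are the types from the client above, and their imports are omitted here.

import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

// Hypothetical callback used only to illustrate the ClientCallback contract;
// it is not part of the exchange client itself.
public class CountingClientCallback
        implements HttpPageBufferClient.ClientCallback
{
    private final AtomicInteger pageCount = new AtomicInteger();
    private final AtomicBoolean finished = new AtomicBoolean();
    private final AtomicReference<Throwable> firstFailure = new AtomicReference<>();

    @Override
    public boolean addPages(HttpPageBufferClient client, List<Page> pages)
    {
        pageCount.addAndGet(pages.size());
        return true; // accept every page; returning false makes the client count them as rejected
    }

    @Override
    public void requestComplete(HttpPageBufferClient client)
    {
        client.scheduleRequest(); // non-blocking: any error delay is applied inside the client
    }

    @Override
    public void clientFinished(HttpPageBufferClient client)
    {
        finished.set(true); // remote buffer is fully consumed and acknowledged
    }

    @Override
    public void clientFailed(HttpPageBufferClient client, Throwable cause)
    {
        firstFailure.compareAndSet(null, cause); // keep only the first failure
    }

    public int getPageCount()
    {
        return pageCount.get();
    }
}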
@ThreadSafe public class SqlQueryManager implements QueryManager { private static final Logger log = Logger.get(SqlQueryManager.class); private final ExecutorService queryExecutor; private final ThreadPoolExecutorMBean queryExecutorMBean; private final int maxQueryHistory; private final Duration maxQueryAge; private final ConcurrentMap<QueryId, QueryExecution> queries = new ConcurrentHashMap<>(); private final Duration clientTimeout; private final ScheduledExecutorService queryManagementExecutor; private final ThreadPoolExecutorMBean queryManagementExecutorMBean; private final QueryMonitor queryMonitor; private final LocationFactory locationFactory; private final QueryIdGenerator queryIdGenerator; private final Map<Class<? extends Statement>, QueryExecutionFactory<?>> executionFactories; private final SqlQueryManagerStats stats = new SqlQueryManagerStats(); @Inject public SqlQueryManager( QueryManagerConfig config, QueryMonitor queryMonitor, QueryIdGenerator queryIdGenerator, LocationFactory locationFactory, Map<Class<? extends Statement>, QueryExecutionFactory<?>> executionFactories) { checkNotNull(config, "config is null"); this.executionFactories = checkNotNull(executionFactories, "executionFactories is null"); this.queryExecutor = Executors.newCachedThreadPool(threadsNamed("query-scheduler-%d")); this.queryExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) queryExecutor); this.queryMonitor = checkNotNull(queryMonitor, "queryMonitor is null"); this.locationFactory = checkNotNull(locationFactory, "locationFactory is null"); this.queryIdGenerator = checkNotNull(queryIdGenerator, "queryIdGenerator is null"); this.maxQueryAge = config.getMaxQueryAge(); this.maxQueryHistory = config.getMaxQueryHistory(); this.clientTimeout = config.getClientTimeout(); queryManagementExecutor = Executors.newScheduledThreadPool( config.getQueryManagerExecutorPoolSize(), threadsNamed("query-management-%d")); queryManagementExecutorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) queryManagementExecutor); queryManagementExecutor.scheduleAtFixedRate( new Runnable() { @Override public void run() { try { removeExpiredQueries(); } catch (Throwable e) { log.warn(e, "Error removing old queries"); } try { failAbandonedQueries(); } catch (Throwable e) { log.warn(e, "Error removing old queries"); } } }, 200, 200, TimeUnit.MILLISECONDS); } @PreDestroy public void stop() { queryManagementExecutor.shutdownNow(); queryExecutor.shutdownNow(); } @Override public List<QueryInfo> getAllQueryInfo() { return ImmutableList.copyOf( filter( transform( queries.values(), new Function<QueryExecution, QueryInfo>() { @Override public QueryInfo apply(QueryExecution queryExecution) { try { return queryExecution.getQueryInfo(); } catch (RuntimeException ignored) { return null; } } }), Predicates.notNull())); } @Override public Duration waitForStateChange(QueryId queryId, QueryState currentState, Duration maxWait) throws InterruptedException { Preconditions.checkNotNull(queryId, "queryId is null"); Preconditions.checkNotNull(maxWait, "maxWait is null"); QueryExecution query = queries.get(queryId); if (query == null) { return maxWait; } query.recordHeartbeat(); return query.waitForStateChange(currentState, maxWait); } @Override public QueryInfo getQueryInfo(QueryId queryId) { checkNotNull(queryId, "queryId is null"); QueryExecution query = queries.get(queryId); if (query == null) { throw new NoSuchElementException(); } query.recordHeartbeat(); return query.getQueryInfo(); } @Override public QueryInfo createQuery(Session 
session, String query) { checkNotNull(query, "query is null"); Preconditions.checkArgument(!query.isEmpty(), "query must not be empty string"); QueryId queryId = queryIdGenerator.createNextQueryId(); Statement statement; try { statement = SqlParser.createStatement(query); } catch (ParsingException e) { return createFailedQuery(session, query, queryId, e); } QueryExecutionFactory<?> queryExecutionFactory = executionFactories.get(statement.getClass()); Preconditions.checkState( queryExecutionFactory != null, "Unsupported statement type %s", statement.getClass().getName()); final QueryExecution queryExecution = queryExecutionFactory.createQueryExecution(queryId, query, session, statement); queryMonitor.createdEvent(queryExecution.getQueryInfo()); queryExecution.addStateChangeListener( new StateChangeListener<QueryState>() { @Override public void stateChanged(QueryState newValue) { if (newValue.isDone()) { QueryInfo info = queryExecution.getQueryInfo(); stats.queryFinished(info); queryMonitor.completionEvent(info); } } }); queries.put(queryId, queryExecution); // start the query in the background queryExecutor.submit(new QueryStarter(queryExecution, stats)); return queryExecution.getQueryInfo(); } @Override public void cancelQuery(QueryId queryId) { checkNotNull(queryId, "queryId is null"); log.debug("Cancel query %s", queryId); QueryExecution query = queries.get(queryId); if (query != null) { query.cancel(); } } @Override public void cancelStage(StageId stageId) { Preconditions.checkNotNull(stageId, "stageId is null"); log.debug("Cancel stage %s", stageId); QueryExecution query = queries.get(stageId.getQueryId()); if (query != null) { query.cancelStage(stageId); } } @Managed @Flatten public SqlQueryManagerStats getStats() { return stats; } @Managed(description = "Query scheduler executor") @Nested public ThreadPoolExecutorMBean getExecutor() { return queryExecutorMBean; } @Managed(description = "Query garbage collector executor") @Nested public ThreadPoolExecutorMBean getManagementExecutor() { return queryManagementExecutorMBean; } public void removeQuery(QueryId queryId) { Preconditions.checkNotNull(queryId, "queryId is null"); log.debug("Remove query %s", queryId); QueryExecution query = queries.remove(queryId); if (query != null) { query.cancel(); } } /** Remove completed queries after a waiting period */ public void removeExpiredQueries() { List<QueryExecution> sortedQueries = IterableTransformer.on(queries.values()) .select(compose(not(isNull()), endTimeGetter())) .orderBy(Ordering.natural().onResultOf(endTimeGetter())) .list(); int toRemove = Math.max(sortedQueries.size() - maxQueryHistory, 0); DateTime oldestAllowedQuery = DateTime.now().minus(maxQueryAge.toMillis()); for (QueryExecution queryExecution : sortedQueries) { try { DateTime endTime = queryExecution.getQueryInfo().getQueryStats().getEndTime(); if ((endTime.isBefore(oldestAllowedQuery) || toRemove > 0) && isAbandoned(queryExecution)) { removeQuery(queryExecution.getQueryInfo().getQueryId()); --toRemove; } } catch (RuntimeException e) { log.warn( e, "Error while inspecting age of query %s", queryExecution.getQueryInfo().getQueryId()); } } } public void failAbandonedQueries() { for (QueryExecution queryExecution : queries.values()) { try { QueryInfo queryInfo = queryExecution.getQueryInfo(); if (queryInfo.getState().isDone()) { continue; } if (isAbandoned(queryExecution)) { log.info("Failing abandoned query %s", queryExecution.getQueryInfo().getQueryId()); queryExecution.fail( new AbandonedException( "Query " + 
queryInfo.getQueryId(), queryInfo.getQueryStats().getLastHeartbeat(), DateTime.now())); } } catch (RuntimeException e) { log.warn( e, "Error while inspecting age of query %s", queryExecution.getQueryInfo().getQueryId()); } } } private boolean isAbandoned(QueryExecution query) { DateTime oldestAllowedHeartbeat = DateTime.now().minus(clientTimeout.toMillis()); DateTime lastHeartbeat = query.getQueryInfo().getQueryStats().getLastHeartbeat(); return lastHeartbeat != null && lastHeartbeat.isBefore(oldestAllowedHeartbeat); } private QueryInfo createFailedQuery( Session session, String query, QueryId queryId, Throwable cause) { URI self = locationFactory.createQueryLocation(queryId); QueryExecution execution = new FailedQueryExecution(queryId, query, session, self, queryExecutor, cause); queries.put(queryId, execution); queryMonitor.createdEvent(execution.getQueryInfo()); queryMonitor.completionEvent(execution.getQueryInfo()); stats.queryFinished(execution.getQueryInfo()); return execution.getQueryInfo(); } private static Function<QueryExecution, DateTime> endTimeGetter() { return new Function<QueryExecution, DateTime>() { @Nullable @Override public DateTime apply(QueryExecution input) { return input.getQueryInfo().getQueryStats().getEndTime(); } }; } private static class QueryStarter implements Runnable { private final QueryExecution queryExecution; private final SqlQueryManagerStats stats; public QueryStarter(QueryExecution queryExecution, SqlQueryManagerStats stats) { this.queryExecution = queryExecution; this.stats = stats; } @Override public void run() { try (SetThreadName setThreadName = new SetThreadName("Query-%s", queryExecution.getQueryInfo().getQueryId())) { stats.queryStarted(); queryExecution.start(); } } } }
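Query expiry and abandonment above both reduce to timestamp comparisons: a finished query is removed once its end time falls behind the retention window or the history limit, and a running query is failed once its last client heartbeat is older than the client timeout. A minimal sketch of the heartbeat rule, assuming a helper class and parameter names chosen only for illustration:

import org.joda.time.DateTime;

// Sketch of the abandonment rule used by SqlQueryManager.isAbandoned above.
public final class AbandonmentCheckSketch
{
    private AbandonmentCheckSketch() {}

    public static boolean isAbandoned(DateTime lastHeartbeat, long clientTimeoutMillis)
    {
        // a query with no recorded heartbeat is not (yet) treated as abandoned
        if (lastHeartbeat == null) {
            return false;
        }
        DateTime oldestAllowedHeartbeat = DateTime.now().minus(clientTimeoutMillis);
        return lastHeartbeat.isBefore(oldestAllowedHeartbeat);
    }
}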
public final class CompilerUtils { private static final Logger log = Logger.get(CompilerUtils.class); private static final boolean DUMP_BYTE_CODE_TREE = false; private static final boolean DUMP_BYTE_CODE_RAW = false; private static final boolean RUN_ASM_VERIFIER = false; // verifier doesn't work right now private static final AtomicReference<String> DUMP_CLASS_FILES_TO = new AtomicReference<>(); private static final AtomicLong CLASS_ID = new AtomicLong(); private CompilerUtils() {} public static ParameterizedType makeClassName(String baseName) { String className = "com.facebook.presto.$gen." + baseName + "_" + CLASS_ID.incrementAndGet(); String javaClassName = toJavaIdentifierString(className); return ParameterizedType.typeFromJavaClassName(javaClassName); } public static String toJavaIdentifierString(String className) { // replace invalid characters with '_' int[] codePoints = className.codePoints().map(c -> Character.isJavaIdentifierPart(c) ? c : '_').toArray(); return new String(codePoints, 0, codePoints.length); } public static <T> Class<? extends T> defineClass( ClassDefinition classDefinition, Class<T> superType, DynamicClassLoader classLoader) { Class<?> clazz = defineClasses(ImmutableList.of(classDefinition), classLoader).values().iterator().next(); return clazz.asSubclass(superType); } public static <T> Class<? extends T> defineClass( ClassDefinition classDefinition, Class<T> superType, Map<Long, MethodHandle> callSiteBindings, ClassLoader parentClassLoader) { Class<?> clazz = defineClass( classDefinition, superType, new DynamicClassLoader(parentClassLoader, callSiteBindings)); return clazz.asSubclass(superType); } private static Map<String, Class<?>> defineClasses( List<ClassDefinition> classDefinitions, DynamicClassLoader classLoader) { ClassInfoLoader classInfoLoader = ClassInfoLoader.createClassInfoLoader(classDefinitions, classLoader); if (DUMP_BYTE_CODE_TREE) { ByteArrayOutputStream out = new ByteArrayOutputStream(); DumpByteCodeVisitor dumpByteCode = new DumpByteCodeVisitor(new PrintStream(out)); for (ClassDefinition classDefinition : classDefinitions) { dumpByteCode.visitClass(classDefinition); } System.out.println(new String(out.toByteArray(), StandardCharsets.UTF_8)); } Map<String, byte[]> byteCodes = new LinkedHashMap<>(); for (ClassDefinition classDefinition : classDefinitions) { ClassWriter cw = new SmartClassWriter(classInfoLoader); classDefinition.visit(cw); byte[] byteCode = cw.toByteArray(); if (RUN_ASM_VERIFIER) { ClassReader reader = new ClassReader(byteCode); CheckClassAdapter.verify(reader, classLoader, true, new PrintWriter(System.out)); } byteCodes.put(classDefinition.getType().getJavaClassName(), byteCode); } String dumpClassPath = DUMP_CLASS_FILES_TO.get(); if (dumpClassPath != null) { for (Map.Entry<String, byte[]> entry : byteCodes.entrySet()) { File file = new File( dumpClassPath, ParameterizedType.typeFromJavaClassName(entry.getKey()).getClassName() + ".class"); try { log.debug("ClassFile: " + file.getAbsolutePath()); Files.createParentDirs(file); Files.write(entry.getValue(), file); } catch (IOException e) { log.error(e, "Failed to write generated class file to: %s", file.getAbsolutePath()); } } } if (DUMP_BYTE_CODE_RAW) { for (byte[] byteCode : byteCodes.values()) { ClassReader classReader = new ClassReader(byteCode); classReader.accept( new TraceClassVisitor(new PrintWriter(System.err)), ClassReader.SKIP_FRAMES); } } Map<String, Class<?>> classes = classLoader.defineClasses(byteCodes); try { for (Class<?> clazz : classes.values()) { 
Reflection.initialize(clazz); } } catch (VerifyError e) { throw new RuntimeException(e); } return classes; } }
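A short usage sketch for CompilerUtils: callers typically obtain a unique class name from makeClassName before building a ClassDefinition, and toJavaIdentifierString guarantees a chosen base name is a legal JVM identifier. The base name "PageProjection" and the example class below are arbitrary illustrations, not part of the original code.

// Usage sketch only; relies on CompilerUtils and ParameterizedType as defined above.
public class CompilerUtilsUsageExample
{
    public static void main(String[] args)
    {
        // a fresh class name; the numeric suffix comes from a process-wide counter
        ParameterizedType type = CompilerUtils.makeClassName("PageProjection");
        System.out.println(type.getJavaClassName());

        // characters that are not valid in a Java identifier are replaced with '_'
        System.out.println(CompilerUtils.toJavaIdentifierString("row->field#3")); // prints row__field_3
    }
}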
public class BenchmarkSuite { private static final Logger LOGGER = Logger.get(BenchmarkSuite.class); public static List<AbstractBenchmark> createBenchmarks(ExecutorService executor) { LocalQueryRunner localQueryRunner = createLocalQueryRunner(executor); LocalQueryRunner localSampledQueryRunner = createLocalSampledQueryRunner(executor); return ImmutableList.<AbstractBenchmark>of( // hand built benchmarks new CountAggregationBenchmark(localQueryRunner), new DoubleSumAggregationBenchmark(localQueryRunner), new HashAggregationBenchmark(localQueryRunner), new PredicateFilterBenchmark(localQueryRunner), new RawStreamingBenchmark(localQueryRunner), new Top100Benchmark(localQueryRunner), new OrderByBenchmark(localQueryRunner), new HashBuildBenchmark(localQueryRunner), new HashJoinBenchmark(localQueryRunner), new HashBuildAndJoinBenchmark(localQueryRunner), new HandTpchQuery1(localQueryRunner), new HandTpchQuery6(localQueryRunner), // sql benchmarks new GroupBySumWithArithmeticSqlBenchmark(localQueryRunner), new CountAggregationSqlBenchmark(localQueryRunner), new SqlDoubleSumAggregationBenchmark(localQueryRunner), new CountWithFilterSqlBenchmark(localQueryRunner), new GroupByAggregationSqlBenchmark(localQueryRunner), new PredicateFilterSqlBenchmark(localQueryRunner), new RawStreamingSqlBenchmark(localQueryRunner), new Top100SqlBenchmark(localQueryRunner), new SqlHashJoinBenchmark(localQueryRunner), new SqlJoinWithPredicateBenchmark(localQueryRunner), new VarBinaryMaxAggregationSqlBenchmark(localQueryRunner), new SqlDistinctMultipleFields(localQueryRunner), new SqlDistinctSingleField(localQueryRunner), new SqlTpchQuery1(localQueryRunner), new SqlTpchQuery6(localQueryRunner), new SqlLikeBenchmark(localQueryRunner), new SqlInBenchmark(localQueryRunner), new SqlSemiJoinInPredicateBenchmark(localQueryRunner), new SqlRegexpLikeBenchmark(localQueryRunner), new SqlApproximatePercentileBenchmark(localQueryRunner), new SqlBetweenBenchmark(localQueryRunner), // Sampled sql benchmarks new RenamingBenchmark( "sampled_", new GroupBySumWithArithmeticSqlBenchmark(localSampledQueryRunner)), new RenamingBenchmark( "sampled_", new CountAggregationSqlBenchmark(localSampledQueryRunner)), new RenamingBenchmark( "sampled_", new SqlJoinWithPredicateBenchmark(localSampledQueryRunner)), new RenamingBenchmark( "sampled_", new SqlDoubleSumAggregationBenchmark(localSampledQueryRunner)), // statistics benchmarks new StatisticsBenchmark.LongVarianceBenchmark(localQueryRunner), new StatisticsBenchmark.LongVariancePopBenchmark(localQueryRunner), new StatisticsBenchmark.DoubleVarianceBenchmark(localQueryRunner), new StatisticsBenchmark.DoubleVariancePopBenchmark(localQueryRunner), new StatisticsBenchmark.LongStdDevBenchmark(localQueryRunner), new StatisticsBenchmark.LongStdDevPopBenchmark(localQueryRunner), new StatisticsBenchmark.DoubleStdDevBenchmark(localQueryRunner), new StatisticsBenchmark.DoubleStdDevPopBenchmark(localQueryRunner), new SqlApproximateCountDistinctLongBenchmark(localQueryRunner), new SqlApproximateCountDistinctDoubleBenchmark(localQueryRunner), new SqlApproximateCountDistinctVarBinaryBenchmark(localQueryRunner)); } private final String outputDirectory; public BenchmarkSuite(String outputDirectory) { this.outputDirectory = checkNotNull(outputDirectory, "outputDirectory is null"); } private File createOutputFile(String fileName) throws IOException { File outputFile = new File(fileName); Files.createParentDirs(outputFile); return outputFile; } public void runAllBenchmarks() throws IOException { ExecutorService 
executor = newCachedThreadPool(daemonThreadsNamed("test")); try { List<AbstractBenchmark> benchmarks = createBenchmarks(executor); LOGGER.info("=== Pre-running all benchmarks for JVM warmup ==="); for (AbstractBenchmark benchmark : benchmarks) { benchmark.runBenchmark(); } LOGGER.info("=== Actually running benchmarks for metrics ==="); for (AbstractBenchmark benchmark : benchmarks) { try (OutputStream jsonOut = new FileOutputStream( createOutputFile( String.format( "%s/json/%s.json", outputDirectory, benchmark.getBenchmarkName()))); OutputStream jsonAvgOut = new FileOutputStream( createOutputFile( String.format( "%s/json-avg/%s.json", outputDirectory, benchmark.getBenchmarkName()))); OutputStream csvOut = new FileOutputStream( createOutputFile( String.format( "%s/csv/%s.csv", outputDirectory, benchmark.getBenchmarkName()))); OutputStream odsOut = new FileOutputStream( createOutputFile( String.format( "%s/ods/%s.json", outputDirectory, benchmark.getBenchmarkName())))) { benchmark.runBenchmark( new ForwardingBenchmarkResultWriter( ImmutableList.of( new JsonBenchmarkResultWriter(jsonOut), new JsonAvgBenchmarkResultWriter(jsonAvgOut), new SimpleLineBenchmarkResultWriter(csvOut), new OdsBenchmarkResultWriter( "presto.benchmark." + benchmark.getBenchmarkName(), odsOut)))); } } } finally { executor.shutdownNow(); } } private static class ForwardingBenchmarkResultWriter implements BenchmarkResultHook { private final List<BenchmarkResultHook> benchmarkResultHooks; private ForwardingBenchmarkResultWriter(List<BenchmarkResultHook> benchmarkResultHooks) { checkNotNull(benchmarkResultHooks, "benchmarkResultWriters is null"); this.benchmarkResultHooks = ImmutableList.copyOf(benchmarkResultHooks); } @Override public BenchmarkResultHook addResults(Map<String, Long> results) { checkNotNull(results, "results is null"); for (BenchmarkResultHook benchmarkResultHook : benchmarkResultHooks) { benchmarkResultHook.addResults(results); } return this; } @Override public void finished() { for (BenchmarkResultHook benchmarkResultHook : benchmarkResultHooks) { benchmarkResultHook.finished(); } } } public static void main(String[] args) throws IOException { String outputDirectory = checkNotNull(System.getProperty("outputDirectory"), "Must specify -DoutputDirectory=..."); new BenchmarkSuite(outputDirectory).runAllBenchmarks(); } }
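ForwardingBenchmarkResultWriter above fans each result map out to several sinks through the BenchmarkResultHook contract: addResults returns the hook itself so calls can be chained, and finished marks the end of a run. The hook below is a hypothetical example that only logs each metric; its class name is an assumption, and the Logger and BenchmarkResultHook types come from the benchmark framework above.

import java.util.Map;

import static java.util.Objects.requireNonNull;

// Hypothetical result hook for illustration; it could be added alongside the JSON,
// CSV, and ODS writers in the list passed to ForwardingBenchmarkResultWriter.
public class LoggingBenchmarkResultWriter
        implements BenchmarkResultHook
{
    private static final Logger log = Logger.get(LoggingBenchmarkResultWriter.class);

    @Override
    public BenchmarkResultHook addResults(Map<String, Long> results)
    {
        requireNonNull(results, "results is null");
        for (Map.Entry<String, Long> entry : results.entrySet()) {
            log.info("%s = %s", entry.getKey(), entry.getValue());
        }
        return this; // returning this allows chained addResults calls
    }

    @Override
    public void finished()
    {
        log.info("benchmark run finished");
    }
}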