# Source code for scannerpy.client

import os
import os.path
import sys
import grpc
import imp
import socket
import time
import ipaddress
import pickle
import struct
import signal
import copy
import collections
import subprocess
import cloudpickle
import types
from tqdm import tqdm
from typing import Sequence, List, Union, Tuple, Optional, Callable

if sys.platform == 'linux' or sys.platform == 'linux2':
    import prctl

from multiprocessing import Process, cpu_count, Event
import multiprocessing as mp
from subprocess import Popen, PIPE
from random import choice
from string import ascii_uppercase

from scannerpy.common import *
from scannerpy.profiler import Profile
from scannerpy.config import Config
from scannerpy.op import OpGenerator, Op, OpColumn, SliceList
from scannerpy.source import SourceGenerator, Source
from scannerpy.sink import SinkGenerator, Sink
from scannerpy.streams import StreamsGenerator
from scannerpy.partitioner import TaskPartitioner
from scannerpy.table import Table
from scannerpy.column import Column
from scannerpy.protobufs import protobufs, python_to_proto
from scannerpy.job import Job
from scannerpy.kernel import Kernel
from scannerpy import types as scannertypes
from scannerpy.storage import StorageBackend, StoredStream, NamedVideoStream
from scannerpy.io import IOGenerator

import scannerpy._python as bindings
import scanner.metadata_pb2 as metadata_types
import scanner.engine.rpc_pb2 as rpc_types
import scanner.engine.rpc_pb2_grpc as grpc_types
import scanner.types_pb2 as misc_types



SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))



class Client(object):
    r"""Entrypoint for all Scanner operations.

    Parameters
    ----------
    master
      The address of the master process. The address should be formatted
      as 'ip:port'. If the `start_cluster` flag is specified, the Client
      object will ssh into the provided address and start a master process.
      You should have ssh access to the target machine and scannerpy should
      be installed.

    workers
      The list of addresses to spawn worker processes on. The addresses
      should be formatted as 'ip:port'. Like with `master`, you should have
      ssh access to the target machine and scannerpy should be installed.
      If `start_cluster` is false, this parameter has no effect.

    start_cluster
      If true, a master process and worker processes will be spawned at the
      addresses specified by `master` and `workers`, respectively.

    config_path
      Path to a Scanner configuration TOML, by default assumed to be
      '~/.scanner/config.toml'.

    config
      The scanner Config to use. If specified, config_path is ignored.

    debug
      This flag is only relevant when `start_cluster == True`. If true, the
      master and worker servers are spawned in the same process as the
      invoking python code, enabling easy gdb-based debugging.

    Other Parameters
    ----------------
    prefetch_table_metadata

    no_workers_timeout

    grpc_timeout

    Attributes
    ----------
    config : Config
      The Config object used to initialize this Client.

    ops : OpGenerator
      Represents the set of available Ops. Ops can be created like so:

      :code:`output = cl.ops.ExampleOp(arg='example')`

      For a more detailed description, see :class:`~scannerpy.op.OpGenerator`

    sources : SourceGenerator
      Represents the set of available Sources. Sources are created just like
      Ops. See :class:`~scannerpy.op.SourceGenerator`

    sinks : SinkGenerator
      Represents the set of available Sinks. Sinks are created just like
      Ops. See :class:`~scannerpy.op.SinkGenerator`

    streams : StreamsGenerator
      Used to specify which elements to sample from a sequence.
      See :class:`~scannerpy.streams.StreamsGenerator`

    partitioner : TaskPartitioner
      Used to specify how to split the elements in a sequence when
      performing a slice operation.
      See :class:`~scannerpy.partitioner.TaskPartitioner`.

    protobufs : ProtobufGenerator
      Used to construct protobuf objects that handle
      serialization/deserialization of the outputs of Ops.
    """
""" def __init__(self, master: str = None, workers: List[str] = None, start_cluster: bool = True, config_path: str = None, config: Config = None, debug: bool = None, enable_watchdog: bool = True, prefetch_table_metadata: bool = True, no_workers_timeout: float = 30, grpc_timeout: float = 30, new_job_retries_limit: int = 5, machine_params = None): if config: self.config = config else: self.config = Config(config_path) self._start_cluster = start_cluster self._workers_started = False self._enable_watchdog = enable_watchdog # Worker always has watchdog enabled to determine when master # connection has failed self._enable_worker_watchdog = True self._prefetch_table_metadata = prefetch_table_metadata self._no_workers_timeout = no_workers_timeout self._debug = debug self._grpc_timeout = grpc_timeout self._new_job_retries_limit = new_job_retries_limit self._machine_params = machine_params if debug is None: self._debug = (master is None and workers is None) self._master = None self._bindings = bindings # Setup Client metadata self._db_path = self.config.db_path self._storage = self.config.storage self._cached_db_metadata = None self._png_dump_prefix = '__png_dump_{:s}_{:s}' self.ops = OpGenerator(self) self.sources = SourceGenerator(self) self.sinks = SinkGenerator(self) self.streams = StreamsGenerator(self) self.partitioner = TaskPartitioner(self) self.io = IOGenerator(self) self._op_cache = {} self._python_ops = set() self._modules = set() self._enumerator_info_cache = {} self._sink_info_cache = {} self._master_conn = None self._workers = {} self._worker_conns = None self._worker_paths = workers self.start_master(master) def __del__(self): # Client crashed during config creation if this attr is missing if hasattr(self, '_start_cluster') and self._start_cluster: self._stop_heartbeat() self.stop_cluster() def __enter__(self): return self def __exit__(self, exception_type, exception_val, exception_tb): self._stop_heartbeat() self.stop_cluster() del self._db def _load_descriptor(self, descriptor, path): d = descriptor() path = '{}/{}'.format(self._db_path, path) try: d.ParseFromString(self._storage.read(path)) except UserWarning: raise ScannerException( 'Internal error. Missing file {}'.format(path)) return d def _save_descriptor(self, descriptor, path): self._storage.write(('{}/{}'.format(self._db_path, path)), descriptor.SerializeToString()) def _load_table_metadata(self, table_names): NUM_TABLES_TO_READ = 100000 tables = [] for i in range(0, len(table_names), NUM_TABLES_TO_READ): get_tables_params = protobufs.GetTablesParams() for table_name in table_names[i:i + NUM_TABLES_TO_READ]: get_tables_params.tables.append(table_name) get_tables_result = self._try_rpc( lambda: self._master.GetTables(get_tables_params)) if not get_tables_result.result.success: raise ScannerException( 'Internal error: GetTables returned error: {}'.format( get_tables_result.result.msg)) tables.extend(get_tables_result.tables) return tables def _load_db_metadata(self): if self._cached_db_metadata is None: desc = self._load_descriptor(protobufs.DatabaseDescriptor, 'db_metadata.bin') self._cached_db_metadata = desc # table id cache self._table_id = {} self._table_name = {} self._table_committed = {} for i, table in enumerate(self._cached_db_metadata.tables): if table.name in self._table_name: raise ScannerException( 'Internal error: multiple tables with same name: {}'. 
format(table.name)) self._table_id[table.id] = i self._table_name[table.name] = i self._table_committed[table.id] = table.committed if self._prefetch_table_metadata: self._table_descriptor = {} # Read all table descriptors from client table_names = list(self._table_name.keys()) tables = self._load_table_metadata(table_names) for table in tables: self._table_descriptor[table.id] = table return self._cached_db_metadata def _make_grpc_channel(self, address): max_message_length = 1024 * 1024 * 1024 return grpc.insecure_channel( address, options=[('grpc.max_send_message_length', max_message_length), ('grpc.max_receive_message_length', max_message_length)]) def _connect_to_worker(self, address): channel = self._make_grpc_channel(address) worker = protobufs.WorkerStub(channel) try: self._worker.Ping( protobufs.Empty(), timeout=self._grpc_timeout) return worker except grpc.RpcError as e: status = e.code() if status == grpc.StatusCode.UNAVAILABLE: pass else: raise ScannerException('Master ping errored with status: {}' .format(status)) return None def _connect_to_master(self): channel = self._make_grpc_channel(self._master_address) self._master = protobufs.MasterStub(channel) result = False try: self._master.Ping( protobufs.Empty(), timeout=self._grpc_timeout) result = True except grpc.RpcError as e: status = e.code() if status == grpc.StatusCode.UNAVAILABLE: pass elif status == grpc.StatusCode.OK: result = True else: raise ScannerException('Master ping errored with status: {}' .format(status)) return result def _run_remote_cmd(self, host, cmd, nohup=False): host_name, _, _ = host.partition(':') host_ip = socket.gethostbyname(host_name) if (ipaddress.ip_address(host_ip).is_loopback or host_name == 'localhost'): return Popen(cmd, shell=True) else: cmd = cmd.replace('"', '\\"') return Popen( "ssh {} \"cd {} && {} {} {}\"".format(host_name, os.getcwd(), '' if nohup else '', cmd, '' if nohup else ''), shell=True) def _start_heartbeat(self): # Start up heartbeat to keep master alive def heartbeat_task(stop_event, master_address, ppid): if sys.platform == 'linux' or sys.platform == 'linux2': prctl.set_pdeathsig(signal.SIGTERM) channel = self._make_grpc_channel(master_address) master = grpc_types.MasterStub(channel) while not stop_event.is_set(): if os.getppid() != ppid: return try: master.PokeWatchdog( rpc_types.Empty(), timeout=self._grpc_timeout) except grpc.RpcError as e: pass time.sleep(1) self._heartbeat_stop_event = Event() pid = os.getpid() self._heartbeat_process = Process( target=heartbeat_task, args=(self._heartbeat_stop_event, self._master_address, pid)) self._heartbeat_process.daemon = True self._heartbeat_process.start() def _stop_heartbeat(self): if (self._enable_watchdog and self._heartbeat_stop_event): self._heartbeat_stop_event.set() def _handle_signal(self, signum, frame): if (signum == signal.SIGINT or signum == signal.SIGTERM or signum == signal.SIGSEGV or signum == signal.SIGABRT): self._stop_heartbeat() self.stop_cluster() if signum == signal.SIGINT: sys.exit(0) else: sys.exit(1) def _try_rpc(self, fn): try: result = fn() except grpc.RpcError as e: raise ScannerException(e) if isinstance(result, protobufs.Result): if not result.success: raise ScannerException(result.msg) return result def _get_source_info(self, source_name): source_info_args = protobufs.SourceInfoArgs() source_info_args.source_name = source_name source_info = self._try_rpc( lambda: self._master.GetSourceInfo(source_info_args)) if not source_info.result.success: raise ScannerException(source_info.result.msg) return 
source_info def _get_enumerator_info(self, enumerator_name): if enumerator_name in self._enumerator_info_cache: return self._enumerator_info_cache[enumerator_name] enumerator_info_args = protobufs.EnumeratorInfoArgs() enumerator_info_args.enumerator_name = enumerator_name enumerator_info = self._try_rpc( lambda: self._master.GetEnumeratorInfo(enumerator_info_args)) if not enumerator_info.result.success: raise ScannerException(enumerator_info.result.msg) self._enumerator_info_cache[enumerator_name] = enumerator_info return enumerator_info def _get_sink_info(self, sink_name): if sink_name in self._sink_info_cache: return self._sink_info_cache[sink_name] sink_info_args = protobufs.SinkInfoArgs() sink_info_args.sink_name = sink_name sink_info = self._try_rpc( lambda: self._master.GetSinkInfo(sink_info_args)) if not sink_info.result.success: raise ScannerException(sink_info.result.msg) self._sink_info_cache[sink_name] = sink_info return sink_info def _get_op_info(self, op_name): if op_name in self._op_cache: op_info = self._op_cache[op_name] else: op_info_args = protobufs.OpInfoArgs() op_info_args.op_name = op_name op_info = self._try_rpc( lambda: self._master.GetOpInfo(op_info_args, self._grpc_timeout) ) if not op_info.result.success: raise ScannerException(op_info.result.msg) self._op_cache[op_name] = op_info return op_info def _check_has_op(self, op_name): self._get_op_info(op_name) def _get_input_columns(self, op_name): return self._get_op_info(op_name).input_columns def _get_output_columns(self, op_name): return self._get_op_info(op_name).output_columns def _toposort(self, outputs): # Perform DFS on modified graph edges = defaultdict(list) in_edges_left = defaultdict(int) source_nodes = [] explored_nodes = set() stack = list(outputs) while len(stack) > 0: c = stack.pop() if c in explored_nodes: continue explored_nodes.add(c) if isinstance(c, Source): source_nodes.append(c) continue for input in c._inputs: edges[input._op].append(c) in_edges_left[c] += 1 if input._op not in explored_nodes: stack.append(input._op) # Keep track of position of input ops and sampling/slicing ops # to use for associating job args to source_ops = {} stream_ops = {} output_ops = {} # Compute sorted list eval_sorted = [] eval_index = {} stack = source_nodes[:] while len(stack) > 0: c = stack.pop() eval_sorted.append(c) op_idx = len(eval_sorted) - 1 eval_index[c] = op_idx for child in edges[c]: in_edges_left[child] -= 1 if in_edges_left[child] == 0: stack.append(child) if isinstance(c, Source): source_ops[c] = op_idx elif (c._name == "Sample" or c._name == "Space" or c._name == "Slice" or c._name == "Unslice"): stream_ops[c] = op_idx elif isinstance(c, Sink): output_ops[c] = op_idx return eval_sorted, \ eval_index, \ source_ops, \ stream_ops, \ output_ops def _parse_size_string(self, s): (prefix, suffix) = (s[:-1], s[-1]) mults = {'G': 1024**3, 'M': 1024**2, 'K': 1024**1} suffix = suffix.upper() if suffix not in mults: raise ScannerException('Invalid size suffix in "{}"'.format(s)) return int(prefix) * mults[suffix]
    def load_op(self, so_path: str, proto_path: str = None):
        r"""Loads a custom op into the Scanner runtime.

        Parameters
        ----------
        so_path
          Path to the custom op's shared library (.so).

        proto_path
          Path to the custom op's arguments protobuf if one exists.

        Raises
        ------
        ScannerException
          Raised when the master fails to load the op.
        """
        if proto_path is not None:
            protobufs.add_module(proto_path)
        op_path = protobufs.OpPath()
        op_path.path = so_path
        self._try_rpc(
            lambda: self._master.LoadOp(op_path, timeout=self._grpc_timeout))
        self._modules.add((so_path, proto_path))
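    # Example (illustrative sketch, not part of the module): loading a custom
    # C++ op from a hypothetical shared library and proto file, then using it
    # in a graph. The paths, op name, and column name are placeholders.
    #
    #   cl = Client()
    #   cl.load_op('/path/to/libmy_op.so', '/path/to/my_op_args.proto')
    #   frame = cl.io.Input([NamedVideoStream(cl, 'example')])
    #   out = cl.ops.MyOp(frame=frame)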
    def has_gpu(self):
        try:
            with open(os.devnull, 'w') as f:
                subprocess.check_call(['nvidia-smi'], stdout=f, stderr=f)
                return True
        except (OSError, subprocess.CalledProcessError):
            pass
        return False
    def summarize(self) -> str:
        r"""Returns a human-readable summarization of the client state.
        """
        summary = ''
        db_meta = self._load_db_metadata()
        if len(db_meta.tables) == 0:
            return 'summarize: your client is empty!'

        tables = [
            ('TABLES', [
                ('ID', [str(t.id) for t in db_meta.tables]),
                ('Name', [t.name for t in db_meta.tables]),
                ('# rows', [
                    str(self.table(t.id).num_rows()) for t in db_meta.tables
                ]),
                ('Columns', [
                    ', '.join(self.table(t.id).column_names())
                    for t in db_meta.tables
                ]),
                ('Committed', [
                    'true' if self.table(t.id).committed() else 'false'
                    for t in db_meta.tables
                ]),
            ]),
        ]

        for table_idx, (label, cols) in enumerate(tables):
            if table_idx > 0:
                summary += '\n\n'
            num_cols = len(cols)
            max_col_lens = [
                max(max([len(s) for s in c] or [0]), len(name))
                for name, c in cols
            ]
            table_width = sum(max_col_lens) + 3 * (num_cols - 1)
            label = '** {} **'.format(label)
            summary += ' ' * int(
                table_width / 2 - len(label) / 2) + label + '\n'
            summary += '-' * table_width + '\n'
            col_name_fmt = ' | '.join(['{{:{}}}' for _ in range(num_cols)])
            col_name_fmt = col_name_fmt.format(*max_col_lens)
            summary += col_name_fmt.format(*[s for s, _ in cols]) + '\n'
            summary += '-' * table_width + '\n'
            row_fmt = ' | '.join(['{{:{}}}' for _ in range(num_cols)])
            row_fmt = row_fmt.format(*max_col_lens)
            for i in range(len(cols[0][1])):
                summary += row_fmt.format(*[c[i] for _, c in cols]) + '\n'
        return summary
    def start_master(self, master: str):
        r"""Starts a Scanner master.

        Parameters
        ----------
        master
          ssh-able address of the master node.
        """
        if master is None:
            self._master_address = (
                self.config.master_address + ':' + self.config.master_port)
        else:
            self._master_address = master

        if ':' not in self._master_address:
            raise ScannerException(
                ('Did you forget to specify the master port number? '
                 'Specified address is {s:s}. It should look like {s:s}:5001')
                .format(s=self._master_address))

        # Start up heartbeat to keep master alive
        # NOTE(apoms): This MUST BE before any grpc channel is created, since
        # it forks a process and forking after channel creation causes hangs
        # in the forked process under grpc
        # https://github.com/grpc/grpc/issues/13873#issuecomment-358476408
        if self._enable_watchdog:
            self._start_heartbeat()

        # Boot up C++ database bindings
        self._db = self._bindings.Database(self.config.storage_config,
                                           str(self._db_path),
                                           str(self._master_address))

        if self._start_cluster:
            # Set handler to shutdown cluster on signal
            # TODO(apoms): we should clear these handlers when stopping
            # the cluster
            signal.signal(signal.SIGINT, self._handle_signal)
            signal.signal(signal.SIGTERM, self._handle_signal)
            signal.signal(signal.SIGSEGV, self._handle_signal)
            signal.signal(signal.SIGABRT, self._handle_signal)

            if self._debug:
                self._master_conn = None
                res = self._bindings.start_master(
                    self._db, self.config.master_port, SCRIPT_DIR,
                    self._enable_watchdog, self._no_workers_timeout,
                    self._new_job_retries_limit).success
                assert res
                res = self._connect_to_master()
                if not res:
                    raise ScannerException(
                        'Failed to connect to local master process on port '
                        '{:s}. (Is there another process that is bound to '
                        'that port already?)'.format(self.config.master_port))
            else:
                master_port = self._master_address.partition(':')[2]
                # https://stackoverflow.com/questions/30469575/how-to-pickle-and-unpickle-to-portable-string-in-python-3
                pickled_config = pickle.dumps(self.config, 0).decode()
                master_cmd = (
                    'python3 -c ' +
                    '\"from scannerpy import start_master\n' +
                    'import pickle\n' +
                    'config=pickle.loads(bytes(\'\'\'{config:s}\'\'\', \'utf8\'))\n' +
                    'start_master(port=\'{master_port:s}\', block=True,\n' +
                    '             config=config, watchdog={watchdog},\n' +
                    '             no_workers_timeout={no_workers})\" ' +
                    '').format(
                        master_port=master_port,
                        config=pickled_config,
                        watchdog=self._enable_watchdog,
                        no_workers=self._no_workers_timeout)
                self._master_conn = self._run_remote_cmd(
                    self._master_address, master_cmd, nohup=True)

                # Wait for master to start
                slept_so_far = 0
                sleep_time = 60
                while slept_so_far < sleep_time:
                    if self._connect_to_master():
                        break
                    time.sleep(0.3)
                    slept_so_far += 0.3
                if slept_so_far >= sleep_time:
                    self._master_conn.kill()
                    self._master_conn = None
                    raise ScannerException(
                        'Timed out waiting to connect to master')
        else:
            self._master_conn = None
            self._worker_conns = None

            # Wait for master to start
            slept_so_far = 0
            sleep_time = 20
            while slept_so_far < sleep_time:
                if self._connect_to_master():
                    break
                time.sleep(0.3)
                slept_so_far += 0.3
            if slept_so_far >= sleep_time:
                raise ScannerException(
                    'Timed out waiting to connect to master')
    def start_workers(self, workers: List[str]):
        r"""Starts Scanner workers.

        Parameters
        ----------
        workers
          List of ssh-able addresses of the worker nodes.
        """
        if workers is None:
            self._worker_addresses = [
                self.config.master_address + ':' + self.config.worker_port
            ]
        else:
            self._worker_addresses = workers

        if self._debug:
            self._worker_conns = None
            machine_params = (self._machine_params
                              or self._bindings.default_machine_params())
            for i in range(len(self._worker_addresses)):
                start_worker(
                    self._master_address,
                    port=str(int(self.config.worker_port) + i),
                    config=self.config,
                    db=self._db,
                    watchdog=self._enable_worker_watchdog,
                    machine_params=machine_params)
        else:
            pickled_config = pickle.dumps(self.config, 0).decode()
            worker_cmd = (
                'python3 -c ' +
                '\"from scannerpy import start_worker\n' +
                'import pickle\n' +
                'config=pickle.loads(bytes(\'\'\'{config:s}\'\'\', \'utf8\'))\n' +
                'start_worker(\'{master:s}\', port=\'{worker_port:s}\',\n' +
                '             block=True,\n' +
                '             watchdog={watchdog},' +
                '             config=config)\" ' +
                '')

            # Start workers now that master is ready
            self._worker_conns = []
            ignored_nodes = 0
            for w in self._worker_addresses:
                try:
                    self._worker_conns.append(
                        self._run_remote_cmd(
                            w,
                            worker_cmd.format(
                                master=self._master_address,
                                config=pickled_config,
                                watchdog=self._enable_worker_watchdog,
                                worker_port=w.partition(':')[2]),
                            nohup=True))
                except Exception as e:
                    print('WARNING: Failed to ssh into {:s} because: {:s}'
                          .format(w, repr(e)))
                    ignored_nodes += 1

            slept_so_far = 0
            # Has to be this long for GCS
            sleep_time = 60
            while slept_so_far < sleep_time:
                active_workers = self._master.ActiveWorkers(
                    protobufs.Empty(), timeout=self._grpc_timeout)
                if (len(active_workers.workers) > len(self._worker_conns)):
                    raise ScannerException(
                        ('Master has more workers than requested '
                         '({:d} vs {:d})').format(
                             len(active_workers.workers),
                             len(self._worker_conns)))
                if (len(active_workers.workers) == len(self._worker_conns)):
                    break
                time.sleep(0.3)
                slept_so_far += 0.3

            if slept_so_far >= sleep_time:
                self.stop_cluster()
                raise ScannerException(
                    'Timed out waiting for workers to connect to master')

            if ignored_nodes > 0:
                print('Ignored {:d} nodes during startup.'.format(
                    ignored_nodes))

        self._workers_started = True
    def stop_cluster(self):
        r"""Stops the Scanner master and workers.
        """
        if self._start_cluster:
            if self._master:
                # Stop heartbeat
                self._stop_heartbeat()
                try:
                    self._try_rpc(
                        lambda: self._master.Shutdown(
                            protobufs.Empty(), timeout=self._grpc_timeout))
                except:
                    pass
                self._master = None
            if self._master_conn:
                self._master_conn.kill()
                self._master_conn = None
            if self._worker_conns:
                for wc in self._worker_conns:
                    wc.kill()
                self._worker_conns = None
    def register_op(self,
                    name: str,
                    input_columns: List[Union[str, Tuple[str, ColumnType]]],
                    output_columns: List[Union[str, Tuple[str, ColumnType]]],
                    variadic_inputs: bool = False,
                    stencil: List[int] = None,
                    unbounded_state: bool = False,
                    bounded_state: int = None,
                    proto_path: str = None):
        r"""Register a new Op with the Scanner master.

        Parameters
        ----------
        name
          Name of the Op.

        input_columns
          A list of the inputs for this Op. Can be either the name of the
          input as a string or a tuple of ('name', ColumnType). If only the
          name is specified as a string, the ColumnType is assumed to be
          ColumnType.Blob.

        output_columns
          A list of the outputs for this Op. Can be either the name of the
          output as a string or a tuple of ('name', ColumnType). If only the
          name is specified as a string, the ColumnType is assumed to be
          ColumnType.Blob.

        variadic_inputs
          If true, this Op may take a variable number of inputs and
          `input_columns` is ignored. Variadic inputs are specified as
          positional arguments when invoking the Op, instead of keyword
          arguments.

        stencil
          Specifies the default stencil to use for the Op. If None, indicates
          that the Op does not have the ability to stencil. A stencil of [0]
          should be specified if the Op can stencil but should not by
          default.

        unbounded_state
          If true, indicates that the Op needs to see all previous elements
          of its input sequences before it can compute a given element. For
          example, to compute output element at index 100, the Op must have
          already produced elements 0-99. This option is mutually exclusive
          with `bounded_state`.

        bounded_state
          If specified, indicates that the Op only needs to see the given
          number of preceding elements (a warmup) of its input sequences
          before it can compute a given element. This option is mutually
          exclusive with `unbounded_state`.

        proto_path
          Optional path to the proto file that describes the configuration
          arguments to this Op.

        Raises
        ------
        ScannerException
          Raised when the master fails to register the Op.
        """
        op_registration = protobufs.OpRegistration()
        op_registration.name = name
        op_registration.variadic_inputs = variadic_inputs
        op_registration.has_unbounded_state = unbounded_state

        def add_col(columns, col):
            if isinstance(col, str):
                c = columns.add()
                c.name = col
                c.type = protobufs.Bytes
            elif isinstance(col, collections.Iterable):
                c = columns.add()
                c.name = col[0]
                c.type = ColumnType.to_proto(protobufs, col[1])
                c.type_name = col[2].cpp_name
            else:
                raise ScannerException(
                    'Column ' + col + ' must be a string name or a tuple of '
                    '(name, column_type)')

        for in_col in input_columns:
            add_col(op_registration.input_columns, in_col)
        for out_col in output_columns:
            add_col(op_registration.output_columns, out_col)

        if stencil is None:
            op_registration.can_stencil = False
        else:
            op_registration.can_stencil = True
            op_registration.preferred_stencil.extend(stencil)

        if bounded_state is not None:
            assert isinstance(bounded_state, int)
            op_registration.has_bounded_state = True
            op_registration.warmup = bounded_state

        if proto_path is not None:
            protobufs.add_module(proto_path)

        self._try_rpc(lambda: self._master.RegisterOp(
            op_registration, timeout=self._grpc_timeout))
    def register_python_kernel(self,
                               op_name: str,
                               device_type: DeviceType,
                               kernel: Union[types.FunctionType,
                                             types.BuiltinFunctionType,
                                             Kernel],
                               batch: int = 1):
        r"""Register a Python Kernel with the Scanner master.

        Parameters
        ----------
        op_name
          Name of the Op.

        device_type
          The device type of the resource this kernel uses.

        kernel
          The class or function that implements the kernel.

        batch
          Specifies a default for how many elements this kernel should batch
          over. If `batch == 1`, the kernel is assumed to not be able to
          batch.

        Raises
        ------
        ScannerException
          Raised when the master fails to register the kernel.
        """
        if isinstance(kernel, types.FunctionType) or isinstance(
                kernel, types.BuiltinFunctionType):

            class KernelWrapper(Kernel):
                def __init__(self, config, **kwargs):
                    self._config = config

                def execute(self, columns):
                    return kernel(self._config, columns)

            kernel_cls = KernelWrapper
        else:
            kernel_cls = kernel

        py_registration = protobufs.PythonKernelRegistration()
        py_registration.op_name = op_name
        py_registration.device_type = DeviceType.to_proto(
            protobufs, device_type)
        py_registration.kernel_code = cloudpickle.dumps(kernel_cls)
        py_registration.batch_size = batch
        self._try_rpc(
            lambda: self._master.RegisterPythonKernel(
                py_registration, timeout=self._grpc_timeout))
        self._python_ops.add(op_name)
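    # Example (illustrative sketch, not part of the module): registering a
    # Python op. `register_op` declares the op's columns and
    # `register_python_kernel` binds a Python function as its kernel (the
    # function receives the kernel config and the input columns). The op and
    # column names below are hypothetical.
    #
    #   def brighten(config, frame):
    #       return frame * 2
    #
    #   cl.register_op('Brighten', ['frame'], ['bright_frame'])
    #   cl.register_python_kernel('Brighten', DeviceType.CPU, brighten)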
    def ingest_videos(self,
                      videos: List[Tuple[str, str]],
                      inplace: bool = False,
                      force: bool = False
                      ) -> Tuple[List[Table], List[Tuple[str, str]]]:
        r"""Creates tables from videos.

        Parameters
        ----------
        videos
          The list of videos to ingest into the client. Each element in the
          list should be ('table_name', 'path/to/video').

        inplace
          If true, ingests the videos without copying them into the client.
          Currently only supported for mp4 containers.

        force
          If true, deletes existing tables with the same names.

        Returns
        -------
        tables: List[Table]
          List of table objects for the ingested videos.

        failures: List[Tuple[str, str]]
          List of ('path/to/video', 'reason for failure') tuples for each
          video which failed to ingest.
        """
        if len(videos) == 0:
            raise ScannerException('Must ingest at least one video.')

        [table_names, paths] = list(zip(*videos))
        to_delete = []
        for table_name in table_names:
            if self.has_table(table_name):
                if force is True:
                    to_delete.append(table_name)
                else:
                    raise ScannerException(
                        'Attempted to ingest over existing table {}'
                        .format(table_name))
        self.delete_tables(to_delete)

        ingest_params = protobufs.IngestParameters()
        ingest_params.table_names.extend(table_names)
        ingest_params.video_paths.extend(paths)
        ingest_params.inplace = inplace
        ingest_result = self._try_rpc(
            lambda: self._master.IngestVideos(ingest_params))
        if not ingest_result.result.success:
            raise ScannerException(ingest_result.result.msg)
        failures = list(
            zip(ingest_result.failed_paths, ingest_result.failed_messages))

        self._cached_db_metadata = None
        return ([
            self.table(t) for (t, p) in videos
            if p not in ingest_result.failed_paths
        ], failures)
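    # Example (illustrative sketch, not part of the module): ingesting two
    # local videos into tables. The table names and file paths are
    # placeholders.
    #
    #   tables, failures = cl.ingest_videos(
    #       [('video1', '/path/to/video1.mp4'),
    #        ('video2', '/path/to/video2.mp4')],
    #       force=True)
    #   for path, reason in failures:
    #       print('Failed to ingest {}: {}'.format(path, reason))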
    def has_table(self, name: str) -> bool:
        r"""Checks if a table exists in the database.

        Parameters
        ----------
        name
          The name of the table to check for.

        Returns
        -------
        bool
          True if the table exists, false otherwise.
        """
        db_meta = self._load_db_metadata()
        if name in self._table_name:
            return True
        return False
    def delete_tables(self, names: List[str]):
        r"""Deletes tables from the database.

        Parameters
        ----------
        names
          The names of the tables to delete.
        """
        delete_tables_params = protobufs.DeleteTablesParams()
        for name in names:
            delete_tables_params.tables.append(name)
        self._try_rpc(lambda: self._master.DeleteTables(delete_tables_params))
        self._cached_db_metadata = None
    def delete_table(self, name: str):
        r"""Deletes a table from the database.

        Parameters
        ----------
        name
          The name of the table to delete.
        """
        self.delete_tables([name])
    def new_table(self,
                  name: str,
                  columns: List[str],
                  rows: List[List[bytes]],
                  fns=None,
                  force: bool = False) -> Table:
        r"""Creates a new table from a list of rows.

        Parameters
        ----------
        name
          String name of the table to create.

        columns
          List of names of table columns.

        rows
          List of rows with each row a list of elements corresponding to the
          specified columns. Elements must be strings of serialized
          representations of the data.

        fns

        force

        Returns
        -------
        Table
          The new table object.
        """
        if self.has_table(name):
            if force:
                self.delete_table(name)
            else:
                raise ScannerException(
                    'Attempted to create table with existing '
                    'name {}'.format(name))
        if fns is not None:
            rows = [[fn(col, protobufs) for fn, col in zip(fns, row)]
                    for row in rows]
        params = protobufs.NewTableParams()
        params.table_name = name
        params.columns[:] = columns

        for i, row in enumerate(rows):
            row_proto = params.rows.add()
            row_proto.columns[:] = row

        self._try_rpc(lambda: self._master.NewTable(params))

        self._cached_db_metadata = None

        return self.table(name)
    def table(self, name: str) -> Table:
        r"""Retrieves a Table.

        Parameters
        ----------
        name
          Name of the table to retrieve.

        Returns
        -------
        Table
          The table object.
        """
        db_meta = self._load_db_metadata()

        table_name = None
        table_id = None
        if isinstance(name, str):
            table_name = name
            if name in self._table_name:
                table_id = db_meta.tables[self._table_name[name]].id
        elif isinstance(name, int):
            table_id = name
            if name in self._table_id:
                table_name = db_meta.tables[self._table_id[name]].name
        else:
            raise ScannerException('Invalid table identifier')

        table = Table(self, table_name, table_id)
        if (self._prefetch_table_metadata
                and table_id in self._table_descriptor):
            table._descriptor = self._table_descriptor[table_id]

        return table
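    # Example (illustrative sketch, not part of the module): looking up a
    # table by name and reading its metadata. Assumes a table named 'video1'
    # already exists in the database.
    #
    #   t = cl.table('video1')
    #   print(t.num_rows(), t.column_names())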
    def sequence(self, name: str) -> Column:
        t = self.table(name)
        if t.committed():
            column_names = t.column_names()
            return t.column('frame' if 'frame' in column_names else 'column')
        else:
            return t.column('column')
    def get_profile(self, job_name, **kwargs):
        db_meta = self._load_db_metadata()
        if isinstance(job_name, str):
            job_id = None
            for job in db_meta.bulk_jobs:
                if job.name == job_name:
                    job_id = job.id
                    break
            if job_id is None:
                raise ScannerException(
                    'Job name {} does not exist'.format(job_name))
        else:
            job_id = job_name
        return Profile(self, job_id, **kwargs)
    def get_active_jobs(self):
        req = protobufs.GetJobsRequest()
        reply = self._try_rpc(
            lambda: self._master.GetJobs(req, timeout=self._grpc_timeout))
        return [x for x in reply.active_bulk_jobs]
    def wait_on_job_gen(self, bulk_job_id, show_progress=True):
        pbar = None
        total_tasks = None
        last_task_count = 0
        last_jobs_failed = 0
        last_failed_workers = 0
        while True:
            try:
                req = protobufs.GetJobStatusRequest()
                req.bulk_job_id = bulk_job_id
                job_status = self._master.GetJobStatus(
                    req, timeout=self._grpc_timeout)

                if (show_progress and pbar is None
                        and job_status.total_jobs != 0
                        and job_status.total_tasks != 0):
                    total_tasks = job_status.total_tasks
                    # Lower smoothing provides more accurate ETAs over long
                    # jobs. See: https://tqdm.github.io/docs/tqdm/
                    pbar = tqdm(total=total_tasks, smoothing=0.05)
            except grpc.RpcError as e:
                raise ScannerException(e)

            if job_status.finished:
                break

            if pbar is not None:
                tasks_completed = job_status.tasks_done
                if tasks_completed - last_task_count > 0:
                    pbar.update(tasks_completed - last_task_count)
                last_task_count = tasks_completed
                pbar.set_postfix({
                    'jobs': job_status.total_jobs - job_status.jobs_done,
                    'tasks': job_status.total_tasks - job_status.tasks_done,
                    'workers': job_status.num_workers,
                })

                time_str = time.strftime('%l:%M%p %z on %b %d, %Y')
                if last_jobs_failed < job_status.jobs_failed:
                    num_jobs_failed = (job_status.jobs_failed
                                       - last_jobs_failed)
                    pbar.write('{:d} {:s} failed at {:s}'.format(
                        num_jobs_failed,
                        'job' if num_jobs_failed < 2 else 'jobs',
                        time_str))
                if last_failed_workers < job_status.failed_workers:
                    num_workers_failed = (job_status.failed_workers
                                          - last_failed_workers)
                    pbar.write('{:d} {:s} failed at {:s}'.format(
                        num_workers_failed,
                        'worker' if num_workers_failed < 2 else 'workers',
                        time_str))
                last_jobs_failed = job_status.jobs_failed
                last_failed_workers = job_status.failed_workers

            yield

        if pbar is not None:
            pbar.update(total_tasks - last_task_count)
            pbar.close()

        yield job_status
        return
    def wait_on_job(self, *args, **kwargs):
        sleep_schedule = []
        for i in range(15):
            sleep_schedule += [0.02 * (2 ** i)] * (50 // (2 ** i))
        gen = self.wait_on_job_gen(*args, **kwargs)
        sleep_attempt = 0
        while True:
            try:
                job_status = next(gen)
                if sleep_attempt < len(sleep_schedule):
                    time.sleep(sleep_schedule[sleep_attempt])
                else:
                    time.sleep(1.0)
                sleep_attempt += 1
            except StopIteration:
                break
        return job_status
    def bulk_fetch_video_metadata(self, tables):
        params = protobufs.GetVideoMetadataParams(
            tables=[t.name() for t in tables])
        result = self._try_rpc(lambda: self._master.GetVideoMetadata(
            params, timeout=self._grpc_timeout))
        return result.videos
    def batch_load(self, tables, column, callback, batch_size=1000,
                   workers=16):
        def batch(l, n):
            for i in range(0, len(l), n):
                yield l[i:i + n]

        with mp.get_context('spawn').Pool(processes=workers) as pool:
            for _ in tqdm(
                    pool.imap_unordered(
                        _batch_load_column,
                        [(l, column, callback)
                         for l in batch([t.name() for t in tables],
                                        batch_size)])):
                pass
    def run(self,
            outputs: Union[Sink, List[Sink]],
            perf_params: Callable[[], PerfParams],
            cache_mode: CacheMode = CacheMode.Error,
            show_progress: bool = True,
            profiling: bool = False,
            task_timeout: int = 0,
            checkpoint_frequency: int = 10,
            detach: bool = False,
            profiler_level: int = 1):
        r"""Runs a collection of jobs.

        Parameters
        ----------
        outputs
          The Sink or Sinks that should be processed.

        perf_params
          Performance-related parameters. These options should be tuned to
          improve the performance of executing a computation graph.

        cache_mode
          Determines whether to overwrite, ignore, or raise an error when
          running a job on existing outputs.

        show_progress
          If true, will display an ASCII progress bar measuring job status.

        profiling

        Other Parameters
        ----------------
        task_timeout

        checkpoint_frequency

        Returns
        -------
        int
          The job id.
        """
        if not isinstance(outputs, list):
            outputs = [outputs]

        sorted_ops, op_index, source_ops, stream_ops, output_ops = (
            self._toposort(outputs))

        to_ingest = defaultdict(list)
        for op in source_ops.keys():
            streams = op._outputs[0]._streams
            if isinstance(streams[0], NamedVideoStream):
                to_ingest[streams[0].storage()] += streams

        for storage, streams in to_ingest.items():
            storage.ingest(self, streams)

        # Collect compression annotations to add to job
        output_column_names = []
        compression_options = []
        for op in sorted_ops:
            if op in outputs:
                for out_col in op.inputs():
                    opts = protobufs.OutputColumnCompression()
                    opts.codec = 'default'
                    if out_col._type == protobufs.Video:
                        for k, v in out_col._encode_options.items():
                            if k == 'codec':
                                opts.codec = v
                            else:
                                opts.options[k] = str(v)
                    compression_options.append(opts)
                # Get output columns
                output_column_names += op._output_names

        job_params = protobufs.BulkJobParameters()
        job_name = ''.join(choice(ascii_uppercase) for _ in range(12))
        job_params.job_name = job_name
        job_params.ops.extend([e.to_proto(op_index) for e in sorted_ops])
        job_params.output_column_names.extend(output_column_names)

        N = None
        for op in sorted_ops:
            n = None
            if op._job_args is not None:
                n = len(op._job_args)
            elif op._extra is not None and 'job_args' in op._extra:
                if not isinstance(op._extra['job_args'], list):
                    op_name = op._name
                    if op_name == 'Sample':
                        op_name = op._extra['type']
                    raise ScannerException(
                        "The arguments to op `{}` are stream config "
                        "arguments and must be lists.".format(op_name))
                n = len(op._extra['job_args'])
            else:
                continue

            if N is None:
                N = n
            elif n != N:
                raise ScannerException(
                    "Op `{}` had {} per-stream arguments, but expected {}"
                    .format(op._name, n, N))
        assert N is not None

        # Collect set of existing stored streams for each output
        output_ops_list = list(output_ops.keys())
        to_delete = []
        for op in output_ops_list:
            streams = op._streams
            storage = streams[0].storage()
            to_delete.append(
                set([i for i, s in enumerate(streams) if s.exists()]))

        # Check that all outputs have matching existing stored streams
        if cache_mode == CacheMode.Ignore:
            for i, s in enumerate(to_delete):
                for j, t in enumerate(to_delete):
                    if i == j:
                        continue
                    diff = s - t
                    if len(diff) > 0:
                        raise Exception(
                            ("Output for stream {} exists in {} but does not "
                             "exist in {}. Either both or neither should "
                             "exist.").format(
                                 next(iter(diff)),
                                 output_ops_list[i]._name,
                                 output_ops_list[j]._name))

        to_cache = set()
        for output_idx, per_output_to_delete in enumerate(to_delete):
            if len(per_output_to_delete) == 0:
                continue

            op = output_ops_list[output_idx]
            if cache_mode == CacheMode.Error:
                stream_idx = next(iter(per_output_to_delete))
                stream = op._streams[stream_idx]
                stored_stream_type_name = stream.__class__.__name__
                stored_stream_instance_name = stream.name()
                raise ScannerException(
                    ("Running this job would overwrite the {:s} `{}`. You "
                     "can change this behavior using "
                     "cl.run(cache_mode=sp.CacheMode.Ignore) to not run for "
                     "outputs that already exist, or sp.CacheMode.Overwrite "
                     "to run anyway and overwrite them.").format(
                         stored_stream_type_name,
                         stored_stream_instance_name))
            elif cache_mode == CacheMode.Overwrite:
                td = per_output_to_delete
                streams = op._streams
                storage = streams[0].storage()
                storage.delete(self, [streams[i] for i in td])
            elif cache_mode == CacheMode.Ignore:
                to_cache |= per_output_to_delete

        # Struct of arrays to array of structs conversion
        jobs = []
        for i in range(N):
            if i in to_cache:
                continue
            op_args = {}
            for op in sorted_ops:
                if op._job_args is not None:
                    op_args[op] = op._job_args[i]
                elif op._extra is not None and 'job_args' in op._extra:
                    op_args[op] = op._extra['job_args'][i]
            jobs.append(Job(op_args=op_args))

        N -= len(to_cache)

        if len(jobs) == 0:
            return None

        for job in jobs:
            j = job_params.jobs.add()
            op_to_op_args = job.op_args()
            for op in sorted_ops:
                if op in source_ops:
                    op_idx = source_ops[op]
                    if not op in op_to_op_args:
                        raise ScannerException(
                            'No arguments bound to source {:s}.'.format(
                                op._name))
                    else:
                        args = op_to_op_args[op]

                    source_input = j.inputs.add()
                    source_input.op_index = op_idx
                    source_input.enumerator_args = args
                elif op in stream_ops:
                    op_idx = stream_ops[op]
                    # If this is an Unslice Op, ignore it since it has no args
                    if op._name == 'Unslice':
                        continue

                    args = op_to_op_args[op]

                    saa = j.sampling_args_assignment.add()
                    saa.op_index = op_idx
                    if not isinstance(args, SliceList):
                        args = SliceList([args])

                    arg_builder = op._extra['arg_builder']
                    for arg in args:
                        # Construct the sampling_args using the arg_builder
                        if isinstance(arg, tuple):
                            # Positional arguments
                            sargs = arg_builder(*arg)
                        elif isinstance(arg, dict):
                            # Keyword arguments
                            sargs = arg_builder(**arg)
                        else:
                            # Single argument
                            sargs = arg_builder(arg)
                        sa = saa.sampling_args.add()
                        sa.CopyFrom(sargs)
                elif op in output_ops:
                    op_idx = output_ops[op]
                    sink_args = j.outputs.add()
                    sink_args.op_index = op_idx

                    args = op_to_op_args[op]
                    sink_args.args = args
                else:
                    # Regular old Op
                    op_idx = op_index[op]
                    oargs = j.op_args.add()
                    oargs.op_index = op_idx
                    if not op in op_to_op_args:
                        continue
                    else:
                        args = op_to_op_args[op]

                    if not isinstance(args, SliceList):
                        args = SliceList([args])

                    for arg in args:
                        oargs.op_args.append(arg)
                        #oargs.op_args.append(serialize_args(arg))

        perf_params = perf_params(
            inputs=[op._outputs[0]._streams for op in source_ops.keys()],
            ops=sorted_ops)

        job_params.compression.extend(compression_options)
        job_params.pipeline_instances_per_node = (
            perf_params.pipeline_instances_per_node or -1)
        job_params.work_packet_size = perf_params.work_packet_size
        job_params.io_packet_size = perf_params.io_packet_size
        job_params.profiling = profiling
        job_params.tasks_in_queue_per_pu = perf_params.queue_size_per_pipeline
        job_params.load_sparsity_threshold = (
            perf_params.load_sparsity_threshold)
        job_params.boundary_condition = (
            protobufs.BulkJobParameters.REPEAT_EDGE)
        job_params.task_timeout = task_timeout
        job_params.checkpoint_frequency = checkpoint_frequency
        job_params.profiler_level = profiler_level

        cpu_pool = perf_params.cpu_pool
        gpu_pool = perf_params.gpu_pool
        job_params.memory_pool_config.pinned_cpu = False
        if cpu_pool is not None:
            job_params.memory_pool_config.cpu.use_pool = True
            if cpu_pool[0] == 'p':
                job_params.memory_pool_config.pinned_cpu = True
                cpu_pool = cpu_pool[1:]
            size = self._parse_size_string(cpu_pool)
            job_params.memory_pool_config.cpu.free_space = size
        else:
            job_params.memory_pool_config.cpu.use_pool = False

        if gpu_pool is not None:
            job_params.memory_pool_config.gpu.use_pool = True
            size = self._parse_size_string(gpu_pool)
            job_params.memory_pool_config.gpu.free_space = size
        else:
            job_params.memory_pool_config.gpu.use_pool = False

        if not self._workers_started and self._start_cluster:
            self.start_workers(self._worker_paths)

        # Invalidate db metadata because of job run
        self._cached_db_metadata = None

        # Run the job
        result = self._try_rpc(lambda: self._master.NewJob(
            job_params, timeout=self._grpc_timeout))

        bulk_job_id = result.bulk_job_id

        if detach:
            return bulk_job_id

        job_status = self.wait_on_job(bulk_job_id, show_progress)
        if not job_status.result.success:
            raise ScannerException(job_status.result.msg)

        db_meta = self._load_db_metadata()

        return bulk_job_id
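
# Example (illustrative sketch, not part of the module): a minimal end-to-end
# computation graph built and run through a Client. The stream names, video
# path, and the Resize op arguments are assumptions; adjust them for an
# actual deployment.
#
#   from scannerpy import Client, PerfParams, CacheMode
#   from scannerpy.storage import NamedVideoStream
#
#   cl = Client()
#   input_stream = NamedVideoStream(cl, 'example', path='example.mp4')
#   frames = cl.io.Input([input_stream])
#   resized = cl.ops.Resize(frame=frames, width=[640], height=[480])
#   output_stream = NamedVideoStream(cl, 'example_resized')
#   output = cl.io.Output(resized, [output_stream])
#   cl.run(output, PerfParams.estimate(), cache_mode=CacheMode.Overwrite)
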
def start_master(port: int = None,
                 config: Config = None,
                 config_path: str = None,
                 block: bool = False,
                 watchdog: bool = True,
                 no_workers_timeout: float = 30,
                 new_job_retries_limit: int = 5):
    r"""Start a master server instance on this node.

    Parameters
    ----------
    port
      The port number to start the master on. If unspecified, it will be
      read from the provided Config.

    config
      The scanner Config to use. If specified, config_path is ignored.

    config_path : optional
      Path to a Scanner configuration TOML, by default assumed to be
      `~/.scanner/config.toml`.

    block : optional
      If true, will wait until the server is shutdown. Server will not
      shutdown currently unless wait_for_server_shutdown is eventually
      called.

    watchdog : optional
      If true, the master will shutdown after a time interval if
      PokeWatchdog is not called.

    no_workers_timeout : optional
      The interval after which the master will consider a job to have failed
      if it has no workers connected to it.

    Returns
    -------
    Database
      A cpp database instance.
    """
    config = config or Config(config_path)
    port = port or config.master_port

    # Load all protobuf types
    db = bindings.Database(config.storage_config, config.db_path,
                           (config.master_address + ':' + port))
    result = bindings.start_master(db, port, SCRIPT_DIR, watchdog,
                                   no_workers_timeout, new_job_retries_limit)
    if not result.success():
        raise ScannerException('Failed to start master: {}'.format(
            result.msg()))

    if block:
        bindings.wait_for_server_shutdown(db)
    return db
def start_worker(master_address: str,
                 machine_params=None,
                 port: int = None,
                 config: Config = None,
                 config_path: str = None,
                 block: bool = False,
                 watchdog: bool = True,
                 db: Client = None):
    r"""Starts a worker instance on this node.

    Parameters
    ----------
    master_address
      The address of the master server to connect this worker to. The
      expected format is '0.0.0.0:5000' (ip:port).

    machine_params
      Describes the resources of the machine that the worker should manage.
      If left unspecified, the machine resources will be inferred.

    config
      The Config object to use in creating the worker. If specified,
      config_path is ignored.

    config_path
      Path to a Scanner configuration TOML, by default assumed to be
      `~/.scanner/config.toml`.

    block
      If true, will wait until the server is shutdown. Server will not
      shutdown currently unless wait_for_server_shutdown is eventually
      called.

    watchdog
      If true, the worker will shutdown after a time interval if
      PokeWatchdog is not called.

    Other Parameters
    ----------------
    db
      This is for internal usage only.

    Returns
    -------
    Database
      A cpp database instance.
    """
    # Worker always has watchdog enabled to determine when master connection
    # has failed
    if not watchdog:
        print(('Forcing worker to enable watchdog. Watchdog must be enabled '
               'to detect if worker has disconnected from master.'),
              file=sys.stderr)
        watchdog = True

    config = config or Config(config_path)
    port = port or config.worker_port

    # Load all protobuf types
    db = db or bindings.Database(config.storage_config, config.db_path,
                                 master_address)
    machine_params = machine_params or bindings.default_machine_params()
    result = bindings.start_worker(db, machine_params, str(port), SCRIPT_DIR,
                                   watchdog)
    if not result.success():
        raise ScannerException('Failed to start worker: {}'.format(
            result.msg()))

    if block:
        bindings.wait_for_server_shutdown(db)
    return result
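
# Example (illustrative sketch, not part of the module): starting a master
# and a worker on the same machine without going through the Client's
# ssh-based cluster startup. The ports come from the default Config.
#
#   from scannerpy.config import Config
#
#   cfg = Config()
#   master_db = start_master(config=cfg, block=False)
#   start_worker(cfg.master_address + ':' + cfg.master_port,
#                config=cfg, block=True)
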
def _batch_load_column(arg):
    (tables, column, callback) = arg
    sc = Client(start_cluster=False, enable_watchdog=False)
    for t in tables:
        callback(t, list(sc.table(t).column(column).load(workers=1)))