Source code for punchpipe.speedster

import os
import time
import argparse
import warnings
import traceback
import multiprocessing
from datetime import datetime
from collections import defaultdict

import yaml
from prefect.logging import disable_run_logger
from sqlalchemy import update
from tqdm.auto import tqdm
from yaml.loader import FullLoader

from punchpipe.cli import find_flow
from punchpipe.control.db import Flow
from punchpipe.control.util import get_database_session


[docs] def load_pipeline_configuration(path: str = None) -> dict: with open(path) as f: config = yaml.load(f, Loader=FullLoader) # TODO: add validation return config
[docs] def load_enabled_flows(pipeline_config): enabled_flows = [] for flow_type in pipeline_config["flows"]: if pipeline_config["flows"][flow_type].get("enabled", True) == "speedy": enabled_flows.append(flow_type) return enabled_flows
[docs] def gather_planned_flows(session, enabled_flows, max_n=None): flows = (session.query(Flow) .where(Flow.state == "planned") .where(Flow.flow_type.in_(enabled_flows)) .order_by(Flow.is_backprocessing.asc(), Flow.priority.desc(), Flow.creation_time.asc()) .limit(max_n).all()) count_per_type = defaultdict(lambda: 0) flow_ids = [] types = [] for flow in flows: types.append(flow.flow_type) count_per_type[flow.flow_type] += 1 flow_ids.append(flow.flow_id) return flow_ids, types, count_per_type
[docs] def worker_init(config_path): global session, flow_type_to_runner, path_to_config with disable_run_logger(), warnings.catch_warnings(): # Otherwise warning spam will hide any progress messages warnings.simplefilter('ignore') session = get_database_session() flow_type_to_runner = dict() path_to_config = config_path
[docs] def worker_run_flow(inputs): flow_id, flow_type, delay = inputs global flow_type_to_runner, session, path_to_config if flow_type not in flow_type_to_runner: runner = find_flow(flow_type + "_process_flow").fn flow_type_to_runner[flow_type] = runner else: runner = flow_type_to_runner[flow_type] session.execute(update(Flow).where(Flow.flow_id == flow_id).values( state='launched', flow_run_name='speedster', launch_time=datetime.now())) with disable_run_logger(), warnings.catch_warnings(): # Otherwise warning spam will hide any progress messages warnings.simplefilter('ignore') try: time.sleep(delay) runner(flow_id, path_to_config, session) except KeyboardInterrupt: session.execute( update(Flow).where(Flow.flow_id == flow_id).values(state='revivable')) session.commit() print(f"Keyboard interrupt in flow {flow_id}; marked as revivable") except: # noqa: E722 print(f"Exception in flow {flow_id}") traceback.print_exc()
if __name__ == "__main__": multiprocessing.set_start_method('forkserver') parser = argparse.ArgumentParser(prog='speedster') parser.add_argument("config", type=str, help="Path to config.") parser.add_argument("-f", "--flows-per-batch", type=int, help="Max number of flows per batch.") parser.add_argument("-b", "--n-batches", type=int, help="Number of batches.") parser.add_argument("-w", "--n-workers", type=int, help="Number of workers") args = parser.parse_args() config_path = args.config pipeline_config = load_pipeline_configuration(config_path) enabled_flows = load_enabled_flows(pipeline_config) session = get_database_session(engine_kwargs=dict(isolation_level="READ COMMITTED")) if args.n_workers is None: args.n_workers = os.cpu_count() if args.flows_per_batch is None: n_cores = args.n_workers else: n_cores = min(args.n_workers, args.flows_per_batch) n_batches_run = 0 with multiprocessing.Pool(n_cores, initializer=worker_init, initargs=(config_path,)) as p: print("Beginning fetch-run loop; press Ctrl-C to exit and allow time for cleanup") if args.flows_per_batch: print(f"Will cap at {args.flows_per_batch} flows per batch") if args.n_batches: print(f"Will stop after {args.n_batches} batches") while True: batch_of_flows, batch_types, count_per_type = gather_planned_flows( session, enabled_flows, args.flows_per_batch) if len(batch_of_flows) == 0: print("No pending flows found---will wait two minutes and try again") try: time.sleep(60*2) except KeyboardInterrupt: break else: print("Batch contents: ", end='') count_report = [] for type in sorted(count_per_type.keys()): print(f"{count_per_type[type]} of {type}, ", end='') print() with tqdm(total=len(batch_of_flows)) as pbar: # Stagger the launches which may give less DB and IO contention delays = [i / 6 if i < n_cores else 0 for i in range(len(batch_of_flows))] try: for _ in p.imap_unordered(worker_run_flow, zip(batch_of_flows, batch_types, delays)): pbar.update() except KeyboardInterrupt: print("Halting") break n_batches_run += 1 if args.n_batches and n_batches_run >= args.n_batches: break