Source code for punchpipe.speedster
import os
import time
import argparse
import warnings
import traceback
import multiprocessing
from datetime import datetime
from collections import defaultdict
import yaml
from prefect.logging import disable_run_logger
from sqlalchemy import update
from tqdm.auto import tqdm
from yaml.loader import FullLoader
from punchpipe.cli import find_flow
from punchpipe.control.db import Flow
from punchpipe.control.util import get_database_session
[docs]
def load_pipeline_configuration(path: str = None) -> dict:
    """Parse the pipeline YAML configuration file found at *path*.

    The document is loaded with PyYAML's ``FullLoader`` and returned as-is
    (normally a mapping; PyYAML yields ``None`` for an empty document).
    No schema validation is performed yet.

    Note: although ``path`` defaults to ``None``, ``open(None)`` fails, so
    callers must always pass a real path.
    """
    # TODO: add validation
    with open(path) as config_file:
        parsed_config = yaml.load(config_file, Loader=FullLoader)
    return parsed_config
[docs]
def load_enabled_flows(pipeline_config):
    """Return the flow types that speedster should run.

    A flow type from ``pipeline_config["flows"]`` is selected only when its
    ``enabled`` setting is exactly the string ``"speedy"``. Entries with no
    ``enabled`` key fall back to ``True``, which never matches, so they are
    skipped. Names come back in the iteration order of the ``flows`` mapping.
    """
    flow_settings = pipeline_config["flows"]
    return [
        name
        for name in flow_settings
        if flow_settings[name].get("enabled", True) == "speedy"
    ]
[docs]
def gather_planned_flows(session, enabled_flows, max_n=None):
    """Fetch the next planned flows whose type is in *enabled_flows*.

    Rows are ordered by ``is_backprocessing`` ascending, then ``priority``
    descending, then ``creation_time`` ascending. At most *max_n* rows are
    returned; ``max_n`` is passed straight to SQLAlchemy's ``.limit()``, where
    ``None`` means no limit.

    Returns
    -------
    tuple
        ``(flow_ids, flow_types, count_per_type)``: two parallel lists in query
        order, plus a ``defaultdict`` mapping each flow type to how many of
        the returned flows have that type.
    """
    query = session.query(Flow)
    query = query.where(Flow.state == "planned")
    query = query.where(Flow.flow_type.in_(enabled_flows))
    query = query.order_by(
        Flow.is_backprocessing.asc(),
        Flow.priority.desc(),
        Flow.creation_time.asc(),
    )
    planned = query.limit(max_n).all()

    flow_ids = [flow.flow_id for flow in planned]
    types = [flow.flow_type for flow in planned]
    count_per_type = defaultdict(int)
    for flow_type in types:
        count_per_type[flow_type] += 1
    return flow_ids, types, count_per_type
[docs]
def worker_init(config_path):
    """Initialize per-process state for a speedster pool worker.

    Used as the ``multiprocessing.Pool`` initializer. It sets three module
    globals that ``worker_run_flow`` reads:

    * ``session`` -- a database session from ``get_database_session()``;
    * ``flow_type_to_runner`` -- an initially empty flow-type -> runner cache;
    * ``path_to_config`` -- the pipeline configuration path passed in here.
    """
    global session, flow_type_to_runner, path_to_config
    with disable_run_logger(), warnings.catch_warnings():
        # Silence warnings so their spam doesn't bury the progress output
        warnings.simplefilter('ignore')
        session = get_database_session()
    flow_type_to_runner = {}
    path_to_config = config_path
[docs]
def worker_run_flow(inputs):
    """Claim one planned flow, run it in this worker, and report failures.

    Runs inside a pool worker whose globals (``session``,
    ``flow_type_to_runner``, ``path_to_config``) were set by ``worker_init``.

    Parameters
    ----------
    inputs : tuple
        ``(flow_id, flow_type, delay)``. ``delay`` is the number of seconds to
        sleep before running the flow, which staggers the launches.

    Behavior
    --------
    The flow row is set to ``state='launched'`` with
    ``flow_run_name='speedster'``, and that claim is committed before the flow
    runs. The runner, ``find_flow(flow_type + "_process_flow").fn``, is cached
    per flow type for the lifetime of the worker. It is called as
    ``runner(flow_id, path_to_config, session)``.

    A Ctrl-C while sleeping or running marks the flow ``revivable``. Any other
    exception is printed and swallowed so the worker can move on to its next
    flow.
    """
    flow_id, flow_type, delay = inputs
    global flow_type_to_runner, session, path_to_config
    runner = flow_type_to_runner.get(flow_type)
    if runner is None:
        runner = find_flow(flow_type + "_process_flow").fn
        flow_type_to_runner[flow_type] = runner

    session.execute(update(Flow).where(Flow.flow_id == flow_id).values(
        state='launched', flow_run_name='speedster', launch_time=datetime.now()))
    # Commit the claim now instead of leaving it pending inside whatever
    # transaction the runner uses. Otherwise it is invisible to other
    # sessions until the runner commits, and it is lost if the runner's work
    # gets rolled back.
    session.commit()

    with disable_run_logger(), warnings.catch_warnings():
        # Otherwise warning spam will hide any progress messages
        warnings.simplefilter('ignore')
        try:
            time.sleep(delay)
            runner(flow_id, path_to_config, session)
        except KeyboardInterrupt:
            # Discard half-finished work (and clear a failed transaction, if
            # any) so the state change below can actually be committed.
            session.rollback()
            session.execute(
                update(Flow).where(Flow.flow_id == flow_id).values(state='revivable'))
            session.commit()
            print(f"Keyboard interrupt in flow {flow_id}; marked as revivable")
        except:  # noqa: E722
            print(f"Exception in flow {flow_id}")
            traceback.print_exc()
            # This session is reused for every later flow in this worker; an
            # un-rolled-back failed transaction could make all of them fail.
            session.rollback()
if __name__ == "__main__":
    # Command-line entry point: repeatedly fetch batches of planned "speedy"
    # flows from the database and run them on a process pool.
    multiprocessing.set_start_method('forkserver')

    parser = argparse.ArgumentParser(prog='speedster')
    parser.add_argument("config", type=str, help="Path to config.")
    parser.add_argument("-f", "--flows-per-batch", type=int, help="Max number of flows per batch.")
    parser.add_argument("-b", "--n-batches", type=int, help="Number of batches.")
    parser.add_argument("-w", "--n-workers", type=int, help="Number of workers")
    args = parser.parse_args()

    config_path = args.config
    pipeline_config = load_pipeline_configuration(config_path)
    enabled_flows = load_enabled_flows(pipeline_config)
    # READ COMMITTED: each poll's query sees flow states committed by other
    # sessions, even though this session is never committed between polls.
    session = get_database_session(engine_kwargs=dict(isolation_level="READ COMMITTED"))

    if args.n_workers is None:
        args.n_workers = os.cpu_count()
    # Never start more workers than one batch can keep busy.
    if args.flows_per_batch is None:
        n_cores = args.n_workers
    else:
        n_cores = min(args.n_workers, args.flows_per_batch)

    n_batches_run = 0
    with multiprocessing.Pool(n_cores, initializer=worker_init, initargs=(config_path,)) as p:
        print("Beginning fetch-run loop; press Ctrl-C to exit and allow time for cleanup")
        if args.flows_per_batch:
            print(f"Will cap at {args.flows_per_batch} flows per batch")
        if args.n_batches:
            print(f"Will stop after {args.n_batches} batches")
        while True:
            batch_of_flows, batch_types, count_per_type = gather_planned_flows(
                session, enabled_flows, args.flows_per_batch)
            if len(batch_of_flows) == 0:
                print("No pending flows found---will wait two minutes and try again")
                try:
                    time.sleep(60*2)
                except KeyboardInterrupt:
                    break
            else:
                # e.g. "Batch contents: 3 of L1, 2 of L2" (no dangling separator)
                summary = ", ".join(
                    f"{count_per_type[flow_type]} of {flow_type}"
                    for flow_type in sorted(count_per_type))
                print(f"Batch contents: {summary}")
                with tqdm(total=len(batch_of_flows)) as pbar:
                    # Stagger the launches which may give less DB and IO contention:
                    # the first n_cores flows start i/6 s apart, the rest immediately.
                    delays = [i / 6 if i < n_cores else 0 for i in range(len(batch_of_flows))]
                    try:
                        for _ in p.imap_unordered(worker_run_flow, zip(batch_of_flows, batch_types, delays)):
                            pbar.update()
                    except KeyboardInterrupt:
                        print("Halting")
                        break
                # Only batches that actually ran count toward --n-batches.
                n_batches_run += 1
                if args.n_batches and n_batches_run >= args.n_batches:
                    break