diff --git a/micall/__main__.py b/micall/__main__.py index 275a69a2a..3fc03b59e 100644 --- a/micall/__main__.py +++ b/micall/__main__.py @@ -101,6 +101,7 @@ "micall/utils/check_sample_sheet.py", "micall/utils/cache.py", "micall/monitor/run_completion_watcher.py", + "micall/utils/kive_collate.py", ] diff --git a/micall/monitor/kive_watcher.py b/micall/monitor/kive_watcher.py index edc74f82c..1bb027bda 100644 --- a/micall/monitor/kive_watcher.py +++ b/micall/monitor/kive_watcher.py @@ -11,7 +11,7 @@ from itertools import count from pathlib import Path from queue import Full, Queue -from typing import IO, Callable, Mapping, Optional, Sequence, Iterable, TextIO, TypeVar +from typing import IO, Callable, Mapping, Optional, Sequence, Iterable, TypeVar from io import StringIO, BytesIO from time import sleep @@ -97,6 +97,8 @@ urllib3.connectionpool.log.setLevel(logging.ERROR) T = TypeVar('T') +COLLATION_INPUT_NAME = 'inputs' +COLLATION_OUTPUT_NAME = 'output' def open_kive(server_url: str) -> KiveAPI: session = KiveAPI(server_url) @@ -309,10 +311,6 @@ def trim_run_name(run_name: str) -> str: return run_name -def get_output_filename(output_name: str) -> str: - return '.'.join(output_name.rsplit('_', 1)) - - def wait_for_retry(attempt_count: int, start_time: datetime) -> None: """Wait with exponential backoff, logging warnings after 1 hour, info messages before.""" delay = calculate_retry_wait(MINIMUM_RETRY_WAIT, @@ -337,20 +335,6 @@ def calculate_retry_wait(min_wait: timedelta, max_wait: timedelta, attempt_count return timedelta(seconds=seconds) -def get_scratch_path(results_path: Path, pipeline_group: PipelineType) -> Path: - if pipeline_group == PipelineType.MAIN: - scratch_name = "scratch" - elif pipeline_group == PipelineType.DENOVO_MAIN: - scratch_name = "scratch_denovo" - elif pipeline_group == PipelineType.PROVIRAL: - scratch_name = "scratch_proviral" - else: - assert pipeline_group == PipelineType.MIXED_HCV_MAIN - scratch_name = "scratch_mixed_hcv" - scratch_path = results_path / scratch_name - return scratch_path - - def get_collated_path(results_path: Path, pipeline_group: PipelineType) -> Path: if pipeline_group == PipelineType.MAIN: target_path = results_path @@ -380,8 +364,10 @@ def __init__(self, self.folder_watchers: dict[str, FolderWatcher] = {} # {base_calls_folder: FolderWatcher} self.app_urls: dict[str | int, str] = {} # {app_id: app_url} self.app_args: dict[str | int, dict[str, str]] = {} # {app_id: {arg_name: arg_url}} + self.app_argument_lists: dict[str | int, list[dict[str, object]]] = {} # {app_id: [argument_dict]} self.external_directory_path: Optional[Path] = None self.external_directory_name: Optional[str] = None + self.collation_runs: dict[tuple[str, PipelineType], Run] = {} def is_full(self) -> bool: if self.config is None: @@ -409,9 +395,50 @@ def get_kive_arguments(self, app_id: str | int) -> dict[str, str]: for argument in arguments if argument['type'] == 'I'} self.app_args[app_id] = kive_app + self.app_argument_lists[app_id] = list(arguments) self.app_urls[app_id] = arguments[0]['app'] return kive_app + def validate_collation_app_signature(self, app_id: int) -> None: + """Ensure the collation app has the expected optional-multiple input signature. + + Expected KIVE signature: + - Input argument named 'inputs' with type='I', position=None, allow_multiple=True + - Output argument named 'output' with type='O' + """ + # Populate cached argument metadata when possible. + self.get_kive_arguments(app_id) + arguments = self.app_argument_lists.get(app_id) + if not arguments: + # Tests can pre-seed app_args/app_urls without argument metadata. + logger.debug('Skipping collation app signature validation for app %r: argument metadata unavailable.', app_id) + return + + input_argument = next( + (arg for arg in arguments + if arg.get('name') == COLLATION_INPUT_NAME and arg.get('type') == 'I'), + None) + if input_argument is None: + raise RuntimeError( + f'Collation app {app_id} is missing required input argument {COLLATION_INPUT_NAME!r}.') + + is_optional_multiple = (input_argument.get('position') is None and + bool(input_argument.get('allow_multiple'))) + if not is_optional_multiple: + raise RuntimeError( + f'Collation app {app_id} argument {COLLATION_INPUT_NAME!r} must be optional multiple ' + f'(formatted as --{COLLATION_INPUT_NAME}* in KIVE_INPUTS). ' + f'Got position={input_argument.get("position")!r}, ' + f'allow_multiple={input_argument.get("allow_multiple")!r}.') + + output_argument = next( + (arg for arg in arguments + if arg.get('name') == COLLATION_OUTPUT_NAME and arg.get('type') == 'O'), + None) + if output_argument is None: + raise RuntimeError( + f'Collation app {app_id} is missing required output argument {COLLATION_OUTPUT_NAME!r}.') + def get_kive_container_name(self, app_id: str | int) -> str: """ Get the container name for a container app. """ self.check_session() @@ -651,247 +678,206 @@ def check_completed_folders(self) -> None: for pipeline_group in list(folder_watcher.active_pipeline_groups): if not folder_watcher.is_pipeline_group_finished(pipeline_group): continue - results_path = self.collate_folder(folder_watcher, - pipeline_group) - folder_watcher.active_pipeline_groups.remove(pipeline_group) - if results_path is not None: - if not folder_watcher.active_pipeline_groups: - disk_operations.touch(results_path / "done_all_processing") - self.folder_watchers.pop(folder) - if not self.folder_watchers: - logger.info('No more folders to process.') - - def collate_folder(self, folder_watcher: FolderWatcher, pipeline_group: PipelineType) -> Optional[Path]: - """ Collate scratch files for a run folder. - - :param FolderWatcher folder_watcher: holds details about the run folder - :param PipelineType pipeline_group: the group of runs to collate - """ - results_path = self.get_results_path(folder_watcher) - - error_message = None - if folder_watcher.is_folder_failed: - error_message = 'Filter quality failed in Kive.' - else: - failed_sample_names = [ - sample_watcher.sample_group.enum - for sample_watcher in folder_watcher.sample_watchers - if sample_watcher.is_failed] - if failed_sample_names: - error_message = 'Samples failed in Kive: {}.'.format( - ', '.join(failed_sample_names)) - if error_message is not None: - run_path = (results_path / "../..").resolve() - disk_operations.write_text(run_path / 'errorprocessing', error_message + '\n') - logger.error('Error in folder %s: %s', run_path, error_message) - return None - if pipeline_group == PipelineType.FILTER_QUALITY: - return results_path - scratch_path = get_scratch_path(results_path, pipeline_group) - target_path = get_collated_path(results_path, pipeline_group) - logger.info('Collating results in %s', target_path) - self.copy_outputs(folder_watcher, scratch_path, target_path) - disk_operations.rmtree(scratch_path) - disk_operations.touch(target_path / 'doneprocessing') - return results_path - - def copy_outputs(self, - folder_watcher: FolderWatcher, - scratch_path: Path, - results_path: Path) -> None: - disk_operations.mkdir_p(results_path, exist_ok=True) - for output_name in DOWNLOADED_RESULTS: - if output_name == 'coverage_maps_tar': - self.extract_coverage_maps(folder_watcher, - scratch_path, - results_path) - continue - if output_name.endswith('_tar'): - self.extract_archive(folder_watcher, - scratch_path, - results_path, - output_name) - continue - if output_name == 'alignment_svg': - self.move_alignment_plot(folder_watcher, - '.svg', - scratch_path, - results_path) - continue - if output_name == 'alignment_png': - self.move_alignment_plot(folder_watcher, - '.png', - scratch_path, - results_path) - continue - if output_name == 'genome_coverage_svg': - self.move_genome_coverage(folder_watcher, - scratch_path, - results_path) - continue - if output_name == 'stitcher_plot_svg': - self.move_stitcher_plot(folder_watcher, - scratch_path, - results_path) - continue - source_count = 0 - filename = get_output_filename(output_name) - target_path = results_path / filename - with disk_operations.disk_file_operation(target_path, 'w') as target: - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) - source_path = scratch_path / sample_name / filename - if not source_path.exists(): - logger.debug('Source file %s does not exist, skipping.', - source_path) - continue + if pipeline_group == PipelineType.FILTER_QUALITY: + folder_watcher.active_pipeline_groups.remove(pipeline_group) + if folder_watcher.is_folder_failed: + disk_operations.write_text( + folder_watcher.run_folder / 'errorprocessing', + 'Filter quality failed in Kive.\n') + self._remove_folder_state(folder) + else: + self._mark_pipeline_group_complete(folder, + folder_watcher, + pipeline_group, + self.get_results_path(folder_watcher)) + continue - with disk_operations.disk_file_operation(source_path, 'r') as source: - if output_name.endswith('_fasta'): - source_count += self.extract_fasta(source, target, sample_name) - else: - source_count += self.extract_csv(source, - target, - sample_name, - source_count) + failed_samples = sorted( + sw.sample_group.enum + for sw in folder_watcher.sample_watchers + if sw.is_failed) + if failed_samples: + disk_operations.write_text( + folder_watcher.run_folder / 'errorprocessing', + f'Samples failed in Kive: {", ".join(failed_samples)}.\n') + self._remove_folder_state(folder) + break - if not source_count: - disk_operations.unlink(target_path) + collation_key = (folder, pipeline_group) + collation_run = self.collation_runs.get(collation_key) + + if collation_run is None: + collation_run = self.run_collation_pipeline(folder_watcher, + pipeline_group) + if collation_run is None: + raise RuntimeError( + f'Failed to launch Kive collation for {folder_watcher.run_name} ' + f'pipeline group {pipeline_group.name}.') + self.collation_runs[collation_key] = collation_run + continue - @staticmethod - def extract_csv(source: TextIO, target: TextIO, sample_name: str, source_count: int) -> int: - reader = DictReader(source) - fieldnames = reader.fieldnames - if fieldnames is None: - # Empty file, nothing to copy. - return 0 - fieldnames = list(fieldnames) - has_sample = 'sample' in fieldnames - if not has_sample: - fieldnames.insert(0, 'sample') - writer = DictWriter(target, fieldnames, lineterminator=os.linesep) - if source_count == 0: - # First source file, copy header. - writer.writeheader() - for row in reader: - if not has_sample: - row['sample'] = sample_name - writer.writerow(row) - return 1 + collation_result = self.poll_collation_run(folder_watcher, + pipeline_group, + collation_run) + if collation_result is None: + continue + self.collation_runs.pop(collation_key, None) + results_path, is_successful = collation_result - @staticmethod - def extract_fasta(source: TextIO, target: TextIO, sample_name: str) -> int: - for line in source: - if line.startswith('>'): - target.write(f'>{sample_name},{line[1:]}') - else: - target.write(line) - return 1 + if not is_successful: + self._remove_folder_state(folder) + break - @staticmethod - def extract_coverage_maps(folder_watcher: FolderWatcher, scratch_path: Path, results_path: Path) -> None: - coverage_path: Path = results_path / "coverage_maps" - disk_operations.mkdir_p(coverage_path, exist_ok=True) - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) - source_path = scratch_path / sample_name / 'coverage_maps.tar' - try: - with tarfile.open(source_path) as f: - for source_info in f: - filename = os.path.basename(source_info.name) - target_path = coverage_path / (sample_name + '.' + filename) - source = f.extractfile(source_info) - if source is None: - raise RuntimeError(f"Failed to extract {source_info.name} from {source_path}") - with source, open(target_path, 'wb') as target: - shutil.copyfileobj(source, target) - except FileNotFoundError: - pass - disk_operations.remove_empty_directory(coverage_path) + folder_watcher.active_pipeline_groups.remove(pipeline_group) + self._mark_pipeline_group_complete(folder, + folder_watcher, + pipeline_group, + results_path) + + def _remove_folder_state(self, folder: str) -> None: + self.folder_watchers.pop(folder, None) + self.loaded_folders.discard(folder) + for key in list(self.collation_runs): + if key[0] == folder: + self.collation_runs.pop(key, None) + if not self.folder_watchers: + logger.info('No more folders to process.') + + def _mark_pipeline_group_complete(self, + folder: str, + folder_watcher: FolderWatcher, + pipeline_group: PipelineType, + results_path: Optional[Path]) -> None: + if results_path is not None: + if not folder_watcher.active_pipeline_groups: + disk_operations.touch(results_path / "done_all_processing") + self._remove_folder_state(folder) + return + if not self.folder_watchers: + logger.info('No more folders to process.') + + def run_collation_pipeline(self, + folder_watcher: FolderWatcher, + pipeline_group: PipelineType) -> Optional[Run]: + if self.config.micall_collation_pipeline_id is None: + raise RuntimeError('Collation pipeline ID is not configured.') + + self.validate_collation_app_signature(self.config.micall_collation_pipeline_id) + + input_datasets: list[RunDataset] = [] + manifest_rows: list[dict[str, str]] = [] + for sample_watcher in folder_watcher.sample_watchers: + for pipeline_type, run in sample_watcher.runs.items(): + if PIPELINE_GROUPS[pipeline_type] != pipeline_group: + continue + sample_name = self.get_sample_name_for_pipeline(sample_watcher, + pipeline_type) + if sample_name is None: + continue + for run_dataset in run.get('datasets', []): + if run_dataset['argument_type'] != 'O': + continue + output_name = run_dataset['argument_name'] + if output_name not in DOWNLOADED_RESULTS: + continue + dataset_url: str = run_dataset['dataset'] + dataset: RunDataset = self.kive_retry( + lambda: self.session.get(dataset_url).json()) + input_index = len(input_datasets) + input_datasets.append(dataset) + manifest_rows.append(dict(index=str(input_index), + sample=trim_name(sample_name), + output_name=output_name)) + + manifest_text = StringIO() + writer = DictWriter(manifest_text, + ['index', 'sample', 'output_name'], + lineterminator='\n') + writer.writeheader() + writer.writerows(manifest_rows) + manifest_dataset = self.find_or_upload_dataset( + BytesIO(manifest_text.getvalue().encode('utf8')), + f'{folder_watcher.run_name}_{pipeline_group.name.lower()}_collation_metadata.csv', + f'Collation metadata for {folder_watcher.run_name} {pipeline_group.name}.') + if manifest_dataset is None: + return None - @staticmethod - def extract_archive(folder_watcher: FolderWatcher, - scratch_path: Path, - results_path: Path, - output_name: str) -> None: - """ Extract contents of tar files. - - There will be a folder named after the output name, with a subfolder - for each sample. - :param folder_watcher: holds a list of all samples to extract from - :param scratch_path: parent folder of all sample working files - :param results_path: parent folder to extract into - :param output_name: the name of tar files to look for with "_tar" - instead of ".tar". - """ - assert output_name.endswith('_tar'), output_name - archive_name = output_name[:-4] - output_path: Path = results_path / archive_name - disk_operations.mkdir_p(output_path, exist_ok=True) - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) - source_path = scratch_path / sample_name / (archive_name + '.tar') - try: - with tarfile.open(source_path) as f: - sample_target_path = output_path / sample_name - disk_operations.mkdir_p(sample_target_path, exist_ok=True) - for source_info in f: - filename = os.path.basename(source_info.name) - target_path = sample_target_path / filename - assert not target_path.exists(), target_path - source = f.extractfile(source_info) - if source is None: - raise RuntimeError(f"Failed to extract {source_info.name} from {source_path}") - with source, open(target_path, 'wb') as target: - shutil.copyfileobj(source, target) - disk_operations.remove_empty_directory(sample_target_path) - except FileNotFoundError: - pass - disk_operations.remove_empty_directory(output_path) + return self.find_or_launch_run( + self.config.micall_collation_pipeline_id, + { + COLLATION_INPUT_NAME: [*input_datasets, manifest_dataset], + "some_other_input": manifest_dataset, # Dummy input to ensure manifest is included in run datasets for easier debugging. Not used by the app. + }, + f'MiCall collation {pipeline_group.name.lower()} on {folder_watcher.run_name}', + folder_watcher.batch) @staticmethod - def move_alignment_plot(folder_watcher: FolderWatcher, - extension: str, - scratch_path: Path, - results_path: Path) -> None: - alignment_path: Path = results_path / "alignment" - disk_operations.mkdir_p(alignment_path, exist_ok=True) - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) - source_path = scratch_path / sample_name / f'alignment{extension}' - target_path = alignment_path / f"{sample_name}_alignment{extension}" - if source_path.exists(): - disk_operations.rename(source_path, target_path) - disk_operations.remove_empty_directory(alignment_path) + def get_sample_name_for_pipeline(sample_watcher: SampleWatcher, + pipeline_type: PipelineType) -> Optional[str]: + if pipeline_type in (PipelineType.MIDI, + PipelineType.MIXED_HCV_MIDI, + PipelineType.DENOVO_MIDI): + return sample_watcher.sample_group.names[1] + return sample_watcher.sample_group.names[0] @staticmethod - def move_genome_coverage(folder_watcher: FolderWatcher, scratch_path: Path, results_path: Path) -> None: - plots_path = results_path / "genome_coverage" + def move_stitcher_plot(folder_watcher: FolderWatcher, + scratch_path: Path, + results_path: Path) -> None: + """Move stitcher plot SVG files from per-sample scratch folders to results.""" + plots_path = results_path / 'stitcher_plots' disk_operations.mkdir_p(plots_path, exist_ok=True) - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) - source_path = scratch_path / sample_name / 'genome_coverage.svg' - target_path = plots_path / f"{sample_name}_genome_coverage.svg" - if source_path.exists(): - disk_operations.rename(source_path, target_path) - concordance_path = scratch_path / sample_name / 'genome_concordance.svg' - target_concordance_path = plots_path / f"{sample_name}_genome_concordance.svg" - if concordance_path.exists(): - disk_operations.rename(concordance_path, target_concordance_path) - disk_operations.remove_empty_directory(plots_path) - @staticmethod - def move_stitcher_plot(folder_watcher: FolderWatcher, scratch_path: Path, results_path: Path) -> None: - plots_path = results_path / "stitcher_plots" - disk_operations.mkdir_p(plots_path, exist_ok=True) - for sample_name in folder_watcher.all_samples: - sample_name = trim_name(sample_name) + sample_names = {trim_name(sample_name) for sample_name in folder_watcher.all_samples} + for sample_name in sample_names: source_path = scratch_path / sample_name / 'stitcher_plot.svg' - target_path = plots_path / f"{sample_name}_stitcher_plot.svg" + target_path = plots_path / f'{sample_name}_stitcher_plot.svg' if source_path.exists(): disk_operations.rename(source_path, target_path) + disk_operations.remove_empty_directory(plots_path) + def poll_collation_run(self, + folder_watcher: FolderWatcher, + pipeline_group: PipelineType, + run: Optional[Run]) -> Optional[tuple[Path, bool]]: + if run is None: + return None + run_status: Run = self.kive_retry(lambda: self.session.endpoints.containerruns.get(run['id'])) + state = run_status['state'] + if state in ('N', 'L', 'R'): + return None + if state in ('X', 'F'): + results_path = self.get_results_path(folder_watcher) + run_path = (results_path / "../..").resolve() + disk_operations.write_text( + run_path / 'errorprocessing', + f'Kive collation failed for {pipeline_group.name} (state: {state}).\n') + return results_path, False + + run_datasets: Sequence[RunDataset] = self.kive_retry( + lambda: self.session.endpoints.containerruns.get(f"{run['id']}/dataset_list/")) + matches = [ + run_dataset + for run_dataset in run_datasets + if run_dataset['argument_name'] == COLLATION_OUTPUT_NAME and run_dataset['argument_type'] == 'O' + ] + if not matches: + raise RuntimeError('Collation run completed without collated output dataset.') + + results_path = self.get_results_path(folder_watcher) + target_path = get_collated_path(results_path, pipeline_group) + disk_operations.mkdir_p(target_path, exist_ok=True) + collated_tar_path = target_path / 'collated_results.tar' + + dataset_url = matches[0]['dataset'] + self.kive_retry(lambda: self.download_file(dataset_url + 'download/', collated_tar_path)) + with tarfile.open(collated_tar_path) as tar_file: + tar_file.extractall(target_path, filter='data') + disk_operations.unlink(collated_tar_path, missing_ok=True) + disk_operations.touch(target_path / 'doneprocessing') + return results_path, True + def run_filter_quality_pipeline(self, folder_watcher: FolderWatcher) -> Optional[Run]: if folder_watcher.quality_dataset is None: raise RuntimeError('Quality dataset not available') @@ -1180,7 +1166,7 @@ def run_proviral_pipeline(self, sample_watcher: SampleWatcher, folder_watcher: F def find_or_launch_run(self, pipeline_id: int, - inputs: Mapping[str, RunDataset], + inputs: Mapping[str, RunDataset | Sequence[RunDataset]], run_name: str, run_batch: Optional[Batch]) -> Optional[Run]: """ Look for a matching container run, or start a new one. @@ -1192,8 +1178,13 @@ def find_or_launch_run(self, run_name = trim_run_name(run_name) filters = ['name', run_name, 'app_id', pipeline_id, 'states', 'NLRSC'] for arg in inputs.values(): - filters.append('input_id') - filters.append(arg['id']) + if isinstance(arg, dict): + filters.append('input_id') + filters.append(arg['id']) + else: + for item in arg: + filters.append('input_id') + filters.append(item['id']) old_runs = self.session.endpoints.containerruns.filter(*filters) untyped_run = self.find_name_and_permissions_match(old_runs, run_name, 'container run') @@ -1209,9 +1200,16 @@ def find_or_launch_run(self, run = None if run is None: try: - run_datasets = [dict(argument=app_arg, - dataset=inputs[name]['url']) - for name, app_arg in app_args.items()] + run_dataset_specs: list[dict[str, str]] = [] + for name, app_arg in app_args.items(): + input_value = inputs[name] + if isinstance(input_value, dict): + run_dataset_specs.append( + dict(argument=app_arg, dataset=input_value['url'])) + else: + run_dataset_specs.extend( + dict(argument=app_arg, dataset=dataset['url']) + for dataset in input_value) except KeyError as e: raise ValueError(f"Pipeline input error: {repr(e)}." f" The specified app with id {pipeline_id} appears to expect a different set of inputs." @@ -1222,7 +1220,7 @@ def find_or_launch_run(self, batch=run_batch['url'], groups_allowed=ALLOWED_GROUPS, app=app_url, - datasets=run_datasets) + datasets=run_dataset_specs) try: run = self.session.endpoints.containerruns.post(json=run_params) except Exception as ex: @@ -1272,28 +1270,6 @@ def fetch_run_status(self, run: Run, folder_watcher: FolderWatcher, pipeline_typ if other_run['id'] == run['id']: other_run['datasets'] = run_datasets break - sample_name = (sample_watcher.sample_group.names[1] - if pipeline_type in (PipelineType.MIDI, - PipelineType.MIXED_HCV_MIDI, - PipelineType.DENOVO_MIDI) - else sample_watcher.sample_group.names[0]) - results_path = self.get_results_path(folder_watcher) - pipeline_group = PIPELINE_GROUPS[pipeline_type] - scratch_path = get_scratch_path(results_path, pipeline_group) - scratch_path /= trim_name(sample_name if sample_name is not None else 'unknown') - disk_operations.mkdir_p(scratch_path, parents=True, exist_ok=True) - for output_name in DOWNLOADED_RESULTS: - matches = [run_dataset - for run_dataset in run_datasets - if (run_dataset['argument_name'] == output_name and - run_dataset['argument_type'] == 'O')] - if not matches: - continue - filename = get_output_filename(output_name) - dataset_url, = [match['dataset'] for match in matches] - self.kive_retry( - lambda: self.download_file(dataset_url + 'download/', - scratch_path / filename)) if is_complete: return None diff --git a/micall/monitor/types.py b/micall/monitor/types.py index ae9d43f1f..b32b49695 100644 --- a/micall/monitor/types.py +++ b/micall/monitor/types.py @@ -32,6 +32,7 @@ class ConfigInterface(Protocol): denovo_main_pipeline_id: Optional[int] micall_filter_quality_pipeline_id: Optional[int] micall_resistance_pipeline_id: Optional[int] + micall_collation_pipeline_id: Optional[int] proviral_pipeline_id: Optional[int] max_active: int pipeline_version: str diff --git a/micall/monitor/watcher.py b/micall/monitor/watcher.py index 689adbc98..d8849c457 100644 --- a/micall/monitor/watcher.py +++ b/micall/monitor/watcher.py @@ -39,6 +39,11 @@ def parse_args(argv=None): type=int, default=os.environ.get('MICALL_RESISTANCE_PIPELINE_ID', None), help="id of resistance pipeline's container app") + parser.add_argument( + '--micall_collation_pipeline_id', + type=int, + default=os.environ.get('MICALL_COLLATION_PIPELINE_ID', None), + help="id of collation pipeline's container app") parser.add_argument( '--mixed_hcv_pipeline_id', type=int, @@ -118,6 +123,10 @@ def parse_args(argv=None): else: parser.error(f"No arguments or environment variables set for main " f"pipeline ids ({', '.join(main_pipeline_ids)}).") + if args.micall_collation_pipeline_id is None: + parser.error("Argument --micall_collation_pipeline_id not set and " + "$MICALL_COLLATION_PIPELINE_ID environment variable " + "not set.") return args diff --git a/micall/tests/test_kive_collate.py b/micall/tests/test_kive_collate.py new file mode 100644 index 000000000..83c20d3c3 --- /dev/null +++ b/micall/tests/test_kive_collate.py @@ -0,0 +1,155 @@ +import sys +import tarfile +from pathlib import Path + +import pytest + +from micall.utils import kive_collate + + +def test_parse_args_with_optional_multiple_and_separator(monkeypatch, tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample,output_name\n') + output_path = tmp_path / 'out.tar' + monkeypatch.setattr( + sys, + 'argv', + ['kive_collate', '--inputs', 'a.csv', 'b.csv', str(metadata_path), '--', str(output_path)]) + + args = kive_collate.parse_args() + + assert args.inputs == [Path('a.csv'), Path('b.csv'), metadata_path] + assert args.output == output_path + assert not args.verbose + assert not args.debug + assert not args.quiet + + +def test_parse_args_with_debug_flag(monkeypatch, tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample,output_name\n') + output_path = tmp_path / 'out.tar' + + monkeypatch.setattr( + sys, + 'argv', + ['kive_collate', '--debug', '--inputs', 'a.csv', str(metadata_path), '--', str(output_path)]) + + args = kive_collate.parse_args() + + assert args.debug + assert not args.verbose + assert not args.quiet + + +def test_main_collates_csv_and_fasta_from_multiple_samples(monkeypatch, tmp_path): + sample1_cascade = tmp_path / 'sample1_cascade.csv' + sample1_fasta = tmp_path / 'sample1_wg.fasta' + sample2_cascade = tmp_path / 'sample2_cascade.csv' + sample2_fasta = tmp_path / 'sample2_wg.fasta' + sample1_cascade.write_text('x,y\n1,2\n') + sample2_cascade.write_text('x,y\n3,4\n') + sample1_fasta.write_text('>seed\nACTG\n') + sample2_fasta.write_text('>seed\nACTG\n') + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text( + 'index,sample,output_name\n' + '0,E11111,cascade_csv\n' + '1,E11111,wg_fasta\n' + '2,E22222,cascade_csv\n' + '3,E22222,wg_fasta\n') + output_path = tmp_path / 'collated.tar' + + monkeypatch.setattr( + sys, + 'argv', + ['kive_collate', '--inputs', + str(sample1_cascade), + str(sample1_fasta), + str(sample2_cascade), + str(sample2_fasta), + str(metadata_path), + '--', + str(output_path)]) + + kive_collate.main() + + extract_path = tmp_path / 'extract' + extract_path.mkdir() + with tarfile.open(output_path) as output_tar: + output_tar.extractall(extract_path, filter='data') + + cascade_text = (extract_path / 'cascade.csv').read_text() + assert cascade_text == ( + 'sample,x,y\n' + 'E11111,1,2\n' + 'E22222,3,4\n' + ) + + fasta_text = (extract_path / 'wg.fasta').read_text() + assert fasta_text == ( + '>E11111,seed\n' + 'ACTG\n' + '>E22222,seed\n' + 'ACTG\n' + ) + + +def test_stage_inputs_by_sample_rejects_invalid_index(tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample,output_name\nabc,E11111,cascade_csv\n') + run_outputs = [tmp_path / 'cascade.csv'] + run_outputs[0].write_text('x,y\n1,2\n') + + with pytest.raises(ValueError, match='invalid index'): + kive_collate.stage_inputs_by_sample(run_outputs, metadata_path, tmp_path / 'scratch') + + +def test_stage_inputs_by_sample_rejects_invalid_sample_name(tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample,output_name\n0,../escape,cascade_csv\n') + run_outputs = [tmp_path / 'cascade.csv'] + run_outputs[0].write_text('x,y\n1,2\n') + + with pytest.raises(ValueError, match='invalid sample name'): + kive_collate.stage_inputs_by_sample(run_outputs, metadata_path, tmp_path / 'scratch') + + +def test_stage_inputs_by_sample_rejects_missing_required_columns(tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample\n0,E11111\n') + run_outputs = [tmp_path / 'cascade.csv'] + run_outputs[0].write_text('x,y\n1,2\n') + + with pytest.raises(ValueError, match='missing required columns'): + kive_collate.stage_inputs_by_sample(run_outputs, metadata_path, tmp_path / 'scratch') + + +def test_stage_inputs_by_sample_rejects_duplicate_output_for_sample(tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text( + 'index,sample,output_name\n' + '0,E11111,cascade_csv\n' + '1,E11111,cascade_csv\n' + ) + run_outputs = [tmp_path / 'cascade1.csv', tmp_path / 'cascade2.csv'] + run_outputs[0].write_text('x,y\n1,2\n') + run_outputs[1].write_text('x,y\n3,4\n') + + with pytest.raises(ValueError, match='duplicates output'): + kive_collate.stage_inputs_by_sample(run_outputs, metadata_path, tmp_path / 'scratch') + + +def test_parse_args_with_explicit_output_separator(monkeypatch, tmp_path): + metadata_path = tmp_path / 'metadata.csv' + metadata_path.write_text('index,sample,output_name\n') + output_path = tmp_path / 'out.tar' + monkeypatch.setattr( + sys, + 'argv', + ['kive_collate', '--inputs', 'a.csv', 'b.csv', str(metadata_path), '--', str(output_path)]) + + args = kive_collate.parse_args() + + assert args.inputs == [Path('a.csv'), Path('b.csv'), metadata_path] + assert args.output == output_path diff --git a/micall/tests/test_kive_watcher.py b/micall/tests/test_kive_watcher.py index f4a9b57df..3ea3dc960 100644 --- a/micall/tests/test_kive_watcher.py +++ b/micall/tests/test_kive_watcher.py @@ -1,10 +1,7 @@ import shutil -import tarfile from gzip import GzipFile -from io import BytesIO, StringIO from pathlib import Path from queue import Full -from tarfile import TarInfo from unittest.mock import patch, ANY, Mock, call from zipfile import ZipFile import tempfile @@ -109,14 +106,6 @@ def mock_session_download_file(file, url): file.write(f'{url},{i}\n'.encode()) -def mock_session_download_fasta(file, url): - for i in range(2): - file.write(f'>{url},{i}\n'.encode('UTF8')) - for j in range(3): - file.write('ACTGTCA'[i+j:].encode()) - file.write(b'\n') - - def mock_failures(failure_count, success_callable): """ Simulate a number of failures, followed by successes. """ def mocked(*args, **kwargs): @@ -193,6 +182,7 @@ def create_mock_wait(): def create_default_config(): default_config = parse_args(argv=['--micall_filter_quality_pipeline_id', '42', '--micall_main_pipeline_id', '43', + '--micall_collation_pipeline_id', '44', '--micall_resistance_pipeline_id', '494']) yield default_config @@ -289,7 +279,8 @@ def test_filter_pipeline_not_set(capsys, monkeypatch): def test_pipeline_set(): args = parse_args(['--micall_filter_quality_pipeline_id', '402', - '--micall_main_pipeline_id', '403']) + '--micall_main_pipeline_id', '403', + '--micall_collation_pipeline_id', '404']) assert args.micall_filter_quality_pipeline_id == 402 @@ -297,6 +288,7 @@ def test_pipeline_set(): def test_pipeline_set_with_environment_variable(monkeypatch): monkeypatch.setenv('MICALL_FILTER_QUALITY_PIPELINE_ID', '99') monkeypatch.setenv('MICALL_MAIN_PIPELINE_ID', '99') + monkeypatch.setenv('MICALL_COLLATION_PIPELINE_ID', '99') args = parse_args([]) assert args.micall_filter_quality_pipeline_id == 99 @@ -2188,6 +2180,7 @@ def test_launch_resistance_run(raw_data_with_two_samples, mock_open_kive, pipeli def test_launch_proviral_run(raw_data_with_two_samples, mock_open_kive): pipelines_config = parse_args(argv=['--micall_filter_quality_pipeline_id', '42', + '--micall_collation_pipeline_id', '44', '--denovo_main_pipeline_id', '43', '--proviral_pipeline_id', '145']) @@ -2278,6 +2271,7 @@ def get_container_app(path): def test_proviral_pipeline_chains_versions(raw_data_with_two_samples, mock_open_kive): """Test that proviral pipeline chains filter_quality and denovo versions.""" pipelines_config = parse_args(argv=['--micall_filter_quality_pipeline_id', '42', + '--micall_collation_pipeline_id', '44', '--denovo_main_pipeline_id', '43', '--proviral_pipeline_id', '145']) @@ -2421,13 +2415,13 @@ def test_skip_resistance_run(raw_data_with_two_samples, mock_open_kive, pipeline dict(dataset='/datasets/112/', argument_type='O', argument_name='nuc_csv')]] # run datasets - mock_session.get.return_value.json.side_effect = [ - dict(url='/datasets/111/', id=111)] - kive_watcher.poll_runs() + with patch.object(kive_watcher, 'run_collation_pipeline', + return_value=dict(id='702', state='N')) as mock_collate: + kive_watcher.poll_runs() mock_session.endpoints.containerruns.post.assert_not_called() - assert kive_watcher.is_idle() + mock_collate.assert_called_once() def test_resistance_run_missing_input(raw_data_with_two_samples, @@ -2816,7 +2810,6 @@ def test_fetch_run_status_filter_quality(raw_data_with_two_runs, def test_fetch_run_status_main(raw_data_with_two_runs, mock_open_kive, pipelines_config): - mock_session = mock_open_kive.return_value base_calls = (raw_data_with_two_runs / "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") folder_watcher = FolderWatcher(base_calls, None) @@ -2826,17 +2819,16 @@ def test_fetch_run_status_main(raw_data_with_two_runs, None), ('V3LOOP', None))) mock_run = dict(id=123) - mock_session.endpoints.containerruns.get.side_effect = [ + sample_watcher.runs[PipelineType.MAIN] = mock_run + expected_datasets = [dict(argument_name='insertions_csv', + argument_type='O', + dataset='/datasets/110/'), + dict(argument_name='nuc_csv', + argument_type='O', + dataset='/datasets/111/')] + mock_open_kive.return_value.endpoints.containerruns.get.side_effect = [ dict(state='C'), # run state refresh - [dict(argument_name='insertions_csv', - argument_type='O', - dataset='/datasets/110/'), - dict(argument_name='nuc_csv', - argument_type='O', - dataset='/datasets/111/')]] # run datasets - expected_scratch = base_calls / "../../../Results/version_0-dev/scratch" - expected_insertion_path = expected_scratch / "2000A-V3LOOP_S2/insertions.csv" - expected_nuc_path = expected_scratch / "2000A-V3LOOP_S2/nuc.csv" + expected_datasets] # run datasets kive_watcher = KiveWatcher(pipelines_config) @@ -2846,17 +2838,12 @@ def test_fetch_run_status_main(raw_data_with_two_runs, [sample_watcher]) assert new_run is None - assert expected_insertion_path.exists() - assert expected_nuc_path.exists() - assert [call(ANY, '/datasets/110/download/'), - call(ANY, '/datasets/111/download/') - ] == mock_session.download_file.call_args_list + assert mock_run.get('datasets') == expected_datasets def test_fetch_run_status_main_and_resistance(raw_data_with_two_runs, mock_open_kive, pipelines_config): - mock_session = mock_open_kive.return_value base_calls = (raw_data_with_two_runs / "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") folder_watcher = FolderWatcher(base_calls, None) @@ -2867,18 +2854,19 @@ def test_fetch_run_status_main_and_resistance(raw_data_with_two_runs, ('V3LOOP', None))) main_run = dict(id=123) resistance_run = dict(id=124) - mock_session.endpoints.containerruns.get.side_effect = [ + sample_watcher.runs[PipelineType.MAIN] = main_run + sample_watcher.runs[PipelineType.RESISTANCE] = resistance_run + main_datasets = [dict(argument_name='nuc_csv', + argument_type='O', + dataset='/datasets/110/')] + resistance_datasets = [dict(argument_name='resistance_csv', + argument_type='O', + dataset='/datasets/112/')] + mock_open_kive.return_value.endpoints.containerruns.get.side_effect = [ dict(state='C'), # main run refresh - [dict(argument_name='nuc_csv', - argument_type='O', - dataset='/datasets/110/')], # main run datasets + main_datasets, # main run datasets dict(state='C'), # resistance run refresh - [dict(argument_name='resistance_csv', - argument_type='O', - dataset='/datasets/112/')]] # resistance run datasets - expected_scratch = base_calls / "../../../Results/version_0-dev/scratch" - expected_nuc_path = expected_scratch / "2000A-V3LOOP_S2/nuc.csv" - expected_resistance_path = expected_scratch / "2000A-V3LOOP_S2/resistance.csv" + resistance_datasets] # resistance run datasets kive_watcher = KiveWatcher(pipelines_config) @@ -2895,14 +2883,13 @@ def test_fetch_run_status_main_and_resistance(raw_data_with_two_runs, assert new_main_run is None assert new_resistance_run is None - assert expected_nuc_path.exists() - assert expected_resistance_path.exists() + assert main_run.get('datasets') == main_datasets + assert resistance_run.get('datasets') == resistance_datasets def test_fetch_run_status_main_and_midi(raw_data_with_hcv_pair, mock_open_kive, pipelines_config): - mock_session = mock_open_kive.return_value base_calls = (raw_data_with_hcv_pair / "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") folder_watcher = FolderWatcher(base_calls, None) @@ -2913,18 +2900,19 @@ def test_fetch_run_status_main_and_midi(raw_data_with_hcv_pair, ('HCV', 'MidHCV'))) main_run = dict(id=123) midi_run = dict(id=124) - mock_session.endpoints.containerruns.get.side_effect = [ + sample_watcher.runs[PipelineType.MAIN] = main_run + sample_watcher.runs[PipelineType.MIDI] = midi_run + main_datasets = [dict(argument_name='nuc_csv', + argument_type='O', + dataset='/datasets/110/')] + midi_datasets = [dict(argument_name='nuc_csv', + argument_type='O', + dataset='/datasets/111/')] + mock_open_kive.return_value.endpoints.containerruns.get.side_effect = [ dict(state='C'), # main run refresh - [dict(argument_name='nuc_csv', - argument_type='O', - dataset='/datasets/110/')], # main outputs + main_datasets, # main outputs dict(state='C'), # midi run refresh - [dict(argument_name='nuc_csv', - argument_type='O', - dataset='/datasets/111/')]] # midi outputs - expected_scratch = base_calls / "../../../Results/version_0-dev/scratch" - expected_main_nuc_path = expected_scratch / "2130A-HCV_S15/nuc.csv" - expected_midi_nuc_path = expected_scratch / "2130AMIDI-MidHCV_S16/nuc.csv" + midi_datasets] # midi outputs kive_watcher = KiveWatcher(pipelines_config) @@ -2939,8 +2927,8 @@ def test_fetch_run_status_main_and_midi(raw_data_with_hcv_pair, assert new_main_run is None assert new_midi_run is None - assert expected_main_nuc_path.exists() - assert expected_midi_nuc_path.exists() + assert main_run.get('datasets') == main_datasets + assert midi_run.get('datasets') == midi_datasets def test_fetch_run_status_session_expired(raw_data_with_two_runs, @@ -2999,7 +2987,7 @@ def test_fetch_run_status_user_cancelled(raw_data_with_two_runs, assert new_run is not original_run -def test_folder_completed(raw_data_with_two_samples, mock_open_kive, default_config): +def test_folder_completed_triggers_collation(raw_data_with_two_samples, mock_open_kive, default_config): base_calls = (raw_data_with_two_samples / "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") kive_watcher = create_kive_watcher_with_main_run( @@ -3036,233 +3024,12 @@ def test_folder_completed(raw_data_with_two_samples, mock_open_kive, default_con [dict(dataset='/datasets/162/', argument_type='O', argument_name='resistance_csv')]] # run datasets - results_path = base_calls / "../../../Results/version_0-dev" - scratch_path = results_path / "scratch" - expected_coverage_map_content = b'This is a coverage map.' - sample_scratch_path = scratch_path / "2110A-V3LOOP_S13" - sample_scratch_path.mkdir(parents=True) - coverage_maps_path = sample_scratch_path / "coverage_maps.tar" - with tarfile.open(coverage_maps_path, 'w') as coverage_maps_tar: - content = BytesIO(expected_coverage_map_content) - tar_info = TarInfo('coverage_maps/R1_coverage.txt') - tar_info.size = len(expected_coverage_map_content) - coverage_maps_tar.addfile(tar_info, content) - expected_coverage_map_path = ( - results_path / "coverage_maps/2110A-V3LOOP_S13.R1_coverage.txt") - expected_mutations_path = results_path / "mutations.csv" - expected_done_path = results_path / "doneprocessing" - expected_all_done_path = results_path / "done_all_processing" - expected_resistance_path = results_path / "resistance.csv" - expected_resistance_content = """\ -sample,url,n -2110A-V3LOOP_S13,/datasets/161/download/,0 -2110A-V3LOOP_S13,/datasets/161/download/,1 -2110A-V3LOOP_S13,/datasets/161/download/,2 -2120A-PR_S14,/datasets/162/download/,0 -2120A-PR_S14,/datasets/162/download/,1 -2120A-PR_S14,/datasets/162/download/,2 -""" - - kive_watcher.poll_runs() - - assert not scratch_path.exists() - assert not expected_mutations_path.exists() - assert expected_resistance_content == expected_resistance_path.read_text() - assert expected_coverage_map_content == expected_coverage_map_path.read_bytes() - assert expected_done_path.exists() - assert expected_all_done_path.exists() - assert kive_watcher.is_idle() - - -def test_folder_completed_except_denovo(raw_data_with_two_samples, mock_open_kive, default_config): - base_calls = (raw_data_with_two_samples / - "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") - default_config.denovo_main_pipeline_id = 495 - kive_watcher = create_kive_watcher_with_main_run( - default_config, - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None)), - is_complete=True) - kive_watcher.app_urls[default_config.denovo_main_pipeline_id] = '/containerapps/105' - kive_watcher.app_args[default_config.denovo_main_pipeline_id] = dict( - bad_cycles_csv='/containerargs/113', - fastq1='/containerargs/114', - fastq2='/containerargs/115') - folder_watcher = kive_watcher.folder_watchers[base_calls] - sample1_watcher, = folder_watcher.sample_watchers - sample2_watcher = kive_watcher.add_sample_group( - base_calls=base_calls, - sample_group=SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.finish_folder(base_calls) - folder_watcher.add_run(dict(id=150), - PipelineType.MAIN, - sample2_watcher, - is_complete=True) - folder_watcher.add_run(dict(id=151), - PipelineType.DENOVO_MAIN, - sample1_watcher) - folder_watcher.add_run(dict(id=152), - PipelineType.RESISTANCE, - sample1_watcher) - folder_watcher.add_run(dict(id=153), - PipelineType.RESISTANCE, - sample2_watcher) - kive_watcher.session.endpoints.containerruns.get.side_effect = [ - dict(id=151, state='C'), # refresh run state for denovo main - [dict(dataset='/datasets/161/', - argument_type='O', - argument_name='amino_csv'), - dict(dataset='/datasets/171/', - argument_type='O', - argument_name='nuc_csv')], # run datasets - dict(id=152, state='C'), # refresh run state for 2110 - [dict(dataset='/datasets/162/', - argument_type='O', - argument_name='resistance_csv')], # run datasets - dict(id=153, state='C'), # refresh run state for 2120 - [dict(dataset='/datasets/163/', - argument_type='O', - argument_name='resistance_csv')]] # run datasets - results_path = base_calls / "../../../Results/version_0-dev" - scratch_path = results_path / "scratch" - sample_scratch_path = scratch_path / "2110A-V3LOOP_S13" - sample_scratch_path.mkdir(parents=True) - denovo_scratch_path = results_path / "scratch_denovo" / "2110A-V3LOOP_S13" - expected_done_path = results_path / "doneprocessing" - expected_all_done_path = results_path / "done_all_processing" - expected_mutations_path = results_path / "mutations.csv" - expected_resistance_path = results_path / "resistance.csv" - expected_amino_path = denovo_scratch_path / "amino.csv" - - kive_watcher.poll_runs() - - assert not scratch_path.exists() - assert not expected_mutations_path.exists() - assert expected_resistance_path.exists() - assert expected_done_path.exists() - assert not expected_all_done_path.exists() - assert expected_amino_path.exists() - assert not kive_watcher.is_idle() + with patch.object(kive_watcher, 'run_collation_pipeline', + return_value=dict(id='702', state='N')) as mock_collate: + kive_watcher.poll_runs() -def test_folder_completed_with_fasta(raw_data_with_two_samples, mock_open_kive, default_config): - base_calls = (raw_data_with_two_samples / - "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") - kive_watcher = create_kive_watcher_with_main_run( - default_config, - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None)), - is_complete=True) - mock_session = mock_open_kive.return_value - mock_session.download_file.side_effect = mock_session_download_fasta - folder_watcher = kive_watcher.folder_watchers[base_calls] - sample1_watcher, = folder_watcher.sample_watchers - sample2_watcher = kive_watcher.add_sample_group( - base_calls=base_calls, - sample_group=SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.finish_folder(base_calls) - folder_watcher.add_run(dict(id=150), - PipelineType.MAIN, - sample2_watcher, - is_complete=True) - folder_watcher.add_run(dict(id=151), - PipelineType.RESISTANCE, - sample1_watcher) - folder_watcher.add_run(dict(id=152), - PipelineType.RESISTANCE, - sample2_watcher) - kive_watcher.session.endpoints.containerruns.get.side_effect = [ - dict(id=151, state='C'), # refresh run state for 2110 - [dict(dataset='/datasets/161/', - argument_type='O', - argument_name='wg_fasta')], # run datasets - dict(id=152, state='C'), # refresh run state for 2120 - [dict(dataset='/datasets/162/', - argument_type='O', - argument_name='wg_fasta')]] # run datasets - results_path = base_calls / "../../../Results/version_0-dev" - scratch_path = results_path / "scratch" - sample_scratch_path = scratch_path / "2110A-V3LOOP_S13" - sample_scratch_path.mkdir(parents=True) - expected_fasta_path = results_path / "wg.fasta" - expected_fasta_content = """\ ->2110A-V3LOOP_S13,/datasets/161/download/,0 -ACTGTCA -CTGTCA -TGTCA ->2110A-V3LOOP_S13,/datasets/161/download/,1 -CTGTCA -TGTCA -GTCA ->2120A-PR_S14,/datasets/162/download/,0 -ACTGTCA -CTGTCA -TGTCA ->2120A-PR_S14,/datasets/162/download/,1 -CTGTCA -TGTCA -GTCA -""" - - kive_watcher.poll_runs() - - assert not scratch_path.exists() - assert expected_fasta_content == expected_fasta_path.read_text() - - -def test_folder_completed_with_svg(raw_data_with_two_samples, mock_open_kive, default_config): - base_calls = (raw_data_with_two_samples / - "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") - kive_watcher = create_kive_watcher_with_main_run( - default_config, - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None)), - is_complete=True) - folder_watcher = kive_watcher.folder_watchers[base_calls] - sample1_watcher, = folder_watcher.sample_watchers - sample2_watcher = kive_watcher.add_sample_group( - base_calls=base_calls, - sample_group=SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.finish_folder(base_calls) - folder_watcher.add_run(dict(id=150), - PipelineType.MAIN, - sample2_watcher, - is_complete=True) - folder_watcher.add_run(dict(id=151), - PipelineType.RESISTANCE, - sample1_watcher) - folder_watcher.add_run(dict(id=152), - PipelineType.RESISTANCE, - sample2_watcher) - kive_watcher.session.endpoints.containerruns.get.side_effect = [ - dict(id=151, state='C'), # refresh run state for 2110 - [dict(dataset='/datasets/161/', - argument_type='O', - argument_name='alignment_svg')], # run datasets - dict(id=152, state='C'), # refresh run state for 2120 - [dict(dataset='/datasets/162/', - argument_type='O', - argument_name='alignment_svg')]] # run datasets - results_path = base_calls / "../../../Results/version_0-dev" - expected_alignment1_path = results_path / "alignment" / "2110A-V3LOOP_S13_alignment.svg" - expected_alignment2_path = results_path / "alignment" / "2120A-PR_S14_alignment.svg" - - kive_watcher.poll_runs() - - assert expected_alignment1_path.exists() - assert expected_alignment2_path.exists() + mock_collate.assert_called_once() def test_folder_not_finished(raw_data_with_two_samples, mock_open_kive, default_config): @@ -3303,12 +3070,10 @@ def test_folder_not_finished(raw_data_with_two_samples, mock_open_kive, default_ argument_type='O', argument_name='resistance_csv')]] # run datasets results_path = base_calls / "../../../Results/version_0-dev" - scratch_path = results_path / "scratch" expected_resistance_path = results_path / "resistance.csv" kive_watcher.poll_runs() - assert scratch_path.exists() assert not expected_resistance_path.exists() @@ -3342,13 +3107,11 @@ def test_folder_not_finished_before_new_start(raw_data_with_two_runs, sample1_watcher) folder_watcher2.quality_dataset = dict(url='/datasets/127/', id=127) results_path = base_calls1 / "../../../Results/version_0-dev" - scratch_path = results_path / "scratch" expected_resistance_path = results_path / "resistance.csv" kive_watcher.poll_runs() assert not expected_resistance_path.exists() - assert scratch_path.exists() def test_folder_failed_quality(raw_data_with_two_samples, mock_open_kive, default_config): @@ -3446,6 +3209,34 @@ def test_folder_failed_sample(raw_data_with_two_samples, mock_open_kive, default assert expected_error_message == expected_error_path.read_text() +def test_folder_collation_failed_does_not_mark_done_all_processing(raw_data_with_two_samples, + mock_open_kive, + default_config): + mock_session = mock_open_kive.return_value + base_calls = (raw_data_with_two_samples / + "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") + kive_watcher = KiveWatcher(default_config) + folder_watcher = kive_watcher.add_folder(base_calls) + + folder_watcher.active_pipeline_groups.add(PipelineType.MAIN) + kive_watcher.loaded_folders.add(base_calls) + collation_key = (base_calls, PipelineType.MAIN) + kive_watcher.collation_runs[collation_key] = dict(id='702', state='N') + mock_session.endpoints.containerruns.get.return_value = dict(id='702', state='F') + + run_path = base_calls / '../../..' + expected_done_all_path = run_path / 'Results/version_0-dev/done_all_processing' + expected_error_path = run_path / 'errorprocessing' + + kive_watcher.check_completed_folders() + + assert not expected_done_all_path.exists() + assert expected_error_path.read_text() == 'Kive collation failed for MAIN (state: F).\n' + assert collation_key not in kive_watcher.collation_runs + assert base_calls not in kive_watcher.loaded_folders + assert base_calls not in kive_watcher.folder_watchers + + def test_add_duplicate_sample(raw_data_with_two_samples, mock_open_kive, default_config): @@ -3546,219 +3337,6 @@ def test_calculate_retry_wait(): attempt_count=10000) -def test_collate_main_results(raw_data_with_two_samples, default_config, mock_open_kive): - run_folder = raw_data_with_two_samples / "MiSeq/runs/140101_M01234" - base_calls = run_folder / "Data/Intensities/BaseCalls" - results_path = run_folder / "Results" - results_path.mkdir(parents=True) - version_folder: Path = results_path / 'version_0-dev' - version_folder.mkdir() - - sample1_scratch = version_folder / "scratch" / "2120A-PR_S14" - sample1_scratch.mkdir(parents=True) - (sample1_scratch / "cascade.csv").write_text("col1,col2\nval1.1,val2.1\n") - sample2_scratch = version_folder / "scratch" / "2110A-V3LOOP_S13" - sample2_scratch.mkdir(parents=True) - (sample2_scratch / "cascade.csv").write_text("col1,col2\nval1.2,val2.2\n") - - expected_cascade_path = version_folder / "cascade.csv" - expected_cascade_text = "sample,col1,col2\n2120A-PR_S14,val1.1,val2.1\n2110A-V3LOOP_S13,val1.2,val2.2\n" - expected_done_path = version_folder / "doneprocessing" - - denovo_scratch_path = version_folder / "scratch_denovo" - denovo_scratch_path.mkdir() - - kive_watcher = KiveWatcher(default_config) - folder_watcher = kive_watcher.add_folder(base_calls) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None))) - - kive_watcher.collate_folder(folder_watcher, PipelineType.MAIN) - - cascade_text = expected_cascade_path.read_text() - assert cascade_text == expected_cascade_text - assert expected_done_path.exists() - assert denovo_scratch_path.exists() - - -def test_collate_denovo_results(raw_data_with_two_samples, default_config, mock_open_kive): - run_folder = raw_data_with_two_samples / "MiSeq/runs/140101_M01234" - base_calls = run_folder / "Data/Intensities/BaseCalls" - results_path = run_folder / "Results" - results_path.mkdir(parents=True) - version_folder: Path = results_path / 'version_0-dev' - version_folder.mkdir() - - sample1_scratch = version_folder / "scratch_denovo" / "2120A-PR_S14" - sample1_scratch.mkdir(parents=True) - (sample1_scratch / "cascade.csv").write_text("col1,col2\n") - sample2_scratch = version_folder / "scratch_denovo" / "2110A-V3LOOP_S13" - sample2_scratch.mkdir(parents=True) - (sample2_scratch / "cascade.csv").write_text("col1,col2\n") - - expected_cascade_path = version_folder / "denovo" / "cascade.csv" - expected_done_path = version_folder / "denovo" / "doneprocessing" - proviral_path = version_folder / "denovo" / "detailed_results" - - main_scratch_path = version_folder / "scratch" - main_scratch_path.mkdir() - - kive_watcher = KiveWatcher(default_config) - folder_watcher = kive_watcher.add_folder(base_calls) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None))) - - kive_watcher.collate_folder(folder_watcher, PipelineType.DENOVO_MAIN) - - assert expected_cascade_path.exists() - assert expected_done_path.exists() - assert main_scratch_path.exists() - assert not proviral_path.exists() - - -def test_collate_proviral_results(raw_data_with_two_samples, default_config, mock_open_kive): - run_folder = raw_data_with_two_samples / "MiSeq/runs/140101_M01234" - base_calls = run_folder / "Data/Intensities/BaseCalls" - results_path = run_folder / "Results" - results_path.mkdir(parents=True) - version_folder: Path = results_path / 'version_0-dev' - version_folder.mkdir() - - sample1_scratch = version_folder / "scratch_proviral" / "2120A-PR_S14" - sample1_scratch.mkdir(parents=True) - (sample1_scratch / "outcome_summary.csv").write_text("col1,col2\nvalue1,value2\n") - (sample1_scratch / "table_precursor.csv").write_text("col1,col2\n") - sample2_scratch = version_folder / "scratch_proviral" / "2110A-V3LOOP_S13" - sample2_scratch.mkdir(parents=True) - (sample2_scratch / "outcome_summary.csv").write_text("col1,col2\nvalue3,value4\n") - (sample2_scratch / "table_precursor.csv").write_text("col1,col2\n") - - expected_outcome_path = version_folder / "proviral" / "outcome_summary.csv" - expected_precursor_path = version_folder / "proviral" / "table_precursor.csv" - expected_outcome_text = "sample,col1,col2\n2120A-PR_S14,value1,value2\n2110A-V3LOOP_S13,value3,value4\n" - - main_scratch_path = version_folder / "scratch" - main_scratch_path.mkdir() - - kive_watcher = KiveWatcher(default_config) - folder_watcher = kive_watcher.add_folder(base_calls) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None))) - - kive_watcher.collate_folder(folder_watcher, PipelineType.PROVIRAL) - - assert expected_outcome_path.exists() - assert expected_precursor_path.exists() - outcome_text = expected_outcome_path.read_text() - assert outcome_text == expected_outcome_text - -def test_collate_mixed_hcv_results(raw_data_with_two_samples, default_config, mock_open_kive): - run_folder = raw_data_with_two_samples / "MiSeq/runs/140101_M01234" - base_calls = run_folder / "Data/Intensities/BaseCalls" - results_path = run_folder / "Results" - results_path.mkdir(parents=True) - version_folder: Path = results_path / 'version_0-dev' - version_folder.mkdir() - - sample1_scratch = version_folder / "scratch_mixed_hcv" / "2120A-PR_S14" - sample1_scratch.mkdir(parents=True) - (sample1_scratch / "mixed_counts.csv").write_text("col1,col2\n") - sample2_scratch = version_folder / "scratch_mixed_hcv" / "2110A-V3LOOP_S13" - sample2_scratch.mkdir(parents=True) - (sample2_scratch / "mixed_counts.csv").write_text("col1,col2\n") - - expected_cascade_path = version_folder / "mixed_hcv" / "mixed_counts.csv" - expected_done_path = version_folder / "mixed_hcv" / "doneprocessing" - - kive_watcher = KiveWatcher(default_config) - folder_watcher = kive_watcher.add_folder(base_calls) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2120A', - ('2120A-PR_S14_L001_R1_001.fastq.gz', None), - ('PR', None))) - kive_watcher.add_sample_group( - base_calls, - SampleGroup('2110A', - ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), - ('V3LOOP', None))) - - kive_watcher.collate_folder(folder_watcher, PipelineType.MIXED_HCV_MAIN) - - assert expected_cascade_path.exists() - assert expected_done_path.exists() - - -def test_collate_csv(): - source1 = StringIO("""\ -a,b,c -1,2,3 -_,-,= -""") - source2 = StringIO("""\ -a,b,c -10,20,30 -""") - expected_target = """\ -sample,a,b,c -E12345,1,2,3 -E12345,_,-,= -E22222,10,20,30 -""" - target = StringIO() - - KiveWatcher.extract_csv(source1, target, 'E12345', source_count=0) - KiveWatcher.extract_csv(source2, target, 'E22222', source_count=1) - - assert target.getvalue() == expected_target - - -def test_collate_csv_with_sample_already_filled(): - source1 = StringIO("""\ -sample,a,b,c -E12345,1,2,3 -E12345,_,-,= -""") - source2 = StringIO("""\ -sample,a,b,c -E22222,10,20,30 -""") - expected_target = """\ -sample,a,b,c -E12345,1,2,3 -E12345,_,-,= -E22222,10,20,30 -""" - target = StringIO() - - KiveWatcher.extract_csv(source1, target, 'ignored', source_count=0) - KiveWatcher.extract_csv(source2, target, 'ignored', source_count=1) - - assert target.getvalue() == expected_target def test_launch_main_good_pipeline_id(mock_open_kive, default_config): _mock_session = mock_open_kive.return_value @@ -3991,3 +3569,66 @@ def test_rmtree_with_ignore_errors(self): # Test rmtree with ignore_errors=True disk_operations.rmtree(results_path, ignore_errors=True) assert not results_path.exists() + + +def test_run_collation_pipeline_submits_multiple_optional_inputs(raw_data_with_two_samples, + mock_open_kive, + default_config): + default_config.micall_collation_pipeline_id = 700 + base_calls = (raw_data_with_two_samples / + "MiSeq/runs/140101_M01234/Data/Intensities/BaseCalls") + kive_watcher = KiveWatcher(default_config) + folder_watcher = kive_watcher.add_folder(base_calls) + folder_watcher.batch = dict(url='/batches/701') + + kive_watcher.add_sample_group( + base_calls, + SampleGroup('2120A', + ('2120A-PR_S14_L001_R1_001.fastq.gz', None), + ('PR', None))) + kive_watcher.add_sample_group( + base_calls, + SampleGroup('2110A', + ('2110A-V3LOOP_S13_L001_R1_001.fastq.gz', None), + ('V3LOOP', None))) + + sample1_run = dict(id='801', datasets=[dict(argument_type='O', + argument_name='cascade_csv', + dataset='/datasets/710/')]) + sample2_run = dict(id='802', datasets=[dict(argument_type='O', + argument_name='cascade_csv', + dataset='/datasets/711/')]) + folder_watcher.add_run(sample1_run, + PipelineType.MAIN, + folder_watcher.sample_watchers[0], + is_complete=True) + folder_watcher.add_run(sample2_run, + PipelineType.MAIN, + folder_watcher.sample_watchers[1], + is_complete=True) + + kive_watcher.app_urls[default_config.micall_collation_pipeline_id] = '/containerapps/700' + kive_watcher.app_args[default_config.micall_collation_pipeline_id] = dict( + inputs='/containerargs/7001') + + mock_session = kive_watcher.session + mock_session.endpoints.containerruns.filter.return_value = [] + mock_session.endpoints.containerruns.post.return_value = dict(id='702', state='N') + mock_session.get.side_effect = [ + Mock(json=Mock(return_value=dict(url='/datasets/710/', id='710'))), + Mock(json=Mock(return_value=dict(url='/datasets/711/', id='711'))), + ] + + with patch.object(kive_watcher, + 'find_or_upload_dataset', + side_effect=[dict(url='/datasets/799/', id='799')]): + run = kive_watcher.run_collation_pipeline(folder_watcher, PipelineType.MAIN) + + assert run is not None + mock_session.endpoints.containerruns.post.assert_called_once() + submitted_datasets = mock_session.endpoints.containerruns.post.call_args.kwargs['json']['datasets'] + assert submitted_datasets == [ + dict(argument='/containerargs/7001', dataset='/datasets/710/'), + dict(argument='/containerargs/7001', dataset='/datasets/711/'), + dict(argument='/containerargs/7001', dataset='/datasets/799/'), + ] diff --git a/micall/utils/kive_collate.py b/micall/utils/kive_collate.py new file mode 100644 index 000000000..a8b89a8ff --- /dev/null +++ b/micall/utils/kive_collate.py @@ -0,0 +1,358 @@ +import argparse +import csv +import logging +import os +import shutil +import sys +import tarfile +import tempfile +from pathlib import Path +from typing import Sequence, TextIO + +logger = logging.getLogger(__name__) + +DOWNLOADED_RESULTS = [ + 'remap_counts_csv', + 'conseq_csv', + 'conseq_all_csv', + 'conseq_region_csv', + 'concordance_csv', + 'concordance_seed_csv', + 'insertions_csv', + 'failed_csv', + 'nuc_csv', + 'amino_csv', + 'failed_align_csv', + 'g2p_csv', + 'g2p_summary_csv', + 'coverage_scores_csv', + 'coverage_maps_tar', + 'cascade_csv', + 'mixed_counts_csv', + 'mixed_amino_csv', + 'mixed_amino_merged_csv', + 'resistance_csv', + 'mutations_csv', + 'nuc_mutations_csv', + 'resistance_fail_csv', + 'resistance_consensus_csv', + 'wg_fasta', + 'mid_fasta', + 'unstitched_cascade_csv', + 'unstitched_conseq_csv', + 'unstitched_contigs_csv', + 'contigs_csv', + 'stitcher_plot_svg', + 'alignment_svg', + 'alignment_png', + 'assembly_fasta', + 'genome_coverage_csv', + 'genome_coverage_svg', + 'genome_concordance_svg', + 'read_entropy_csv', + 'outcome_summary_csv', + 'conseqs_primers_csv', + 'contigs_primers_csv', + 'table_precursor_csv', + 'proviral_landscape_csv', + 'hivseqinr_results_tar', + 'detailed_results_tar', +] + +REQUIRED_MANIFEST_COLUMNS = {'index', 'sample', 'output_name'} + + +def get_output_filename(output_name: str) -> str: + return '.'.join(output_name.rsplit('_', 1)) + + +def extract_csv(source: TextIO, target: TextIO, sample_name: str, source_count: int) -> int: + reader = csv.DictReader(source) + fieldnames = reader.fieldnames + if fieldnames is None: + return 0 + fieldnames = list(fieldnames) + has_sample = 'sample' in fieldnames + if not has_sample: + fieldnames.insert(0, 'sample') + writer = csv.DictWriter(target, fieldnames, lineterminator='\n') + if source_count == 0: + writer.writeheader() + for row in reader: + if not has_sample: + row['sample'] = sample_name + writer.writerow(row) + return 1 + + +def extract_fasta(source: TextIO, target: TextIO, sample_name: str) -> int: + for line in source: + if line.startswith('>'): + target.write(f'>{sample_name},{line[1:]}') + else: + target.write(line) + return 1 + + +def remove_empty_directory(path: Path) -> None: + if path.exists() and not any(path.iterdir()): + path.rmdir() + + +def extract_coverage_maps(sample_names: list[str], scratch_path: Path, results_path: Path) -> None: + coverage_path = results_path / 'coverage_maps' + coverage_path.mkdir(parents=True, exist_ok=True) + logger.debug('Extracting coverage maps for %d samples into %s.', len(sample_names), coverage_path) + for sample_name in sample_names: + source_path = scratch_path / sample_name / 'coverage_maps.tar' + try: + with tarfile.open(source_path) as source_tar: + for source_info in source_tar: + filename = os.path.basename(source_info.name) + target_path = coverage_path / f'{sample_name}.{filename}' + source = source_tar.extractfile(source_info) + if source is None: + continue + with source, open(target_path, 'wb') as target: + shutil.copyfileobj(source, target) + except FileNotFoundError: + logger.debug('Missing coverage_maps.tar for sample %s at %s.', sample_name, source_path) + continue + remove_empty_directory(coverage_path) + + +def extract_archive(sample_names: list[str], scratch_path: Path, results_path: Path, output_name: str) -> None: + archive_name = output_name[:-4] + output_path = results_path / archive_name + output_path.mkdir(parents=True, exist_ok=True) + logger.debug('Extracting archive output %s for %d samples into %s.', output_name, len(sample_names), output_path) + for sample_name in sample_names: + source_path = scratch_path / sample_name / f'{archive_name}.tar' + try: + with tarfile.open(source_path) as source_tar: + sample_target_path = output_path / sample_name + sample_target_path.mkdir(parents=True, exist_ok=True) + for source_info in source_tar: + filename = os.path.basename(source_info.name) + target_path = sample_target_path / filename + source = source_tar.extractfile(source_info) + if source is None: + continue + with source, open(target_path, 'wb') as target: + shutil.copyfileobj(source, target) + remove_empty_directory(sample_target_path) + except FileNotFoundError: + logger.debug('Missing archive %s for sample %s at %s.', archive_name, sample_name, source_path) + continue + remove_empty_directory(output_path) + + +def move_alignment_plot(sample_names: list[str], extension: str, scratch_path: Path, results_path: Path) -> None: + alignment_path = results_path / 'alignment' + alignment_path.mkdir(parents=True, exist_ok=True) + for sample_name in sample_names: + source_path = scratch_path / sample_name / f'alignment{extension}' + target_path = alignment_path / f'{sample_name}_alignment{extension}' + if source_path.exists(): + source_path.rename(target_path) + remove_empty_directory(alignment_path) + + +def move_genome_coverage(sample_names: list[str], scratch_path: Path, results_path: Path) -> None: + plots_path = results_path / 'genome_coverage' + plots_path.mkdir(parents=True, exist_ok=True) + for sample_name in sample_names: + source_path = scratch_path / sample_name / 'genome_coverage.svg' + target_path = plots_path / f'{sample_name}_genome_coverage.svg' + if source_path.exists(): + source_path.rename(target_path) + concordance_path = scratch_path / sample_name / 'genome_concordance.svg' + target_concordance_path = plots_path / f'{sample_name}_genome_concordance.svg' + if concordance_path.exists(): + concordance_path.rename(target_concordance_path) + remove_empty_directory(plots_path) + + +def move_stitcher_plot(sample_names: list[str], scratch_path: Path, results_path: Path) -> None: + plots_path = results_path / 'stitcher_plots' + plots_path.mkdir(parents=True, exist_ok=True) + for sample_name in sample_names: + source_path = scratch_path / sample_name / 'stitcher_plot.svg' + target_path = plots_path / f'{sample_name}_stitcher_plot.svg' + if source_path.exists(): + source_path.rename(target_path) + remove_empty_directory(plots_path) + + +def copy_outputs(sample_names: list[str], scratch_path: Path, results_path: Path) -> None: + logger.info('Collating %d samples into %s.', len(sample_names), results_path) + results_path.mkdir(parents=True, exist_ok=True) + for output_name in DOWNLOADED_RESULTS: + logger.debug('Processing output %s.', output_name) + if output_name == 'coverage_maps_tar': + extract_coverage_maps(sample_names, scratch_path, results_path) + continue + if output_name.endswith('_tar'): + extract_archive(sample_names, scratch_path, results_path, output_name) + continue + if output_name == 'alignment_svg': + move_alignment_plot(sample_names, '.svg', scratch_path, results_path) + continue + if output_name == 'alignment_png': + move_alignment_plot(sample_names, '.png', scratch_path, results_path) + continue + if output_name == 'genome_coverage_svg': + move_genome_coverage(sample_names, scratch_path, results_path) + continue + if output_name == 'stitcher_plot_svg': + move_stitcher_plot(sample_names, scratch_path, results_path) + continue + + source_count = 0 + filename = get_output_filename(output_name) + target_path = results_path / filename + with target_path.open('w', encoding='utf-8', newline='') as target: + for sample_name in sample_names: + source_path = scratch_path / sample_name / filename + if not source_path.exists(): + continue + with source_path.open(encoding='utf-8', newline='') as source: + if output_name.endswith('_fasta'): + source_count += extract_fasta(source, target, sample_name) + else: + source_count += extract_csv(source, target, sample_name, source_count) + if source_count == 0 and target_path.exists(): + target_path.unlink() + logger.debug('Removed empty collated output %s.', target_path) + + +def configure_logging(args: argparse.Namespace) -> None: + if args.quiet: + level = logging.ERROR + elif args.verbose: + level = logging.INFO + elif args.debug: + level = logging.DEBUG + else: + level = logging.WARN + + logger.setLevel(level) + logging.basicConfig(level=level, + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s') + + +def parse_args(args: Sequence[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description='Collate per-sample MiCall outputs into run-level grouped files.') + verbosity_group = parser.add_mutually_exclusive_group() + verbosity_group.add_argument('--verbose', action='store_true', help='Increase output verbosity.') + verbosity_group.add_argument('--debug', action='store_true', help='Maximum output verbosity.') + verbosity_group.add_argument('--quiet', action='store_true', help='Minimize output verbosity.') + verbosity_group.add_argument('--no-verbose', action='store_true', help='Normal output verbosity.', default=True) + parser.add_argument('--inputs', nargs='*', type=Path, required=True, + help='Input run outputs followed by the metadata manifest.') + parser.add_argument('output', type=Path, + help='Output tar file path for collated results.') + return parser.parse_args(args) + + +def stage_inputs_by_sample(run_outputs: Sequence[Path], metadata_csv: Path, scratch_path: Path) -> list[str]: + logger.info('Reading metadata manifest from %s.', metadata_csv) + with metadata_csv.open(newline='', encoding='utf-8') as manifest_file: + reader = csv.DictReader(manifest_file) + if reader.fieldnames is None: + raise ValueError('Metadata manifest is missing a header row.') + missing_columns = REQUIRED_MANIFEST_COLUMNS - set(reader.fieldnames) + if missing_columns: + missing = ', '.join(sorted(missing_columns)) + raise ValueError(f'Metadata manifest is missing required columns: {missing}.') + rows = list(reader) + logger.info('Loaded %d metadata rows with %d run outputs.', len(rows), len(run_outputs)) + + sample_names: set[str] = set() + for row_number, row in enumerate(rows, start=2): + index_text = row.get('index', '') + try: + file_index = int(index_text) + except (TypeError, ValueError) as ex: + raise ValueError( + f'Metadata manifest row {row_number} has invalid index {index_text!r}.') from ex + + sample_name = (row.get('sample') or '').strip() + output_name = (row.get('output_name') or '').strip() + logger.debug('Manifest row %d: index=%s sample=%s output=%s', + row_number, + index_text, + sample_name, + output_name) + if file_index < 0 or file_index >= len(run_outputs): + raise ValueError( + f'Invalid run_outputs index {file_index} in metadata manifest row {row_number}.') + if not sample_name: + raise ValueError(f'Metadata manifest row {row_number} has empty sample value.') + if (Path(sample_name).name != sample_name or + sample_name in ('.', '..') or + '/' in sample_name or + '\\' in sample_name): + raise ValueError( + f'Metadata manifest row {row_number} has invalid sample name {sample_name!r}.') + if output_name not in DOWNLOADED_RESULTS: + raise ValueError( + f'Metadata manifest row {row_number} has unknown output_name {output_name!r}.') + + source_path = run_outputs[file_index] + if not source_path.is_file(): + raise FileNotFoundError( + f'Metadata manifest row {row_number} references missing run output {source_path}.') + sample_path = scratch_path / sample_name + sample_path.mkdir(parents=True, exist_ok=True) + target_path = sample_path / get_output_filename(output_name) + if target_path.exists(): + raise ValueError( + f'Metadata manifest row {row_number} duplicates output {output_name!r} ' + f'for sample {sample_name!r}.') + shutil.copyfile(source_path, target_path) + sample_names.add(sample_name) + + logger.info('Prepared scratch staging for %d unique samples.', len(sample_names)) + return sorted(sample_names) + + +def main(argv: Sequence[str] | None = None) -> None: + print(f'Starting kive_collate with arguments: {argv}', file=sys.stderr) + args = parse_args(argv) + configure_logging(args) + inputs: Sequence[Path] = args.inputs + + run_outputs: Sequence[Path] = inputs[:-1] + metadata_csv: Path = inputs[-1] + output: Path = args.output + + logger.info('Starting kive collation run.') + logger.debug('Arguments: inputs=%s run_outputs=%s metadata_csv=%s output=%s', + inputs, + run_outputs, + metadata_csv, + output) + + output.parent.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory() as tmp_text: + tmp_path = Path(tmp_text) + scratch_path = tmp_path / 'scratch' + collated_path = tmp_path / 'collated' + scratch_path.mkdir(parents=True, exist_ok=True) + collated_path.mkdir(parents=True, exist_ok=True) + + sample_names = stage_inputs_by_sample(run_outputs, + metadata_csv, + scratch_path) + copy_outputs(sample_names, scratch_path, collated_path) + + with tarfile.open(output, 'w') as output_tar: + for file_path in sorted(collated_path.rglob('*')): + if file_path.is_file(): + output_tar.add(file_path, file_path.relative_to(collated_path)) + logger.info('Finished kive collation run. Output: %s', output) + + +if __name__ == '__main__': + main(sys.argv[1:]) diff --git a/micall/utils/singularity_build.py b/micall/utils/singularity_build.py index 77d09dd48..b7fa90de1 100644 --- a/micall/utils/singularity_build.py +++ b/micall/utils/singularity_build.py @@ -85,6 +85,18 @@ %apphelp denovo Standard pipeline with de novo assembly instead of mapping to reference sequences. + +%apphelp collation + Collate per-sample MiCall outputs into run-level grouped files. + +%applabels collation + KIVE_INPUTS --inputs* + KIVE_OUTPUTS output + KIVE_THREADS 1 + KIVE_MEMORY 1000 + +%apprun collation + python -m micall.utils.kive_collate --debug "$@" """