diff --git a/sotodlib/preprocess/pcore.py b/sotodlib/preprocess/pcore.py index 0149ffbc2..b820b13cc 100644 --- a/sotodlib/preprocess/pcore.py +++ b/sotodlib/preprocess/pcore.py @@ -446,7 +446,9 @@ def __getitem__(self, index): else: return result - def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, data_amans=None): + + def run(self, aman, proc_aman=None, full_aman=None, select=True, + sim=False, update_plot=False, data_amans=None): """ The main workhorse function for the pipeline class. This function takes an AxisManager TOD and successively runs the pipeline of preprocessing @@ -470,6 +472,12 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, d returned this preprocess axismanager. In this case, calls to ``process.calc_and_save()`` are skipped as the information is expected to be present in this AxisManager. + full_aman: AxisManager (Optional) + A preprocess axismanager. This axis manager stores the outputs of + preprocessing functions (proc_aman) but without any of the detector + or samps restrictions applied, thus maintaining its original shape. + This is returned at the end of the pipeline. If not passed it is + instantiated with the same number of dets and samps as aman. select: boolean (Optional) if True, the aman detector axis is restricted as described in each preprocess module. Most pipelines are developed with @@ -490,18 +498,22 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, d Returns ------- - proc_aman: AxisManager + full_aman: AxisManager A preprocess axismanager that contains all data products calculated - throughout the running of the pipeline - + throughout the running of the pipeline. + success: str + A string that stores the name of the last process step that the pipeline + completed. If the pipeline successfully finishes all steps, success = 'end'. """ if proc_aman is None: if 'preprocess' in aman: proc_aman = aman.preprocess.copy() - full = aman.preprocess.copy() + if full_aman is None: + full_aman = aman.preprocess.copy() else: proc_aman = core.AxisManager(aman.dets, aman.samps) - full = core.AxisManager( aman.dets, aman.samps) + if full_aman is None: + full_aman = core.AxisManager( aman.dets, aman.samps) run_calc = True update_plot = False else: @@ -510,7 +522,8 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, d det_list = [det for det in proc_aman.dets.vals if det in aman.dets.vals] aman.restrict('dets', det_list) proc_aman.restrict('dets', det_list) - full = proc_aman.copy() + if full_aman is None: + full_aman = proc_aman.copy() run_calc = False if 'frequency_cutoffs' not in proc_aman: @@ -553,7 +566,7 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, d if run_calc: aman, proc_aman = process.calc_and_save(aman, proc_aman) process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png')) - update_full_aman( proc_aman, full, self.wrap_valid) + update_full_aman( proc_aman, full_aman, self.wrap_valid) if update_plot: process.plot(aman, proc_aman, filename=os.path.join(self.plot_dir, '{ctime}/{obsid}', f'{step+1}_{{name}}.png')) plt.close() @@ -561,22 +574,22 @@ def run(self, aman, proc_aman=None, select=True, sim=False, update_plot=False, d process.select(aman, proc_aman) proc_aman.restrict('dets', aman.dets.vals) self.logger.debug(f"{proc_aman.dets.count} detectors remaining") - + if aman.dets.count == 0: success = process.name break if run_calc: - _wrap_valid_ranges(proc_aman, full, valid_name='valid_data', + _wrap_valid_ranges(proc_aman, full_aman, valid_name='valid_data', wrap_name='valid_data') - # copy updated frequency cutoffs to full - if "frequency_cutoffs" in full: - full.move("frequency_cutoffs", None) - full.wrap("frequency_cutoffs", proc_aman["frequency_cutoffs"]) + # copy updated frequency cutoffs to full_aman + if "frequency_cutoffs" in full_aman: + full_aman.move("frequency_cutoffs", None) + full_aman.wrap("frequency_cutoffs", proc_aman["frequency_cutoffs"]) + + return full_aman, success - return full, success - class _FracFlaggedMixIn(object): diff --git a/sotodlib/preprocess/preprocess_util.py b/sotodlib/preprocess/preprocess_util.py index 3323c80f5..1dfd4e6d0 100644 --- a/sotodlib/preprocess/preprocess_util.py +++ b/sotodlib/preprocess/preprocess_util.py @@ -516,7 +516,7 @@ def load_preprocess_det_select(obs_id, configs, context=None, def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, - no_signal=None, logger=None): + no_signal=None, logger=None, return_full_aman=False): """Loads the saved information from the preprocessing pipeline and runs the processing section of the pipeline. @@ -543,12 +543,19 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, logger : PythonLogger Optional. Logger object. If None, a new logger is created. + return_full_aman : bool + Optional. Return unrestricted axis manager alongside restricted aman + if True, otherwise return None. Returns ------- aman : core.AxisManager or None Loaded and restricted axis manager with preprocessing metadata. Returns ``None`` if all detectors cut. + full_aman : core.AxisManager or None + Unrestricted preprocessing axis manager. Used when running multilayer + pipeline to ensure saved detector axis has the full size when saving + metadata. """ if logger is None: @@ -556,6 +563,12 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, configs, context = get_preprocess_context(configs, context) meta = context.get_meta(obs_id, dets=dets, meta=meta) + + if return_full_aman: + full_aman = meta.preprocess.copy() + else: + full_aman = None + if ( 'valid_data' in meta.preprocess and isinstance(meta.preprocess.valid_data, core.AxisManager) @@ -574,7 +587,7 @@ def load_and_preprocess(obs_id, configs, context=None, dets=None, meta=None, pipe = Pipeline(configs["process_pipe"], logger=logger) aman = context.get_obs(meta, no_signal=no_signal) pipe.run(aman, aman.preprocess, select=False) - return aman + return aman, full_aman def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, @@ -674,6 +687,21 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, ): pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger) + logger.info("Restricting detectors on all init pipeline processes") + if ( + 'valid_data' in meta_init.preprocess and + isinstance(meta_init.preprocess.valid_data, core.AxisManager) + ): + keep_all = has_any_cuts(meta_init.preprocess.valid_data.valid_data) + else: + keep_all = np.ones(meta_init.dets.count,dtype=bool) + for process in pipe_init[:]: + keep = process.select(meta_init, in_place=False) + if isinstance(keep, np.ndarray): + keep_all &= keep + meta_init.restrict("dets", meta_init.dets.vals[keep_all]) + meta_proc.restrict("dets", meta_init.dets.vals) + logger.info("Restricting detectors on all proc pipeline processes") if ( 'valid_data' in meta_proc.preprocess and @@ -686,6 +714,7 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, keep = process.select(meta_proc, in_place=False) if isinstance(keep, np.ndarray): keep_all &= keep + meta_proc.restrict("dets", meta_proc.dets.vals[keep_all]) meta_init.restrict('dets', meta_proc.dets.vals) @@ -705,6 +734,7 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, return aman logger.info("Running dependent pipeline") + if stop_for_sims: aman = out_amans_init[(len(pipe_init), 'last')] proc_aman = context_proc.get_meta(obs_id, meta=aman) @@ -723,7 +753,6 @@ def multilayer_load_and_preprocess(obs_id, configs_init, configs_proc, if name != 'last' }) return out_amans - else: pipe_proc.run(aman, aman.preprocess, select=False) return aman @@ -821,7 +850,7 @@ def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, filled with AxisManager processed up to step-1. This is used to pre-load all data AxisManager which could be required when processing simulations (e.g. to provide a T2P template) - + Returns ------- aman : core.AxisManager or None @@ -858,14 +887,37 @@ def multilayer_load_and_preprocess_sim(obs_id, configs_init, configs_proc, logger=logger): pipe_proc = Pipeline(configs_proc["process_pipe"], logger=logger) + logger.info("Restricting detectors on all init pipeline processes") + if ( + 'valid_data' in meta_init.preprocess and + isinstance(meta_init.preprocess.valid_data, core.AxisManager) + ): + keep_all = has_any_cuts(meta_init.preprocess.valid_data.valid_data) + else: + keep_all = np.ones(meta_init.dets.count,dtype=bool) + for process in pipe_init[:]: + keep = process.select(meta_init, in_place=False) + if isinstance(keep, np.ndarray): + keep_all &= keep + meta_init.restrict("dets", meta_init.dets.vals[keep_all]) + meta_proc.restrict("dets", meta_init.dets.vals) + logger.info("Restricting detectors on all proc pipeline processes") - keep_all = np.ones(meta_proc.dets.count, dtype=bool) - for process in pipe_proc[:]: - keep = process.select(meta_proc, in_place=False) - if isinstance(keep, np.ndarray): - keep_all &= keep + if ( + 'valid_data' in meta_proc.preprocess and + isinstance(meta_proc.preprocess.valid_data, core.AxisManager) + ): + keep_all = has_any_cuts(meta_proc.preprocess.valid_data.valid_data) + else: + keep_all = np.ones(meta_proc.dets.count, dtype=bool) + for process in pipe_proc[:]: + keep = process.select(meta_proc, in_place=False) + if isinstance(keep, np.ndarray): + keep_all &= keep + meta_proc.restrict("dets", meta_proc.dets.vals[keep_all]) meta_init.restrict('dets', meta_proc.dets.vals) + aman = context_init.get_obs(meta_proc, no_signal=True) # One needs to correct HWP model and gamma @@ -1327,9 +1379,14 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None, if db_init_exist and not db_proc_exist: out_dict_init = None try: + # need unrestricted proc aman for second layer + if configs_proc is not None: + return_full_aman = True + else: + return_full_aman = False logger.info(f"Loading and applying preprocessing for initial layer db on {obs_id}:{group}") - aman = load_and_preprocess(obs_id=obs_id, dets=dets, configs=configs_init, - logger=logger) + aman, proc_aman = load_and_preprocess(obs_id=obs_id, dets=dets, configs=configs_init, + logger=logger, return_full_aman=return_full_aman) except Exception as e: errmsg, tb = PreprocessErrors.get_errors(e) logger.error(f"Initial layer Pipeline Load Error for {obs_id}: {group}\n{errmsg}\n{tb}") @@ -1434,7 +1491,7 @@ def preproc_or_load_group(obs_id, configs_init, dets, configs_proc=None, pipe_proc = Pipeline(configs_proc["process_pipe"], plot_dir=configs_proc["plot_dir"], logger=logger) - proc_aman, success = pipe_proc.run(aman) + proc_aman, success = pipe_proc.run(aman, full_aman=proc_aman) pipe_init = Pipeline(configs_init["process_pipe"], plot_dir=configs_init["plot_dir"], logger=logger) diff --git a/sotodlib/site_pipeline/make_depth1_map.py b/sotodlib/site_pipeline/make_depth1_map.py index 2aa27d759..92e026233 100644 --- a/sotodlib/site_pipeline/make_depth1_map.py +++ b/sotodlib/site_pipeline/make_depth1_map.py @@ -154,7 +154,7 @@ def make_depth1_map( # which is pointless, but shouldn't cost that much time # obs = context.get_obs(obs_id, dets={"wafer_slot":detset, "band":band}) try: - obs = pp_util.load_and_preprocess( + obs, _ = pp_util.load_and_preprocess( obs_id, preproc, dets={"wafer_slot": detset, "wafer.bandpass": band}, diff --git a/sotodlib/site_pipeline/multilayer_preprocess_tod.py b/sotodlib/site_pipeline/multilayer_preprocess_tod.py index 5a069d2a3..ff82a2e86 100644 --- a/sotodlib/site_pipeline/multilayer_preprocess_tod.py +++ b/sotodlib/site_pipeline/multilayer_preprocess_tod.py @@ -474,7 +474,7 @@ def _main(executor: Union["MPICommExecutor", "ProcessPoolExecutor"], n_groups_fail += 1 if n_groups_fail > 0: - raise RuntimeError(f"preprocess_tod ended with {n_obs_fail}/{len(obs_errors)} " + raise RuntimeError(f"multilayer_preprocess_tod ended with {n_obs_fail}/{len(obs_errors)} " f"failed obsids and {n_groups_fail}/{len(run_list)} failed groups") logger.info("multilayer_preprocess_tod is done")