diff --git a/docs/det_match.rst b/docs/det_match.rst index 84f42c0b2..73085b39e 100644 --- a/docs/det_match.rst +++ b/docs/det_match.rst @@ -6,7 +6,7 @@ DetMatch The ``sotodlib.coords.det_match`` module allows us to map resonators from one source to another, using information such as resonator frequency, bias-line assignments, and pointing information. This is particularly useful to create -a map from resonators in a SMuRF tune-file, to real detectors from either a +a map from resonators in a SMuRF tune-file to real detectors from either a design-file, or a handmade solutions file based on a separate tune-file. This works by translating the detector matching problem into an instance @@ -90,10 +90,43 @@ lenient with the frequency penalty: ) match = dm.Match(src, dst, match_params=mpars) +Det Match Solutions +````````````````````````` + +The ``det_match_solutions`` script can be used to generate "handmade" detector match +solutions sets for all wafers. A solution set is defined as a mapping from a tune file +with the addition of pointing information from fits to a point source in that observation to +the design wafer information (i.e. matching tune readout IDs to design detector IDs with +xi and eta constraints). Solutions are useful due to the potentially alrge frequency shifts +between the tunesets and design frequencies.It is performs multiple matches +sequentially while correcting for frequency and pointing offsets between them. + +The major steps in this script are: + +- Load pointing xi and eta information from fits to observations of point sources. + These are derived by fitting TODs or maps of observations targeting point + sources (planets or the Moon) and are stored as a structured array in an + ``hdf5`` file under a group named ``focal_plane`` and should include entries + for all det_ids from the matching tune (NaNs are allowed). It should also include + an estimate of the coefficient of determination, R\ :sup:`2` for excluding bad fits. + Multiple pointing files may be input in which case they will a match will be + performed and the median xi and eta values will be used from all matched resonators. +- Do the first match for the wafer using pointing, frequency, and bias line information. +- Subtract the median xi and eta offset from matched detectors. Also remove frequency + offsets through box median interpolation. +- Run a second match after offset correction. +- Perform a grid based pointing offset given a selection radius in the config file. +- Run the third match after second pointing offset correction. + API ------- .. automodule:: sotodlib.coords.det_match + :members: + :undoc-members: + :show-inheritance: + +.. automodule:: sotodlib.coords.det_match_solutions :members: :undoc-members: :show-inheritance: \ No newline at end of file diff --git a/sotodlib/coords/det_match.py b/sotodlib/coords/det_match.py index 90107e92b..ea9bc7806 100644 --- a/sotodlib/coords/det_match.py +++ b/sotodlib/coords/det_match.py @@ -1,6 +1,6 @@ import warnings from dataclasses import dataclass, fields, asdict, field -from typing import List, Optional, Tuple, Iterator, Union +from typing import List, Optional, Tuple, Iterator, Union, get_origin, get_args from copy import deepcopy import h5py @@ -147,10 +147,10 @@ class Resonator: bg: int = -1 det_x: float = np.nan det_y: float = np.nan - det_row: int = 0 - det_col: int = 0 + det_row: Optional[int] = None + det_col: Optional[int] = None pixel_num: int = 0 - det_rhomb: str = '' + det_rhomb: Optional[str] = None det_pol: str = '' det_freq: int = 0 det_bandpass: str = '' @@ -190,12 +190,18 @@ def apply_design_properties(smurf_res, design_res, in_place=False, apply_pointin r = deepcopy(smurf_res) design_props = [ - 'bg', 'det_x', 'det_y', 'det_row', 'det_col', 'pixel_num', 'det_rhomb', + 'bg', 'det_x', 'det_y', 'pixel_num', 'det_pol', 'det_freq', 'det_bandpass', 'det_angle_raw_deg', 'det_angle_actual_deg', 'det_type', 'det_id', 'is_optical', 'mux_bondpad', 'mux_subband', 'mux_band', 'mux_channel', 'mux_layout_pos' ] + + # for LF + for attr in ('det_row', 'det_col', 'det_rhomb'): + if getattr(design_res, attr, None) is not None: + design_props.append(attr) + if apply_pointing: design_props += ['xi', 'eta', 'gamma'] @@ -267,6 +273,8 @@ def as_array(self): dtype = [] data = [] for field in fields(Resonator): + if get_origin(field.type) is Union and type(None) in get_args(field.type): + continue if field.type == str: typ = ' np.ndarray: - src_arr = self.src.as_array() - dst_arr = self.dst.as_array() + src_arr = self.src.as_array() # pointing or tuneset + dst_arr = self.dst.as_array() # design or pointing mat = np.zeros((len(self.src), len(self.dst)), dtype=float) @@ -706,6 +749,11 @@ def _get_biadjacency_matrix(self) -> np.ndarray: m = src_arr['is_north'][:, None] != dst_arr['is_north'][None, :] mat[m] = np.inf + src_no_match = np.isin(src_arr['det_type'], ['NC']) + dst_no_match = np.isin(dst_arr['det_type'], ['NC']) + mat[src_no_match, :] = np.inf + mat[:, dst_no_match] = np.inf + if self.match_pars.enforce_pointing_reqs: # Det types of DARK and SLOT are allowed to may or may not have # pointing data. For these types of detectors, the cost is left @@ -714,11 +762,7 @@ def _get_biadjacency_matrix(self) -> np.ndarray: src_has_pointing = np.isfinite(src_arr['xi']) & np.isfinite(src_arr['eta']) dst_has_pointing = np.isfinite(dst_arr['xi']) & np.isfinite(dst_arr['eta']) - src_no_match = np.isin(src_arr['det_type'], ['NC']) - dst_no_match = np.isin(dst_arr['det_type'], ['NC']) - mat[src_no_match, :] = np.inf - mat[:, dst_no_match] = np.inf - + # UNRT, SQID, and BARE must not have pointing src_pointing_forbidden = np.isin(src_arr['det_type'], ['UNRT', 'SQID', 'BARE']) dst_pointing_forbidden = np.isin(dst_arr['det_type'], ['UNRT', 'SQID', 'BARE']) m = src_pointing_forbidden[:, None] & dst_has_pointing[None, :] @@ -726,6 +770,7 @@ def _get_biadjacency_matrix(self) -> np.ndarray: m = src_has_pointing[:, None] & dst_pointing_forbidden[None, :] mat[m] = np.inf + # OPTC must have pointing src_pointing_required = np.isin(src_arr['det_type'], ['OPTC']) dst_pointing_required = np.isin(dst_arr['det_type'], ['OPTC']) m = src_pointing_required[:, None] & (~dst_has_pointing[None, :]) @@ -733,19 +778,44 @@ def _get_biadjacency_matrix(self) -> np.ndarray: m = (~src_has_pointing[:, None]) & dst_pointing_required[None, :] mat[m] = np.inf + # These should always have BG = -1 + src_unassigned_type = np.isin(src_arr['det_type'], ['UNRT', 'SQID', 'BARE']) + dst_unassigned_type = np.isin(dst_arr['det_type'], ['UNRT', 'SQID', 'BARE']) + # pointing or tuneset have unassigned type (shouldn't ever happen since + # they won't have det_types yet) + m = src_unassigned_type[:, None] & (dst_arr['bg'][None, :] !=-1) + mat[m] = np.inf + # Design or pointing have unassigned type (should happen since dst + # will have det_types) + m = dst_unassigned_type[None, :] & (src_arr['bg'][:, None] !=-1) + mat[m] = np.inf + # Frequency offset df = src_arr['res_freq'][:, None] - dst_arr['res_freq'][None, :] df -= self.match_pars.freq_offset_mhz mat += np.exp((np.abs(df / self.match_pars.freq_width)) ** 2) - # BG mismatch - bgs_mismatch = src_arr['bg'][:, None] != dst_arr['bg'][None, :] - bgs_unassigned = (src_arr['bg'][:, None] == -1) | (dst_arr['bg'][None, :] == -1) + # Design or pointing unassigned + dst_unassigned = (dst_arr['bg'][None, :] == -1) + # pointing or tune unassigned + src_unassigned = (src_arr['bg'][:, None] == -1) - m = bgs_mismatch & bgs_unassigned + # Whether or not to match unassigned bg to assigned bgs + # don't want matches when matching from design to pointing + m = dst_unassigned & (~src_unassigned) + if not self.match_pars.allow_unassigned_to_assigned: + mat[m] = np.inf + else: + mat[m] += self.match_pars.unassigned_bg_mismatch_pen + + # Match assigned bg to unassigned bg + m = (~dst_unassigned) & src_unassigned mat[m] += self.match_pars.unassigned_bg_mismatch_pen - m = bgs_mismatch & (~bgs_unassigned) - mat[m] += self.match_pars.assigned_bg_mismatch_pen + + # Assigned bgs must not be mismatched + bgs_mismatch = src_arr['bg'][:, None] != dst_arr['bg'][None, :] + m = bgs_mismatch & (~dst_unassigned) & (~src_unassigned) + mat[m] = np.inf # If pointing, add cost if assigned too far dd = np.sqrt( @@ -801,9 +871,11 @@ def _match(self): for r1, r2 in self.get_match_iter(include_unmatched=True): if r1 is None: r2.matched = 0 + r2.match_idx = -1 continue if r2 is None: r1.matched = 0 + r1.match_idx = -1 continue r1.matched = 1 diff --git a/sotodlib/coords/det_match_solutions.py b/sotodlib/coords/det_match_solutions.py new file mode 100644 index 000000000..2399e04f4 --- /dev/null +++ b/sotodlib/coords/det_match_solutions.py @@ -0,0 +1,589 @@ +from sotodlib.core import Context, AxisManager, LabelAxis +from sotodlib.coords import det_match as dm +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Tuple, Callable, Any +from collections import defaultdict +import os +import numpy as np +import yaml +from copy import deepcopy +from scipy import interpolate +from tqdm.auto import tqdm, trange +from copy import deepcopy + +import h5py +import sys + + +@dataclass +class SolutionsCfg: + """ + Args + ------ + ctx_path: str + Path to context file to use to pull tod metadata. + pointing_results_dir: str + Results to directory that contains pointing results. Files in + directory should look like: + ``` + focal_plane__.hdf + ``` + results_dir: str + Directory where results should be stored. + wafer_info_path: str + Path to the wafer_info h5 file. + tel_type: str + Tel type for the optics model. Either "SAT" or "LAT" + base_obs_id: str + Obs_id to use as a base for matching when merging multiple pointing + obs_ids for a wafer. Will default to the pointing obs_id with the + greatest number of detectors above the min_R2 threshold. + zemax_path: str + If running for a "LAT" tel_type, the path to the zemax file must be specified. + apply_roll: bool + Whether or not to apply the obs_id roll angle. Some pointing sets + may already be corrected for roll angle. + pointing_field: str + Name of sub axis manager in pointing tune axis maanger containng the + pointing information. + site_pipeline_cfg_dir: str + Path to site-pipeline-config dir. Defaults to the env var + ``$SITE_PIPELINE_CONFIG_DIR``. + finite_xi_thresh: int + Minimum number of dets a pointing result must have to add it to the + analysis. + min_r2: float + Minimum R-squared for det pointing to be considered. + sel_rad: float + Selection radius for grid-based interpolation pointing offset subtraction. + unassigned_slots: int: + Number of additional "unassigned" node to use per-side + wafer_map_path: str + Path to the wafer map file. Defaults to ``/shared/detmatpping/wafer_map.yaml``. + match_pars: dict + Dictionary of match parameters to use for pointing obs_id merging and + each match iteration. Should have the form:: + + match_pars: + pointing: + freq_width: 0.4 + dist_width: 2.0 + match0: + freq_width: 200 + dist_width: 0.4 + match1: + freq_width: 50 + dist_width: 0.8 + match2: + freq_width: 5 + dist_width: 0.1 + + Initial pointing offset: Tuple[float, float] + Estimated pointing offset for the boresight. This should be + (xi_offset, eta_offset) where both are in radians. + ufm_to_fp_path: str + Path to file that maps wafer_slot to position on focal plane. + freq_correct_by_muxband: bool + If true, apply the same freq offset correction to all resonators in a mux-band. + """ + + ctx_path: str + pointing_results_dir: str + results_dir: str + wafer_info_path: str + tel_type: str + base_obs_id: Optional[str] = None + zemax_path: Optional[str] = None + apply_roll: bool = True + + pointing_field: str = "tod_pointing" + site_pipeline_cfg_dir: str = "$SITE_PIPELINE_CONFIG_DIR" + finite_xi_thresh: int = ( + 500 # Min number of dets with finite xi to consider a pointing input + ) + min_r2: float = 0.9 + sel_rad: float = 2.0 + unassigned_slots: int = 1200 + wafer_map_path: Optional[str] = None + match_pars: Dict[str, dict] = field(default_factory=lambda: defaultdict(dict)) + + initial_pointing_offset: Tuple[float, float] = (0, 0) + ufm_to_fp_path: Optional[str] = None + freq_correct_by_muxband: bool = True + + ctx: Context = field(init=False) + wafer_map: Dict[str, dict] = field(init=False) + + @classmethod + def from_dict(cls, d: Dict[str, Any]) -> "SolutionsCfg": + return cls(**d) + + @classmethod + def from_yaml(cls, path: str) -> "SolutionsCfg": + with open(path, "r") as f: + return cls.from_dict(yaml.safe_load(f)) + + def __post_init__(self): + self.ctx = Context(self.ctx_path) + + if not os.path.exists(self.results_dir): + if os.path.exists(os.path.split(self.results_dir)[0]): + os.makedirs(self.results_dir) + else: + raise FileNotFoundError( + f"Could not find results dir or basedir: {self.results_dir}" + ) + + self.site_pipeline_cfg_dir = os.path.expandvars(self.site_pipeline_cfg_dir) + if self.wafer_map_path is None: + self.wafer_map_path = os.path.join( + self.site_pipeline_cfg_dir, "shared/detmapping/wafer_map.yaml" + ) + with open(self.wafer_map_path, "r") as f: + self.wafer_map = yaml.safe_load(f) + + if self.ufm_to_fp_path is None: + self.ufm_to_fp_path = os.path.join( + self.site_pipeline_cfg_dir, "shared/focalplane/ufm_to_fp.yaml" + ) + + +@dataclass +class PointingInfo: + pointing: np.ndarray + obs_id: str + obs: dict + meta: AxisManager + preprocessed: bool = False + + +_meta_cache = {} +def get_meta(cfg: SolutionsCfg, obs_id: str, wafer_slot: Optional[str] = None): + if obs_id in _meta_cache: + meta = _meta_cache[obs_id] + else: + meta = cfg.ctx.get_meta(obs_id) + _meta_cache[obs_id] = meta + meta = deepcopy(meta) + if wafer_slot is not None: + meta.restrict("dets", meta.det_info.wafer_slot == wafer_slot) + return meta + + +def load_good_pointing_info(cfg: SolutionsCfg, wafer_slot: str) -> List[PointingInfo]: + """ + Load pointing data for each pointing measurement on disk + """ + files = [] + for f in os.listdir(cfg.pointing_results_dir): + if os.path.splitext(f)[0].split("_")[-1] == wafer_slot: + files.append(os.path.join(cfg.pointing_results_dir, f)) + + pointing_info = [] + for f in tqdm(files): + d = h5py.File(f)["focal_plane"] + if np.sum(np.isfinite(d["xi"])) < cfg.finite_xi_thresh: + continue + obs_id = "_".join(os.path.basename(f).split("_")[2:-1]) + obs = cfg.ctx.obsdb.get(obs_id) + pinfo = PointingInfo( + pointing=d, + obs_id=obs_id, + obs=obs, + meta=get_meta(cfg, obs_id, wafer_slot=wafer_slot), + ) + pointing_info.append(pinfo) + + return pointing_info + + +def pointing_preprocess(cfg: SolutionsCfg, pinfo: PointingInfo): + """ + Add tod_pointing to PointingInfo metadata, adjusting for boresight angle and + pointing offset. + """ + meta: AxisManager = pinfo.meta + assert (meta.det_info.readout_id == pinfo.pointing["dets:readout_id"].astype(str)).all() + + tod_pointing = AxisManager(meta.dets) + + _xi = deepcopy(pinfo.pointing["xi"]) + _eta = deepcopy(pinfo.pointing["eta"]) + theta = 0 + offset = cfg.initial_pointing_offset + + _xi += offset[0] + _eta += offset[1] + xi = _xi * np.cos(theta) - _eta * np.sin(theta) + eta = _xi * np.sin(theta) + _eta * np.cos(theta) + + tod_pointing.wrap("xi", xi) + tod_pointing.wrap("eta", eta) + tod_pointing.wrap("r2", pinfo.pointing["R2"]) + tod_pointing.xi[tod_pointing.r2 < cfg.min_r2] = np.nan + tod_pointing.eta[tod_pointing.r2 < cfg.min_r2] = np.nan + + if 'tod_pointing' in meta: + meta.move('tod_pointing', None) + + meta.wrap("tod_pointing", tod_pointing) + return meta + + +def merge_pointing_info(cfg: SolutionsCfg, pinfos: List[PointingInfo], base_idx=0): + """ + Combine all pointing measurements into a single resonator set, with the + median pointing info from all. This requires a base_idx to be specified, + which will be the index of the PointingInfo to use to create the ResSet + template. For all other PointingInfo objects, resonators will be matched to + the base resset based on resonance frequency, to compile all pointing + measurements for a given detector. The median of all measurements will be + used as the real value. + """ + for pinfo in pinfos: + pointing_preprocess(cfg, pinfo) + + meta = pinfos[base_idx].meta + stream_id = meta.det_info.stream_id[0] + wafer_slot = meta.det_info.wafer_slot[0] + base_resset = dm.ResSet.from_aman(meta, stream_id=stream_id, pointing=meta.tod_pointing) + + pointing_map = { + r.idx: [(r.xi, r.eta)] for r in base_resset + } + + match_pars = dm.MatchParams( + freq_width=cfg.match_pars["pointing"]["freq_width"], + dist_width=np.deg2rad(cfg.match_pars["pointing"]["dist_width"]) + ) + + for i in range(len(pinfos)): + if i == base_idx: + continue + meta = pinfos[i].meta + src = dm.ResSet.from_aman(meta, stream_id=stream_id, pointing=meta.tod_pointing) + dst = base_resset + match = dm.Match(src, dst, match_pars=match_pars) + for rsrc, rdst in match.get_match_iter(include_unmatched=False): + pointing_map[rdst.idx].append((rsrc.xi, rsrc.eta)) + + for r in base_resset: + r.xi, r.eta = np.nanmedian(np.array(pointing_map[r.idx]).T, axis=1) + + return base_resset, pointing_map + + +def get_best_tod_pointing(cfg: SolutionsCfg, pinfos: List[PointingInfo]) -> AxisManager: + _readout_ids = pinfos[0].pointing["dets:readout_id"] + + readout_ids = list(map(lambda bs: bs.decode(), _readout_ids)) + dets = LabelAxis("dets", readout_ids) + ndets = dets.count + + for pinfo in pinfos: # Shift and rotate xi/eta per pointing observation + _xi = pinfo.pointing["xi"] + cfg.initial_pointing_offset[0] + _eta = pinfo.pointing["eta"] + cfg.initial_pointing_offset[1] + theta = np.deg2rad(pinfo.obs["roll_center"]) if cfg.apply_roll else 0. + pinfo.xi = _xi * np.cos(theta) - _eta * np.sin(theta) + pinfo.eta = _xi * np.sin(theta) + _eta * np.cos(theta) + + xis = np.full(ndets, np.nan) + etas = np.full(ndets, np.nan) + + for i in trange(len(readout_ids)): # Find optimal xi/eta per readout channel + _xis = np.full(len(pinfos), np.nan) + _etas = np.full(len(pinfos), np.nan) + _r2s = np.full(len(pinfos), np.nan) + for j, pi in enumerate(pinfos): + rc = np.where(pi.pointing["dets:readout_id"] == readout_ids[i].encode())[0] + if not rc: + continue + rc = rc[0] + _xis[j] = pi.xi[rc] + _etas[j] = pi.eta[rc] + _r2s[j] = pi.pointing["R2"][rc] + xis[i] = np.nanmean(_xis[_r2s > cfg.min_r2]) + etas[i] = np.nanmean(_etas[_r2s > cfg.min_r2]) + + tod_pointing = AxisManager(dets) + tod_pointing.wrap("xi", xis, [(0, "dets")]) + tod_pointing.wrap("eta", etas, [(0, "dets")]) + + return tod_pointing + + +@dataclass +class MatchSolution: + match: dm.Match + am: AxisManager + match_iterations: List[dm.Match] = field(default_factory=list) + + +def get_pt_offset_interp(match, sel_rad=np.deg2rad(2)) -> Tuple[Any, Any]: + _xis, _etas, _dxis, _detas = [], [], [], [] + for r1, r2 in match.get_match_iter(include_unmatched=False): + _xis.append(r1.xi) + _etas.append(r1.eta) + _dxis.append(r1.xi - r2.xi) + _detas.append(r1.eta - r2.eta) + xis = np.array(_xis) + etas = np.array(_etas) + dxis = np.array(_dxis) + detas = np.array(_detas) + + xi_list = np.arange(np.nanmin(xis), np.nanmax(xis), sel_rad / 2) + eta_list = np.arange(np.nanmin(etas), np.nanmax(etas), sel_rad / 2) + xi_grid, eta_grid = np.meshgrid(xi_list, eta_list) + dxi_data = np.full_like(xi_grid, np.nan) + deta_data = np.full_like(eta_grid, np.nan) + for i, j in np.ndindex(xi_grid.shape): + sel = ( + np.sqrt((xis - xi_grid[i, j]) ** 2 + (etas - eta_grid[i, j]) ** 2) < sel_rad + ) + sel &= np.isfinite(dxis) & np.isfinite(detas) + dxi_data[i, j] = np.nanmedian(dxis[sel]) + deta_data[i, j] = np.nanmedian(detas[sel]) + + dxi_interp = interpolate.RegularGridInterpolator( + (xi_list, eta_list), dxi_data.T, bounds_error=False, fill_value=None + ) + deta_interp = interpolate.RegularGridInterpolator( + (xi_list, eta_list), deta_data.T, bounds_error=False, fill_value=None + ) + return dxi_interp, deta_interp + + +def get_foffset_interp( + match, is_north, box_size=50, box_step=25 +) -> Callable[[float], float]: + df, f, is_norths = [], [], [] + for r1, r2 in match.get_match_iter(include_unmatched=False): + df.append(r1.res_freq - r2.res_freq) + f.append(r1.res_freq) + is_norths.append(r1.is_north) + df_arr = np.array(df) + f_arr = np.array(f) + is_north_arr = np.array(is_norths, dtype=bool) + + f0, f1 = np.min(f), np.max(f) + df_meds = [] + fcs = [] + for fc in np.arange(f0, f1, box_step): + sel = (f > fc - box_size / 2) & (f < fc + box_size / 2) + sel &= is_north_arr == is_north + df_meds.append(np.nanmedian(df_arr[sel])) + fcs.append(fc) + + # Create interpolation + f_func = interpolate.interp1d(fcs, df_meds, fill_value="extrapolate") + return f_func + + +@dataclass +class MatchSolutionResult: + results: Dict[str, Optional[MatchSolution]] + am: Optional[AxisManager] = None + traceback: Optional[str] = None + + +def match_wafer( + cfg: SolutionsCfg, + am: AxisManager, + stream_id: str, + meas_rset: Optional[dm.ResSet] +) -> MatchSolution: + """ + Create a match solution for a given wafer slot. + + Args + ------ + cfg: SolutionsCfg + Configuration object + am: AxisManager + Axis manager containing detector info about relevant wafer slot, along + with measured pointing data. + stream_id: str + Stream Id of the wafer + """ + match_iterations = [] + + m = am.det_info.stream_id == stream_id + wafer_slot = am.det_info.wafer_slot[m][0] + + if meas_rset is None: + src = dm.ResSet.from_aman(am, stream_id, pointing=am[cfg.pointing_field]) + else: + src = meas_rset + + pt_cfg = dm.PointingConfig( + fp_file=cfg.ufm_to_fp_path, wafer_slot=wafer_slot, tel_type=cfg.tel_type, + zemax_path=cfg.zemax_path, + roll=np.deg2rad(am.obs_info.roll_center) if cfg.apply_roll else 0.0, + tube_slot = am.obs_info.tube_slot + ) + dst = dm.ResSet.from_wafer_info_file(cfg.wafer_info_path, stream_id, pt_cfg=pt_cfg) + + # first match + match_pars = dm.MatchParams( + freq_width=cfg.match_pars["match0"]["freq_width"], + dist_width=np.deg2rad(cfg.match_pars["match0"]["dist_width"]), + enforce_pointing_reqs=True, + allow_unassigned_to_assigned=False, + unassigned_slots=cfg.unassigned_slots + ) + match = dm.Match(src, dst, match_pars=match_pars, apply_dst_pointing=False) + + match_iterations.append(deepcopy(match)) + dxis, detas = [], [] + dfs = [] + is_north = [] + for r1, r2 in match.get_match_iter(include_unmatched=False): + dxis.append(r2.xi - r1.xi) + detas.append(r2.eta - r1.eta) + dfs.append(r2.res_freq - r1.res_freq) + is_north.append(r2.is_north) + dxi = np.nanmedian(dxis) + deta = np.nanmedian(detas) + + for r in match.src: + r.xi += dxi + r.eta += deta + + match._match() + match_iterations.append(deepcopy(match)) + + if cfg.freq_correct_by_muxband: + da = match.dst.as_array() + df_arr = np.full(len(da), np.nan) + for i, r in enumerate(match.dst): + if r.matched: + df_arr[i] = r.res_freq - match.src[r.match_idx].res_freq + + for is_north in [0, 1]: + for mb in np.unique(da["mux_band"]): + mask = (da["mux_band"] == mb) & (da["is_north"] == is_north) + df_med = np.nanmedian(df_arr[mask]) + for res_idx in np.where(mask)[0]: + match.dst[res_idx].res_freq -= df_med + else: + # Correct freq offset by box median interpolation + foffset_north = get_foffset_interp(match, True) + foffset_south = get_foffset_interp(match, False) + for r in match.dst: + if r.is_north: + r.res_freq += foffset_north(r.res_freq) + else: + r.res_freq += foffset_south(r.res_freq) + + # Second match + match.match_pars = dm.MatchParams( + freq_width=cfg.match_pars["match1"]["freq_width"], + dist_width=np.deg2rad(cfg.match_pars["match1"]["dist_width"]), + enforce_pointing_reqs=True, + allow_unassigned_to_assigned=False, + unassigned_slots=cfg.unassigned_slots + ) + match._match() + + dxi_interp, deta_interp = get_pt_offset_interp(match, sel_rad=np.deg2rad(cfg.sel_rad)) + for r in match.src: + if np.isnan(r.xi): + continue + r.xi -= dxi_interp((r.xi, r.eta)).item() + r.eta -= deta_interp((r.xi, r.eta)).item() + + # Third match + match.match_pars.freq_width = cfg.match_pars["match2"]["freq_width"] + match.match_pars.dist_width = np.deg2rad(cfg.match_pars["match2"]["dist_width"]) + + match._match() + + match_iterations.append(deepcopy(match)) + + return MatchSolution( + match=match, + match_iterations=match_iterations, + am=am, + ) + + +@dataclass +class FullWaferSolution: + match_solution: MatchSolution + pointing_results: List[PointingInfo] + meta: AxisManager + stream_id: str + + +def create_empty_match(cfg, am, wafer_slot, save=False): + m = am.det_info.wafer_slot == wafer_slot + stream_id = am.det_info.stream_id[m][0] + src = dm.ResSet.from_aman(am, stream_id) + + if save: + resset_file = os.path.join(cfg.results_dir, f"{stream_id}.npy") + np.save(resset_file, src.as_array()) + + return src + +def save_wafer_solution(cfg: SolutionsCfg, solution: FullWaferSolution): + solution.stream_id + resset_file = os.path.join(cfg.results_dir, f"{solution.stream_id}.npy") + match_file = os.path.join(cfg.results_dir, "matches", f"{solution.stream_id}.h5") + if not os.path.exists(os.path.dirname(match_file)): + os.makedirs(os.path.dirname(match_file)) + + match = solution.match_solution.match + np.save(resset_file, match.merged.as_array()) + match.save(match_file) + + +def get_wafer_solution( + cfg: SolutionsCfg, wafer_slot: str, save=False +) -> Optional[FullWaferSolution]: + pointing_results = load_good_pointing_info(cfg, wafer_slot) + if len(pointing_results) == 0: + return None + + if cfg.base_obs_id is not None: + for i, pi in enumerate(pointing_results): + if pi.obs_id == cfg.base_obs_id: + base_idx = i + break + else: + raise ValueError(f"Pointing info for base obs_id not found: {cfg.base_obs_id}") + else: + base_idx = 0 + + meas_rset, pointing_map = merge_pointing_info(cfg, pointing_results, base_idx=base_idx) + # tod_pointing = get_best_tod_pointing(cfg, pointing_results) + + meta = pointing_results[0].meta + stream_id = meta.det_info.stream_id[0] + + match_solution = match_wafer(cfg, meta, stream_id, meas_rset=meas_rset) + + solution = FullWaferSolution( + pointing_results=pointing_results, + match_solution=match_solution, + meta=meta, + stream_id=stream_id, + ) + + if save: + save_wafer_solution(cfg, solution) + + return solution + + +def solve_all(cfg) -> Dict[str, Optional[FullWaferSolution]]: + wafer_slots = ["ws0", "ws1", "ws2", "ws3", "ws4", "ws5", "ws6"] + results = {ws: get_wafer_solution(cfg, ws, save=True) for ws in wafer_slots} + return results + + +if __name__ == "__main__": + cfg_file = sys.argv[1] + with open(cfg_file, "r") as f: + cfg = SolutionsCfg(**yaml.safe_load(f)) + solve_all(cfg) diff --git a/sotodlib/site_pipeline/cli.py b/sotodlib/site_pipeline/cli.py index 399f64c37..b6257f985 100644 --- a/sotodlib/site_pipeline/cli.py +++ b/sotodlib/site_pipeline/cli.py @@ -59,6 +59,7 @@ def main(obs_id=None, config_file=None, logger=None): update_obsdb_ancil, make_coadd_atomic_map, make_cosamp_hk, + update_det_match, ) # Dictionary matching element name to a submodule (which must have @@ -81,7 +82,8 @@ def main(obs_id=None, config_file=None, logger=None): 'update-obsdb': update_obsdb, 'update-obsdb-ancil': update_obsdb_ancil, 'make-cosamp-hk': make_cosamp_hk, - 'make-coadd-atomic-map': make_coadd_atomic_map + 'make-coadd-atomic-map': make_coadd_atomic_map, + 'update-det-match': update_det_match, } CLI_NAME = 'so-site-pipeline' diff --git a/sotodlib/site_pipeline/update_det_match.py b/sotodlib/site_pipeline/update_det_match.py index f24fbf000..3cf74ca2c 100644 --- a/sotodlib/site_pipeline/update_det_match.py +++ b/sotodlib/site_pipeline/update_det_match.py @@ -11,6 +11,7 @@ import time from sotodlib.coords import det_match, optics +from sotodlib.coords.det_match_solutions import SolutionsCfg, solve_all from sotodlib import core from sotodlib.core.metadata import ManifestDb from sotodlib.io.metadata import write_dataset @@ -25,6 +26,14 @@ logger.setLevel(logging.INFO) +def get_detset_time(detset: str) -> float: + """ + Gets timestamp associated with a detset. Will parse this from the detset + name, assuming it is of the form _