diff --git a/sotodlib/tod_ops/azss.py b/sotodlib/tod_ops/azss.py index 5cb94267f..18481d4b4 100644 --- a/sotodlib/tod_ops/azss.py +++ b/sotodlib/tod_ops/azss.py @@ -43,6 +43,19 @@ def _check_azcoverage(aman, flags, az=None, coverage_threshold=0.95, return coverages < coverage_threshold, coverages +def _valid_dets_mask(azss_stats, min_valid_bins=1): + """ + Boolean mask of detectors usable for fitting / interpolation. + + A detector is considered valid if it has at least ``min_valid_bins`` + non-NaN bins and is not flagged in ``azss_stats['bad_dets']``, if present. + """ + valid = (~np.isnan(azss_stats.binned_signal)).sum(axis=1) >= min_valid_bins + if 'bad_dets' in azss_stats: + valid &= ~azss_stats['bad_dets'] + return valid + + def bin_by_az(aman, signal=None, az=None, azrange=None, bins=100, flags=None, apodize_edges=True, apodize_edges_samps=1600, apodize_flags=True, apodize_flags_samps=200): @@ -120,6 +133,45 @@ def bin_by_az(aman, signal=None, az=None, azrange=None, bins=100, flags=None, range=azrange, bins=bins, flags=flags, weight_for_signal=weight_for_signal) return binning_dict + +def _legendre_x(az, binned_az, m, fit_range, bin_width): + """ + Build Legendre x-coordinates (rescaled to [-1, 1]) for both the sample + azimuth array and the bin centers, given a 1D mask m of valid bins. + + If fit_range is None the range is inferred from the bin centers + selected by m (extended by half a bin on each side). + """ + if fit_range is None: + az_min = binned_az[m].min() - bin_width / 2 + az_max = binned_az[m].max() + bin_width / 2 + else: + az_min, az_max = fit_range[0], fit_range[1] + span = az_max - az_min + mid = az_min + az_max + x_samp = (2 * az - mid) / span + x_bin = np.where(m, (2 * binned_az - mid) / span, np.nan) + return x_samp, x_bin + + +def _fit_one_det(binned_signal_i, x_bin, m, sigma_i, max_mode): + """ + Fit a Legendre polynomial to one detector's binned signal. + + Returns + ------- + coeffs : ndarray, shape (max_mode + 1,) + binned_model : ndarray, shape (nbins,) (NaN at masked bins) + sum_of_squares : float + redchi2 : float + """ + coeffs = L.legfit(x_bin[m], binned_signal_i[m], max_mode) + binned_model = np.where(m, L.legval(x_bin, coeffs), np.nan) + ssq = np.sum((binned_signal_i[m] - binned_model[m])**2) + redchi2 = ssq / sigma_i**2 / (m.sum() - max_mode - 1) + return coeffs, binned_model, ssq, redchi2 + + def fit_azss(az, azss_stats, max_mode, fit_range=None, overwrite=False): """ Function for fitting Legendre polynomials to signal binned in azimuth. @@ -148,28 +200,71 @@ def fit_azss(az, azss_stats, max_mode, fit_range=None, overwrite=False): Model fit for each detector size ndets x n_samps """ bin_width = azss_stats.binned_az[1] - azss_stats.binned_az[0] - m = ~np.isnan(azss_stats.binned_signal[0]) # masks bins without counts - if np.count_nonzero(m) < max_mode + 1: - raise ValueError('Number of valid bins is smaller than mode of Legendre function') - - if fit_range is None: - az_min = np.min(azss_stats.binned_az[m]) - bin_width / 2 - az_max = np.max(azss_stats.binned_az[m]) + bin_width / 2 + ndets = azss_stats.dets.count + nbins = azss_stats.bin_az_samps.count + + m_2d = ~np.isnan(azss_stats.binned_signal) + + # Number of valid bins cannot be smaller than mode of Legendre function. + valid_dets = _valid_dets_mask(azss_stats, min_valid_bins=max_mode + 1) + + model = np.zeros((ndets, len(az))) + coeffs = np.zeros((ndets, max_mode + 1)) + binned_model = np.full((ndets, nbins), np.nan) + sum_of_squares = np.full(ndets, np.nan) + redchi2s = np.full(ndets, np.nan) + + if not valid_dets.any(): + logger.warning('All the detectors have low az coverage and cannot make model') + return model + + use_cached = ('coeffs' in azss_stats) and not overwrite + + is_uniform = np.array_equal(m_2d[valid_dets].any(axis=0), + m_2d[valid_dets].all(axis=0)) + + if is_uniform: + m = m_2d[valid_dets][0] + x_samp, x_legendre_bin_centers = _legendre_x( + az, azss_stats.binned_az, m, fit_range, bin_width) + + if use_cached: + return L.legval(x_samp, azss_stats.coeffs.T) + + # Vectorized Legendre fit over all valid detectors at once. + c = L.legfit(x_legendre_bin_centers[m], + azss_stats.binned_signal[valid_dets][:, m].T, + max_mode).T # shape: (n_valid, max_mode + 1) + coeffs[valid_dets] = c + bm = L.legval(x_legendre_bin_centers, c.T) + binned_model[valid_dets] = np.where(m, bm, np.nan) + model[valid_dets] = L.legval(x_samp, c.T) + + diff = azss_stats.binned_signal[valid_dets][:, m] - bm[:, m] + sum_of_squares[valid_dets] = np.sum(diff**2, axis=-1) + redchi2s[valid_dets] = ( + sum_of_squares[valid_dets] + / azss_stats.uniform_binned_signal_sigma[valid_dets]**2 + / (m.sum() - max_mode - 1) + ) else: - az_min, az_max = fit_range[0], fit_range[1] - - x_legendre = (2 * az - (az_min+az_max)) / (az_max - az_min) - x_legendre_bin_centers = (2 * azss_stats.binned_az - (az_min+az_max)) / (az_max - az_min) - x_legendre_bin_centers = np.where(~m, np.nan, x_legendre_bin_centers) - if ('coeffs' in azss_stats) and not overwrite: - return L.legval(x_legendre, azss_stats.coeffs.T) - - coeffs = L.legfit(x_legendre_bin_centers[m], azss_stats.binned_signal[:, m].T, max_mode) - coeffs = coeffs.T - binned_model = L.legval(x_legendre_bin_centers, coeffs.T) - binned_model = np.where(~m, np.nan, binned_model) - sum_of_squares = np.sum(((azss_stats.binned_signal[:, m] - binned_model[:, m])**2), axis=-1) - redchi2s = sum_of_squares/azss_stats.uniform_binned_signal_sigma**2 / (len(x_legendre_bin_centers[m]) - max_mode - 1) + for i in np.where(valid_dets)[0]: + m_i = m_2d[i] + x_samp, x_legendre_bin_centers = _legendre_x( + az, azss_stats.binned_az, m_i, fit_range, bin_width) + if use_cached: + model[i] = L.legval(x_samp, azss_stats.coeffs[i]) + continue + (coeffs[i], + binned_model[i], + sum_of_squares[i], + redchi2s[i]) = _fit_one_det( + azss_stats.binned_signal[i], x_legendre_bin_centers, m_i, + azss_stats.uniform_binned_signal_sigma[i], max_mode) + model[i] = L.legval(x_samp, coeffs[i]) + + if use_cached: + return model mode_names = np.array([f'legendre{mode}' for mode in range(max_mode + 1)], dtype=' np.std(tod.signal - model_sig_tod)) + model = azss.fit_azss(tod.boresight.az, azss_stats, max_mode) + self.assertTrue(np.all(model_sig_tod == model)) + + def test_fit_with_flags(self): + """ + Test the Azimuth Synchronous Signal fitting functions with flags + """ + max_mode = 10 + tod = make_fake_azss_tod(noise_amp=0, n_scans=50, max_mode=max_mode) + + mask = np.zeros((tod.dets.count, tod.samps.count), dtype=bool) + # one detector has partial az coverage, one detector has zero az coverage + mask[0, tod.boresight.az > np.percentile(tod.boresight.az, 95)] = True + mask[-1, :] = True + flags = RangesMatrix.from_mask(mask) + + azss_stats, model_sig_tod = azss.get_azss( + tod, + method='fit', + max_mode=max_mode, + azrange=None, + bins=100, + flags=flags, + ) + self.assertTrue(~np.any(np.isnan(tod.signal))) + self.assertTrue(np.std(tod.signal[~mask]) > np.std(tod.signal[~mask] - model_sig_tod[~mask])) + + # check consistency of model made with cached legendre coeffs + model = azss.fit_azss(tod.boresight.az, azss_stats, max_mode) + self.assertTrue(np.all(model_sig_tod == model)) + def test_interpolate(self): """ Test the interpolation method of Azimuth Synchronous Signal subtraction. @@ -144,7 +177,9 @@ def test_interpolate_with_flags(self): tod = make_fake_azss_tod(noise_amp=0, n_scans=50, max_mode=max_mode) mask = np.zeros((tod.dets.count, tod.samps.count), dtype=bool) - mask[-1, :] = True # one detector has low az coverage + # one detector has partial az coverage, one detector has zero az coverage + mask[0, tod.boresight.az > np.percentile(tod.boresight.az, 95)] = True + mask[-1, :] = True flags = RangesMatrix.from_mask(mask) azss_stats, model_sig_tod = azss.get_azss(