Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 129 additions & 34 deletions sotodlib/tod_ops/azss.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ def _check_azcoverage(aman, flags, az=None, coverage_threshold=0.95,
return coverages < coverage_threshold, coverages


def _valid_dets_mask(azss_stats, min_valid_bins=1):
"""
Boolean mask of detectors usable for fitting / interpolation.

A detector is considered valid if it has at least ``min_valid_bins``
non-NaN bins and is not flagged in ``azss_stats['bad_dets']``, if present.
"""
valid = (~np.isnan(azss_stats.binned_signal)).sum(axis=1) >= min_valid_bins
if 'bad_dets' in azss_stats:
valid &= ~azss_stats['bad_dets']
return valid


def bin_by_az(aman, signal=None, az=None, azrange=None, bins=100, flags=None,
apodize_edges=True, apodize_edges_samps=1600,
apodize_flags=True, apodize_flags_samps=200):
Expand Down Expand Up @@ -120,6 +133,45 @@ def bin_by_az(aman, signal=None, az=None, azrange=None, bins=100, flags=None,
range=azrange, bins=bins, flags=flags, weight_for_signal=weight_for_signal)
return binning_dict


def _legendre_x(az, binned_az, m, fit_range, bin_width):
"""
Build Legendre x-coordinates (rescaled to [-1, 1]) for both the sample
azimuth array and the bin centers, given a 1D mask m of valid bins.

If fit_range is None the range is inferred from the bin centers
selected by m (extended by half a bin on each side).
"""
if fit_range is None:
az_min = binned_az[m].min() - bin_width / 2
az_max = binned_az[m].max() + bin_width / 2
else:
az_min, az_max = fit_range[0], fit_range[1]
span = az_max - az_min
mid = az_min + az_max
x_samp = (2 * az - mid) / span
x_bin = np.where(m, (2 * binned_az - mid) / span, np.nan)
return x_samp, x_bin


def _fit_one_det(binned_signal_i, x_bin, m, sigma_i, max_mode):
"""
Fit a Legendre polynomial to one detector's binned signal.

Returns
-------
coeffs : ndarray, shape (max_mode + 1,)
binned_model : ndarray, shape (nbins,) (NaN at masked bins)
sum_of_squares : float
redchi2 : float
"""
coeffs = L.legfit(x_bin[m], binned_signal_i[m], max_mode)
binned_model = np.where(m, L.legval(x_bin, coeffs), np.nan)
ssq = np.sum((binned_signal_i[m] - binned_model[m])**2)
redchi2 = ssq / sigma_i**2 / (m.sum() - max_mode - 1)
return coeffs, binned_model, ssq, redchi2


def fit_azss(az, azss_stats, max_mode, fit_range=None, overwrite=False):
"""
Function for fitting Legendre polynomials to signal binned in azimuth.
Expand Down Expand Up @@ -148,28 +200,71 @@ def fit_azss(az, azss_stats, max_mode, fit_range=None, overwrite=False):
Model fit for each detector size ndets x n_samps
"""
bin_width = azss_stats.binned_az[1] - azss_stats.binned_az[0]
m = ~np.isnan(azss_stats.binned_signal[0]) # masks bins without counts
if np.count_nonzero(m) < max_mode + 1:
raise ValueError('Number of valid bins is smaller than mode of Legendre function')

if fit_range is None:
az_min = np.min(azss_stats.binned_az[m]) - bin_width / 2
az_max = np.max(azss_stats.binned_az[m]) + bin_width / 2
ndets = azss_stats.dets.count
nbins = azss_stats.bin_az_samps.count

m_2d = ~np.isnan(azss_stats.binned_signal)

# Number of valid bins cannot be smaller than mode of Legendre function.
valid_dets = _valid_dets_mask(azss_stats, min_valid_bins=max_mode + 1)

model = np.zeros((ndets, len(az)))
coeffs = np.zeros((ndets, max_mode + 1))
binned_model = np.full((ndets, nbins), np.nan)
sum_of_squares = np.full(ndets, np.nan)
redchi2s = np.full(ndets, np.nan)

if not valid_dets.any():
logger.warning('All the detectors have low az coverage and cannot make model')
return model

use_cached = ('coeffs' in azss_stats) and not overwrite

is_uniform = np.array_equal(m_2d[valid_dets].any(axis=0),
m_2d[valid_dets].all(axis=0))

if is_uniform:
m = m_2d[valid_dets][0]
x_samp, x_legendre_bin_centers = _legendre_x(
az, azss_stats.binned_az, m, fit_range, bin_width)

if use_cached:
return L.legval(x_samp, azss_stats.coeffs.T)

# Vectorized Legendre fit over all valid detectors at once.
c = L.legfit(x_legendre_bin_centers[m],
azss_stats.binned_signal[valid_dets][:, m].T,
max_mode).T # shape: (n_valid, max_mode + 1)
coeffs[valid_dets] = c
bm = L.legval(x_legendre_bin_centers, c.T)
binned_model[valid_dets] = np.where(m, bm, np.nan)
model[valid_dets] = L.legval(x_samp, c.T)

diff = azss_stats.binned_signal[valid_dets][:, m] - bm[:, m]
sum_of_squares[valid_dets] = np.sum(diff**2, axis=-1)
redchi2s[valid_dets] = (
sum_of_squares[valid_dets]
/ azss_stats.uniform_binned_signal_sigma[valid_dets]**2
/ (m.sum() - max_mode - 1)
)
else:
az_min, az_max = fit_range[0], fit_range[1]

x_legendre = (2 * az - (az_min+az_max)) / (az_max - az_min)
x_legendre_bin_centers = (2 * azss_stats.binned_az - (az_min+az_max)) / (az_max - az_min)
x_legendre_bin_centers = np.where(~m, np.nan, x_legendre_bin_centers)
if ('coeffs' in azss_stats) and not overwrite:
return L.legval(x_legendre, azss_stats.coeffs.T)

coeffs = L.legfit(x_legendre_bin_centers[m], azss_stats.binned_signal[:, m].T, max_mode)
coeffs = coeffs.T
binned_model = L.legval(x_legendre_bin_centers, coeffs.T)
binned_model = np.where(~m, np.nan, binned_model)
sum_of_squares = np.sum(((azss_stats.binned_signal[:, m] - binned_model[:, m])**2), axis=-1)
redchi2s = sum_of_squares/azss_stats.uniform_binned_signal_sigma**2 / (len(x_legendre_bin_centers[m]) - max_mode - 1)
for i in np.where(valid_dets)[0]:
m_i = m_2d[i]
x_samp, x_legendre_bin_centers = _legendre_x(
az, azss_stats.binned_az, m_i, fit_range, bin_width)
if use_cached:
model[i] = L.legval(x_samp, azss_stats.coeffs[i])
continue
(coeffs[i],
binned_model[i],
sum_of_squares[i],
redchi2s[i]) = _fit_one_det(
azss_stats.binned_signal[i], x_legendre_bin_centers, m_i,
azss_stats.uniform_binned_signal_sigma[i], max_mode)
model[i] = L.legval(x_samp, coeffs[i])

if use_cached:
return model

mode_names = np.array([f'legendre{mode}' for mode in range(max_mode + 1)], dtype='<U10')
modes_axis = core.LabelAxis(name='azss_modes', vals=mode_names)
Expand All @@ -179,7 +274,7 @@ def fit_azss(az, azss_stats, max_mode, fit_range=None, overwrite=False):
azss_stats.wrap('coeffs', coeffs, [(0, 'dets'), (1, 'azss_modes')])
azss_stats.wrap('redchi2s', redchi2s, [(0, 'dets')])

return L.legval(x_legendre, coeffs.T)
return model


def get_azss(aman, signal='signal', az=None, azrange=None, bins=100, flags=None, scan_flags=None,
Expand Down Expand Up @@ -365,24 +460,24 @@ def get_azss_model(aman, azss_stats, az=None, method='interpolate',

if method == 'interpolate':
model = np.zeros((aman.dets.count, aman.samps.count))
if 'bad_dets' in azss_stats:
valid_dets = ~azss_stats['bad_dets']
else:
valid_dets = np.ones(aman.dets.count, dtype=bool)
if sum(valid_dets) == 0:
logger.info('All the detectors have low az coverage and cannot make model')

valid_dets = _valid_dets_mask(azss_stats, min_valid_bins=1)
if not valid_dets.any():
logger.warning('All the detectors have low az coverage and cannot make model')
return model

mask = ~np.isnan(azss_stats.binned_signal[valid_dets, :])
is_uniform = np.all(mask == mask[0, :])
m_2d = ~np.isnan(azss_stats.binned_signal)
is_uniform = np.array_equal(m_2d[valid_dets].any(axis=0),
m_2d[valid_dets].all(axis=0))
if is_uniform:
m = mask[0, :]
m = m_2d[valid_dets][0]
f_template = interp1d(azss_stats.binned_az[m], azss_stats.binned_signal[:, m][valid_dets, :], fill_value='extrapolate')
model[valid_dets, :] = f_template(az)
else:
for i, (m, binned_signal) in enumerate(zip(mask, azss_stats.binned_signal[valid_dets, :])):
f_template = interp1d(azss_stats.binned_az[m], binned_signal[m], fill_value='extrapolate')
model[valid_dets, :][i] = f_template(az)
for i in np.where(valid_dets)[0]:
m = m_2d[i]
f_template = interp1d(azss_stats.binned_az[m], azss_stats.binned_signal[i][m], fill_value='extrapolate')
model[i, :] = f_template(az)

if np.any(~np.isfinite(model)):
logger.warning('azss model has nan. set zero to nan but this may make glitch')
Expand Down
45 changes: 40 additions & 5 deletions tests/test_azss.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_scan(n_scans=33, scan_accel=0.025, scanrate=0.025,


def make_fake_azss_tod(max_mode=20, noise_amp=1, n_scans=10,
ndets=2, input_coeffs=None):
ndets=3, input_coeffs=None):
"""
Makes an axis manager with azimuth synchronous signal
in it, populated via legendre polynomials plus gaussian noise.
Expand Down Expand Up @@ -94,16 +94,18 @@ def get_coeff_metric(tod):
"""
Evaluates fit is working by comparing coefficients in to out.
"""
print(tod.input_coeffs[0])
print(tod.azss_stats.coeffs[0])
print(tod.input_coeffs)
print(tod.azss_stats.coeffs)
outmetric_num = (tod.azss_stats.coeffs - tod.input_coeffs)**2
outmetric_denom = (tod.input_coeffs)**2
return np.median(100*(outmetric_num/outmetric_denom))


class AzssTest(unittest.TestCase):
"Test the Azimuth Synchronous Signal fitting functions"
def test_fit(self):
"""
Test the Azimuth Synchronous Signal fitting functions
"""
max_mode = 10
tod = make_fake_azss_tod(noise_amp=0, n_scans=50, max_mode=max_mode)
azss_stats, model_sig_tod = azss.get_azss(
Expand All @@ -119,6 +121,37 @@ def test_fit(self):
self.assertTrue(~np.any(np.isnan(tod.signal)))
self.assertTrue(np.std(tod.signal) > np.std(tod.signal - model_sig_tod))

model = azss.fit_azss(tod.boresight.az, azss_stats, max_mode)
self.assertTrue(np.all(model_sig_tod == model))

def test_fit_with_flags(self):
"""
Test the Azimuth Synchronous Signal fitting functions with flags
"""
max_mode = 10
tod = make_fake_azss_tod(noise_amp=0, n_scans=50, max_mode=max_mode)

mask = np.zeros((tod.dets.count, tod.samps.count), dtype=bool)
# one detector has partial az coverage, one detector has zero az coverage
mask[0, tod.boresight.az > np.percentile(tod.boresight.az, 95)] = True
mask[-1, :] = True
flags = RangesMatrix.from_mask(mask)

azss_stats, model_sig_tod = azss.get_azss(
tod,
method='fit',
max_mode=max_mode,
azrange=None,
bins=100,
flags=flags,
)
self.assertTrue(~np.any(np.isnan(tod.signal)))
self.assertTrue(np.std(tod.signal[~mask]) > np.std(tod.signal[~mask] - model_sig_tod[~mask]))

# check consistency of model made with cached legendre coeffs
model = azss.fit_azss(tod.boresight.az, azss_stats, max_mode)
self.assertTrue(np.all(model_sig_tod == model))

def test_interpolate(self):
"""
Test the interpolation method of Azimuth Synchronous Signal subtraction.
Expand All @@ -144,7 +177,9 @@ def test_interpolate_with_flags(self):
tod = make_fake_azss_tod(noise_amp=0, n_scans=50, max_mode=max_mode)

mask = np.zeros((tod.dets.count, tod.samps.count), dtype=bool)
mask[-1, :] = True # one detector has low az coverage
# one detector has partial az coverage, one detector has zero az coverage
mask[0, tod.boresight.az > np.percentile(tod.boresight.az, 95)] = True
mask[-1, :] = True
flags = RangesMatrix.from_mask(mask)

azss_stats, model_sig_tod = azss.get_azss(
Expand Down