Skip to content

Commit ae3d191

Browse files
authored
Save memory in PointingProvider.get_pointings (#488)
* Save memory in PointingProvider.get_pointings PointingProvider.get_pointings used to run the slerp operation over the whole timeline, thus allocating a significant amount of memory. Now the class' constructor accepts a new argument that limits the maximum amount of memory for the temporary buffer used to store the fully-sampled quaternions (default is 256 MB). This should slow a bit the generation of pointings, but it has the advantage to limit the amount of memory needed for a simulation run. * [skipci] Update CHANGELOG
1 parent f327427 commit ae3d191

4 files changed

Lines changed: 138 additions & 21 deletions

File tree

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# HEAD
22

3+
- Save memory in pointing generation [#488](https://github.com/litebird/litebird_sim/pull/488)
4+
35
-   **Breaking change**: Major reworking of the interfaces and handling of inputs across the framework [#479](https://github.com/litebird/litebird_sim/pull/479), in detail:
46

57
1. Rework the handling of spherical harmonics by integrating ducc0 as the primary engine for SHT operations, including interpolation.

litebird_sim/observations.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .input_sky import SkyGenerationParams
1515
from .maps_and_harmonics import HealpixMap, SphericalHarmonics
1616
from .mpi import MPI_COMM_GRID, _SerialMpiCommunicator
17-
from .pointings import PointingProvider
17+
from .pointings import PointingProvider, DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB
1818
from .scanning import RotQuaternion
1919
from .units import Units
2020

@@ -910,6 +910,7 @@ def prepare_pointings(
910910
instrument: InstrumentInfo,
911911
spin2ecliptic_quats: RotQuaternion,
912912
hwp: HWP | None = None,
913+
maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
913914
) -> None:
914915
"""Prepare quaternion-based pointing and HWP information for this observation.
915916
@@ -936,6 +937,10 @@ def prepare_pointings(
936937
Optional HWP model. If provided, it is stored and its Mueller matrix
937938
applied to all detectors lacking one.
938939
940+
maximum_internal_buffer_mem_mb (float):
941+
Maximum number of megabytes (MB) to allocate for internal buffers during
942+
the computation of pointings. Set to -1 to remove any limit.
943+
939944
Raises:
940945
AssertionError:
941946
If `hwp` is not provided and one or more detectors do not have a
@@ -947,10 +952,19 @@ def prepare_pointings(
947952
internal :class:`.PointingProvider`.
948953
"""
949954

955+
assert (maximum_internal_buffer_mem_mb > 0) or (
956+
maximum_internal_buffer_mem_mb == -1
957+
), (
958+
"Invalid value for maximum_internal_buffer_mem_mb ({val}), it must either be -1 or a positive number".format(
959+
val=maximum_internal_buffer_mem_mb
960+
)
961+
)
962+
950963
bore2ecliptic_quats = spin2ecliptic_quats * instrument.bore2spin_quat
951964
pointing_provider = PointingProvider(
952965
bore2ecliptic_quats=bore2ecliptic_quats,
953966
hwp=hwp,
967+
maximum_internal_buffer_mem_mb=maximum_internal_buffer_mem_mb,
954968
)
955969

956970
self.pointing_provider = pointing_provider

litebird_sim/pointings.py

Lines changed: 70 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
RotQuaternion,
99
)
1010

11+
DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB = 256.0
12+
1113

1214
class PointingProvider:
1315
"""Provides detector pointing angles and HWP angles based on scanning geometry.
@@ -51,10 +53,42 @@ def __init__(
5153
# Note that we require here *boresight*→Ecliptic instead of *spin*→Ecliptic
5254
bore2ecliptic_quats: RotQuaternion,
5355
hwp: HWP | None = None,
56+
maximum_internal_buffer_mem_mb: float = DEFAULT_INTERNAL_BUFFER_SIZE_FOR_POINTINGS_MB,
5457
):
5558
self.bore2ecliptic_quats = bore2ecliptic_quats
59+
self.maximum_internal_buffer_mem_mb = maximum_internal_buffer_mem_mb
5660
self.hwp = hwp
5761

62+
def _optimal_block_lengths(self, total_nsamples: int) -> list[int]:
63+
# Size of one quaternion, in bytes
64+
quaternion_size_bytes = 4 * self.bore2ecliptic_quats.quats.itemsize
65+
66+
# Average number of quaternions in each block, to make sure that no more than
67+
# a fixed number of MB is ever needed to store them
68+
quaternions_per_block = int(
69+
self.maximum_internal_buffer_mem_mb * 1024 * 1024 / quaternion_size_bytes
70+
)
71+
72+
# How many blocks of quaternions will need to be processed
73+
number_of_blocks = total_nsamples // quaternions_per_block
74+
75+
# If something was left out of the previous calculation, include an additional block
76+
if total_nsamples % quaternions_per_block != 0:
77+
number_of_blocks += 1
78+
79+
# Instead of making the first N−1 blocks of the same size and add any leftover to the
80+
# last block, which might even have 1 sample, try to create the blocks so that they all
81+
# have roughly the same number of elements
82+
result = []
83+
quaternions_left = total_nsamples
84+
while quaternions_left > 0:
85+
current_block_length = quaternions_left // (number_of_blocks - len(result))
86+
result.append(current_block_length)
87+
quaternions_left -= current_block_length
88+
89+
assert sum(result) == total_nsamples
90+
return result
91+
5892
def has_hwp(self):
5993
"""Return ``True`` if a HWP has been set.
6094
@@ -122,27 +156,21 @@ def get_pointings(
122156
one is a float and the other is an `astropy.time.Time`).
123157
"""
124158

125-
full_quaternions = (self.bore2ecliptic_quats * detector_quat).slerp(
126-
start_time=start_time,
127-
sampling_rate_hz=sampling_rate_hz,
128-
nsamples=nsamples,
129-
)
159+
if isinstance(start_time, astropy.time.Time):
160+
assert isinstance(start_time_global, astropy.time.Time), (
161+
"The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
162+
)
163+
start_time_s = (start_time - start_time_global).to("s").value
164+
else:
165+
assert isinstance(start_time_global, (int, float)), (
166+
"The start_time is a float, so start_time_global must also be a float."
167+
)
168+
start_time_s = start_time - start_time_global
130169

131170
if self.hwp is not None:
132171
if hwp_buffer is None:
133172
hwp_buffer = np.empty(nsamples, dtype=pointings_dtype)
134173

135-
if isinstance(start_time, astropy.time.Time):
136-
assert isinstance(start_time_global, astropy.time.Time), (
137-
"The start_time is a astropy.time.Time object, so start_time_global must also be an astropy.time.Time object."
138-
)
139-
start_time_s = (start_time - start_time_global).to("s").value
140-
else:
141-
assert isinstance(start_time_global, (int, float)), (
142-
"The start_time is a float, so start_time_global must also be a float."
143-
)
144-
start_time_s = start_time - start_time_global
145-
146174
self.hwp.get_hwp_angle(
147175
output_buffer=hwp_buffer,
148176
start_time_s=start_time_s,
@@ -154,10 +182,32 @@ def get_pointings(
154182
if pointing_buffer is None:
155183
pointing_buffer = np.empty(shape=(nsamples, 3), dtype=pointings_dtype)
156184

157-
all_compute_pointing_and_orientation(
158-
result_matrix=pointing_buffer,
159-
quat_matrix=full_quaternions,
160-
)
185+
block_lengths = self._optimal_block_lengths(total_nsamples=nsamples)
186+
187+
det_to_ecliptic_quats = self.bore2ecliptic_quats * detector_quat
188+
cur_time = start_time
189+
start_sample = 0
190+
for cur_block_length in block_lengths:
191+
cur_quaternions = det_to_ecliptic_quats.slerp(
192+
start_time=cur_time,
193+
sampling_rate_hz=sampling_rate_hz,
194+
nsamples=cur_block_length,
195+
)
196+
all_compute_pointing_and_orientation(
197+
result_matrix=pointing_buffer[
198+
start_sample : (start_sample + cur_block_length), :
199+
],
200+
quat_matrix=cur_quaternions,
201+
)
202+
203+
if isinstance(cur_time, astropy.time.Time):
204+
cur_time += astropy.time.TimeDelta(
205+
cur_block_length / sampling_rate_hz, format="sec"
206+
)
207+
else:
208+
cur_time += cur_block_length / sampling_rate_hz
209+
210+
start_sample += cur_block_length
161211

162212
return pointing_buffer, hwp_buffer
163213

test/test_scanning.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,54 @@ def test_time_dependent_quaternions_operations():
538538
expected[0, :] = qconst1.quats[0, :]
539539
lbs.quat_right_multiply(expected[0, :], *qconst2.quats[0, :])
540540
np.testing.assert_allclose(actual=result.quats, desired=expected)
541+
542+
543+
def test_chunked_pointing_generation():
544+
quat_array = lbs.RotQuaternion(
545+
quats=np.array(
546+
[
547+
# This is not really a “rotating” quaternion: we repeat
548+
# the same quaternion (90° rotation around x) thrice
549+
# just for testing
550+
[1.0, 0.0, 0.0, 1.0],
551+
[1.0, 0.0, 0.0, 1.0],
552+
[1.0, 0.0, 0.0, 1.0],
553+
]
554+
/ np.sqrt(2)
555+
),
556+
start_time=0.0,
557+
sampling_rate_hz=0.25, # Four seconds per quaternion
558+
)
559+
560+
# Make room for 5 quaternions at most
561+
quaternion_size_in_bytes = 32
562+
pp = lbs.PointingProvider(
563+
bore2ecliptic_quats=quat_array,
564+
maximum_internal_buffer_mem_mb=(quaternion_size_in_bytes * 5) / (1024 * 1024),
565+
)
566+
567+
num_of_samples = 12
568+
block_lengths = pp._optimal_block_lengths(total_nsamples=num_of_samples)
569+
assert len(block_lengths) == 3
570+
assert sum(block_lengths) == num_of_samples
571+
572+
pointing_buf, hwp_buf = pp.get_pointings(
573+
detector_quat=lbs.RotQuaternion(quats=np.array([[0.0, 0.0, 0.0, 1.0]])),
574+
start_time=0.0,
575+
start_time_global=0.0,
576+
sampling_rate_hz=1.0,
577+
nsamples=num_of_samples,
578+
)
579+
assert pointing_buf.shape == (num_of_samples, 3)
580+
581+
# We expect the +z axis of the detector to be rotated by 90° around the x axis,
582+
# so that it should point towards −y. This implies that ϑ = π/2 and φ = −π/2
583+
584+
# ϑ
585+
np.testing.assert_allclose(pointing_buf[:, 0], np.pi / 2)
586+
587+
# φ
588+
np.testing.assert_allclose(pointing_buf[:, 1], -np.pi / 2)
589+
590+
# ψ
591+
np.testing.assert_allclose(pointing_buf[:, 2], 0, atol=1e-15)

0 commit comments

Comments
 (0)