Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions pisa/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,28 +294,35 @@ def detailed_metric_info(self, new_info):
[d[self.metric[0 if not is_detectors else i]]["maps"]["total"] for
i, d in enumerate(self._detailed_metric_info)]
)
# Set off by assuming we're unable to obtain the total prior contribution
metric_val_from_priors = np.nan
if is_detectors:
# 1)
# TODO: We don't have access to the Detectors instance itself here
# -> no straightforward way to correctly determine the total prior
# contribution (cf. Detectors.init_params())
metric_val_from_priors = np.nan
pass
elif not "priors_penalty_total" in self._detailed_metric_info[0][self.metric[0]]:
# 2) Don't attempt to manually compute total prior contribution:
# there may be correlations!
pass
else:
# 2)
# can obtain the total prior contribution from any of the list entries
metric_val_from_priors = np.sum(
self._detailed_metric_info[0][self.metric[0]]["priors"]
)
total_metric_val_from_detailed = metric_val_from_maps + metric_val_from_priors
# 3) Can read out the pre-computed total prior contribution from
# any of the list entries
metric_val_from_priors = self._detailed_metric_info[0][self.metric[0]]["priors_penalty_total"]
else:
self._detailed_metric_info = self.deserialize_detailed_metric_info(new_info)
if self.metric_val is None:
return
# sanity check on metric value
total_metric_val_from_detailed = (
self._detailed_metric_info[self.metric[0]]["maps"]["total"] +
np.sum(self._detailed_metric_info[self.metric[0]]["priors"])
)
metric_val_from_priors = np.nan
if "priors_penalty_total" in self._detailed_metric_info[self.metric[0]]:
metric_val_from_priors = self._detailed_metric_info[self.metric[0]]["priors_penalty_total"]
metric_val_from_maps = self._detailed_metric_info[self.metric[0]]["maps"]["total"]

# Get the total and compare if possible
total_metric_val_from_detailed = metric_val_from_maps + metric_val_from_priors

if not np.isnan(total_metric_val_from_detailed):
if not recursiveEquality(total_metric_val_from_detailed, self.metric_val):
logging.warning(
Expand Down Expand Up @@ -455,6 +462,7 @@ def get_detailed_metric_info(data_dist, hypo_maker, hypo_asimov_dist, params,
if include_maps_binned:
name_vals_d['maps_binned'] = MapSet(maps_binned)
name_vals_d['priors'] = params.priors_penalties(metric=metric)
name_vals_d['priors_penalty_total'] = params.priors_penalty(metric=metric)
detailed_metric_info[m] = name_vals_d
return detailed_metric_info

Expand All @@ -479,6 +487,10 @@ def deserialize_detailed_metric_info(info_dict):
# Deserialize if necessary
name_vals_d["maps_binned"] = MapSet(**info_dict[m]["maps_binned"])
name_vals_d["priors"] = info_dict[m]["priors"]
if "priors_penalty_total" in info_dict[m]:
# Don't assume that this entry exists either, so that old fit
# results can still be deserialized
name_vals_d["priors_penalty_total"] = info_dict[m]["priors_penalty_total"]
detailed_metric_info[m] = name_vals_d
return detailed_metric_info

Expand Down