diff --git a/pisa/analysis/analysis.py b/pisa/analysis/analysis.py index fe9ab18fd..10a18cdea 100644 --- a/pisa/analysis/analysis.py +++ b/pisa/analysis/analysis.py @@ -294,28 +294,35 @@ def detailed_metric_info(self, new_info): [d[self.metric[0 if not is_detectors else i]]["maps"]["total"] for i, d in enumerate(self._detailed_metric_info)] ) + # Set off by assuming we're unable to obtain the total prior contribution + metric_val_from_priors = np.nan if is_detectors: # 1) # TODO: We don't have access to the Detectors instance itself here # -> no straightforward way to correctly determine the total prior # contribution (cf. Detectors.init_params()) - metric_val_from_priors = np.nan + pass + elif not "priors_penalty_total" in self._detailed_metric_info[0][self.metric[0]]: + # 2) Don't attempt to manually compute total prior contribution: + # there may be correlations! + pass else: - # 2) - # can obtain the total prior contribution from any of the list entries - metric_val_from_priors = np.sum( - self._detailed_metric_info[0][self.metric[0]]["priors"] - ) - total_metric_val_from_detailed = metric_val_from_maps + metric_val_from_priors + # 3) Can read out the pre-computed total prior contribution from + # any of the list entries + metric_val_from_priors = self._detailed_metric_info[0][self.metric[0]]["priors_penalty_total"] else: self._detailed_metric_info = self.deserialize_detailed_metric_info(new_info) if self.metric_val is None: return # sanity check on metric value - total_metric_val_from_detailed = ( - self._detailed_metric_info[self.metric[0]]["maps"]["total"] + - np.sum(self._detailed_metric_info[self.metric[0]]["priors"]) - ) + metric_val_from_priors = np.nan + if "priors_penalty_total" in self._detailed_metric_info[self.metric[0]]: + metric_val_from_priors = self._detailed_metric_info[self.metric[0]]["priors_penalty_total"] + metric_val_from_maps = self._detailed_metric_info[self.metric[0]]["maps"]["total"] + + # Get the total and compare if possible + total_metric_val_from_detailed = metric_val_from_maps + metric_val_from_priors + if not np.isnan(total_metric_val_from_detailed): if not recursiveEquality(total_metric_val_from_detailed, self.metric_val): logging.warning( @@ -455,6 +462,7 @@ def get_detailed_metric_info(data_dist, hypo_maker, hypo_asimov_dist, params, if include_maps_binned: name_vals_d['maps_binned'] = MapSet(maps_binned) name_vals_d['priors'] = params.priors_penalties(metric=metric) + name_vals_d['priors_penalty_total'] = params.priors_penalty(metric=metric) detailed_metric_info[m] = name_vals_d return detailed_metric_info @@ -479,6 +487,10 @@ def deserialize_detailed_metric_info(info_dict): # Deserialize if necessary name_vals_d["maps_binned"] = MapSet(**info_dict[m]["maps_binned"]) name_vals_d["priors"] = info_dict[m]["priors"] + if "priors_penalty_total" in info_dict[m]: + # Don't assume that this entry exists either, so that old fit + # results can still be deserialized + name_vals_d["priors_penalty_total"] = info_dict[m]["priors_penalty_total"] detailed_metric_info[m] = name_vals_d return detailed_metric_info