diff --git a/CHANGELOG.md b/CHANGELOG.md index fd7498d..89f043a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.12.2] - 2026-06-03 + +### Fixed + +* Models with static metrics breaking when adding to XArray dataset + + + + ## [0.12.1] - 2026-04-20 Code cleaning and documentation updates. diff --git a/reno/__init__.py b/reno/__init__.py index 50eb83c..275f505 100644 --- a/reno/__init__.py +++ b/reno/__init__.py @@ -98,7 +98,7 @@ warnings.simplefilter("always", RuntimeWarning) -__version__ = "0.12.1" +__version__ = "0.12.2" __all__ = [ "Bernoulli", diff --git a/reno/model.py b/reno/model.py index 397f743..947e5e0 100644 --- a/reno/model.py +++ b/reno/model.py @@ -944,8 +944,18 @@ def dataset(self) -> xr.Dataset: # noqa: C901 # be 1 per step new_vars = {} for metric in self.metrics: - coords = ["sample"] if len(metric.value.shape) == 1 else ["sample", "step"] - new_vars[metric.qual_name()] = (coords, metric.value) + val = metric.value + if metric.is_static(): # noqa: SIM102 + # bleh, see note above + if ( + isinstance(val, (int, float)) + or (isinstance(val, np.ndarray) and len(val.shape) == 0) + or (len(val.shape) > 0 and val.shape[0] != self.last_n) + or len(val.shape) == 0 + ): + val = np.broadcast_to(val, (self.last_n,)) + coords = ["sample"] if len(val.shape) == 1 else ["sample", "step"] + new_vars[metric.qual_name()] = (coords, val) ds = ds.assign(new_vars) # merge in any sub datasets diff --git a/reno/ops.py b/reno/ops.py index 97e5c3d..cd6d338 100644 --- a/reno/ops.py +++ b/reno/ops.py @@ -1433,10 +1433,13 @@ def op_latex(self, **kwargs: dict) -> str: return f"\\text{{interpolate}}({self.sub_equation_parts[0].latex(**kwargs)}, {self.sub_equation_parts[1].latex(**kwargs)}, {self.sub_equation_parts[2].latex(**kwargs)})" def op_eval(self, **kwargs: dict) -> np.ndarray: + input_eval = self.sub_equation_parts[0].eval(**kwargs) + x_eval = self.sub_equation_parts[1].eval(**kwargs) + y_eval = self.sub_equation_parts[2].eval(**kwargs) return np.interp( - self.sub_equation_parts[0].eval(**kwargs), - self.sub_equation_parts[1].eval(**kwargs), - self.sub_equation_parts[2].eval(**kwargs), + input_eval, + x_eval, + y_eval, ) def pt(self, **refs: dict[str, pt.TensorVariable]) -> pt.TensorVariable: diff --git a/reno/third_party/__init__.py b/reno/third_party/__init__.py new file mode 100644 index 0000000..007a777 --- /dev/null +++ b/reno/third_party/__init__.py @@ -0,0 +1,9 @@ +"""Public API for the mmdfuse function, re-exported for convenient import. + +Example: + from reno.third_party.mmdfuse import mmdfuse +""" + +from .mmdfuse import mmdfuse + +__all__ = ["mmdfuse"] diff --git a/reno/third_party/mmdfuse/LICENSE.md b/reno/third_party/mmdfuse/LICENSE.md new file mode 100644 index 0000000..4a19d1a --- /dev/null +++ b/reno/third_party/mmdfuse/LICENSE.md @@ -0,0 +1,21 @@ +# MIT License + +Copyright (c) 2023 Antonin Schrab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/reno/third_party/mmdfuse/__init__.py b/reno/third_party/mmdfuse/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/reno/third_party/mmdfuse/mmd_gaussian_fixed.py b/reno/third_party/mmdfuse/mmd_gaussian_fixed.py new file mode 100644 index 0000000..f90d9a1 --- /dev/null +++ b/reno/third_party/mmdfuse/mmd_gaussian_fixed.py @@ -0,0 +1,169 @@ +# ============================================================================= +# Gaussian-kernel MMD distance with fixed bandwidths. +# +# This file modifies the MMD-FUSE implementation by removing two-sample +# testing, permutations, softmax/logsumexp fusion, and data-dependent +# bandwidth selection. It uses only Gaussian kernels with fixed bandwidths +# averages the kernel matrices into a single multi-scale Gaussian kernel. +# +# The modification allows using MMD as a metric to pairwise compare +# multiple distributions in a consistent manner. +# ============================================================================= + +import numpy as np + + +# Fixed, data-independent Gaussian bandwidths. Ideally, features are already +# on a comparable scale, e.g. standardized/z-scored or otherwise normalized. +DEFAULT_FIXED_BANDWIDTHS = np.array([0.1, 0.3, 1.0, 3.0, 10.0], dtype=float) + + +def np_distances(X, Y, l="l2", max_samples=None, matrix=False): + """ + Compute pairwise l1 or l2 distances using NumPy broadcasting. + + Parameters + ---------- + X, Y : ndarray, shape (n_samples, n_features) + l : {"l1", "l2"} + max_samples : int or None + Optional cap applied to both X and Y before computing distances. + matrix : bool + If True, return the full distance matrix. If False, return the + upper-triangular entries of the distance matrix. + """ + X = np.asarray(X, dtype=float) + Y = np.asarray(Y, dtype=float) + + Xs = X[:max_samples] + Ys = Y[:max_samples] + + diff = Xs[:, None, :] - Ys[None, :, :] + + if l == "l1": + output = np.sum(np.abs(diff), axis=-1) + elif l == "l2": + output = np.sqrt(np.sum(diff**2, axis=-1)) + else: + raise ValueError("Value of 'l' must be either 'l1' or 'l2'.") + + if matrix: + return output + return output[np.triu_indices(output.shape[0])] + + +def gaussian_kernel_matrix(pairwise_l2_distances, bandwidth): + """Gaussian/RBF kernel matrix from an l2 distance matrix.""" + d = pairwise_l2_distances / bandwidth + return np.exp(-(d**2) / 2) + + +def average_gaussian_kernel_matrix(pairwise_l2_distances, bandwidths=DEFAULT_FIXED_BANDWIDTHS): + """ + Build one multi-scale Gaussian kernel by averaging Gaussian kernels over + fixed bandwidths. No softmax, no learned weights, no data-dependent tuning. + """ + bandwidths = np.asarray(bandwidths, dtype=float) + if bandwidths.ndim != 1 or len(bandwidths) == 0: + raise ValueError("bandwidths must be a non-empty 1D array.") + if np.any(bandwidths <= 0): + raise ValueError("All bandwidths must be positive.") + + K = np.zeros_like(pairwise_l2_distances, dtype=float) + for bandwidth in bandwidths: + K += gaussian_kernel_matrix(pairwise_l2_distances, bandwidth) + K /= len(bandwidths) + return K + + +def mmd2_average_gaussian( + X, + Y, + bandwidths=DEFAULT_FIXED_BANDWIDTHS, + unbiased=False, + return_kernel=False, +): + """ + Compute MMD^2 between X and Y using a single kernel formed by averaging + multiple fixed-bandwidth Gaussian kernels for measuring distribution + dissimilarity. + + Parameters + ---------- + X, Y : ndarray, shape (n_samples, n_features) + Samples from the two windows/distributions to compare. + bandwidths : ndarray + Fixed Gaussian bandwidths. The default is + DEFAULT_FIXED_BANDWIDTHS = [0.1, 0.3, 1.0, 3.0, 10.0]. + unbiased : bool + If True, use the unbiased U-statistic estimator and exclude diagonals + in Kxx and Kyy. This estimator can be negative for finite samples. + If False, use the biased V-statistic estimator, which is nonnegative + up to numerical precision and is better for a metric-like distance. + return_kernel : bool + If True, also return the averaged full kernel matrix and bandwidths. + + Returns + ------- + mmd2 : float + Estimated squared MMD. Small means similar distributions; large means + different distributions. + bandwidths : ndarray + The fixed bandwidths used. + K : ndarray, optional + Full averaged kernel matrix over concat(X, Y), returned only when + return_kernel=True. + """ + X = np.asarray(X, dtype=float) + Y = np.asarray(Y, dtype=float) + + if X.ndim != 2 or Y.ndim != 2: + raise ValueError("X and Y must both be 2D arrays: samples x features.") + if X.shape[1] != Y.shape[1]: + raise ValueError("X and Y must have the same number of features.") + + m = X.shape[0] + n = Y.shape[0] + + if unbiased and (m < 2 or n < 2): + raise ValueError("Unbiased MMD requires at least 2 samples in X and Y.") + + bandwidths = np.asarray(bandwidths, dtype=float) + + Z = np.concatenate((X, Y), axis=0) + pairwise_l2 = np_distances(Z, Z, l="l2", matrix=True) + K = average_gaussian_kernel_matrix(pairwise_l2, bandwidths) + + Kxx = K[:m, :m] + Kyy = K[m:, m:] + Kxy = K[:m, m:] + + if unbiased: + mmd2 = ( + (np.sum(Kxx) - np.trace(Kxx)) / (m * (m - 1)) + + (np.sum(Kyy) - np.trace(Kyy)) / (n * (n - 1)) + - 2 * np.mean(Kxy) + ) + else: + mmd2 = np.mean(Kxx) + np.mean(Kyy) - 2 * np.mean(Kxy) + + # Guard tiny negative values caused by floating-point roundoff. + if not unbiased and mmd2 < 0 and np.isclose(mmd2, 0.0): + mmd2 = 0.0 + + if return_kernel: + return float(mmd2), bandwidths.copy(), K + return float(mmd2), bandwidths.copy() + + +def mmd_average_gaussian(*args, **kwargs): + """ + Return sqrt(max(MMD^2, 0)) using fixed-bandwidth averaged Gaussian kernels. + + With the default biased estimator, this is a nonnegative empirical MMD + distance. The kernel/bandwidth choice is fixed and not fit to the data. + """ + result = mmd2_average_gaussian(*args, **kwargs) + mmd2 = result[0] + mmd = np.sqrt(max(mmd2, 0.0)) + return (mmd, *result[1:]) diff --git a/reno/third_party/mmdfuse/mmdfuse.py b/reno/third_party/mmdfuse/mmdfuse.py new file mode 100644 index 0000000..eb65801 --- /dev/null +++ b/reno/third_party/mmdfuse/mmdfuse.py @@ -0,0 +1,404 @@ +"""Third party library for MMD-FUSE, a value of information algorithm. + +See https://github.com/antoninschrab/mmdfuse for the original. +""" + +# ============================================================================= +# mmdfuse.py - copy of the original MMD-FUSE implementation +# ============================================================================= +# Copyright (c) 2023 Biggs, Schrab & Gretton +# Licensed under the MIT License (see LICENSE file in this directory). +# +# ------------------------------------------------------------------------- +# Modifications made for the Reno repository +# ------------------------------------------------------------------------- +# * Removed JAX dependency - all JAX-specific imports and `jax.numpy` +# calls have been replaced with NumPy equivalents. The public API +# (`mmdfuse(...)`) and the statistical behaviour remain unchanged. +# * Updated type hints to use `numpy.ndarray` instead of `jax.numpy.ndarray`. +# * Minor refactoring to avoid JAX random-key handling. +# * Update docstring style. +# * Add typehints and some minor variable renames. +# +# ------------------------------------------------------------------------- +# Citation +# ------------------------------------------------------------------------- +# If you use this code in a publication, please cite the original work: +# @article{biggs2023mmdfuse, +# author = {Biggs, Felix and Schrab, Antonin and Gretton, Arthur}, +# title = {{MMD-FUSE}: {L}earning and Combining Kernels for Two-Sample Testing Without Data Splitting}, +# year = {2023}, +# journal = {Advances in Neural Information Processing Systems}, +# volume = {36} +# } +# +# Repo link: https://github.com/antoninschrab/mmdfuse +# ============================================================================= + +from typing import Literal + +import numpy as np +from numpy.typing import ArrayLike + + +def _logsumexp(a: ArrayLike, axis: int = None, b: float = 1.0) -> np.ndarray: + """NumPy-only stable logsumexp. + + Equivalent to ``scipy.special.logsumexp(a, axis=axis, b=b)`` for positive scalar b. + """ + a = np.asarray(a) + a_max = np.max(a, axis=axis, keepdims=True) + + # Avoid invalid operations if all values are -inf + shifted = np.exp(a - a_max) + s = np.sum(b * shifted, axis=axis, keepdims=True) + out = a_max + np.log(s) + + if axis is not None: + out = np.squeeze(out, axis=axis) + + return out + + +def kernel_matrix( + pairwise_matrix: np.ndarray, + dist_metric: Literal["l1", "l2"], + kernel: str, + bandwidth: float, + rq_kernel_exponent: float = 0.5, +) -> np.ndarray: + """Compute kernel matrix for a given kernel and bandwidth. + + Args: + pairwise_matrix (ndarray): Matrix of pairwise distances. + dist_metric (Literal["l1", "l2"]): Distance type. + kernel (str): Kernel name. + bandwidth (float): Kernel bandwidth. + rq_kernel_exponent (float): Exponent for rational quadratic kernel. + """ + d = pairwise_matrix / bandwidth + + if kernel == "gaussian" and dist_metric == "l2": + return np.exp(-(d**2) / 2) + + elif kernel == "laplace" and dist_metric == "l1": + return np.exp(-d * np.sqrt(2)) + + elif kernel == "rq" and dist_metric == "l2": + return (1 + d**2 / (2 * rq_kernel_exponent)) ** (-rq_kernel_exponent) + + elif kernel == "imq" and dist_metric == "l2": + return (1 + d**2) ** (-0.5) + + elif (kernel == "matern_0.5_l1" and dist_metric == "l1") or ( + kernel == "matern_0.5_l2" and dist_metric == "l2" + ): + return np.exp(-d) + + elif (kernel == "matern_1.5_l1" and dist_metric == "l1") or ( + kernel == "matern_1.5_l2" and dist_metric == "l2" + ): + return (1 + np.sqrt(3) * d) * np.exp(-np.sqrt(3) * d) + + elif (kernel == "matern_2.5_l1" and dist_metric == "l1") or ( + kernel == "matern_2.5_l2" and dist_metric == "l2" + ): + return (1 + np.sqrt(5) * d + 5 / 3 * d**2) * np.exp(-np.sqrt(5) * d) + + elif (kernel == "matern_3.5_l1" and dist_metric == "l1") or ( + kernel == "matern_3.5_l2" and dist_metric == "l2" + ): + return ( + 1 + np.sqrt(7) * d + 2 * 7 / 5 * d**2 + 7 * np.sqrt(7) / 3 / 5 * d**3 + ) * np.exp(-np.sqrt(7) * d) + + elif (kernel == "matern_4.5_l1" and dist_metric == "l1") or ( + kernel == "matern_4.5_l2" and dist_metric == "l2" + ): + return ( + 1 + + 3 * d + + 3 * (6**2) / 28 * d**2 + + (6**3) / 84 * d**3 + + (6**4) / 1680 * d**4 + ) * np.exp(-3 * d) + + else: + raise ValueError('The values of "l" and "kernel" are not valid.') + + +def np_distances( + X: ArrayLike, + Y: ArrayLike, + dist_metric: Literal["l1", "l2"], + max_samples: int = None, + matrix: bool = False, +) -> np.ndarray: + """NumPy replacement for jax_distances. + + Computes pairwise l1 or l2 distances using broadcasting. + + Args: + X (ndarray): shape (m, d) + Y (ndarray): shape (n, d) + dist_metric (Literal["l1", "l2"]): Distance type. + max_samples (int): Maximum number of pairs to draw for computing distances. + matrix (bool): Returns the full distance matrix if ``True``, otherwise just the + upper-triangular entries. + """ + X = np.asarray(X) + Y = np.asarray(Y) + + Xs = X[:max_samples] + Ys = Y[:max_samples] + + diff = Xs[:, None, :] - Ys[None, :, :] + + if dist_metric == "l1": + output = np.sum(np.abs(diff), axis=-1) + elif dist_metric == "l2": + output = np.sqrt(np.sum(diff**2, axis=-1)) + else: + raise ValueError("Value of 'l' must be either 'l1' or 'l2'.") + + if matrix: + return output + else: + return output[np.triu_indices(output.shape[0])] + + +def compute_bandwidths( + X: np.ndarray, + Y: np.ndarray, + dist_metric: Literal["l1", "l2"], + number_bandwidths: int, + only_median: bool = False, +) -> np.ndarray: + """NumPy replacement for the JAX/JIT compute_bandwidths function.""" + Z = np.concatenate((X, Y), axis=0) + distances = np_distances(Z, Z, dist_metric, matrix=False) + + median = np.median(distances) + + if only_median: + return median + + distances = distances + (distances == 0) * median + dd = np.sort(distances) + + lambda_min = dd[int(np.floor(len(dd) * 0.05))] / 2 + lambda_max = dd[int(np.floor(len(dd) * 0.95))] * 2 + + bandwidths = np.linspace(lambda_min, lambda_max, number_bandwidths) + return bandwidths + + +def _make_rng(key: int | np.random.Generator = None) -> np.random.Generator: + """Convert a key-like input into a NumPy random Generator, if not one already. + + Args: + key (int | np.random.Generator): The key to use or an existing generator. + """ + if isinstance(key, np.random.Generator): + return key + return np.random.default_rng(key) + + +def mmdfuse( + X: ArrayLike, + Y: ArrayLike, + key: int | np.random.Generator = None, + alpha: float = 0.05, + kernels: str | tuple[str] | list[str] = ("laplace", "gaussian"), + lambda_multiplier: float = 1.0, + number_bandwidths: int = 10, + number_permutations: int = 2000, + return_p_val: bool = False, +) -> int | tuple[int, float, np.ndarray]: + """Two-Sample MMD-FUSE test, NumPy-only version. + + Args: + X (ArrayLike): shape (m, d) + Y (ArrayLike): shape (n, d) + key (int | np.random.Generator): Random seed or NumPy Generator. + Example: ``key=0`` + alpha (float): Test level. + kernels (str | tuple[str] | list[str]): Kernel names. + lambda_multiplier (float): ??? + number_bandwidths (int): ??? + number_permutations (int): ??? + return_p_val (bool): ??? + + Returns: + 0 if the test fails to reject the null. 1 if the test rejects the null. + Or, if return_p_val=True, returns a tuple of the int output, float p_val, and + the numpy array containing all statistics. + """ + X = np.asarray(X, dtype=float) + Y = np.asarray(Y, dtype=float) + + rng = _make_rng(key) + + # Match original behavior: ensure n <= m + if Y.shape[0] > X.shape[0]: + X, Y = Y, X + + m = X.shape[0] + n = Y.shape[0] + + assert n <= m + assert n >= 2 and m >= 2 + assert 0 < alpha < 1 + assert lambda_multiplier > 0 + assert number_bandwidths > 1 and type(number_bandwidths) is int + assert number_permutations > 0 and type(number_permutations) is int + + if type(kernels) is str: + kernels = (kernels,) + + valid_kernels = ( + "imq", + "rq", + "gaussian", + "matern_0.5_l2", + "matern_1.5_l2", + "matern_2.5_l2", + "matern_3.5_l2", + "matern_4.5_l2", + "laplace", + "matern_0.5_l1", + "matern_1.5_l1", + "matern_2.5_l1", + "matern_3.5_l1", + "matern_4.5_l1", + ) + + for kernel in kernels: + assert kernel in valid_kernels + + all_kernels_l1 = ( + "laplace", + "matern_0.5_l1", + "matern_1.5_l1", + "matern_2.5_l1", + "matern_3.5_l1", + "matern_4.5_l1", + ) + + all_kernels_l2 = ( + "imq", + "rq", + "gaussian", + "matern_0.5_l2", + "matern_1.5_l2", + "matern_2.5_l2", + "matern_3.5_l2", + "matern_4.5_l2", + ) + + number_kernels = len(kernels) + kernels_l1 = [k for k in kernels if k in all_kernels_l1] + kernels_l2 = [k for k in kernels if k in all_kernels_l2] + + # Setup for permutations + B = number_permutations + total = m + n + + # Shape: (B + 1, m + n) + idx = np.empty((B + 1, total), dtype=int) + for b in range(B + 1): + idx[b] = rng.permutation(total) + + # 11 + v11 = np.concatenate((np.ones(m), -np.ones(n))) + V11i = np.tile(v11, (B + 1, 1)) + V11 = np.take_along_axis(V11i, idx, axis=1) + V11[B] = v11 + V11 = V11.T + + # 10 + v10 = np.concatenate((np.ones(m), np.zeros(n))) + V10i = np.tile(v10, (B + 1, 1)) + V10 = np.take_along_axis(V10i, idx, axis=1) + V10[B] = v10 + V10 = V10.T + + # 01 + v01 = np.concatenate((np.zeros(m), -np.ones(n))) + V01i = np.tile(v01, (B + 1, 1)) + V01 = np.take_along_axis(V01i, idx, axis=1) + V01[B] = v01 + V01 = V01.T + + # Compute all permuted MMD estimates + N = number_bandwidths * number_kernels + M = np.zeros((N, B + 1)) + + kernel_count = -1 + + Z = np.concatenate((X, Y), axis=0) + + for r in range(2): + kernels_l = (kernels_l1, kernels_l2)[r] + dist_metric = ("l1", "l2")[r] + + if len(kernels_l) > 0: + # Pairwise distance matrix + pairwise_matrix = np_distances(Z, Z, dist_metric, matrix=True) + + # Collection of bandwidths + distances = pairwise_matrix[np.triu_indices(pairwise_matrix.shape[0])] + + median = np.median(distances) + distances = distances + (distances == 0) * median + + dd = np.sort(distances) + lambda_min = dd[int(np.floor(len(dd) * 0.05))] / 2 + lambda_max = dd[int(np.floor(len(dd) * 0.95))] * 2 + + bandwidths = np.linspace(lambda_min, lambda_max, number_bandwidths) + + # Compute all permuted MMD estimates for either l1 or l2 + for kernel in kernels_l: + kernel_count += 1 + + for i in range(number_bandwidths): + bandwidth = bandwidths[i] + + # Compute kernel matrix and set diagonal to zero + K = kernel_matrix(pairwise_matrix, dist_metric, kernel, bandwidth) + np.fill_diagonal(K, 0) + + # Compute standard deviation + unscaled_std = np.sqrt(np.sum(K**2)) + + # Matrix products + KV10 = K @ V10 + KV01 = K @ V01 + KV11 = K @ V11 + + values = ( + np.sum(V10 * KV10, axis=0) + * (n - m + 1) + * (n - 1) + / (m * (m - 1)) + + np.sum(V01 * KV01, axis=0) * (m - n + 1) / m + + np.sum(V11 * KV11, axis=0) * (n - 1) / m + ) + + values = values / unscaled_std * np.sqrt(n * (n - 1)) + + M[kernel_count * number_bandwidths + i] = values + + # Compute permuted and original statistics + all_statistics = _logsumexp(lambda_multiplier * M, axis=0, b=1 / N) + original_statistic = all_statistics[-1] + + # Compute p-value and test output + p_val = np.mean(all_statistics >= original_statistic) + output = int(p_val <= alpha) + + if return_p_val: + return output, p_val, all_statistics + else: + return output diff --git a/tests/test_components.py b/tests/test_components.py index 1146e2f..823fd6a 100644 --- a/tests/test_components.py +++ b/tests/test_components.py @@ -735,3 +735,13 @@ def test_variable_boolean(): ds1 = m() ds2 = m(v0=False) + + +def test_static_metrics(): + """A metric that is static shouldn't break when constructing the dataset.""" + m = Model() + with m: + met = Metric(Scalar(5) + 4) + + ds = m() + assert ds.met.values[0] == 9