diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst
index 2261498c8be0..6517db169e19 100644
--- a/docs/source/metrics.rst
+++ b/docs/source/metrics.rst
@@ -351,6 +351,7 @@ Complete list of metrics
     SSIM
     TopKCategoricalAccuracy
     Bleu
+    Perplexity
     Rouge
     RougeL
     RougeN
diff --git a/ignite/metrics/__init__.py b/ignite/metrics/__init__.py
index b1813cc92935..dd81e6e41f09 100644
--- a/ignite/metrics/__init__.py
+++ b/ignite/metrics/__init__.py
@@ -33,6 +33,7 @@
 from ignite.metrics.mutual_information import MutualInformation
 from ignite.metrics.nlp.bleu import Bleu
+from ignite.metrics.nlp.perplexity import Perplexity
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
 from ignite.metrics.precision import Precision
 from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
 from ignite.metrics.psnr import PSNR
@@ -93,6 +94,7 @@
     "Rouge",
     "RougeN",
     "RougeL",
+    "Perplexity",
     "regression",
     "clustering",
     "fairness",
diff --git a/ignite/metrics/nlp/__init__.py b/ignite/metrics/nlp/__init__.py
index 506f0bab51e1..d0212882b78b 100644
--- a/ignite/metrics/nlp/__init__.py
+++ b/ignite/metrics/nlp/__init__.py
@@ -1,8 +1,10 @@
 from ignite.metrics.nlp.bleu import Bleu
+from ignite.metrics.nlp.perplexity import Perplexity
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
 
 __all__ = [
     "Bleu",
+    "Perplexity",
     "Rouge",
     "RougeN",
     "RougeL",
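Reviewer note (illustration, not part of the patch): with the ``__init__`` changes above, the metric is exposed from both the top-level and the NLP namespaces. A quick sanity check, assuming this branch is installed:

    from ignite.metrics import Perplexity as TopLevel
    from ignite.metrics.nlp import Perplexity as FromNlp

    assert TopLevel is FromNlp  # both import paths resolve to the same class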
diff --git a/ignite/metrics/nlp/perplexity.py b/ignite/metrics/nlp/perplexity.py
new file mode 100644
index 000000000000..45c4ddf98138
--- /dev/null
+++ b/ignite/metrics/nlp/perplexity.py
@@ -0,0 +1,107 @@
+from typing import Callable, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["Perplexity"]
+
+
+class Perplexity(Metric):
+    r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.
+
+    .. math::
+        \text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)
+
+    where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
+    conditional probability of token :math:`w_i` given the preceding tokens.
+
+    Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
+    over all tokens. Lower perplexity indicates a better language model.
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+    - `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
+      containing the unnormalized log-probabilities (logits).
+    - `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.
+
+    Note:
+        Perplexity uses token-weighted accumulation rather than averaging per-batch values, so that
+        every token contributes equally regardless of batch or sequence length. The total NLL and
+        total token count are accumulated across all batches, and the final perplexity is computed
+        as ``exp(total_nll / total_tokens)``.
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+        device: specifies which device updates are accumulated on. Setting the
+            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
+            non-blocking. By default, CPU.
+        ignore_index: target index that is ignored, i.e. it contributes neither to the accumulated
+            NLL nor to the token count, e.g. the index of a padding token. Default, -100.
+
+    Examples:
+
+        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.
+
+        .. testcode::
+
+            from ignite.metrics.nlp import Perplexity
+            import torch
+
+            ppl = Perplexity()
+
+            # batch_size=2, vocab_size=5, seq_len=3
+            y_pred = torch.randn(2, 5, 3)  # unnormalized logits
+            y = torch.randint(0, 5, (2, 3))
+
+            ppl.update((y_pred, y))
+
+            print(type(ppl.compute()))
+
+        .. testoutput::
+
+            <class 'float'>
+
+    .. versionadded:: 0.5.2
+    """
+
+    _state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")
+
+    def __init__(
+        self,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        ignore_index: int = -100,
+    ):
+        self._ignore_index = ignore_index
+        super().__init__(output_transform=output_transform, device=device)
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+        self._sum_of_nll = torch.tensor(0.0, device=self._device)
+        self._num_tokens = torch.tensor(0, device=self._device)
+
+    @reinit__is_reduced
+    def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
+        y_pred, y = output
+        y_pred = y_pred.detach()
+        y = y.detach()
+
+        if y_pred.ndim < 2:
+            raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")
+
+        if y.ndim < 1:
+            raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")
+
+        nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index)
+        self._sum_of_nll += nll.to(self._device)
+        self._num_tokens += (y != self._ignore_index).sum().to(self._device)
+
+    @sync_all_reduce("_sum_of_nll", "_num_tokens")
+    def compute(self) -> float:
+        if self._num_tokens == 0:
+            raise NotComputableError("Perplexity must have at least one example before it can be computed.")
+
+        return torch.exp(self._sum_of_nll / self._num_tokens).item()
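Reviewer note (illustration, not part of the patch): the docstring's point about token-weighted accumulation is the key design choice here. A minimal sketch of why it matters, with illustrative shapes: accumulating ``exp(total_nll / total_tokens)`` is not the same as averaging per-batch perplexities once batches carry different numbers of tokens.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)

    # Two batches with different token counts: 2*4 = 8 tokens vs 8*4 = 32 tokens.
    b1_pred, b1_y = torch.randn(2, 5, 4), torch.randint(0, 5, (2, 4))
    b2_pred, b2_y = torch.randn(8, 5, 4), torch.randint(0, 5, (8, 4))

    # Token-weighted, as the metric accumulates: exp(total_nll / total_tokens).
    total_nll = F.cross_entropy(b1_pred, b1_y, reduction="sum") + F.cross_entropy(b2_pred, b2_y, reduction="sum")
    ppl_weighted = torch.exp(total_nll / (b1_y.numel() + b2_y.numel()))

    # Naive alternative: mean of per-batch perplexities, biased towards small batches.
    ppl_averaged = (torch.exp(F.cross_entropy(b1_pred, b1_y)) + torch.exp(F.cross_entropy(b2_pred, b2_y))) / 2

    print(ppl_weighted.item(), ppl_averaged.item())  # the two values generally differ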
diff --git a/tests/ignite/metrics/nlp/test_perplexity.py b/tests/ignite/metrics/nlp/test_perplexity.py
new file mode 100644
index 000000000000..26c207c03815
--- /dev/null
+++ b/tests/ignite/metrics/nlp/test_perplexity.py
@@ -0,0 +1,147 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import ignite.distributed as idist
+from ignite.engine import Engine
+from ignite.exceptions import NotComputableError
+from ignite.metrics.nlp import Perplexity
+
+torch.manual_seed(12)
+
+
+def test_zero_sample():
+    ppl = Perplexity()
+    with pytest.raises(
+        NotComputableError, match=r"Perplexity must have at least one example before it can be computed"
+    ):
+        ppl.compute()
+
+
+def test_invalid_y_pred_shape():
+    ppl = Perplexity()
+    with pytest.raises(ValueError, match=r"y_pred must be at least 2-dimensional"):
+        ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))
+
+
+def test_reset_clears_state():
+    torch.manual_seed(2)
+    ppl = Perplexity()
+
+    y_pred = torch.randn(2, 5, 3)
+    y = torch.randint(0, 5, (2, 3))
+    ppl.update((y_pred, y))
+    ppl.reset()
+
+    with pytest.raises(NotComputableError):
+        ppl.compute()
+
+
+def _reference_perplexity(y_pred, y):
+    """Reference implementation: token-weighted NLL."""
+    nll = F.cross_entropy(y_pred, y, reduction="sum")
+    return torch.exp(nll / y.numel()).item()
+
+
+@pytest.mark.parametrize("n_times", range(3))
+def test_compute_matches_reference(n_times, available_device):
+    ppl = Perplexity(device=available_device)
+    assert ppl._device == torch.device(available_device)
+
+    torch.manual_seed(n_times)
+    y_pred = torch.randn(4, 10, 5)
+    y = torch.randint(0, 10, (4, 5))
+
+    ppl.reset()
+    ppl.update((y_pred, y))
+
+    ref = _reference_perplexity(y_pred, y)
+    assert pytest.approx(ppl.compute(), abs=1e-4) == ref
+
+
+@pytest.mark.parametrize("n_times", range(3))
+def test_token_weighted_accumulation(n_times, available_device):
+    """Token-weighted accumulation across multiple batches."""
+    ppl = Perplexity(device=available_device)
+    assert ppl._device == torch.device(available_device)
+
+    torch.manual_seed(n_times)
+
+    b1_pred = torch.randn(2, 5, 4)
+    b1_y = torch.randint(0, 5, (2, 4))
+    b2_pred = torch.randn(3, 5, 4)
+    b2_y = torch.randint(0, 5, (3, 4))
+
+    ppl.reset()
+    ppl.update((b1_pred, b1_y))
+    ppl.update((b2_pred, b2_y))
+
+    combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
+    combined_y = torch.cat([b1_y, b2_y], dim=0)
+    ppl_ref = _reference_perplexity(combined_pred, combined_y)
+
+    assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.usefixtures("distributed")
+class TestDistributed:
+    def test_accumulator_device(self):
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            ppl = Perplexity(device=metric_device)
+            assert ppl._device == metric_device
+            assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"
+
+            y_pred = torch.randn(2, 5, 3, device=device)
+            y = torch.randint(0, 5, (2, 3), device=device)
+            ppl.update((y_pred, y))
+
+            assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"
+
+    @pytest.mark.parametrize("n_epochs", [1, 2])
+    def test_integration(self, n_epochs):
+        rank = idist.get_rank()
+        torch.manual_seed(10 + rank)
+
+        n_iters = 20
+        batch_size = 4
+        vocab_size = 10
+        seq_len = 5
+
+        metric_devices = [torch.device("cpu")]
+        device = idist.device()
+        if device.type != "xla":
+            metric_devices.append(device)
+
+        for metric_device in metric_devices:
+            y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len)).to(device)
+            y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len).to(device)
+
+            def update(engine, i):
+                return (
+                    y_preds[i * batch_size : (i + 1) * batch_size],
+                    y_true[i * batch_size : (i + 1) * batch_size],
+                )
+
+            engine = Engine(update)
+            ppl = Perplexity(device=metric_device)
+            ppl.attach(engine, "ppl")
+
+            data = list(range(n_iters))
+            engine.run(data=data, max_epochs=n_epochs)
+
+            y_true_gathered = idist.all_gather(y_true)
+            y_preds_gathered = idist.all_gather(y_preds)
+
+            assert "ppl" in engine.state.metrics
+            res = engine.state.metrics["ppl"]
+
+            ref = _reference_perplexity(y_preds_gathered, y_true_gathered)
+
+            assert pytest.approx(res, abs=1e-4) == ref
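Reviewer note (illustration, not part of the patch): the tests above do not exercise ``ignore_index`` yet. A sketch of the intended padding behavior, assuming this branch is installed and using illustrative shapes, with padded target positions set to the ignore index:

    import torch
    import torch.nn.functional as F
    from ignite.metrics.nlp import Perplexity

    torch.manual_seed(0)

    pad = -100  # default ignore_index, matching F.cross_entropy
    y_pred = torch.randn(2, 5, 6)    # (batch_size, vocab_size, seq_len)
    y = torch.randint(0, 5, (2, 6))
    y[:, 4:] = pad                   # last two positions of each sequence are padding

    ppl = Perplexity(ignore_index=pad)
    ppl.update((y_pred, y))

    # Reference computed only over the 2 * 4 = 8 non-padded positions.
    nll = F.cross_entropy(y_pred[..., :4], y[..., :4], reduction="sum")
    ref = torch.exp(nll / y[..., :4].numel()).item()

    assert abs(ppl.compute() - ref) < 1e-4

Padded positions are excluded from both the accumulated NLL and the token count, so padding does not dilute the reported perplexity.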