Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from ignite.metrics.mutual_information import MutualInformation
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.precision import Precision
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.psnr import PSNR
Expand Down Expand Up @@ -93,6 +94,7 @@
"Rouge",
"RougeN",
"RougeL",
"Perplexity",
"regression",
"clustering",
"fairness",
Expand Down
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Perplexity",
Comment thread
steaphenai marked this conversation as resolved.
"Rouge",
"RougeN",
"RougeL",
Expand Down
103 changes: 103 additions & 0 deletions ignite/metrics/nlp/perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from collections.abc import Callable

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


class Perplexity(Metric):
    r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.

    .. math::
        \text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)

    where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
    conditional probability of token :math:`w_i` given the preceding tokens.

    Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
    over all tokens. Lower perplexity indicates a better language model.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
      containing the unnormalized log-probabilities (logits).
    - `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.

    Note:
        Perplexity uses token-weighted accumulation rather than batch-average to avoid bias
        towards shorter sequences. The total NLL and total token count are accumulated across
        all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output model and
            you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.

    Examples:

        For more information on how metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics.nlp import Perplexity
            import torch

            ppl = Perplexity()

            # batch_size=2, vocab_size=5, seq_len=3
            y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1)
            y = torch.randint(0, 5, (2, 3))

            ppl.update((y_pred, y))

            print(type(ppl.compute()))

        .. testoutput::

            <class 'float'>

    .. versionadded:: 0.5.2
    """

    _state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
    ):
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        # Running sum of negative log-likelihood over every token seen so far.
        self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device)
        # Total number of target tokens seen so far.
        self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device)

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        """Accumulate the summed NLL and token count for one batch.

        Args:
            output: ``(y_pred, y)`` where ``y_pred`` has shape ``(N, C, d1, ..., dk)``
                (logits) and ``y`` has shape ``(N, d1, ..., dk)`` (target indices).

        Raises:
            ValueError: if ``y_pred`` / ``y`` have too few dimensions or their
                shapes are mutually inconsistent.
        """
        y_pred, y = output

        if y_pred.ndim < 2:
            raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")

        if y.ndim < 1:
            raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")

        # cross_entropy requires y to match y_pred with the class dim (dim 1) removed;
        # fail early with a clear message instead of an opaque downstream error.
        expected_y_shape = (y_pred.shape[0],) + tuple(y_pred.shape[2:])
        if tuple(y.shape) != expected_y_shape:
            raise ValueError(
                f"y must have shape {expected_y_shape} to match y_pred of shape "
                f"{tuple(y_pred.shape)} (got shape: {tuple(y.shape)})"
            )

        # detach(): metric accumulation must not keep the autograd graph alive.
        nll = F.cross_entropy(y_pred.detach(), y, reduction="sum")
        self._sum_of_nll += nll.to(self._device, dtype=torch.double)
        self._num_tokens += y.numel()

    @sync_all_reduce("_sum_of_nll", "_num_tokens")
    def compute(self) -> float:
        """Return perplexity ``exp(total_nll / total_tokens)`` over all updates.

        Raises:
            NotComputableError: if no tokens have been accumulated yet.
        """
        if self._num_tokens == 0:
            raise NotComputableError("Perplexity must have at least one example before it can be computed.")

        return torch.exp(self._sum_of_nll / self._num_tokens).item()
99 changes: 99 additions & 0 deletions tests/ignite/metrics/nlp/test_perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import pytest
import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.nlp import Perplexity


def test_zero_sample():
    """A freshly reset metric has no tokens, so compute() must raise."""
    metric = Perplexity()
    metric.reset()
    with pytest.raises(NotComputableError):
        metric.compute()


def test_compute_matches_manual():
    """Metric result agrees with a hand-computed exp(mean NLL) reference."""
    torch.manual_seed(42)
    metric = Perplexity()
    metric.reset()

    logits = torch.randn(4, 10, 5)
    targets = torch.randint(0, 10, (4, 5))
    metric.update((logits, targets))

    total_nll = F.cross_entropy(logits, targets, reduction="sum").item()
    reference = torch.tensor(total_nll / targets.numel()).exp().item()

    assert abs(metric.compute() - reference) < 1e-4


def test_token_weighted_accumulation():
    """Accumulation across batches must be token-weighted: the result equals
    exp(sum of per-batch NLL sums / total token count)."""
    torch.manual_seed(0)
    metric = Perplexity()
    metric.reset()

    # Two batches with different batch sizes and sequence lengths.
    first_logits, first_targets = torch.randn(2, 5, 4), torch.randint(0, 5, (2, 4))
    second_logits, second_targets = torch.randn(3, 5, 10), torch.randint(0, 5, (3, 10))

    metric.update((first_logits, first_targets))
    metric.update((second_logits, second_targets))

    nll_sum = (
        F.cross_entropy(first_logits, first_targets, reduction="sum").item()
        + F.cross_entropy(second_logits, second_targets, reduction="sum").item()
    )
    token_count = first_targets.numel() + second_targets.numel()
    reference = torch.tensor(nll_sum / token_count).exp().item()

    assert abs(metric.compute() - reference) < 1e-4


def test_returns_float():
    """compute() must return a plain Python float, not a tensor."""
    torch.manual_seed(1)
    metric = Perplexity()
    metric.reset()

    logits = torch.randn(2, 5, 3)
    targets = torch.randint(0, 5, (2, 3))
    metric.update((logits, targets))

    assert isinstance(metric.compute(), float)


def test_invalid_y_pred_shape():
    """A 1-D y_pred must be rejected with a descriptive ValueError."""
    metric = Perplexity()
    metric.reset()

    bad_pred, bad_target = torch.tensor([1.0, 2.0]), torch.tensor([0])
    with pytest.raises(ValueError, match="y_pred must be at least 2-dimensional"):
        metric.update((bad_pred, bad_target))


def test_reset_clears_state():
    """reset() after an update must return the metric to the empty state."""
    torch.manual_seed(2)
    metric = Perplexity()

    logits = torch.randn(2, 5, 3)
    targets = torch.randint(0, 5, (2, 3))
    metric.update((logits, targets))
    metric.reset()

    with pytest.raises(NotComputableError):
        metric.compute()


def test_single_token():
    """The degenerate one-token case still yields a positive float."""
    metric = Perplexity()
    metric.reset()

    logits = torch.randn(1, 5, 1)
    targets = torch.randint(0, 5, (1, 1))
    metric.update((logits, targets))

    value = metric.compute()
    assert isinstance(value, float)
    assert value > 0