1 change: 1 addition & 0 deletions docs/source/metrics.rst
@@ -351,6 +351,7 @@ Complete list of metrics
SSIM
TopKCategoricalAccuracy
Bleu
Perplexity
Rouge
RougeL
RougeN
2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -33,6 +33,7 @@
from ignite.metrics.mutual_information import MutualInformation
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
from ignite.metrics.precision import Precision
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.psnr import PSNR
@@ -93,6 +94,7 @@
"Rouge",
"RougeN",
"RougeL",
"Perplexity",
"regression",
"clustering",
"fairness",
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/__init__.py
@@ -1,8 +1,10 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Perplexity",
"Rouge",
"RougeN",
"RougeL",
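Both `__init__.py` edits above export the new class, so either import path should resolve to the same object once this branch is installed. A minimal sanity check (assuming the PR is checked out and installed):

# Minimal sanity check, assuming this PR's branch is installed:
from ignite.metrics import Perplexity as TopLevel
from ignite.metrics.nlp import Perplexity

assert TopLevel is Perplexity  # same class via both export paths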
107 changes: 107 additions & 0 deletions ignite/metrics/nlp/perplexity.py
@@ -0,0 +1,107 @@
from typing import Callable, Tuple, Union

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


class Perplexity(Metric):
r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.

.. math::
\text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)

where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
conditional probability of token :math:`w_i` given the preceding tokens.

Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
over all tokens. Lower perplexity indicates a better language model.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
containing the unnormalized log-probabilities (logits).
- `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.

Note:
Perplexity uses token-weighted accumulation rather than a per-batch average, to avoid
biasing the result towards shorter sequences. The total NLL and total token count are
accumulated across all batches, and the final perplexity is computed as
``exp(total_nll / total_tokens)``.

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
ignore_index: target value that is ignored when accumulating the NLL and
counting tokens (for example, a padding token id). By default, -100.

Examples:

For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. testcode::

from ignite.metrics.nlp import Perplexity
import torch

ppl = Perplexity()

# batch_size=2, vocab_size=5, seq_len=3
y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1)
y = torch.randint(0, 5, (2, 3))

ppl.update((y_pred, y))

print(type(ppl.compute()))

.. testoutput::

<class 'float'>

.. versionadded:: 0.5.2
"""

_state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")

def __init__(
self,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
ignore_index: int = -100,
) -> None:
self._ignore_index = ignore_index
super().__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_nll = torch.tensor(0.0, device=self._device)
self._num_tokens = torch.tensor(0, device=self._device)

@reinit__is_reduced
def update(self, output: Tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
y_pred = y_pred.detach()
y = y.detach()

if y_pred.ndim < 2:
raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")

if y.ndim < 1:
raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")

nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index)
self._sum_of_nll += nll.to(self._device)
self._num_tokens += (y != self._ignore_index).sum().to(self._device)

@sync_all_reduce("_sum_of_nll", "_num_tokens")
def compute(self) -> float:
if self._num_tokens == 0:
raise NotComputableError("Perplexity must have at least one example before it can be computed.")

return torch.exp(self._sum_of_nll / self._num_tokens).item()
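As a standalone illustration of the token-weighted accumulation described in the class Note, here is a hedged sketch (assuming the new module is importable from this branch) showing that two updates with unequal batch sizes match a single pass over the concatenated data:

import torch
import torch.nn.functional as F

from ignite.metrics.nlp import Perplexity

torch.manual_seed(0)
ppl = Perplexity()

# Two batches of unequal size; shapes are (batch, vocab_size, seq_len).
b1_pred, b1_y = torch.randn(2, 5, 4), torch.randint(0, 5, (2, 4))
b2_pred, b2_y = torch.randn(6, 5, 4), torch.randint(0, 5, (6, 4))
ppl.update((b1_pred, b1_y))
ppl.update((b2_pred, b2_y))

# Token-weighted accumulation is equivalent to one pass over the concatenation.
all_pred = torch.cat([b1_pred, b2_pred])
all_y = torch.cat([b1_y, b2_y])
expected = torch.exp(F.cross_entropy(all_pred, all_y, reduction="sum") / all_y.numel()).item()
assert abs(ppl.compute() - expected) < 1e-4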
147 changes: 147 additions & 0 deletions tests/ignite/metrics/nlp/test_perplexity.py
@@ -0,0 +1,147 @@
import pytest
import torch
import torch.nn.functional as F

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.nlp import Perplexity

torch.manual_seed(12)


def test_zero_sample():
ppl = Perplexity()
with pytest.raises(
NotComputableError, match=r"Perplexity must have at least one example before it can be computed"
):
ppl.compute()


def test_invalid_y_pred_shape():
ppl = Perplexity()
with pytest.raises(ValueError, match=r"y_pred must be at least 2-dimensional"):
ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))


def test_reset_clears_state():
torch.manual_seed(2)
ppl = Perplexity()

y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))
ppl.update((y_pred, y))
ppl.reset()

with pytest.raises(NotComputableError):
ppl.compute()


def _reference_perplexity(y_pred, y):
"""Reference implementation: token-weighted NLL."""
nll = F.cross_entropy(y_pred, y, reduction="sum")
return torch.exp(nll / y.numel()).item()


@pytest.mark.parametrize("n_times", range(3))
def test_compute_matches_reference(n_times, available_device):
ppl = Perplexity(device=available_device)
assert ppl._device == torch.device(available_device)

torch.manual_seed(n_times)
y_pred = torch.randn(4, 10, 5)
y = torch.randint(0, 10, (4, 5))

ppl.reset()
ppl.update((y_pred, y))

ref = _reference_perplexity(y_pred, y)
assert pytest.approx(ppl.compute(), abs=1e-4) == ref


@pytest.mark.parametrize("n_times", range(3))
def test_token_weighted_accumulation(n_times, available_device):
"""Token-weighted accumulation across multiple batches."""
ppl = Perplexity(device=available_device)
assert ppl._device == torch.device(available_device)

torch.manual_seed(n_times)

b1_pred = torch.randn(2, 5, 4)
b1_y = torch.randint(0, 5, (2, 4))
b2_pred = torch.randn(3, 5, 4)
b2_y = torch.randint(0, 5, (3, 4))

ppl.reset()
ppl.update((b1_pred, b1_y))
ppl.update((b2_pred, b2_y))

combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
combined_y = torch.cat([b1_y, b2_y], dim=0)
ppl_ref = _reference_perplexity(combined_pred, combined_y)

assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_accumulator_device(self):
metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
ppl = Perplexity(device=metric_device)
assert ppl._device == metric_device
assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"

y_pred = torch.randn(2, 5, 3, device=device)
y = torch.randint(0, 5, (2, 3), device=device)
ppl.update((y_pred, y))

assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"

@pytest.mark.parametrize("n_epochs", [1, 2])
def test_integration(self, n_epochs):
rank = idist.get_rank()
torch.manual_seed(10 + rank)

n_iters = 20
batch_size = 4
vocab_size = 10
seq_len = 5

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len)).to(device)
y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len).to(device)

def update(engine, i):
return (
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
ppl = Perplexity(device=metric_device)
ppl.attach(engine, "ppl")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_true_gathered = idist.all_gather(y_true)
y_preds_gathered = idist.all_gather(y_preds)

assert "ppl" in engine.state.metrics
res = engine.state.metrics["ppl"]

ref = _reference_perplexity(y_preds_gathered, y_true_gathered)

assert pytest.approx(res, abs=1e-4) == ref
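The tests above never pass a custom ``ignore_index``, so padded batches are not exercised here. A small sketch of the intended usage (``PAD = 0`` is a hypothetical padding id chosen for illustration):

import torch

from ignite.metrics.nlp import Perplexity

PAD = 0  # hypothetical padding token id for this sketch
ppl = Perplexity(ignore_index=PAD)

# Second sequence carries two trailing PAD tokens; shapes: (2, 5, 4) and (2, 4).
y_pred = torch.randn(2, 5, 4)
y = torch.tensor([[1, 2, 3, 4], [1, 2, PAD, PAD]])

ppl.update((y_pred, y))
# Only the 6 non-PAD tokens contribute to the NLL sum and the token count.
print(ppl.compute())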