2 changes: 2 additions & 0 deletions ignite/metrics/__init__.py
@@ -33,6 +33,7 @@
from ignite.metrics.mutual_information import MutualInformation
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
from ignite.metrics.precision import Precision
from ignite.metrics.precision_recall_curve import PrecisionRecallCurve
from ignite.metrics.psnr import PSNR
@@ -93,6 +94,7 @@
"Rouge",
"RougeN",
"RougeL",
"Perplexity",
"regression",
"clustering",
"fairness",
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/__init__.py
@@ -1,8 +1,10 @@
from ignite.metrics.nlp.bleu import Bleu
from ignite.metrics.nlp.perplexity import Perplexity
from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN

__all__ = [
"Bleu",
"Perplexity",
"Rouge",
"RougeN",
"RougeL",
105 changes: 105 additions & 0 deletions ignite/metrics/nlp/perplexity.py
@@ -0,0 +1,105 @@
from collections.abc import Callable

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


class Perplexity(Metric):
r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.

.. math::
\text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)

where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
conditional probability of token :math:`w_i` given the preceding tokens.

Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
over all tokens. Lower perplexity indicates a better language model.

- ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
- `y_pred` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
containing the unnormalized log-probabilities (logits).
- `y` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.

    Note:
        Perplexity uses token-weighted accumulation rather than a per-batch average to avoid
        bias towards shorter sequences. The total NLL and total token count are accumulated
        across all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``.
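        For example, two batches with summed NLLs :math:`s_1` and :math:`s_2` over
        :math:`n_1` and :math:`n_2` tokens give :math:`\exp((s_1 + s_2) / (n_1 + n_2))`,
        which in general differs from the mean of the two per-batch perplexities.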

Args:
output_transform: a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`, visit :ref:`attach-engine`.

.. testcode::

from ignite.metrics.nlp import Perplexity
import torch

ppl = Perplexity()

# batch_size=2, vocab_size=5, seq_len=3
            y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))

ppl.update((y_pred, y))

print(type(ppl.compute()))

.. testoutput::

<class 'float'>

.. versionadded:: 0.5.2
"""

_state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")

def __init__(
self,
output_transform: Callable = lambda x: x,
device: str | torch.device = torch.device("cpu"),
):
super().__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_nll = torch.tensor(0.0, dtype=torch.double, device=self._device)
self._num_tokens = torch.tensor(0, dtype=torch.long, device=self._device)

@reinit__is_reduced
def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
y_pred, y = output
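        # Detach inputs so the accumulated state never retains the autograd graph.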
y_pred = y_pred.detach()
y = y.detach()

if y_pred.ndim < 2:
raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")

if y.ndim < 1:
raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")

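        # F.cross_entropy accepts (N, C, d1, ...) logits with (N, d1, ...) integer targets,
        # so reduction="sum" gives the total NLL over every token in the batch.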
nll = F.cross_entropy(y_pred, y, reduction="sum")
self._sum_of_nll += nll.to(self._device, dtype=torch.double)
self._num_tokens += y.numel()

@sync_all_reduce("_sum_of_nll", "_num_tokens")
def compute(self) -> float:
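        # The decorator has already all-reduced _sum_of_nll and _num_tokens across processes.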
if self._num_tokens == 0:
raise NotComputableError("Perplexity must have at least one example before it can be computed.")

return torch.exp(self._sum_of_nll / self._num_tokens).item()
162 changes: 162 additions & 0 deletions tests/ignite/metrics/nlp/test_perplexity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
import pytest
import torch
import torch.nn.functional as F

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.nlp import Perplexity

torch.manual_seed(12)


def test_zero_sample():
ppl = Perplexity()
    with pytest.raises(
        NotComputableError, match=r"Perplexity must have at least one example before it can be computed"
    ):
ppl.compute()


def test_invalid_y_pred_shape():
ppl = Perplexity()
with pytest.raises(ValueError, match=r"y_pred must be at least 2-dimensional"):
ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))


def test_reset_clears_state():
torch.manual_seed(2)
ppl = Perplexity()

y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))
ppl.update((y_pred, y))
ppl.reset()

with pytest.raises(NotComputableError):
ppl.compute()


def _reference_perplexity(y_pred, y):
"""Reference implementation: token-weighted NLL."""
nll = F.cross_entropy(y_pred, y, reduction="sum")
return torch.exp(nll / y.numel()).item()


@pytest.mark.parametrize("n_times", range(3))
def test_compute_matches_reference(n_times, available_device):
ppl = Perplexity(device=available_device)
assert ppl._device == torch.device(available_device)

torch.manual_seed(n_times)
y_pred = torch.randn(4, 10, 5)
y = torch.randint(0, 10, (4, 5))

ppl.reset()
ppl.update((y_pred, y))

ref = _reference_perplexity(y_pred, y)
assert pytest.approx(ppl.compute(), abs=1e-4) == ref


@pytest.mark.parametrize("n_times", range(3))
def test_token_weighted_accumulation(n_times, available_device):
"""Token-weighted accumulation across multiple batches."""
ppl = Perplexity(device=available_device)
assert ppl._device == torch.device(available_device)

torch.manual_seed(n_times)

b1_pred = torch.randn(2, 5, 4)
b1_y = torch.randint(0, 5, (2, 4))
b2_pred = torch.randn(3, 5, 10)
b2_y = torch.randint(0, 5, (3, 10))

ppl.reset()
ppl.update((b1_pred, b1_y))
ppl.update((b2_pred, b2_y))

combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
combined_y = torch.cat([b1_y, b2_y], dim=0)
ppl_ref = _reference_perplexity(combined_pred, combined_y)

assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref


def test_accumulator_detached():
"""Metric state tensors must be detached from the computation graph."""
ppl = Perplexity()
ppl.reset()

y_pred = torch.randn(2, 5, 3, requires_grad=True)
y = torch.randint(0, 5, (2, 3))
ppl.update((y_pred, y))

assert not ppl._sum_of_nll.requires_grad
assert not ppl._num_tokens.requires_grad


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.usefixtures("distributed")
class TestDistributed:
def test_accumulator_device(self):
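        # Always test accumulation on CPU; also on the distributed device unless it is XLA.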
metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
ppl = Perplexity(device=metric_device)
assert ppl._device == metric_device
assert ppl._sum_of_nll.device == metric_device, (
f"{ppl._sum_of_nll.device} vs {metric_device}"
)

y_pred = torch.randn(2, 5, 3, device=device)
y = torch.randint(0, 5, (2, 3), device=device)
ppl.update((y_pred, y))

assert ppl._sum_of_nll.device == metric_device, (
f"{ppl._sum_of_nll.device} vs {metric_device}"
)

@pytest.mark.parametrize("n_epochs", [1, 2])
def test_integration(self, n_epochs):
rank = idist.get_rank()
torch.manual_seed(10 + rank)

n_iters = 20
batch_size = 4
vocab_size = 10
seq_len = 5

metric_devices = [torch.device("cpu")]
device = idist.device()
if device.type != "xla":
metric_devices.append(device)

for metric_device in metric_devices:
y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len)).to(device)
y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len).to(device)

def update(engine, i):
return (
y_preds[i * batch_size: (i + 1) * batch_size],
y_true[i * batch_size: (i + 1) * batch_size],
)

engine = Engine(update)
ppl = Perplexity(device=metric_device)
ppl.attach(engine, "ppl")

data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

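            # Gather targets and predictions from all ranks so the reference covers the full dataset.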
y_true_gathered = idist.all_gather(y_true)
y_preds_gathered = idist.all_gather(y_preds)

assert "ppl" in engine.state.metrics
res = engine.state.metrics["ppl"]

ref = _reference_perplexity(y_preds_gathered, y_true_gathered)

assert pytest.approx(res, abs=1e-4) == ref