feat(metrics): add Perplexity metric to ignite.metrics.nlp #3743
Open: steaphenai wants to merge 13 commits into pytorch:master from steaphenai:feat/perplexity-metric-pr
Changes from all commits (13, all by steaphenai):

- fa394ca feat(metrics): add Perplexity metric to NLP metrics
- e1e1e5f fix(metrics): detach Perplexity accumulators and refine tests
- 4335359 test(metrics): align Perplexity tests with metric patterns
- d5ab433 fix(metrics): address Perplexity review follow-ups
- dae37e9 test(metrics): use _reference_perplexity in token-weighted accumulati…
- f9ecaa1 Update tests/ignite/metrics/nlp/test_perplexity.py
- e5c0cfd feat(metrics): add ignore_index to Perplexity, expose in docs, remove…
- 46a3b8d Merge branch 'master' into feat/perplexity-metric-pr
- fa8fb7f style: fix ruff formatting in test_perplexity.py
- 49ccdc4 Merge branch 'feat/perplexity-metric-pr' of https://github.com/steaph…
- c7a3720 fix(tests): fix token weighted accumulation test with different seq l…
- 3143650 fix(tests): use _reference_perplexity and matching seq lengths in acc…
- b514e12 fix(metrics): remove explicit double dtype from Perplexity accumulato…
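For quick context before the diffs, here is a minimal usage sketch of the metric this PR adds (shapes and values are illustrative; the API matches the implementation diff below):

```python
import torch

from ignite.metrics.nlp import Perplexity  # added by this PR

ppl = Perplexity()

# Logits of shape (batch_size, vocab_size, seq_len) and integer targets.
y_pred = torch.randn(2, 5, 3)
y = torch.randint(0, 5, (2, 3))

ppl.update((y_pred, y))
print(ppl.compute())  # a plain float; lower is better
```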
Docs: Perplexity is added to the complete list of metrics.

```diff
@@ -351,6 +351,7 @@ Complete list of metrics
 SSIM
 TopKCategoricalAccuracy
 Bleu
+Perplexity
 Rouge
 RougeL
 RougeN
```
New file (107 added lines): the Perplexity metric implementation.

```python
from collections.abc import Callable

import torch
import torch.nn.functional as F

from ignite.exceptions import NotComputableError
from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce

__all__ = ["Perplexity"]


class Perplexity(Metric):
    r"""Calculates the `Perplexity <https://en.wikipedia.org/wiki/Perplexity>`_ of a language model.

    .. math::
       \text{PPL}(W) = \exp \left( -\frac{1}{N} \sum_{i=1}^{N} \log P(w_i | w_1, \ldots, w_{i-1}) \right)

    where :math:`N` is the total number of tokens and :math:`P(w_i | w_1, \ldots, w_{i-1})` is the
    conditional probability of token :math:`w_i` given the preceding tokens.

    Perplexity is computed as :math:`\exp(\text{NLL})` where NLL is the mean negative log-likelihood
    over all tokens. Lower perplexity indicates a better language model.

    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
    - ``y_pred`` must be a floating-point tensor of shape ``(batch_size, vocab_size, seq_len)``
      containing the unnormalized log-probabilities (logits).
    - ``y`` must be a long tensor of shape ``(batch_size, seq_len)`` containing the target token indices.

    Note:
        Perplexity uses token-weighted accumulation rather than a batch average to avoid bias
        towards shorter sequences. The total NLL and the total token count are accumulated across
        all batches, and the final perplexity is computed as ``exp(total_nll / total_tokens)``.

    Args:
        output_transform: a callable that is used to transform the
            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
            form expected by the metric. This can be useful if, for example, you have a multi-output
            model and you want to compute the metric with respect to one of the outputs.
            By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
        device: specifies which device updates are accumulated on. Setting the
            metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
            non-blocking. By default, CPU.
        ignore_index: target value that is ignored, contributing neither to the NLL sum nor to the
            token count, following :func:`torch.nn.functional.cross_entropy`. By default, -100.

    Examples:

        For more information on how the metric works with :class:`~ignite.engine.engine.Engine`,
        visit :ref:`attach-engine`.

        .. testcode::

            from ignite.metrics.nlp import Perplexity
            import torch

            ppl = Perplexity()

            # batch_size=2, vocab_size=5, seq_len=3
            y_pred = torch.log_softmax(torch.randn(2, 5, 3), dim=1)
            y = torch.randint(0, 5, (2, 3))

            ppl.update((y_pred, y))

            print(type(ppl.compute()))

        .. testoutput::

            <class 'float'>

    .. versionadded:: 0.5.2
    """

    _state_dict_all_req_keys = ("_sum_of_nll", "_num_tokens")

    def __init__(
        self,
        output_transform: Callable = lambda x: x,
        device: str | torch.device = torch.device("cpu"),
        ignore_index: int = -100,
    ):
        self._ignore_index = ignore_index
        super().__init__(output_transform=output_transform, device=device)

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum_of_nll = torch.tensor(0.0, device=self._device)
        self._num_tokens = torch.tensor(0, device=self._device)

    @reinit__is_reduced
    def update(self, output: tuple[torch.Tensor, torch.Tensor]) -> None:
        y_pred, y = output

        y_pred = y_pred.detach()
        y = y.detach()

        if y_pred.ndim < 2:
            raise ValueError(f"y_pred must be at least 2-dimensional (got shape: {y_pred.shape})")

        if y.ndim < 1:
            raise ValueError(f"y must be at least 1-dimensional (got shape: {y.shape})")

        nll = F.cross_entropy(y_pred, y, reduction="sum", ignore_index=self._ignore_index)
        self._sum_of_nll += nll.to(self._device)
        # Move the token count to the metric's device, consistent with the NLL sum above.
        self._num_tokens += (y != self._ignore_index).sum().to(self._device)

    @sync_all_reduce("_sum_of_nll", "_num_tokens")
    def compute(self) -> float:
        if self._num_tokens == 0:
            raise NotComputableError("Perplexity must have at least one example before it can be computed.")

        return torch.exp(self._sum_of_nll / self._num_tokens).item()
```
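The Note in the docstring is worth making concrete: pooling the NLL sum and the token count before exponentiating is not the same as averaging per-batch perplexities. A small self-contained sketch of the accumulation, done by hand rather than through the class (batch sizes and lengths are illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Two batches with different numbers of tokens: 2*4 = 8 and 3*6 = 18.
b1_pred, b1_y = torch.randn(2, 5, 4), torch.randint(0, 5, (2, 4))
b2_pred, b2_y = torch.randn(3, 5, 6), torch.randint(0, 5, (3, 6))

nll1 = F.cross_entropy(b1_pred, b1_y, reduction="sum")
nll2 = F.cross_entropy(b2_pred, b2_y, reduction="sum")

# Token-weighted accumulation, as in the metric: exp(total_nll / total_tokens).
pooled = torch.exp((nll1 + nll2) / (b1_y.numel() + b2_y.numel()))

# Naive average of per-batch perplexities: weights both batches equally
# even though they contribute different numbers of tokens.
naive = (torch.exp(nll1 / b1_y.numel()) + torch.exp(nll2 / b2_y.numel())) / 2

print(pooled.item(), naive.item())  # generally differ
```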
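Commit e5c0cfd adds the ``ignore_index`` argument (defaulting to -100, the same convention as ``torch.nn.functional.cross_entropy``). A sketch of the intended padding workflow, assuming the ``update`` semantics shown above: padded positions contribute neither to the NLL sum nor to the token count.

```python
import torch

from ignite.metrics.nlp import Perplexity

PAD = -100  # default ignore_index

ppl = Perplexity(ignore_index=PAD)

# Two sequences padded to seq_len=4; the second has only 2 real tokens.
y_pred = torch.randn(2, 5, 4)  # (batch_size, vocab_size, seq_len)
y = torch.tensor([[1, 4, 0, 2],
                  [3, 2, PAD, PAD]])

ppl.update((y_pred, y))
# Only the 6 non-padded tokens enter exp(total_nll / total_tokens).
print(ppl.compute())
```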
New file (147 added lines): tests for the new metric.

```python
import pytest
import torch
import torch.nn.functional as F

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics.nlp import Perplexity

torch.manual_seed(12)


def test_zero_sample():
    ppl = Perplexity()
    with pytest.raises(
        NotComputableError, match=r"Perplexity must have at least one example before it can be computed"
    ):
        ppl.compute()


def test_invalid_y_pred_shape():
    ppl = Perplexity()
    with pytest.raises(ValueError, match=r"y_pred must be at least 2-dimensional"):
        ppl.update((torch.tensor([1.0, 2.0]), torch.tensor([0])))


def test_reset_clears_state():
    torch.manual_seed(2)
    ppl = Perplexity()

    y_pred = torch.randn(2, 5, 3)
    y = torch.randint(0, 5, (2, 3))
    ppl.update((y_pred, y))
    ppl.reset()

    with pytest.raises(NotComputableError):
        ppl.compute()


def _reference_perplexity(y_pred, y):
    """Reference implementation: token-weighted NLL."""
    nll = F.cross_entropy(y_pred, y, reduction="sum")
    return torch.exp(nll / y.numel()).item()


@pytest.mark.parametrize("n_times", range(3))
def test_compute_matches_reference(n_times, available_device):
    ppl = Perplexity(device=available_device)
    assert ppl._device == torch.device(available_device)

    torch.manual_seed(n_times)
    y_pred = torch.randn(4, 10, 5)
    y = torch.randint(0, 10, (4, 5))

    ppl.reset()
    ppl.update((y_pred, y))

    ref = _reference_perplexity(y_pred, y)
    assert pytest.approx(ppl.compute(), abs=1e-4) == ref


@pytest.mark.parametrize("n_times", range(3))
def test_token_weighted_accumulation(n_times, available_device):
    """Token-weighted accumulation across multiple batches."""
    ppl = Perplexity(device=available_device)
    assert ppl._device == torch.device(available_device)

    torch.manual_seed(n_times)

    b1_pred = torch.randn(2, 5, 4)
    b1_y = torch.randint(0, 5, (2, 4))
    b2_pred = torch.randn(3, 5, 4)
    b2_y = torch.randint(0, 5, (3, 4))

    ppl.reset()
    ppl.update((b1_pred, b1_y))
    ppl.update((b2_pred, b2_y))

    combined_pred = torch.cat([b1_pred, b2_pred], dim=0)
    combined_y = torch.cat([b1_y, b2_y], dim=0)
    ppl_ref = _reference_perplexity(combined_pred, combined_y)

    assert pytest.approx(ppl.compute(), abs=1e-4) == ppl_ref


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.usefixtures("distributed")
class TestDistributed:
    def test_accumulator_device(self):
        metric_devices = [torch.device("cpu")]
        device = idist.device()
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            ppl = Perplexity(device=metric_device)
            assert ppl._device == metric_device
            assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"

            y_pred = torch.randn(2, 5, 3, device=device)
            y = torch.randint(0, 5, (2, 3), device=device)
            ppl.update((y_pred, y))

            assert ppl._sum_of_nll.device == metric_device, f"{ppl._sum_of_nll.device} vs {metric_device}"

    @pytest.mark.parametrize("n_epochs", [1, 2])
    def test_integration(self, n_epochs):
        rank = idist.get_rank()
        torch.manual_seed(10 + rank)

        n_iters = 20
        batch_size = 4
        vocab_size = 10
        seq_len = 5

        metric_devices = [torch.device("cpu")]
        device = idist.device()
        if device.type != "xla":
            metric_devices.append(device)

        for metric_device in metric_devices:
            y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len)).to(device)
            y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len).to(device)

            def update(engine, i):
                return (
                    y_preds[i * batch_size : (i + 1) * batch_size],
                    y_true[i * batch_size : (i + 1) * batch_size],
                )

            engine = Engine(update)
            ppl = Perplexity(device=metric_device)
            ppl.attach(engine, "ppl")

            data = list(range(n_iters))
            engine.run(data=data, max_epochs=n_epochs)

            y_true_gathered = idist.all_gather(y_true)
            y_preds_gathered = idist.all_gather(y_preds)

            assert "ppl" in engine.state.metrics
            res = engine.state.metrics["ppl"]

            ref = _reference_perplexity(y_preds_gathered, y_true_gathered)

            assert pytest.approx(res, abs=1e-4) == ref
```
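For readers who want the Engine integration pattern without the distributed fixtures, here is a single-process distillation of test_integration above (sizes mirror the test; no all-gather is needed on one process):

```python
import torch

from ignite.engine import Engine
from ignite.metrics.nlp import Perplexity

n_iters, batch_size, vocab_size, seq_len = 20, 4, 10, 5

torch.manual_seed(10)
y_true = torch.randint(0, vocab_size, size=(n_iters * batch_size, seq_len))
y_preds = torch.randn(n_iters * batch_size, vocab_size, seq_len)


def update(engine, i):
    # Each iteration returns one (y_pred, y) slice of the dataset.
    return (
        y_preds[i * batch_size : (i + 1) * batch_size],
        y_true[i * batch_size : (i + 1) * batch_size],
    )


engine = Engine(update)
Perplexity().attach(engine, "ppl")

engine.run(data=list(range(n_iters)), max_epochs=1)
print(engine.state.metrics["ppl"])
```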