diff --git a/ignite/metrics/nlp/__init__.py b/ignite/metrics/nlp/__init__.py
index 506f0bab51e1..2cfc991d1f66 100644
--- a/ignite/metrics/nlp/__init__.py
+++ b/ignite/metrics/nlp/__init__.py
@@ -1,9 +1,11 @@
 from ignite.metrics.nlp.bleu import Bleu
 from ignite.metrics.nlp.rouge import Rouge, RougeL, RougeN
+from ignite.metrics.nlp.word_error_rate import WordErrorRate
 
 __all__ = [
     "Bleu",
     "Rouge",
     "RougeN",
     "RougeL",
+    "WordErrorRate",
 ]
diff --git a/ignite/metrics/nlp/word_error_rate.py b/ignite/metrics/nlp/word_error_rate.py
new file mode 100644
index 000000000000..77bbd2dea6c6
--- /dev/null
+++ b/ignite/metrics/nlp/word_error_rate.py
@@ -0,0 +1,134 @@
+from typing import Any, Callable, Sequence, Union
+
+import torch
+from torch.types import Number
+
+from ignite.exceptions import NotComputableError
+from ignite.metrics.metric import Metric, reinit__is_reduced, sync_all_reduce
+
+__all__ = ["WordErrorRate"]
+
+
+def _edit_distance(ref: Sequence[Any], pred: Sequence[Any]) -> int:
+    """Computes the Levenshtein distance between two sequences."""
+    n = len(ref)
+    m = len(pred)
+
+    if n == 0:
+        return m
+    if m == 0:
+        return n
+
+    dp = list(range(m + 1))
+
+    for i in range(1, n + 1):
+        prev_diag = dp[0]
+        dp[0] = i
+        for j in range(1, m + 1):
+            temp = dp[j]
+            if ref[i - 1] == pred[j - 1]:
+                dp[j] = prev_diag
+            else:
+                dp[j] = min(dp[j - 1], dp[j], prev_diag) + 1
+            prev_diag = temp
+
+    return dp[m]
+
+
+class _BaseErrorRate(Metric):
+    """
+    Base class for error rate metrics based on Levenshtein distance (edit distance).
+    """
+
+    def __init__(
+        self,
+        output_transform: Callable = lambda x: x,
+        device: Union[str, torch.device] = torch.device("cpu"),
+        skip_unrolling: bool = False,
+    ):
+        super().__init__(output_transform=output_transform, device=device, skip_unrolling=skip_unrolling)
+
+    @reinit__is_reduced
+    def reset(self) -> None:
+        self._num_errors = torch.tensor(0.0, device=self._device)
+        self._num_refs = torch.tensor(0.0, device=self._device)
+        super().reset()
+
+    def _tokenize(self, text: str) -> Sequence[Any]:
+        raise NotImplementedError
+
+    @reinit__is_reduced
+    def update(self, output: Sequence[Sequence[str]]) -> None:
+        y_pred, y = output[0], output[1]
+
+        if isinstance(y_pred, str) and isinstance(y, str):
+            y_pred = [y_pred]
+            y = [y]
+
+        if len(y_pred) != len(y):
+            raise ValueError(
+                f"y_pred and y must have the same length. Got y_pred of length {len(y_pred)} and y of length {len(y)}."
+            )
+
+        errors = 0.0
+        refs = 0.0
+        for p, r in zip(y_pred, y):
+            p_tokens = self._tokenize(p)
+            r_tokens = self._tokenize(r)
+
+            errors += _edit_distance(r_tokens, p_tokens)
+            refs += len(r_tokens)
+
+        self._num_errors += torch.tensor(errors, device=self._device)
+        self._num_refs += torch.tensor(refs, device=self._device)
+
+    @sync_all_reduce("_num_errors", "_num_refs")
+    def compute(self) -> Number:
+        if self._num_refs == 0:
+            raise NotComputableError("Error rate must have at least one valid reference sequence to be computed.")
+        return (self._num_errors / self._num_refs).item()
+
+
+class WordErrorRate(_BaseErrorRate):
+    r"""Calculates the Word Error Rate (WER).
+
+    WER is defined as the total number of errors (substitutions, deletions, and insertions)
+    at the word level divided by the total number of words in the reference sequence.
+
+    .. math::
+       \text{WER} = \frac{S + D + I}{N} = \frac{S + D + I}{S + D + C}
+
+    where :math:`S` is the number of substitutions, :math:`D` is the number of deletions,
+    :math:`I` is the number of insertions, :math:`C` is the number of correct words,
+    and :math:`N` is the total number of words in the reference (:math:`N = S + D + C`).
+
+    - ``update`` must receive output of the form ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
+    - `y_pred` must be a list of strings (predicted sentences).
+    - `y` must be a list of strings (reference sentences).
+
+    Args:
+        output_transform: a callable that is used to transform the
+            :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
+            form expected by the metric. This can be useful if, for example, you have a multi-output model and
+            you want to compute the metric with respect to one of the outputs.
+        device: specifies which device updates are accumulated on. Setting the metric's
+            device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
+            default, CPU.
+        skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
+            true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
+            Alternatively, ``output_transform`` can be used to handle this.
+
+    Examples:
+        .. code-block:: python
+
+            from ignite.metrics.nlp import WordErrorRate
+
+            wer = WordErrorRate()
+            y_pred = ["the cat sat on the mat", "hello world"]
+            y = ["the cat sat on mat", "hello world"]
+            wer.update((y_pred, y))
+            print(wer.compute())  # 1 insertion / 7 reference words ~= 0.1429
+    """
+
+    def _tokenize(self, text: str) -> Sequence[str]:
+        return text.split()
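Reviewer note: `_edit_distance` keeps a single rolling DP row instead of the full (n+1) x (m+1) Levenshtein matrix, with `prev_diag` carrying the diagonal cell across the in-place overwrite. The standalone sketch below (not part of the patch; `traced_edit_distance` is an illustrative name) prints each row of the same recurrence:

# Standalone sketch, not part of the patch: traces the rolling single-row
# Levenshtein DP used by _edit_distance above.
from typing import Any, Sequence


def traced_edit_distance(ref: Sequence[Any], pred: Sequence[Any]) -> int:
    m = len(pred)
    dp = list(range(m + 1))  # row 0: empty reference vs. first j predicted tokens -> j insertions
    print("row 0:", dp)
    for i in range(1, len(ref) + 1):
        prev_diag = dp[0]  # holds dp[i-1][j-1] before dp[j-1] is overwritten
        dp[0] = i          # i reference tokens vs. an empty prediction: i deletions
        for j in range(1, m + 1):
            temp = dp[j]   # save dp[i-1][j]; it becomes the next diagonal
            if ref[i - 1] == pred[j - 1]:
                dp[j] = prev_diag  # tokens match: no edit needed
            else:
                # 1 + min(insertion dp[j-1], deletion dp[j], substitution prev_diag)
                dp[j] = min(dp[j - 1], dp[j], prev_diag) + 1
            prev_diag = temp
        print(f"row {i}:", dp)
    return dp[m]


# The "completely different" case from the tests below:
distance = traced_edit_distance(
    "hello world test sequence".split(),   # reference
    "completely different string".split(),  # prediction
)
print(distance)  # 4 -> 3 substitutions + 1 deletion

On that pair the rows evolve [0, 1, 2, 3] -> [1, 1, 2, 3] -> [2, 2, 2, 3] -> [3, 3, 3, 3] -> [4, 4, 4, 4], giving the 4 edits over 4 reference words that drive the WER of 1.0 asserted in the tests.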
diff --git a/tests/ignite/metrics/nlp/test_word_error_rate.py b/tests/ignite/metrics/nlp/test_word_error_rate.py
new file mode 100644
index 000000000000..fdc86a07b3d7
--- /dev/null
+++ b/tests/ignite/metrics/nlp/test_word_error_rate.py
@@ -0,0 +1,67 @@
+import pytest
+import torch
+
+import ignite.distributed as idist
+from ignite.exceptions import NotComputableError
+from ignite.metrics.nlp import WordErrorRate
+
+
+def test_wer_wrong_inputs():
+    wer = WordErrorRate()
+
+    with pytest.raises(NotComputableError, match=r"Error rate must have at least one valid reference sequence"):
+        wer.compute()
+
+    with pytest.raises(ValueError, match=r"y_pred and y must have the same length"):
+        wer.update((["a", "b"], ["a"]))
+
+    with pytest.raises(ValueError, match=r"y_pred and y must have the same length"):
+        wer.update((["a"], ["a", "b"]))
+
+
+def test_wer_compute():
+    wer = WordErrorRate()
+
+    # Exact match
+    wer.update((["hello world", "test sequence"], ["hello world", "test sequence"]))
+    assert pytest.approx(wer.compute()) == 0.0
+
+    # 1 Substitution
+    wer.reset()
+    wer.update((["hello word"], ["hello world"]))
+    # 1 error / 2 words = 0.5
+    assert pytest.approx(wer.compute()) == 0.5
+
+    # 1 Deletion
+    wer.reset()
+    wer.update((["hello"], ["hello world"]))
+    # 1 error / 2 words = 0.5
+    assert pytest.approx(wer.compute()) == 0.5
+
+    # 1 Insertion
+    wer.reset()
+    wer.update((["hello world test"], ["hello world"]))
+    # 1 error / 2 words = 0.5
+    assert pytest.approx(wer.compute()) == 0.5
+
+    # Completely different
+    wer.reset()
+    wer.update((["completely different string"], ["hello world test sequence"]))
+    # 'completely', 'different', 'string' vs 'hello', 'world', 'test', 'sequence':
+    # 4 errors (3 substitutions, 1 deletion) / 4 reference words = 1.0
+    assert pytest.approx(wer.compute()) == 1.0
+
+
+def test_wer_batching():
+    wer = WordErrorRate()
+    # Batch 1
+    wer.update((["the cat sat", "hello world"], ["the bat sat", "hello"]))
+    # Batch 2
+    wer.update((["test string"], ["test string again"]))
+
+    # 1 sub (the bat sat) = 1 error / 3 refs
+    # 1 ins (hello world) = 1 error / 1 ref
+    # 1 del (test string again) = 1 error / 3 refs
+    # Total errors = 3
+    # Total refs = 3 + 1 + 3 = 7
+    assert pytest.approx(wer.compute()) == 3 / 7
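Reviewer note: the docstring example drives the metric by hand; attached to an ignite Engine it behaves the same way. A minimal usage sketch, not part of the patch; the `process_function` and `data` batches are dummies standing in for a real decoding step:

# Minimal usage sketch, not part of the patch: attaching the new metric to an
# evaluator whose process_function returns (y_pred, y) sentence lists.
from ignite.engine import Engine
from ignite.metrics.nlp import WordErrorRate

# Dummy batches of (predicted, reference) sentence lists.
data = [
    (["the cat sat on the mat"], ["the cat sat on mat"]),
    (["hello world"], ["hello world"]),
]


def process_function(engine, batch):
    y_pred, y = batch  # a real pipeline would run a model and decode here
    return y_pred, y   # the default output_transform passes this to update()


evaluator = Engine(process_function)
WordErrorRate().attach(evaluator, "wer")

state = evaluator.run(data)
print(state.metrics["wer"])  # 1 insertion / 7 reference words ~= 0.1429

Because `compute` is decorated with `sync_all_reduce("_num_errors", "_num_refs")`, the same attachment works in a distributed run: error and reference counts are summed across processes before the division.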