-
Notifications
You must be signed in to change notification settings - Fork 486
New Metric: Soft Dynamic Time Warping (Soft-DTW)
#3287
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
VijayVignesh1
wants to merge
17
commits into
Lightning-AI:master
Choose a base branch
from
VijayVignesh1:feature/soft_dtw_loss
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 7 commits
Commits
Show all changes
17 commits
Select commit
Hold shift + click to select a range
3a5444c
Initial Commit
VijayVignesh1 131c611
Modifying the implementation and adding initial testcases
VijayVignesh1 81be46d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c34c862
Resolving pre-commit errors
VijayVignesh1 89aee8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 80a63fa
Adding more tests and cleaning up the code
VijayVignesh1 f22fbca
Adding reduction parameter over batch dimension
VijayVignesh1 f6264e3
Removing float check for gamma
VijayVignesh1 4e66cad
Removing inline function
VijayVignesh1 6fe1555
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8dc4fa5
Modifying docstring
VijayVignesh1 9ad7b81
Modifying docstring
VijayVignesh1 1649851
Mdifying _devel.txt
VijayVignesh1 1a9d93e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8771562
Adding numba installation to lightning workflows
VijayVignesh1 41096dd
Removing cuda usage from softdtw reference
VijayVignesh1 0f34ea5
Merge branch 'master' into feature/soft_dtw_loss
justusschock File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,22 @@ | ||
| .. customcarditem:: | ||
| :header: Soft Dynamic Time Warping | ||
| :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg | ||
| :tags: timeseries | ||
|
|
||
| .. include:: ../links.rst | ||
|
|
||
| ######################### | ||
| Soft Dynamic Time Warping | ||
| ######################### | ||
|
|
||
| Module Interface | ||
| ________________ | ||
|
|
||
| .. autoclass:: torchmetrics.timeseries.SoftDTW | ||
| :exclude-members: update, compute | ||
|
|
||
|
|
||
| Functional Interface | ||
| ____________________ | ||
|
|
||
| .. autofunction:: torchmetrics.functional.timeseries.soft_dtw |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| pysdtw==0.0.5 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright The Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from torchmetrics.functional.timeseries.softdtw import soft_dtw | ||
|
|
||
| __all__ = ["soft_dtw"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,163 @@ | ||
| # Copyright The Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| import math | ||
| from typing import Callable, Literal, Optional | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| def _soft_dtw_validate_args( | ||
| preds: Tensor, target: Tensor, gamma: float, reduction: Literal["mean", "sum", "none"] | ||
| ) -> None: | ||
| """Validate the input arguments for the soft_dtw function.""" | ||
| valid_reduction = ("mean", "sum", "none") | ||
| if reduction not in valid_reduction: | ||
| raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}") | ||
| if preds.ndim != 3 or target.ndim != 3: | ||
| raise ValueError("Inputs preds and target must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].") | ||
| if preds.shape[0] != target.shape[0]: | ||
| raise ValueError("Batch size of preds and target must be the same.") | ||
| if preds.shape[2] != target.shape[2]: | ||
| raise ValueError("Feature dimension of preds and target must be the same.") | ||
| if not isinstance(gamma, float) or gamma <= 0: | ||
| raise ValueError("Gamma must be a positive float.") | ||
|
|
||
|
|
||
| def _soft_dtw_update(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: | ||
| """Compute the Soft-DTW distance between two batched sequences.""" | ||
| b, n, d = preds.shape | ||
| _, m, _ = target.shape | ||
| device, dtype = target.device, target.dtype | ||
| if preds.dtype != target.dtype: | ||
| target = target.to(preds.dtype) | ||
|
|
||
| if distance_fn is None: | ||
|
|
||
| def distance_fn(x: Tensor, y: Tensor) -> Tensor: | ||
| """Default to squared Euclidean distance.""" | ||
| return torch.cdist(x, y, p=2).pow(2) | ||
|
|
||
| distances = distance_fn(preds, target) # [B, N, M] | ||
|
|
||
| r = torch.ones((b, n + 2, m + 2), device=device, dtype=dtype) * math.inf | ||
| r[:, 0, 0] = 0.0 | ||
|
|
||
| def softmin(a: Tensor, b: Tensor, c: Tensor, gamma: float) -> Tensor: | ||
|
VijayVignesh1 marked this conversation as resolved.
Outdated
|
||
| """Compute the soft minimum of three tensors.""" | ||
| vals = torch.stack([a, b, c], dim=-1) | ||
| return -gamma * torch.logsumexp(-vals / gamma, dim=-1) | ||
|
|
||
| # Anti-diagonal approach inspired from https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8400444 | ||
| for k in range(2, n + m + 1): | ||
| i_vals = torch.arange(1, n + 1, device=device) | ||
| j_vals = k - i_vals | ||
| mask = (j_vals >= 1) & (j_vals <= m) | ||
| i_vals = i_vals[mask] | ||
| j_vals = j_vals[mask] | ||
|
|
||
| if len(i_vals) == 0: | ||
| continue | ||
|
|
||
| r1 = r[:, i_vals - 1, j_vals - 1] | ||
| r2 = r[:, i_vals - 1, j_vals] | ||
| r3 = r[:, i_vals, j_vals - 1] | ||
| r[:, i_vals, j_vals] = distances[:, i_vals - 1, j_vals - 1] + softmin(r1, r2, r3, gamma) | ||
|
|
||
| return r[:, n, m] | ||
|
|
||
|
|
||
| def _soft_dtw_compute(scores: Tensor, reduction: Literal["sum", "mean", "none"] = "mean") -> Tensor: | ||
| """Aggregate the computed Soft-DTW distances based on the specified reduction method.""" | ||
| if reduction == "none": | ||
| return scores | ||
| if reduction == "mean": | ||
| return scores.mean() | ||
| return scores.sum() | ||
|
|
||
|
|
||
| def soft_dtw( | ||
| preds: Tensor, | ||
| target: Tensor, | ||
| gamma: float = 1.0, | ||
| distance_fn: Optional[Callable] = None, | ||
| reduction: Literal["sum", "mean", "none"] = "mean", | ||
| ) -> Tensor: | ||
| r"""Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences. | ||
|
|
||
| This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by | ||
| Marco Cuturi and Mathieu Blondel (2017). | ||
| It replaces the hard minimum in DTW recursion with a soft-minimum using a log-sum-exp formulation: | ||
|
|
||
| .. math:: | ||
| \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right) | ||
|
|
||
| The Soft-DTW recurrence is then defined as: | ||
|
|
||
| .. math:: | ||
| R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1}) | ||
|
|
||
| where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. It could be | ||
| computed using any differentiable distance function, such as squared Euclidean distance or cosine distance. | ||
|
|
||
| The final Soft-DTW distance is :math:`R_{N,M}`. | ||
|
|
||
| Args: | ||
| preds: Tensor of shape ``[B, N, D]`` — batch of input sequences. | ||
| target: Tensor of shape ``[B, M, D]`` — batch of target sequences. | ||
| gamma: Smoothing parameter (:math:`\gamma > 0`). | ||
| Smaller values make the loss closer to standard DTW (hard minimum), | ||
| while larger values produce a smoother and more differentiable surface. | ||
| distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix. | ||
| If ``None``, defaults to squared Euclidean distance. | ||
| reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``]. | ||
| Defaults to ``mean``. | ||
|
|
||
| Returns: | ||
| A tensor of shape ``[B]`` containing the Soft-DTW distance for each sequence pair in the batch. | ||
|
|
||
| Raises: | ||
| ValueError: | ||
| If ``reduction`` is not one of [``sum``, ``mean``, ``none``]. | ||
| ValueError: | ||
| If ``gamma`` is not a positive float. | ||
| ValueError: | ||
| If input tensors to ``preds`` and ``target`` are not 3-dimensional | ||
| with the same batch size and feature dimension. | ||
|
|
||
| Example:: | ||
| >>> import torch | ||
| >>> from torchmetrics.functional.timeseries import soft_dtw | ||
| >>> | ||
| >>> x = torch.tensor([[[0.0], [1.0], [2.0]]]) # [B, N, D] | ||
| >>> y = torch.tensor([[[0.0], [2.0], [3.0]]]) # [B, M, D] | ||
| >>> soft_dtw(x, y, gamma=0.1) | ||
| tensor([0.4003]) | ||
|
|
||
|
|
||
| Example (custom distance function):: | ||
| >>> def cosine_dist(a, b): | ||
| ... a = torch.nn.functional.normalize(a, dim=-1) | ||
| ... b = torch.nn.functional.normalize(b, dim=-1) | ||
| ... return 1 - torch.bmm(a, b.transpose(1, 2)) | ||
| >>> | ||
| >>> x = torch.randn(2, 5, 3) | ||
| >>> y = torch.randn(2, 6, 3) | ||
| >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) | ||
| tensor([2.8301, 3.0128]) | ||
|
|
||
| """ | ||
| _soft_dtw_validate_args(preds, target, gamma, reduction) | ||
| scores = _soft_dtw_update(preds, target, gamma, distance_fn) | ||
| return _soft_dtw_compute(scores, reduction) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright The Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from torchmetrics.timeseries.softdtw import SoftDTW | ||
|
|
||
| __all__ = ["SoftDTW"] |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.