Fix MetricCollection state comparison for nested sequence states #3337

Open
omkar-334 wants to merge 5 commits into Lightning-AI:master from omkar-334:fix-metric

Conversation

omkar-334 commented Mar 17, 2026

What does this PR do?

Fixes #3335

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Fixes a bug where MetricCollection crashes while auto-merging compute groups for metrics with nested sequence state, such as
MeanAveragePrecision.

Previously, _equal_metric_states assumed that list-valued state always contained tensors and accessed .shape directly on each element. This
breaks for metrics whose state contains tuples or other nested structures, leading to:

AttributeError: 'tuple' object has no attribute 'shape'
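
The failure mode can be reproduced without torchmetrics. The sketch below assumes the old check iterated over a list-valued state and read `.shape` from each element; `old_list_state_equal` is a hypothetical stand-in for illustration, not the actual code in `collections.py`, and plain integers stand in for tensors.

```python
def old_list_state_equal(state1: list, state2: list) -> bool:
    """Hypothetical stand-in for a tensor-only list comparison."""
    # Assumes every element exposes `.shape`, which only holds for tensors.
    return len(state1) == len(state2) and all(
        s1.shape == s2.shape for s1, s2 in zip(state1, state2)
    )


# A nested state shaped like list[tuple[...]]; plain ints stand in for tensors.
nested_state = [(1, 2), (3, 4)]

message = ""
try:
    old_list_state_equal(nested_state, nested_state)
except AttributeError as err:
    # The elements are tuples, so the `.shape` lookup fails.
    message = str(err)

print(message)
```

Any state whose list elements are containers rather than tensors would trip the same attribute lookup.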

Fix: I replaced the tensor-only list comparison with a recursive state comparison function.


📚 Documentation preview 📚: https://torchmetrics--3337.org.readthedocs.build/en/3337/

Signed-off-by: Omkar Kabde <omkarkabde@gmail.com>
Author

omkar-334 commented Mar 17, 2026

Before -
Screenshot 2026-03-18 at 1 55 44 AM

After -
Screenshot 2026-03-18 at 1 53 58 AM

codecov Bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 9.09091% with 10 lines in your changes missing coverage. Please review.
✅ Project coverage is 36%. Comparing base (d184220) to head (1f374cd).

❌ Your project check has failed because the head coverage (36%) is below the target coverage (95%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@          Coverage Diff           @@
##           master   #3337   +/-   ##
======================================
- Coverage      37%     36%   -0%     
======================================
  Files         349     349           
  Lines       19901   19907    +6     
======================================
+ Hits         7264    7265    +1     
- Misses      12637   12642    +5     

@Borda Borda added the bug / fix Something isn't working label Mar 18, 2026
@Borda Borda requested a review from Copilot March 18, 2026 15:09
Contributor

Copilot AI left a comment

Pull request overview

Fixes a crash in MetricCollection when auto-merging compute groups for metrics whose states contain nested sequences (e.g., list of tuples), as reported in #3335.

Changes:

  • Introduced a recursive state-value comparison helper to correctly compare nested structures in metric states.
  • Updated MetricCollection._equal_metric_states to use the new recursive comparison.
  • Added a unit test covering nested sequence state values to prevent regressions.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.

| File | Description |
| --- | --- |
| src/torchmetrics/collections.py | Replaces tensor-only list state comparison with a recursive comparator that supports nested sequences/mappings. |
| tests/unittests/bases/test_collections.py | Adds a regression test ensuring compute group merging works with nested sequence state values. |

@omkar-334

@justusschock all tests are passing.... can you review this?

@m-matthias

Is there any update on this? It would be great if metric collections containing e.g. MeanAveragePrecision can be used with compute_groups enabled.

Collaborator

@Borda Borda left a comment

Thanks for tracking down this issue and the PR is well-targeted. A few things stand out:

The root cause was clear: _equal_metric_states used isinstance(state1, list) which misses tuple states that MeanAveragePrecision stores. The fix generalizes this to all Sequence types via a recursive helper.

The new _equal_state_value function is clean -- it handles Tensor, Mapping, Sequence (with string exclusion), and falls back to direct == for primitives. No new imports needed, no API changes.

The test case is well-constructed: DummyNestedListMetric faithfully reproduces the MeanAveragePrecision state shape (list[tuple[tensor, tensor]]), and the assertions verify both the compute-group merging and the state comparison.

Thematic areas to consider:

  • The helper function is internal (not exported in __init__.py), which is correct, but the docstring could benefit from mentioning what types it supports.
  • For future robustness, the MeanAveragePrecision reproduction in the linked issue would be a valuable addition (or a comment noting the fix covers it).
  • Consider whether dict states with tensor values are explicitly tested somewhere -- the recursive Mapping branch is new and worth confirming.

I've left a couple of minor observations below.

    return string[: -len(suffix)] if string.endswith(suffix) else string


def _equal_state_value(state1: Any, state2: Any) -> bool:

The _equal_state_value docstring is one line. Consider listing supported types for callers: e.g., "Recursively compare metric state values. Supports: Tensor (shape+value), Mapping (key+recursive value), Sequence/str-excluded (length+recursive element), and primitives (direct ==)."
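
For reference, the fragments quoted in this review can be assembled into a minimal sketch of such a helper, with the suggested docstring in place. The name `equal_state_value_sketch` is hypothetical, the tensor branch is an assumption (the thread does not show the PR's actual tensor check), and `torch` is imported optionally so the sketch also runs without it.

```python
from collections.abc import Mapping, Sequence
from typing import Any

try:  # torch is optional here so the sketch also runs without it
    import torch
except ImportError:
    torch = None


def equal_state_value_sketch(state1: Any, state2: Any) -> bool:
    """Recursively compare metric state values.

    Supports: Tensor (shape + value), Mapping (keys + recursive values),
    Sequence excluding str (length + recursive elements), and primitives
    (direct ==).
    """
    # Exact type match: a list never equals a tuple, even with equal items.
    if type(state1) is not type(state2):
        return False

    # Assumed tensor branch; the PR's real tensor check is not quoted in the thread.
    if torch is not None and isinstance(state1, torch.Tensor):
        return state1.shape == state2.shape and bool(torch.equal(state1, state2))

    if isinstance(state1, Mapping):
        return state1.keys() == state2.keys() and all(
            equal_state_value_sketch(state1[k], state2[k]) for k in state1
        )

    if isinstance(state1, Sequence) and not isinstance(state1, str):
        return len(state1) == len(state2) and all(
            equal_state_value_sketch(s1, s2) for s1, s2 in zip(state1, state2)
        )

    # Primitives and anything else fall back to direct equality.
    return bool(state1 == state2)
```

With plain Python values, `equal_state_value_sketch([(1, 2)], [(1, 2)])` is true, while `equal_state_value_sketch([(1, 2)], [[1, 2]])` is false because the inner types differ.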


def _equal_state_value(state1: Any, state2: Any) -> bool:
    """Recursively compare metric state values while preserving structure checks."""
    if type(state1) is not type(state2):

type(state1) is not type(state2) is intentionally strict (exact type match), which is the right call here -- but worth a comment since most Python code uses != or isinstance. Consider # noqa: E721 or a short note.
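
A small stdlib-only illustration of why the exact type check behaves differently from value equality or `isinstance`. `Box` is a hypothetical tuple subclass chosen for the example; nothing here comes from the PR itself.

```python
from collections import namedtuple

Box = namedtuple("Box", ["x1", "y1"])  # hypothetical tuple subclass

plain = (1, 2)
boxed = Box(1, 2)

values_equal = plain == boxed                      # tuple equality ignores the subclass
isinstance_match = isinstance(boxed, type(plain))  # a subclass instance passes isinstance
exact_type_match = type(plain) is type(boxed)      # the exact check tells them apart

print(values_equal, isinstance_match, exact_type_match)  # True True False
```

Only the exact comparison treats the two containers as structurally different, which is the behavior the strict check preserves.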

Comment on lines +67 to +71
    if isinstance(state1, Mapping):
        return state1.keys() == state2.keys() and all(_equal_state_value(state1[k], state2[k]) for k in state1)

    if isinstance(state1, Sequence) and not isinstance(state1, str):
        return len(state1) == len(state2) and all(_equal_state_value(s1, s2) for s1, s2 in zip(state1, state2))

The Mapping and Sequence branches use generator expressions with all(). For very deep nesting (e.g., list of list of list of ...) this could hit Python recursion limits. In practice metric states are rarely >3 levels deep, so this is theoretical -- but note the limit if it matters for your use case.
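
To make the recursion concern concrete, here is a rough sketch using a simplified, lists-only comparison (`recursive_equal`, `nest`, and `hits_recursion_limit` are hypothetical names for this illustration). It assumes the interpreter's default recursion limit or anything well below the tested depth.

```python
from typing import Any


def recursive_equal(a: Any, b: Any) -> bool:
    """Simplified recursive comparison (lists only, for illustration)."""
    if type(a) is not type(b):
        return False
    if isinstance(a, list):
        return len(a) == len(b) and all(recursive_equal(x, y) for x, y in zip(a, b))
    return a == b


def nest(depth: int) -> list:
    """Build [[[...[0]...]]] iteratively so construction itself never recurses."""
    value: list = [0]
    for _ in range(depth):
        value = [value]
    return value


def hits_recursion_limit(depth: int) -> bool:
    """Return True if comparing two structures of this depth raises RecursionError."""
    try:
        recursive_equal(nest(depth), nest(depth))
    except RecursionError:
        return True
    return False


print(hits_recursion_limit(3))       # shallow nesting, typical of metric states
print(hits_recursion_limit(50_000))  # far deeper than any realistic state
```

Shallow nesting compares without trouble; only pathologically deep structures run into the limit, which matches the observation that this is a theoretical concern.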

Labels

bug / fix Something isn't working

Projects

None yet

Development

Successfully merging this pull request may close these issues.

MetricCollection not working with MeanAveragePrecision

4 participants