Fix MetricCollection state comparison for nested sequence states#3337
Fix MetricCollection state comparison for nested sequence states#3337omkar-334 wants to merge 5 commits intoLightning-AI:masterfrom
Conversation
Signed-off-by: Omkar Kabde <omkarkabde@gmail.com>
Codecov Report❌ Patch coverage is ❌ Your project check has failed because the head coverage (36%) is below the target coverage (95%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## master #3337 +/- ##
======================================
- Coverage 37% 36% -0%
======================================
Files 349 349
Lines 19901 19907 +6
======================================
+ Hits 7264 7265 +1
- Misses 12637 12642 +5 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
Fixes a crash in MetricCollection when auto-merging compute groups for metrics whose states contain nested sequences (e.g., list of tuples), as reported in #3335.
Changes:
- Introduced a recursive state-value comparison helper to correctly compare nested structures in metric states.
- Updated
MetricCollection._equal_metric_statesto use the new recursive comparison. - Added a unit test covering nested sequence state values to prevent regressions.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
src/torchmetrics/collections.py |
Replaces tensor-only list state comparison with a recursive comparator that supports nested sequences/mappings. |
tests/unittests/bases/test_collections.py |
Adds a regression test ensuring compute group merging works with nested sequence state values. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
|
@justusschock all tests are passing.... can you review this? |
|
Is there any update on this? It would be great if metric collections containing e.g. MeanAveragePrecision can be used with compute_groups enabled. |
Borda
left a comment
There was a problem hiding this comment.
Thanks for tracking down this issue and the PR is well-targeted. A few things stand out:
The root cause was clear: _equal_metric_states used isinstance(state1, list) which misses tuple states that MeanAveragePrecision stores. The fix generalizes this to all Sequence types via a recursive helper.
The new _equal_state_value function is clean -- it handles Tensor, Mapping, Sequence (with string exclusion), and falls back to direct == for primitives. No new imports needed, no API changes.
The test case is well-constructed: DummyNestedListMetric faithfully reproduces the MeanAveragePrecision state shape (list[tuple[tensor, tensor]]), and the assertions verify both the compute-group merging and the state comparison.
Thematic areas to consider:
The helper function is internal (not exported in init.py) which is correct, but the docstring could benefit from mentioning what types it supports.
For future robustness, the MeanAveragePrecision reproduction in the linked issue would be a valuable addition (or a comment noting the fix covers it).
Consider whether dict states with tensor values are explicitly tested somewhere -- the recursive Mapping branch is new and worth confirming.
I've left a couple of minor observations below.
| return string[: -len(suffix)] if string.endswith(suffix) else string | ||
|
|
||
|
|
||
| def _equal_state_value(state1: Any, state2: Any) -> bool: |
There was a problem hiding this comment.
The _equal_state_value docstring is one line. Consider listing supported types for callers: e.g., "Recursively compare metric state values. Supports: Tensor (shape+value), Mapping (key+recursive value), Sequence/str-excluded (length+recursive element), and primitives (direct ==)."
|
|
||
| def _equal_state_value(state1: Any, state2: Any) -> bool: | ||
| """Recursively compare metric state values while preserving structure checks.""" | ||
| if type(state1) is not type(state2): |
There was a problem hiding this comment.
type(state1) is not type(state2) is intentionally strict (exact type match), which is the right call here -- but worth a comment since most Python code uses != or isinstance. Consider # noqa: E721 or a short note.
| if isinstance(state1, Mapping): | ||
| return state1.keys() == state2.keys() and all(_equal_state_value(state1[k], state2[k]) for k in state1) | ||
|
|
||
| if isinstance(state1, Sequence) and not isinstance(state1, str): | ||
| return len(state1) == len(state2) and all(_equal_state_value(s1, s2) for s1, s2 in zip(state1, state2)) |
There was a problem hiding this comment.
The Mapping and Sequence branches use generator expressions with all(). For very deep nesting (e.g., list of list of list of ...) this could hit Python recursion limits. In practice metric states are rarely >3 levels deep, so this is theoretical -- but note the limit if it matters for your use case.


What does this PR do?
Fixes #3335
Before submitting
PR review
Fixes a bug where
MetricCollectioncrashes while auto-merging compute groups for metrics with nested sequence state, such asMeanAveragePrecision.Previously,
_equal_metric_statesassumed that list-valued state always contained tensors and accessed.shapedirectly on each element. Thisbreaks for metrics whose state contains tuples or other nested structures, leading to:
Fix - i replaced the tensor-only list comparison with a recursive state comparision function
📚 Documentation preview 📚: https://torchmetrics--3337.org.readthedocs.build/en/3337/