fix(pytorch): place unsqueeze axis at the correct position #1496
Open
areporeporepo wants to merge 2 commits into
The PyTorch `unsqueeze` parser located the new size-1 axis with `output_shape.index(output_shape[squeeze_dim])`, i.e. by the *value* of the indexed dimension. `list.index` returns the first dimension that happens to share that size, so `torch.unsqueeze(x, dim=-1)` (and any unsqueeze whose target dimension collides in size with an earlier one) inserted the axis at the wrong position. Negative dims were also never normalized.

Insert the size-1 axis directly at `squeeze_dim`, normalizing negative dims to the range torch accepts ([-(D+1), D]). Adds a regression test for the common `dim=-1` case across backends and io types.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb
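The difference between the value-based lookup and the positional insert can be reproduced with plain Python lists. This is only a sketch: the helper names `insert_by_value`, `insert_by_position` and `target_shape` are hypothetical, and it assumes the shape list carries the batch dimension as a leading `None` that is dropped before the target shape is reported. It is not the parser code itself.

```python
def insert_by_value(shape, dim):
    """Old approach as described: locate the axis via the size stored at `dim`."""
    out = list(shape)
    index = out.index(out[dim])  # first entry with that size, not necessarily `dim`
    out.insert(index, 1)
    return out


def insert_by_position(shape, dim):
    """Fixed approach: normalize a negative `dim`, then insert at that position."""
    out = list(shape)
    if dim < 0:
        dim += len(out) + 1  # -1 maps to len(out), the slot after the last axis
    out.insert(dim, 1)
    return out


def target_shape(shape):
    """Drop the leading batch entry (assumed to be None) before reporting."""
    return shape[1:] if shape and shape[0] is None else shape


# Input (N, 4), dim=-1
print(target_shape(insert_by_value([None, 4], -1)))        # [1, 4]
print(target_shape(insert_by_position([None, 4], -1)))     # [4, 1]

# Input (N, 4, 4), dim=2: both 4s look identical to list.index
print(target_shape(insert_by_value([None, 4, 4], 2)))      # [1, 4, 4]
print(target_shape(insert_by_position([None, 4, 4], 2)))   # [4, 1, 4]
```

For `dim=1` on an `(N, 4)` input the two helpers agree, which is consistent with the bug only showing up for negative dims or repeated sizes.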
Pull request overview
This PR fixes the PyTorch unsqueeze shape parser so the new size-1 axis is inserted at the correct position (especially for negative dim and duplicate dimension sizes), and adds a regression test to ensure dim=-1 behaves as expected.
Changes:
- Fix `parse_unsqueeze_layer` to insert the new axis at `squeeze_dim` and normalize negative dims.
- Add a PyTorch API regression test validating that `torch.unsqueeze(x, dim=-1)` results in `target_shape == [4, 1]`.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| `hls4ml/converters/pytorch/reshape.py` | Corrects unsqueeze axis insertion logic and negative-dim handling. |
| `test/pytest/test_pytorch_api.py` | Adds a regression test covering unsqueeze(dim=-1) shape parsing and numerical parity. |
Comment on lines +84 to +86

    if squeeze_dim < 0:
        squeeze_dim += len(output_shape) + 1
    output_shape.insert(squeeze_dim, 1)
Reject `dim` outside torch's accepted [-(D+1), D] range with a clear error instead of letting `list.insert` silently clamp an invalid value to a wrong shape. Addresses the Copilot review comment on the PR.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb
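A minimal sketch of such a range check, written against plain lists. The helper name `normalize_unsqueeze_dim`, the choice of `ValueError`, and the message text are assumptions for illustration and may differ from what the commit does. `ndim` stands for D, the number of dimensions of the input including the batch dimension.

```python
def normalize_unsqueeze_dim(dim, ndim):
    """Return a non-negative insert position for an input with `ndim` axes.

    torch.unsqueeze accepts dim in [-(ndim + 1), ndim]; anything else is rejected.
    """
    if not -(ndim + 1) <= dim <= ndim:
        raise ValueError(
            f'unsqueeze dim {dim} is out of range [{-(ndim + 1)}, {ndim}] '
            f'for a {ndim}-dimensional input'
        )
    return dim + ndim + 1 if dim < 0 else dim


shape = [None, 4]  # assumed layout: batch entry first, so ndim == 2

# Without a check, list.insert clamps an invalid index instead of failing.
clamped = list(shape)
clamped.insert(10, 1)
print(clamped)  # [None, 4, 1], even though dim=10 is invalid for this input

# With the check, valid dims are normalized and invalid ones raise.
print(normalize_unsqueeze_dim(-1, len(shape)))  # 2
try:
    normalize_unsqueeze_dim(10, len(shape))
except ValueError as err:
    print(err)
```

Validating before the insert keeps an out-of-range `dim` from turning into a plausible-looking but wrong target shape.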
Description
The PyTorch `unsqueeze` parser placed the new size-1 axis using `output_shape.index(output_shape[squeeze_dim])`, i.e. it searched for the value of the indexed dimension. `list.index()` returns the first dimension that shares that size, so the axis was inserted in the wrong position whenever:
- a negative `dim` is used (e.g. `torch.unsqueeze(x, dim=-1)`), or
- the target dimension collides in size with an earlier one.

Negative dims were also never normalized to torch's accepted range.

Example (on `main`)
For an input of shape `(N, 4)`:

| Call | Expected `target_shape` | `main` |
|---|---|---|
| `torch.unsqueeze(x, dim=-1)` | `[4, 1]` | `[1, 4]` ❌ |

And for `(N, 4, 4)`, `torch.unsqueeze(x, dim=2)` should give `[4, 1, 4]` but `main` produces `[1, 4, 4]`, because `list.index(4)` returns the first `4`.

Fix
Insert the size-1 axis directly at `squeeze_dim`, normalizing negative dims to the range torch accepts (`[-(D+1), D]`). This is a small, self-contained change in `hls4ml/converters/pytorch/reshape.py`.

Testing
Added `test_unsqueeze` to `test/pytest/test_pytorch_api.py`, parametrized over backends and io types and mirroring the existing `test_squeeze`. It exercises `torch.unsqueeze(x, dim=-1)` and asserts the resulting reshape reports `target_shape == [4, 1]`. The test fails on `main` (`[1, 4]`) and passes with this change. The existing `test_squeeze` (which covers `unsqueeze(x, dim=1)`) continues to pass, and `pre-commit` passes on both edited files.

🤖 Generated with Claude Code
https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb