Skip to content

fix(pytorch): place unsqueeze axis at the correct position#1496

Open
areporeporepo wants to merge 2 commits into
fastmachinelearning:mainfrom
areporeporepo:fix-pytorch-unsqueeze-axis
Open

fix(pytorch): place unsqueeze axis at the correct position#1496
areporeporepo wants to merge 2 commits into
fastmachinelearning:mainfrom
areporeporepo:fix-pytorch-unsqueeze-axis

Conversation

@areporeporepo

Copy link
Copy Markdown

Description

The PyTorch unsqueeze parser placed the new size-1 axis using
output_shape.index(output_shape[squeeze_dim]) — i.e. it searched for the
value of the indexed dimension. list.index() returns the first dimension
that shares that size, so the axis was inserted in the wrong position whenever:

  • the dim is negative (e.g. the very common torch.unsqueeze(x, dim=-1)), or
  • the target dimension's size matches an earlier dimension.

Negative dims were also never normalized to torch's accepted range.

Example (on main)

For an input of shape (N, 4):

call expected target_shape parsed on main
torch.unsqueeze(x, dim=-1) [4, 1] [1, 4]

And for (N, 4, 4), torch.unsqueeze(x, dim=2) should give [4, 1, 4] but
main produces [1, 4, 4], because list.index(4) returns the first 4.

Fix

Insert the size-1 axis directly at squeeze_dim, normalizing negative dims to
the range torch accepts ([-(D+1), D]). This is a small, self-contained change
in hls4ml/converters/pytorch/reshape.py.

Testing

Added test_unsqueeze to test/pytest/test_pytorch_api.py, parametrized over
backends and io types and mirroring the existing test_squeeze. It exercises
torch.unsqueeze(x, dim=-1) and asserts the resulting reshape reports
target_shape == [4, 1]. The test fails on main ([1, 4]) and passes with
this change. The existing test_squeeze (which covers unsqueeze(x, dim=1))
continues to pass, and pre-commit passes on both edited files.

🤖 Generated with Claude Code

https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb

The PyTorch `unsqueeze` parser located the new size-1 axis with
`output_shape.index(output_shape[squeeze_dim])`, i.e. by the *value* of the
indexed dimension. `list.index` returns the first dimension that happens to
share that size, so `torch.unsqueeze(x, dim=-1)` (and any unsqueeze whose
target dimension collides in size with an earlier one) inserted the axis at
the wrong position. Negative dims were also never normalized.

Insert the size-1 axis directly at `squeeze_dim`, normalizing negative dims
to the range torch accepts ([-(D+1), D]). Adds a regression test for the
common `dim=-1` case across backends and io types.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb
Copilot AI review requested due to automatic review settings June 29, 2026 18:59

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes the PyTorch unsqueeze shape parser so the new size-1 axis is inserted at the correct position (especially for negative dim and duplicate dimension sizes), and adds a regression test to ensure dim=-1 behaves as expected.

Changes:

  • Fix parse_unsqueeze_layer to insert the new axis at squeeze_dim and normalize negative dims.
  • Add a PyTorch API regression test validating torch.unsqueeze(x, dim=-1) results in target_shape == [4, 1].

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

File Description
hls4ml/converters/pytorch/reshape.py Corrects unsqueeze axis insertion logic and negative-dim handling.
test/pytest/test_pytorch_api.py Adds a regression test covering unsqueeze(dim=-1) shape parsing and numerical parity.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +84 to +86
if squeeze_dim < 0:
squeeze_dim += len(output_shape) + 1
output_shape.insert(squeeze_dim, 1)
Reject `dim` outside torch's accepted [-(D+1), D] range with a clear error
instead of letting `list.insert` silently clamp an invalid value to a wrong
shape. Addresses the Copilot review comment on the PR.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01SVVaAFkshNYfkJLo6GLukb
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants