Skip to content

Fix incorrect dtype default in aten_full causing ONNX shape-inference errors#2926

Merged
titaiwangms merged 5 commits into
mainfrom
copilot/bugfix-onnx-export-aten-full
Jun 3, 2026
Merged

Fix incorrect dtype default in aten_full causing ONNX shape-inference errors#2926
titaiwangms merged 5 commits into
mainfrom
copilot/bugfix-onnx-export-aten-full

Conversation

Copy link
Copy Markdown
Contributor

Copilot AI commented Jun 1, 2026

aten_full defaulted dtype to FLOAT.dtype. When fill_value is integral and no dtype is supplied, the value was cast to float, yielding type mismatches downstream (e.g. Expand: Inferred elem type differs from existing elem type: (1) vs (7)) and failing strict shape inference.

Change

  • onnxscript/function_libs/torch_lib/ops/core.py: default dtype changed from FLOAT.dtype to -1, matching aten_full_like. The existing if dtype != -1: guard now skips the cast when dtype is unspecified, preserving fill_value's original type.
def aten_full(
    size: Union[INT64, INT32],
    fill_value: TensorType,
    dtype: int = -1,  # was FLOAT.dtype
    ...
):
    if dtype != -1:
        fill_value = op.Cast(fill_value, to=dtype)
    size = op.Cast(size, to=INT64.dtype)
    return op.Expand(fill_value, size)

With the prior default, an integral fill_value with no dtype reproduced the Expand TypeInferenceError; the change resolves it while leaving the explicit-dtype path unchanged.

Copilot AI changed the title [WIP] Fix bug in ONNX export of aten_full due to incorrect dtype default Fix incorrect dtype default in aten_full causing ONNX shape-inference errors Jun 1, 2026
Copilot AI requested a review from titaiwangms June 1, 2026 16:23
@justinchuby justinchuby requested a review from Copilot June 1, 2026 16:24
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes aten_full’s default dtype handling in the Torch lib so that, when dtype is not specified, the fill_value’s original element type is preserved (avoiding downstream ONNX type/shape-inference mismatches such as Expand type inference errors).

Changes:

  • Update aten_full’s dtype default from FLOAT.dtype to -1 (unspecified), aligning it with aten_full_like.
  • Rely on the existing if dtype != -1: guard to skip Cast when dtype is unspecified.

@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 1, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.65%. Comparing base (b16ef16) to head (f85e90b).
⚠️ Report is 2 commits behind head on main.
✅ All tests successful. No failed tests found.

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2926   +/-   ##
=======================================
  Coverage   72.64%   72.65%           
=======================================
  Files         259      259           
  Lines       31655    31655           
  Branches     2981     2981           
=======================================
+ Hits        22997    22998    +1     
  Misses       7649     7649           
+ Partials     1009     1008    -1     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@titaiwangms titaiwangms marked this pull request as ready for review June 1, 2026 16:56
@titaiwangms titaiwangms enabled auto-merge (squash) June 1, 2026 16:56
@titaiwangms titaiwangms disabled auto-merge June 1, 2026 18:32
@titaiwangms titaiwangms enabled auto-merge (squash) June 1, 2026 18:33
@titaiwangms titaiwangms merged commit 02a0030 into main Jun 3, 2026
33 checks passed
@titaiwangms titaiwangms deleted the copilot/bugfix-onnx-export-aten-full branch June 3, 2026 17:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Development

Successfully merging this pull request may close these issues.

bug in ONNX export of aten_full due to incorrect dtype default argument

5 participants