PyTorch convenience features #246

Open
chhwang wants to merge 86 commits into main from merge-test-pr215

@chhwang chhwang commented Apr 2, 2026

Python API

  • Tensor.eval(stream=None) — evaluate graph and return torch tensor
  • All ops accept torch.Tensor directly (auto-conversion)
  • Model API: set_model(), current_model(), use_model()
  • Planner(model=...) parameter
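
The set_model() / current_model() / use_model() trio reads like a scoped "current model" pattern. Below is a minimal sketch of how such a trio could behave, assuming use_model() is a context manager that restores the previous model on exit. The Model class here is a stand-in, and none of this is taken from the ARK implementation.

```python
from contextlib import contextmanager


class Model:
    """Stand-in for a graph container; hypothetical, not ARK's Model class."""

    def __init__(self, name):
        self.name = name


# Module-level state: the model that newly created ops would be recorded into.
_current = Model("default")


def set_model(model):
    """Make `model` the target for subsequently created ops."""
    global _current
    _current = model


def current_model():
    """Return the model that new ops would be recorded into."""
    return _current


@contextmanager
def use_model(model):
    """Temporarily switch the current model, restoring the previous one on exit."""
    previous = current_model()
    set_model(model)
    try:
        yield model
    finally:
        # Restore even if the body raises, so scopes nest safely.
        set_model(previous)
```

Under this assumption, `with use_model(m): ...` scopes op construction to `m`, while set_model(m) switches the target until it is changed again.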

C++ fixes

  • Reduce ops: Tile config for correct grid alignment when fused
  • WwiseReduce: per-row reduction fix for multi-row tiles
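
For reviewers, the intended semantics of a width-wise reduction over multi-row tiles can be stated as a plain-Python reference: each row in a tile yields its own result, regardless of how many rows the tile covers. This is only an illustration of the expected behavior; the function name and tiling parameter are hypothetical and do not mirror the C++ kernel.

```python
def wwise_sum_tiled(matrix, tile_rows):
    """Sum each row along its width, visiting the matrix `tile_rows` rows at a time."""
    out = []
    for start in range(0, len(matrix), tile_rows):
        tile = matrix[start:start + tile_rows]
        # One accumulator per row in the tile, not one per tile.
        out.extend(sum(row) for row in tile)
    return out


matrix = [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]
result = wwise_sum_tiled(matrix, tile_rows=2)
```

The result should not depend on the tile height, which is the property a multi-row tile test would check.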

Tests

  • 70+ Python op tests replacing C++ tests
  • test_eval.py: caching, recompilation, stream interleaving, chained ops
  • Deleted 9 C++ test files
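
The caching and recompilation cases can be pictured with a toy lazy graph: evaluating an unchanged graph twice should compile once, and adding an op should trigger a recompile. The sketch below is a self-contained stand-in, assuming a plan cache keyed on the op list. LazyGraph and its attributes are hypothetical and are not part of ARK or test_eval.py.

```python
class LazyGraph:
    """Toy graph that 'compiles' its op list lazily and caches the plan."""

    def __init__(self):
        self.ops = []
        self.compile_count = 0
        self._cached_key = None
        self._cached_plan = None

    def add_op(self, fn):
        """Append a unary function to the graph and return self for chaining."""
        self.ops.append(fn)
        return self

    def eval(self, value):
        """Run all ops on `value`, recompiling only if the op list changed."""
        key = tuple(self.ops)
        if key != self._cached_key:
            self._cached_plan = list(self.ops)  # stand-in for real compilation
            self._cached_key = key
            self.compile_count += 1
        for fn in self._cached_plan:
            value = fn(value)
        return value


def double(x):
    return 2 * x


def increment(x):
    return x + 1


graph = LazyGraph().add_op(double)
first = graph.eval(3)    # compiles, then runs
second = graph.eval(4)   # reuses the cached plan
graph.add_op(increment)
third = graph.eval(3)    # op list changed, so the plan is rebuilt
```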

Examples

  • Tutorials use eval() + torch tensors
  • MHA module matching FlashAttention-2 performance

chhwang and others added 30 commits May 27, 2024 04:35
- Introduced support for multiple Runtime instances
- Added utility functions for multi-runtime management
- Ensured backward compatibility with existing usage patterns of Runtime
- Added unit tests for multi-runtime functionality

---------

Co-authored-by: noli <t-ngerawork@microsoft.com>
- Adds Torch to ARK tensor conversion support
- New ModelBufferManager class handles external buffer registration and
simplifies buffer access during kernel initialization
- Adds test cases for ARK to Torch conversion support

---------

Co-authored-by: Changho Hwang <changhohwang@microsoft.com>
chhwang added 22 commits July 31, 2024 16:01
Merge main (including PRs #222, #235, #245) into the PR #215 branch.

Resolution strategy:
- Take main's version for all C++ core, Python API, and C++ binding files
  (features from PR #215 were reworked and landed via PRs #222 and #235)
- Remove superseded PR #215 files:
  - torch_mock.py (replaced by python/ark/torch/mock.py)
  - unittest_common.py (replaced by python/unittest/common.py)
  - model_buffer_manager.hpp (replaced by buffer_registry)
  - arkprof.py (Profiler class in main suffices)
  - model_7b_b1_s2048.py / plan_llama2_7b_b1_s2048.json (superseded by current llama examples)
- Rewrite test_conversion.py for current API:
  - get_torch_view() -> to_torch()
  - Remove delete_all_runtimes() / reset() calls
  - Use pytest_ark decorator and with-block pattern
- Rewrite torch_tutorial.py to use placeholder API instead of RuntimeModule
- Add test_conversion to test runner
@chhwang chhwang changed the title from "Enhance multi-runtime support and optimize tensor conversions" to "PyTorch convenience features" Apr 2, 2026
@chhwang chhwang mentioned this pull request Apr 2, 2026

2 participants