Add mps support when loading model from checkpoint #181
etherealsunshine wants to merge 7 commits into prescient-design:main from …/lobster into add_mps_support
Conversation
On the lookout for a Python implementation of Flash Attention for the MPS architecture; as of right now, I can't seem to find any.

@etherealsunshine have you tried this out with MPS?
Hi @ncfrey, MPS gets auto-detected correctly, Flash Attention properly falls back to SDPA on MPS, and basic tensor operations work. I just tested the device detection and configuration logic directly since I don't have access to the pretrained models, but the core MPS handling works correctly. I don't have immediate plans for training a custom model yet, so if you think that's sufficient I can add these validation tests to the test suite. Let me know what you think!
LiudengZhang
left a comment
Looks good to me — clean and well-scoped change.
The auto-detection fallback order (CUDA → MPS → CPU) makes sense and is consistent with what PyTorch Lightning does internally. The torch.backends.mps.is_available() guard is correct.
One minor thought: if a user passes device="mps" with use_flash_attn=True, the existing guard on line 1104 (if use_flash_attn and device != "cuda") will already handle it correctly by disabling flash attention. So no edge case there.
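The fallback order and the flash-attention guard discussed above can be sketched as follows. This is a minimal illustration, not the PR's actual code: `detect_device` and `resolve_flash_attn` are hypothetical names, and in the real code the availability flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`.

```python
def detect_device(cuda_available: bool, mps_available: bool) -> str:
    """Auto-detect a device using the fallback order CUDA -> MPS -> CPU.

    The availability flags stand in for torch.cuda.is_available() and
    torch.backends.mps.is_available() so the logic is testable anywhere.
    """
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def resolve_flash_attn(device: str, use_flash_attn: bool) -> bool:
    """Mirror the guard described above: flash attention is only kept on CUDA.

    On MPS or CPU the flag is dropped, so attention falls back to SDPA.
    """
    return use_flash_attn and device == "cuda"
```

With this sketch, a user passing `device="mps"` with `use_flash_attn=True` ends up with flash attention disabled, matching the edge-case behavior described in the comment.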
Description
Adds support for Apple Metal Performance Shaders (MPS) when loading models from checkpoint in UME. Addresses #180
Type of Change