Add MPS support when loading model from checkpoint #181

Open

etherealsunshine wants to merge 7 commits into prescient-design:main from etherealsunshine:add_mps_support

Conversation

etherealsunshine (Contributor) commented Aug 6, 2025

Description

Adds support for Apple Metal Performance Shaders (MPS) when loading models from a checkpoint in UME. Addresses #180.
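
For illustration, a minimal sketch of the CUDA → MPS → CPU fallback this change describes; the helper name `_resolve_device` is hypothetical and may not match the PR's actual code:

```python
import torch

def _resolve_device(device: str | None = None) -> str:
    # Hypothetical helper sketching the auto-detection described above:
    # respect an explicit user choice, else fall back CUDA -> MPS -> CPU.
    if device is not None:
        return device
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"
```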

Type of Change

  • Bug fix
  • New feature
  • Documentation update
  • Performance improvement
  • Code refactoring

etherealsunshine (Contributor, Author) commented Aug 6, 2025

I'm on the lookout for a Python implementation of Flash Attention for the MPS architecture; as of right now, I can't find one.

ncfrey (Contributor) commented Aug 13, 2025

@etherealsunshine have you tried this out with MPS?

etherealsunshine (Contributor, Author) commented

Hi @ncfrey, MPS gets auto-detected correctly, Flash Attention properly falls back to SDPA on MPS, and basic tensor operations work. I only tested the device detection and configuration logic directly, since I don't have access to the pretrained models, but the core MPS handling works correctly. I don't have immediate plans to train a custom model yet, so if you think that's sufficient I can add these validation checks to the test suite. Let me know what you think!
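
For reference, a minimal sketch of what such validation tests might look like; the test names are hypothetical, not the PR's actual code:

```python
import pytest
import torch
import torch.nn.functional as F

@pytest.mark.skipif(not torch.backends.mps.is_available(), reason="requires an MPS device")
def test_mps_detected_and_tensor_ops_work():
    # Basic tensor math should run on the MPS device.
    x = torch.randn(8, 8, device="mps")
    y = x @ x.T
    assert y.device.type == "mps"

def test_sdpa_fallback_runs_off_cuda():
    # SDPA is the fallback when flash attention is disabled off-CUDA;
    # it runs on CPU and MPS, so the fallback path is exercisable anywhere.
    q = k = v = torch.randn(1, 2, 4, 8)
    out = F.scaled_dot_product_attention(q, k, v)
    assert out.shape == q.shape
```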

LiudengZhang left a comment

Looks good to me — clean and well-scoped change.

The auto-detection fallback order (CUDA → MPS → CPU) makes sense and is consistent with what PyTorch Lightning does internally. The torch.backends.mps.is_available() guard is correct.

One minor thought: if a user passes device="mps" with use_flash_attn=True, the existing guard on line 1104 (if use_flash_attn and device != "cuda") will already handle it correctly by disabling flash attention. So no edge case there.
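
A minimal sketch of the guard being discussed, wrapped in a hypothetical helper for clarity:

```python
def _configure_flash_attn(use_flash_attn: bool, device: str) -> bool:
    # Mirrors the existing guard cited above: flash-attn kernels are
    # CUDA-only, so any other device (mps, cpu) falls back to SDPA.
    if use_flash_attn and device != "cuda":
        return False
    return use_flash_attn
```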

