fix: Allow loading CUDA-saved models on CPU-only machines #296
Conversation
Thank you for your pull request and welcome to our community. We could not parse the GitHub identity of the following contributors: Simon Openclaw.
Pull request overview
This PR aims to fix #295 by making CEBRA.load(...) gracefully fall back to CPU when a checkpoint indicates it was saved on a CUDA device but CUDA is unavailable at load time.
Changes:
- Add `_resolve_checkpoint_device()` to convert checkpoint `device_` values (string or `torch.device`) into a runtime-valid device, with CUDA→CPU fallback (a sketch follows this list).
- Update `_load_cebra_with_sklearn_backend()` to use the resolved device for model/criterion/solver `.to(...)` calls and to update `cebra_.device_` (and `cebra_.device` on fallback).
- Add tests that monkeypatch CUDA availability and validate successful load + inference.
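A minimal sketch of what such a helper could look like (the name comes from the PR; the body is an illustration of the described behavior, not the merged implementation):

```python
import warnings

import torch


def _resolve_checkpoint_device(device):
    # Normalize str / torch.device inputs once, then compare device types.
    device = torch.device(device)
    if device.type == "cuda" and not torch.cuda.is_available():
        # Assumed behavior per the PR description: fall back instead of crashing.
        warnings.warn(
            "Checkpoint was saved on a CUDA device but CUDA is unavailable; "
            "falling back to CPU.", UserWarning)
        return "cpu"
    return str(device)
```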
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `cebra/integrations/sklearn/cebra.py` | Introduces checkpoint-device resolution and applies it during sklearn-backend loading. |
| `tests/test_sklearn.py` | Adds regression tests intended to cover CUDA-saved checkpoint loading on CPU-only environments. |
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
saved_device = state["device_"]
load_device = _resolve_checkpoint_device(saved_device)
```
Copilot AI · Feb 11, 2026
The new CPU-fallback logic only changes subsequent .to(load_device) calls, but loading a truly CUDA-saved checkpoint can still fail earlier in torch.load when the checkpoint contains CUDA tensors and CUDA isn’t available. Consider adding a retry/automatic fallback in CEBRA.load / _safe_torch_load that catches the CUDA deserialization RuntimeError and re-loads with map_location='cpu' (when the caller didn’t already pass map_location).
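A sketch of that suggestion, assuming `_safe_torch_load` forwards its keyword arguments to `torch.load` (illustrative only, not the PR's final code):

```python
import warnings

import torch


def _safe_torch_load(filename, weights_only=True, **kwargs):
    try:
        return torch.load(filename, weights_only=weights_only, **kwargs)
    except RuntimeError as error:
        # torch raises a RuntimeError mentioning CUDA when the checkpoint
        # holds CUDA tensors and no map_location redirects them.
        if "CUDA" in str(error) and "map_location" not in kwargs:
            warnings.warn(
                "Checkpoint contains CUDA tensors but CUDA is unavailable; "
                "retrying with map_location='cpu'.", UserWarning)
            return torch.load(filename, weights_only=weights_only,
                              map_location="cpu", **kwargs)
        raise
```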
```python
# Train a model on CPU
cebra_model = cebra_sklearn_cebra.CEBRA(
    model_architecture=model_architecture,
    max_iterations=5,
    device="cpu",
).fit(X)

with _windows_compatible_tempfile(mode="w+b") as tempname:
    # Save the model
    cebra_model.save(tempname)

    # Modify the checkpoint to have a CUDA device
    checkpoint = cebra_sklearn_cebra._safe_torch_load(tempname)
    checkpoint["state"]["device_"] = saved_device
    torch.save(checkpoint, tempname)
```
Copilot AI · Feb 11, 2026
This test is described as loading a “CUDA-saved checkpoint”, but it trains/saves the model on CPU and only edits checkpoint['state']['device_']. That doesn’t exercise the common failure mode where the checkpoint’s state_dict tensors are actually on CUDA and torch.load fails unless map_location is used. Consider either generating a real CUDA checkpoint when available, or monkeypatching torch.load/_safe_torch_load to simulate the CUDA deserialization error and assert the loader retries/falls back correctly.
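One way to simulate the failure the comment describes (a sketch; it assumes the loader implements the retry suggested above, and `offset1-model` is used only as a small example architecture):

```python
import numpy as np
import torch

import cebra.integrations.sklearn.cebra as cebra_sklearn_cebra


def test_load_retries_on_cuda_deserialization_error(monkeypatch, tmp_path):
    # Train and save a tiny CPU model so a real checkpoint file exists.
    X = np.random.uniform(0, 1, (100, 5)).astype("float32")
    model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model",
                                      max_iterations=5,
                                      device="cpu").fit(X)
    path = str(tmp_path / "model.pt")
    model.save(path)

    real_torch_load = torch.load

    def fake_torch_load(*args, **kwargs):
        # Mimic torch.load failing on CUDA tensors unless map_location
        # redirects them to CPU.
        if kwargs.get("map_location") is None:
            raise RuntimeError(
                "Attempting to deserialize object on a CUDA device but "
                "torch.cuda.is_available() is False.")
        return real_torch_load(*args, **kwargs)

    monkeypatch.setattr(torch, "load", fake_torch_load)
    loaded = cebra_sklearn_cebra.CEBRA.load(path)
    assert loaded.transform(X).shape[0] == X.shape[0]
```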
`tests/test_sklearn.py` (Outdated)
```python
monkeypatch.setattr(torch.cuda, "is_available", lambda: False)

# Load with explicit map_location
loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
```
Copilot AI · Feb 11, 2026
The test name/doc/comment mention an explicit map_location='cpu' override, but CEBRA.load(tempname) is called without passing map_location (or any kwargs). Either pass map_location via CEBRA.load(tempname, map_location='cpu') (since **kwargs are forwarded to torch.load) or rename/update the test/docstring to reflect what’s actually being tested.
```diff
- loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname)
+ loaded_model = cebra_sklearn_cebra.CEBRA.load(tempname, map_location='cpu')
```
Add comprehensive tests for the CUDA-to-CPU fallback fix:

- `test_load_cuda_checkpoint_falls_back_to_cpu`: parametrized with 4 CUDA device variants and 2 model architectures
- `test_load_cuda_checkpoint_with_device_override`: tests `map_location` behavior

These tests verify:

- Models saved with `device='cuda'` can load on CPU-only machines
- Device attributes are correctly updated to `'cpu'`
- Model parameters are on CPU
- Model can perform inference (`transform`) after loading

Expected to FAIL before the fix is applied.
Force-pushed from b952d02 to 6fae503.
Thank you for your contribution. We require contributors to sign our Contributor License Agreement (CLA), and we do not have a signed CLA on file for you. In order for us to review and merge your code, please sign our CLA here. After you have signed, you can comment on this PR with
Force-pushed from 6fae503 to efe8b95.
Thanks for the review! I've addressed all the comments.

All tests pass (including the new ones).
When a CEBRA checkpoint was saved on a CUDA device but is loaded on a machine without CUDA available, it now gracefully falls back to CPU instead of crashing with a RuntimeError.

Changes:

- Add `_resolve_checkpoint_device()` helper to handle device resolution
- Update `_load_cebra_with_sklearn_backend()` to use the resolved device
- Handle both string and `torch.device` types, including `cuda:0` variants
- Update model device attributes after resolution

Fixes: Loading a model saved with `device='cuda'` on a CPU-only machine
Force-pushed from efe8b95 to 97d5b90.
Added the real CUDA-saved checkpoint to the repo. This checkpoint was saved with CUDA tensors and serves as:

The checkpoint is stored in PyTorch's newer directory format (version 3). While the current test environment can't directly load it due to format limitations, the actual fix is still fully tested via the mock-based tests that simulate the CUDA deserialization error and verify that the retry logic works correctly.
See comment in #295 (comment); I think we need to discuss if there is an issue in the first place: using `map_location` already covers the remapping. So the implementation could be:
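One plausible reading of that implementation (a sketch, not the snippet from the thread): default the remapping up front instead of retrying after a failure.

```python
import torch


def load(filename, map_location=None, **kwargs):
    # If no accelerator is available and the caller gave no map_location,
    # ask torch.load to remap all storages onto the CPU directly.
    if map_location is None and not torch.cuda.is_available():
        map_location = "cpu"
    return torch.load(filename, map_location=map_location, **kwargs)
```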
stes left a comment:
Left some minor comments. In addition:

- code for generating the test data (on GPU) should be added (see the tests/ folder; this can be a separate utility)
- the logic should be robust to other devices (e.g. mps)
- the binary files should not go into the repo -- but I can handle this once the PR is ready (and code for generating the checkpoint is available)
If we go for this change to the logic, I would recommend changing the high-level signature of the `CEBRA.load` function; e.g., a default for `map_location` could be added.

The value of the current logic is that when a user attempts to use a CEBRA model on a GPU machine and for some reason (e.g. failure to use Docker with NVIDIA, outdated driver, etc.) the GPU is not available, we will see an error during load.

On the other hand, an automated mapping to CPU might be more user-friendly and should be minimally documented. If we go with the auto-remapping, we should at a minimum add a warning when this happens.
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
if legacy_mode:
    checkpoint = torch.load(filename, weights_only=False, **kwargs)
else:
    with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
        checkpoint = torch.load(filename,
                                weights_only=weights_only,
                                **kwargs)
```
Duplicate; I recommend using `cebra.load` again but adapting the `map_location` parameter instead, and failing on the second attempt.
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
    else:
        raise
else:
    raise
```
raise a meaningful error here
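For illustration, a chained error along these lines (the helper name, its arguments, and the message wording are assumptions; `filename`, `saved_device`, and `error` stand in for the surrounding loader's variables):

```python
def _raise_load_error(filename: str, saved_device: str, error: Exception):
    # Chain the original exception so torch's message stays in the traceback
    # while the user sees an actionable summary first.
    raise RuntimeError(
        f"Could not load checkpoint {filename!r}: it was saved on "
        f"{saved_device!r} and the automatic CPU fallback also failed."
    ) from error
```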
…lLab#296

- Refactored `_safe_torch_load()` to use recursion instead of duplicate logic
- Added meaningful error messages when CPU fallback fails
- Added UserWarning when auto-remapping CUDA/MPS to CPU
- Extended `_resolve_checkpoint_device()` to handle MPS fallback
- Added test for MPS checkpoint fallback
- Added test for meaningful error on retry failure
- Added test for error with explicit map_location
- Created tests/generate_cuda_checkpoint.py utility for GPU test data
- Removed binary checkpoint files from repo
- Updated .gitignore to exclude test checkpoint binaries

All 53 tests pass (14 CUDA/MPS tests + 39 regression tests)
Force-pushed from 1e6749f to 55f7589.
@stes Thanks for the review! All comments addressed.

Re: `map_location='cpu'` default: demo branch robosimon#3 shows this alone fails (8/8 tests fail). The issue is that `state['device_']` still says `'cuda'` after loading, so `model.to('cuda')` fails afterwards.

Test results: all 53 tests pass (14 CUDA/MPS + 39 regression). Let me know if any adjustments are needed!
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
Args:
    device: The device from the checkpoint (str or torch.device).

Returns:
    str: The resolved device string ('cpu' or validated device).
```
Please don't mention types in Args/Returns; type-annotate instead.
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
if isinstance(device, torch.device):
    device = str(device)
```
Not robust; use the `torch.device` type instead of string parsing:
https://docs.pytorch.org/docs/stable/tensor_attributes.html#torch.device
| f"got {type(device)}.") | ||
|
|
||
| fallback_to_cpu = False | ||
|
|
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
args, state, state_dict = cebra_info['args'], cebra_info[
    'state'], cebra_info['state_dict']

# Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
```
```diff
- # Resolve device: use CPU when checkpoint was saved on CUDA but CUDA is not available
```

Remove comments that are obvious from context.
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
for key, value in state.items():
    setattr(cebra_, key, value)

# Update device attributes to the resolved device for the current runtime
```
See above.

```diff
- # Update device attributes to the resolved device for the current runtime
```
`cebra/integrations/sklearn/cebra.py` (Outdated)
```python
if isinstance(saved_device_str,
              str) and saved_device_str.startswith("cuda") and load_device == "cpu":
    cebra_.device = "cpu"
```
See above; let's use `torch.device` instead of string operations. E.g., instead of `startswith` you can check the `.type` of the device.
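Applied to the diff above (with `state`, `load_device`, and `cebra_` as in the PR), the check might become something like:

```python
# Sketch: torch.device normalizes "cuda", "cuda:0", and device objects,
# so .type comparisons replace the startswith() string parsing.
saved_device = torch.device(state["device_"])
if saved_device.type in ("cuda", "mps") and torch.device(load_device).type == "cpu":
    cebra_.device = "cpu"
```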
- Add type annotations to `_resolve_checkpoint_device`: `Union[str, torch.device] -> str`
- Remove type mentions from docstring Args/Returns
- Use `torch.device.type` instead of string `startswith` checks
- Remove obvious comments from `_load_cebra_with_sklearn_backend`
- Use `torch.device` for device type checking in load backend

All 10 related tests pass.
@stes All review comments addressed:

`_resolve_checkpoint_device`: type annotations added, type mentions removed from the docstring, and `torch.device.type` checks used instead of string parsing.

`_load_cebra_with_sklearn_backend`: obvious comments removed, and `torch.device` used for device type checking.

All 10 related tests pass. Let me know if there's anything else!
Fix #295
This PR adds automatic CPU fallback when loading CUDA-saved checkpoints on CPU-only machines.
Changes:
- `_resolve_checkpoint_device()` helper to detect the CUDA → CPU fallback scenario
- `_load_cebra_with_sklearn_backend()` to use the resolved device for all `.to()` calls
- Handle both `str` and `torch.device` types (including `"cuda:0"` variants)
- Update `cebra_.device_` and `cebra_.device` attributes after resolution

Test coverage:

- Device variants: `"cuda"`, `"cuda:0"`, `torch.device("cuda")`, `torch.device("cuda", 0)`
- Model architectures: `"offset1-model"`, `"parametrized-model-5"`

Verification: