diff --git a/exegol/utils/DockerUtils.py b/exegol/utils/DockerUtils.py
index b6e4eb4e..bb8cf901 100644
--- a/exegol/utils/DockerUtils.py
+++ b/exegol/utils/DockerUtils.py
@@ -134,6 +134,12 @@ def createContainer(self, model: ExegolContainerTemplate, temporary: bool = Fals
                        "tty": model.config.tty,
                        "mounts": model.config.getVolumes(),
                        "working_dir": model.config.getWorkingDir()}
+        # GPU passthrough: expose every GPU ("all") or only the first one, depending on detection.
+        gpu_flag = self.isGPUAvailable()
+        if gpu_flag is not None:
+            docker_args["device_requests"] = [
+                docker.types.DeviceRequest(count=-1 if gpu_flag == "all" else 1, capabilities=[["gpu"]])
+            ]
         if temporary:
             # Only the 'run' function support the "remove" parameter
             docker_create_function = self.__client.containers.run
@@ -198,6 +204,20 @@ def getContainer(self, tag: str) -> ExegolContainer:
             # In this case, ObjectNotFound is raised
             raise ObjectNotFound
 
+    def isGPUAvailable(self) -> Optional[str]:
+        """Detect whether an NVIDIA GPU can be passed through to a container.
+        Return "all" when several CUDA devices are present, "device=0" for a
+        single device, or None when no CUDA GPU is usable (torch missing, no
+        driver, or a non-CUDA backend such as Apple MPS, which cannot be
+        exposed to a Linux container)."""
+        try:
+            import torch  # optional heavy dependency: degrade gracefully if missing
+        except ImportError:
+            return None
+        if torch.cuda.is_available():
+            # Multiple GPUs -> request them all; single GPU -> request one device only.
+            return "all" if torch.cuda.device_count() > 1 else "device=0"
+        return None
+
     # # # Volumes Section # # #
 
     def __loadDockerVolume(self, volume_path: str, volume_name: str) -> Volume:
diff --git a/requirements.txt b/requirements.txt
index cf72eb36..c778ef31 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ GitPython~=3.1.43
 PyYAML>=6.0.2
 argcomplete~=3.5.0
 tzlocal~=5.2; platform_system != 'Linux'
+torch~=2.6.0