diff --git a/exegol/console/cli/actions/ExegolParameters.py b/exegol/console/cli/actions/ExegolParameters.py index 05947830..dc188212 100644 --- a/exegol/console/cli/actions/ExegolParameters.py +++ b/exegol/console/cli/actions/ExegolParameters.py @@ -26,6 +26,7 @@ def __init__(self) -> None: "Get a [blue]tmux[/blue] shell": "exegol start --shell [blue]tmux[/blue]", "Share a specific [blue]hardware device[/blue] [bright_black](e.g. Proxmark)[/bright_black]": "exegol start -d [bright_magenta]/dev/ttyACM0[/bright_magenta]", "Share every [blue]USB device[/blue] connected to the host": "exegol start -d [magenta]/dev/bus/usb/[/magenta]", + "Enable [blue]NVIDIA GPU[/blue] passthrough": "exegol start [blue]gpu[/blue] [bright_blue]free[/bright_blue] --gpu [magenta]nvidia[/magenta]", } def __call__(self, *args, **kwargs): diff --git a/exegol/console/cli/actions/GenericParameters.py b/exegol/console/cli/actions/GenericParameters.py index 91c9e268..8d14096f 100644 --- a/exegol/console/cli/actions/GenericParameters.py +++ b/exegol/console/cli/actions/GenericParameters.py @@ -240,7 +240,13 @@ def __init__(self, groupArgs: List[GroupArg]): dest="devices", default=[], action="append", - help="Add host [default not bold]device(s)[/default not bold] at the container creation (example: -d /dev/ttyACM0 -d /dev/bus/usb/)") + help="Add host [default not bold]device(s)[/default not bold] at the container creation (example: -d /dev/ttyACM0 -d /dev/bus/usb/ -d nvidia.com/gpu=all)") + self.gpu = Option("--gpu", + dest="gpu", + choices=["nvidia"], + default=None, + action="store", + help="Enable GPU passthrough using Docker CDI on Linux hosts (example: --gpu nvidia)") self.hosts_file = Option("--hosts-file", dest="hosts_file", @@ -263,6 +269,7 @@ def __init__(self, groupArgs: List[GroupArg]): {"arg": self.hostname, "required": False}, {"arg": self.privileged, "required": False}, {"arg": self.devices, "required": False}, + {"arg": self.gpu, "required": False}, {"arg": self.X11, "required": False}, {"arg": self.my_resources, "required": False}, {"arg": self.exegol_resources, "required": False}, diff --git a/exegol/model/ContainerConfig.py b/exegol/model/ContainerConfig.py index 0012da7a..c6c5d85c 100644 --- a/exegol/model/ContainerConfig.py +++ b/exegol/model/ContainerConfig.py @@ -113,6 +113,7 @@ def __init__(self, container: Optional[Container] = None, container_name: Option self.__wrapper_start_enabled: bool = False self.__mounts: List[Mount] = [] self.__devices: List[str] = [] + self.__device_requests: List[Dict[str, Union[str, int, List[str]]]] = [] self.__capabilities: List[str] = [] self.__sysctls: Dict[str, str] = {} self.__envs: Dict[str, str] = {} @@ -190,6 +191,19 @@ def __parseContainerConfig(self, container: Container) -> None: self.__devices.append( f"{device.get('PathOnHost', '?')}:{device.get('PathInContainer', '?')}:{device.get('CgroupPermissions', '?')}") logger.debug(f"└── Load devices : {self.__devices}") + device_requests = host_config.get("DeviceRequests", []) + if device_requests is not None: + for request in device_requests: + if request is None: + continue + driver = request.get("Driver") + device_ids = request.get("DeviceIDs") + if driver == "cdi" and isinstance(device_ids, list): + for device_id in device_ids: + if isinstance(device_id, str): + self.__addCdiDevice(device_id) + else: + logger.debug(f"Unsupported device request: {request}") extra_hosts = host_config.get("ExtraHosts", []) for entry in extra_hosts: hostname, ip = entry.rsplit(":", 1) @@ -344,6 +358,18 @@ async def configFromUser(self) -> "ContainerConfig": if ParametersManager().volumes is not None: for volume in ParametersManager().volumes: await self.addRawVolume(volume) + gpu_vendor = ParametersManager().gpu + if gpu_vendor: + if EnvInfo.isMacHost() or EnvInfo.isWindowsHost(): + logger.critical("The --gpu option is currently supported only on Linux hosts.") + gpu_cdi_selectors = { + "nvidia": "nvidia.com/gpu=all", + } + cdi_selector = gpu_cdi_selectors.get(gpu_vendor) + if cdi_selector is None: + logger.critical(f"Unsupported GPU vendor for --gpu: {gpu_vendor}") + if cdi_selector not in ParametersManager().devices: + self.addUserDevice(cdi_selector) if ParametersManager().devices is not None: for device in ParametersManager().devices: self.addUserDevice(device) @@ -1330,6 +1356,10 @@ def getDevices(self) -> List[str]: """Devices config getter""" return self.__devices + def getDeviceRequests(self) -> List[Dict[str, Union[str, int, List[str]]]]: + """Device requests config getter (used for CDI selectors).""" + return self.__device_requests + def addEnv(self, key: str, value: str) -> None: """Add or update an environment variable to the container configuration""" self.__envs[key] = value @@ -1566,8 +1596,20 @@ def addUserDevice(self, user_device_config: str) -> None: logger.warning("Orbstack does not support (yet) USB device passthrough.") logger.verbose("Official doc: https://docs.orbstack.dev/machines/#usb-devices") logger.critical("Device configuration cannot be applied, aborting operation.") + if self.__isCdiDevice(user_device_config): + self.__addCdiDevice(user_device_config) + return self.__addDevice(user_device_config) + def __addCdiDevice(self, device_selector: str) -> None: + """Add a CDI selector as a Docker device request.""" + self.__device_requests.append({"Driver": "cdi", "DeviceIDs": [device_selector]}) + + @staticmethod + def __isCdiDevice(device: str) -> bool: + """Return True when user input looks like a CDI selector.""" + return re.match(r"^[^/:]+/[^:=]+=[^:]+$", device) is not None + async def addRawPort(self, user_test_port: str) -> None: """Add port config or range of ports from user input. Format must be [:][-][:[-]][:] @@ -1711,11 +1753,21 @@ def getTextMounts(self, verbose: bool = False) -> str: def getTextDevices(self, verbose: bool = False) -> str: """Text formatter for Devices configuration. The verbose mode show full device configuration.""" result = '' - for device in self.__devices: + text_devices = list(self.__devices) + for request in self.__device_requests: + driver = request.get("Driver") + device_ids = request.get("DeviceIDs") + if driver == "cdi" and isinstance(device_ids, list): + text_devices.extend([device for device in device_ids if isinstance(device, str)]) + for device in text_devices: if verbose: result += f"{device}{os.linesep}" else: - src, dest = device.split(':')[:2] + split_device = device.split(':') + if len(split_device) < 2: + result += f"{device}{os.linesep}" + continue + src, dest = split_device[:2] if src == dest: result += f"{src}{os.linesep}" else: diff --git a/exegol/model/ExegolContainer.py b/exegol/model/ExegolContainer.py index d3cb91db..e3061a95 100644 --- a/exegol/model/ExegolContainer.py +++ b/exegol/model/ExegolContainer.py @@ -150,7 +150,12 @@ async def __start_container(self) -> None: self.__container.start() except APIError as e: logger.debug(e) - logger.critical(f"Docker raised a critical error when starting the container [green]{self.name}[/green], error message is: {e.explanation}") + explanation = str(e.explanation if e.explanation is not None else "") + if "cdi device injection failed" in explanation.lower() and "nvidia.com/gpu=all" in explanation.lower(): + logger.warning("Hint: verify that nvidia-container-toolkit is installed. See https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html") + explanation = explanation.replace('[', '\\[') + logger.error(f"Docker raised a critical error when starting the container [green]{self.name}[/green], error message is: {explanation}") + logger.critical("Error while starting exegol container. Exiting.") if not self.config.legacy_entrypoint: # TODO improve startup compatibility check try: # Try to find log / startup messages. Will time out after 2 seconds if the image don't support status update through container logs. diff --git a/exegol/utils/DockerUtils.py b/exegol/utils/DockerUtils.py index eb899e1e..6c92825c 100644 --- a/exegol/utils/DockerUtils.py +++ b/exegol/utils/DockerUtils.py @@ -129,6 +129,7 @@ def createContainer(self, model: ExegolContainerTemplate, temporary: bool = Fals "hostname": model.config.hostname, "extra_hosts": model.config.getExtraHost(), "devices": model.config.getDevices(), + "device_requests": model.config.getDeviceRequests(), "environment": model.config.getEnvs(), "labels": model.config.getLabels(), "ports": model.config.getPorts(),