diff --git a/sandboxes/__init__.py b/sandboxes/__init__.py index 52b3e10..97adf60 100644 --- a/sandboxes/__init__.py +++ b/sandboxes/__init__.py @@ -13,6 +13,7 @@ SandboxState, ) from .base import Sandbox as BaseSandbox +from .constants import VALID_PROVIDERS, validate_provider, validate_providers from .exceptions import ( ProviderError, SandboxAuthenticationError, @@ -52,6 +53,10 @@ "RetryConfig", "with_retry", "CircuitBreaker", + # Constants and validation + "VALID_PROVIDERS", + "validate_provider", + "validate_providers", # Exceptions "SandboxError", "SandboxNotFoundError", diff --git a/sandboxes/cli.py b/sandboxes/cli.py index affc11e..1698e1e 100644 --- a/sandboxes/cli.py +++ b/sandboxes/cli.py @@ -10,6 +10,8 @@ from tabulate import tabulate from sandboxes import SandboxConfig +from sandboxes.constants import validate_provider +from sandboxes.exceptions import ProviderError from sandboxes.providers.cloudflare import CloudflareProvider from sandboxes.providers.daytona import DaytonaProvider from sandboxes.providers.e2b import E2BProvider @@ -18,6 +20,13 @@ def get_provider(name: str): """Get a provider instance by name.""" + # Validate provider name + try: + validate_provider(name, allow_none=False) + except ProviderError as e: + click.echo(f"❌ {e}", err=True) + sys.exit(1) + providers = { "e2b": E2BProvider, "modal": ModalProvider, @@ -25,11 +34,6 @@ def get_provider(name: str): "cloudflare": CloudflareProvider, } - if name not in providers: - click.echo(f"❌ Unknown provider: {name}", err=True) - click.echo(f"Available providers: {', '.join(providers.keys())}", err=True) - sys.exit(1) - try: return providers[name]() except Exception as e: diff --git a/sandboxes/constants.py b/sandboxes/constants.py new file mode 100644 index 0000000..7c66b5d --- /dev/null +++ b/sandboxes/constants.py @@ -0,0 +1,47 @@ +"""Constants and validation for sandboxes library.""" + +from .exceptions import ProviderError + +# Valid provider names +VALID_PROVIDERS = frozenset(["e2b", "modal", "daytona", "cloudflare"]) + + +def validate_provider(provider: str | None, allow_none: bool = True) -> None: + """ + Validate a provider name. + + Args: + provider: Provider name to validate + allow_none: Whether to allow None as a valid value + + Raises: + ProviderError: If provider is invalid + """ + if provider is None: + if allow_none: + return + raise ProviderError("Provider cannot be None") + + if provider not in VALID_PROVIDERS: + raise ProviderError( + f"Invalid provider: '{provider}'. " + f"Valid providers are: {', '.join(sorted(VALID_PROVIDERS))}" + ) + + +def validate_providers(providers: list[str] | None, allow_none: bool = True) -> None: + """ + Validate a list of provider names. + + Args: + providers: List of provider names to validate + allow_none: Whether to allow None values in the list + + Raises: + ProviderError: If any provider is invalid + """ + if providers is None: + return + + for provider in providers: + validate_provider(provider, allow_none=allow_none) diff --git a/sandboxes/manager.py b/sandboxes/manager.py index 8dac718..0c422db 100644 --- a/sandboxes/manager.py +++ b/sandboxes/manager.py @@ -4,6 +4,7 @@ from typing import Any from .base import ExecutionResult, Sandbox, SandboxConfig, SandboxProvider +from .constants import validate_provider from .exceptions import ProviderError logger = logging.getLogger(__name__) @@ -51,7 +52,9 @@ def get_provider(self, name: str | None = None) -> SandboxProvider: if not name: raise ProviderError("No provider specified and no default provider set") + # Validate provider name (only if not already registered, to allow custom/test providers) if name not in self.providers: + validate_provider(name, allow_none=False) raise ProviderError(f"Provider '{name}' not registered") return self.providers[name] @@ -70,12 +73,20 @@ async def create_sandbox( provider: Preferred provider name fallback_providers: List of providers to try if primary fails """ + # Validate provider names (skip if already registered, to allow custom/test providers) + if provider and provider not in self.providers: + validate_provider(provider, allow_none=False) + if fallback_providers: + for fallback in fallback_providers: + if fallback not in self.providers: + validate_provider(fallback, allow_none=False) + providers_to_try = [provider] if provider else [] if fallback_providers: providers_to_try.extend(fallback_providers) - if not providers_to_try: + if not providers_to_try and self.default_provider: providers_to_try = [self.default_provider] last_error = None @@ -84,8 +95,8 @@ async def create_sandbox( continue try: - provider = self.get_provider(provider_name) - sandbox = await provider.create_sandbox(config) + provider_obj = self.get_provider(provider_name) + sandbox = await provider_obj.create_sandbox(config) logger.info(f"Created sandbox {sandbox.id} with provider {provider_name}") return sandbox except Exception as e: diff --git a/sandboxes/sandbox.py b/sandboxes/sandbox.py index 05209b3..58f74fa 100644 --- a/sandboxes/sandbox.py +++ b/sandboxes/sandbox.py @@ -8,6 +8,7 @@ from .base import ExecutionResult, SandboxConfig from .base import Sandbox as BaseSandbox +from .constants import validate_provider, validate_providers from .manager import SandboxManager from .providers import CloudflareProvider, DaytonaProvider, E2BProvider, ModalProvider @@ -139,6 +140,10 @@ def configure( default_provider="e2b" ) """ + # Validate default_provider if specified + if default_provider: + validate_provider(default_provider, allow_none=False) + manager = cls._ensure_manager() if e2b_api_key: @@ -169,6 +174,10 @@ async def _create_impl( **kwargs: Any, ) -> Sandbox: """Internal implementation of sandbox creation.""" + # Validate provider names + validate_provider(provider, allow_none=True) + validate_providers(fallback, allow_none=False) + manager = cls._ensure_manager() # Build config @@ -272,6 +281,9 @@ async def find( Returns: Sandbox instance if found, None otherwise """ + # Validate provider if specified + validate_provider(provider, allow_none=True) + manager = cls._ensure_manager() providers_to_check = [provider] if provider else list(manager.providers.keys()) diff --git a/tests/test_manager.py b/tests/test_manager.py index 7954f32..f945ef0 100644 --- a/tests/test_manager.py +++ b/tests/test_manager.py @@ -177,4 +177,4 @@ async def test_invalid_provider(self, manager): with pytest.raises(ProviderError) as exc_info: await manager.create_sandbox(config, provider="nonexistent") - assert "Failed to create sandbox" in str(exc_info.value) + assert "Invalid provider" in str(exc_info.value) diff --git a/tests/test_validation.py b/tests/test_validation.py new file mode 100644 index 0000000..c7a2533 --- /dev/null +++ b/tests/test_validation.py @@ -0,0 +1,166 @@ +"""Tests for provider validation.""" + +import pytest + +from sandboxes.constants import VALID_PROVIDERS, validate_provider, validate_providers +from sandboxes.exceptions import ProviderError + + +class TestValidateProvider: + """Test validate_provider function.""" + + def test_valid_providers(self): + """Test validation passes for valid providers.""" + for provider in VALID_PROVIDERS: + validate_provider(provider, allow_none=False) + + def test_none_allowed(self): + """Test None is allowed when allow_none=True.""" + validate_provider(None, allow_none=True) + + def test_none_not_allowed(self): + """Test None raises error when allow_none=False.""" + with pytest.raises(ProviderError, match="Provider cannot be None"): + validate_provider(None, allow_none=False) + + def test_invalid_provider(self): + """Test invalid provider raises error.""" + with pytest.raises(ProviderError, match="Invalid provider: 'invalid'"): + validate_provider("invalid", allow_none=False) + + def test_invalid_provider_shows_valid_options(self): + """Test error message shows valid provider options.""" + with pytest.raises(ProviderError) as exc_info: + validate_provider("invalid", allow_none=False) + + error_msg = str(exc_info.value) + assert "cloudflare" in error_msg + assert "daytona" in error_msg + assert "e2b" in error_msg + assert "modal" in error_msg + + def test_case_sensitive(self): + """Test provider names are case-sensitive.""" + with pytest.raises(ProviderError): + validate_provider("E2B", allow_none=False) + + with pytest.raises(ProviderError): + validate_provider("Modal", allow_none=False) + + +class TestValidateProviders: + """Test validate_providers function.""" + + def test_valid_providers_list(self): + """Test validation passes for list of valid providers.""" + validate_providers(["e2b", "modal", "daytona"], allow_none=False) + + def test_empty_list(self): + """Test empty list is valid.""" + validate_providers([], allow_none=False) + + def test_none_list(self): + """Test None list is valid.""" + validate_providers(None, allow_none=False) + + def test_list_with_none_allowed(self): + """Test list with None values when allow_none=True.""" + validate_providers(["e2b", None, "modal"], allow_none=True) + + def test_list_with_none_not_allowed(self): + """Test list with None values raises error when allow_none=False.""" + with pytest.raises(ProviderError, match="Provider cannot be None"): + validate_providers(["e2b", None, "modal"], allow_none=False) + + def test_invalid_provider_in_list(self): + """Test invalid provider in list raises error.""" + with pytest.raises(ProviderError, match="Invalid provider"): + validate_providers(["e2b", "invalid", "modal"], allow_none=False) + + def test_all_invalid(self): + """Test all invalid providers raises error.""" + with pytest.raises(ProviderError): + validate_providers(["bad1", "bad2"], allow_none=False) + + +class TestIntegrationWithManager: + """Test validation integration with Manager.""" + + def test_manager_get_provider_validates(self): + """Test Manager.get_provider validates provider names.""" + from sandboxes.manager import SandboxManager + + manager = SandboxManager() + + with pytest.raises(ProviderError, match="Invalid provider"): + manager.get_provider("invalid") + + @pytest.mark.asyncio + async def test_manager_create_sandbox_validates(self): + """Test Manager.create_sandbox validates provider names.""" + from sandboxes import SandboxConfig + from sandboxes.manager import SandboxManager + + manager = SandboxManager() + + with pytest.raises(ProviderError, match="Invalid provider"): + await manager.create_sandbox(SandboxConfig(), provider="invalid") + + @pytest.mark.asyncio + async def test_manager_create_sandbox_validates_fallback(self): + """Test Manager.create_sandbox validates fallback provider names.""" + from sandboxes import SandboxConfig + from sandboxes.manager import SandboxManager + + manager = SandboxManager() + + with pytest.raises(ProviderError, match="Invalid provider"): + await manager.create_sandbox( + SandboxConfig(), provider="e2b", fallback_providers=["modal", "invalid"] + ) + + +class TestIntegrationWithSandbox: + """Test validation integration with Sandbox.""" + + def test_sandbox_configure_validates_default_provider(self): + """Test Sandbox.configure validates default_provider.""" + from sandboxes import Sandbox + + with pytest.raises(ProviderError, match="Invalid provider"): + Sandbox.configure(default_provider="invalid") + + @pytest.mark.asyncio + async def test_sandbox_create_validates_provider(self): + """Test Sandbox.create validates provider.""" + from sandboxes import Sandbox + + with pytest.raises(ProviderError, match="Invalid provider"): + await Sandbox.create(provider="invalid") + + @pytest.mark.asyncio + async def test_sandbox_create_validates_fallback(self): + """Test Sandbox.create validates fallback providers.""" + from sandboxes import Sandbox + + with pytest.raises(ProviderError, match="Invalid provider"): + await Sandbox.create(fallback=["e2b", "invalid"]) + + @pytest.mark.asyncio + async def test_sandbox_find_validates_provider(self): + """Test Sandbox.find validates provider.""" + from sandboxes import Sandbox + + with pytest.raises(ProviderError, match="Invalid provider"): + await Sandbox.find(labels={"test": "true"}, provider="invalid") + + +class TestIntegrationWithCLI: + """Test validation integration with CLI.""" + + def test_cli_get_provider_validates(self): + """Test CLI get_provider validates provider names.""" + from sandboxes.cli import get_provider + + with pytest.raises(SystemExit): + get_provider("invalid")