Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sandboxes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SandboxState,
)
from .base import Sandbox as BaseSandbox
from .constants import VALID_PROVIDERS, validate_provider, validate_providers
from .exceptions import (
ProviderError,
SandboxAuthenticationError,
Expand Down Expand Up @@ -52,6 +53,10 @@
"RetryConfig",
"with_retry",
"CircuitBreaker",
# Constants and validation
"VALID_PROVIDERS",
"validate_provider",
"validate_providers",
# Exceptions
"SandboxError",
"SandboxNotFoundError",
Expand Down
14 changes: 9 additions & 5 deletions sandboxes/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from tabulate import tabulate

from sandboxes import SandboxConfig
from sandboxes.constants import validate_provider
from sandboxes.exceptions import ProviderError
from sandboxes.providers.cloudflare import CloudflareProvider
from sandboxes.providers.daytona import DaytonaProvider
from sandboxes.providers.e2b import E2BProvider
Expand All @@ -18,18 +20,20 @@

def get_provider(name: str):
"""Get a provider instance by name."""
# Validate provider name
try:
validate_provider(name, allow_none=False)
except ProviderError as e:
click.echo(f"❌ {e}", err=True)
sys.exit(1)

providers = {
"e2b": E2BProvider,
"modal": ModalProvider,
"daytona": DaytonaProvider,
"cloudflare": CloudflareProvider,
}

if name not in providers:
click.echo(f"❌ Unknown provider: {name}", err=True)
click.echo(f"Available providers: {', '.join(providers.keys())}", err=True)
sys.exit(1)

try:
return providers[name]()
except Exception as e:
Expand Down
47 changes: 47 additions & 0 deletions sandboxes/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Constants and validation for sandboxes library."""

from .exceptions import ProviderError

# Valid provider names
VALID_PROVIDERS = frozenset(["e2b", "modal", "daytona", "cloudflare"])


def validate_provider(provider: str | None, allow_none: bool = True) -> None:
"""
Validate a provider name.

Args:
provider: Provider name to validate
allow_none: Whether to allow None as a valid value

Raises:
ProviderError: If provider is invalid
"""
if provider is None:
if allow_none:
return
raise ProviderError("Provider cannot be None")

if provider not in VALID_PROVIDERS:
raise ProviderError(
f"Invalid provider: '{provider}'. "
f"Valid providers are: {', '.join(sorted(VALID_PROVIDERS))}"
)


def validate_providers(providers: list[str] | None, allow_none: bool = True) -> None:
"""
Validate a list of provider names.

Args:
providers: List of provider names to validate
allow_none: Whether to allow None values in the list

Raises:
ProviderError: If any provider is invalid
"""
if providers is None:
return

for provider in providers:
validate_provider(provider, allow_none=allow_none)
17 changes: 14 additions & 3 deletions sandboxes/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Any

from .base import ExecutionResult, Sandbox, SandboxConfig, SandboxProvider
from .constants import validate_provider
from .exceptions import ProviderError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -51,7 +52,9 @@ def get_provider(self, name: str | None = None) -> SandboxProvider:
if not name:
raise ProviderError("No provider specified and no default provider set")

# Validate provider name (only if not already registered, to allow custom/test providers)
if name not in self.providers:
validate_provider(name, allow_none=False)
raise ProviderError(f"Provider '{name}' not registered")

return self.providers[name]
Expand All @@ -70,12 +73,20 @@ async def create_sandbox(
provider: Preferred provider name
fallback_providers: List of providers to try if primary fails
"""
# Validate provider names (skip if already registered, to allow custom/test providers)
if provider and provider not in self.providers:
validate_provider(provider, allow_none=False)
if fallback_providers:
for fallback in fallback_providers:
if fallback not in self.providers:
validate_provider(fallback, allow_none=False)

providers_to_try = [provider] if provider else []

if fallback_providers:
providers_to_try.extend(fallback_providers)

if not providers_to_try:
if not providers_to_try and self.default_provider:
providers_to_try = [self.default_provider]

last_error = None
Expand All @@ -84,8 +95,8 @@ async def create_sandbox(
continue

try:
provider = self.get_provider(provider_name)
sandbox = await provider.create_sandbox(config)
provider_obj = self.get_provider(provider_name)
sandbox = await provider_obj.create_sandbox(config)
logger.info(f"Created sandbox {sandbox.id} with provider {provider_name}")
return sandbox
except Exception as e:
Expand Down
12 changes: 12 additions & 0 deletions sandboxes/sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .base import ExecutionResult, SandboxConfig
from .base import Sandbox as BaseSandbox
from .constants import validate_provider, validate_providers
from .manager import SandboxManager
from .providers import CloudflareProvider, DaytonaProvider, E2BProvider, ModalProvider

Expand Down Expand Up @@ -139,6 +140,10 @@ def configure(
default_provider="e2b"
)
"""
# Validate default_provider if specified
if default_provider:
validate_provider(default_provider, allow_none=False)

manager = cls._ensure_manager()

if e2b_api_key:
Expand Down Expand Up @@ -169,6 +174,10 @@ async def _create_impl(
**kwargs: Any,
) -> Sandbox:
"""Internal implementation of sandbox creation."""
# Validate provider names
validate_provider(provider, allow_none=True)
validate_providers(fallback, allow_none=False)

manager = cls._ensure_manager()

# Build config
Expand Down Expand Up @@ -272,6 +281,9 @@ async def find(
Returns:
Sandbox instance if found, None otherwise
"""
# Validate provider if specified
validate_provider(provider, allow_none=True)

manager = cls._ensure_manager()

providers_to_check = [provider] if provider else list(manager.providers.keys())
Expand Down
2 changes: 1 addition & 1 deletion tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,4 +177,4 @@ async def test_invalid_provider(self, manager):
with pytest.raises(ProviderError) as exc_info:
await manager.create_sandbox(config, provider="nonexistent")

assert "Failed to create sandbox" in str(exc_info.value)
assert "Invalid provider" in str(exc_info.value)
166 changes: 166 additions & 0 deletions tests/test_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
"""Tests for provider validation."""

import pytest

from sandboxes.constants import VALID_PROVIDERS, validate_provider, validate_providers
from sandboxes.exceptions import ProviderError


class TestValidateProvider:
"""Test validate_provider function."""

def test_valid_providers(self):
"""Test validation passes for valid providers."""
for provider in VALID_PROVIDERS:
validate_provider(provider, allow_none=False)

def test_none_allowed(self):
"""Test None is allowed when allow_none=True."""
validate_provider(None, allow_none=True)

def test_none_not_allowed(self):
"""Test None raises error when allow_none=False."""
with pytest.raises(ProviderError, match="Provider cannot be None"):
validate_provider(None, allow_none=False)

def test_invalid_provider(self):
"""Test invalid provider raises error."""
with pytest.raises(ProviderError, match="Invalid provider: 'invalid'"):
validate_provider("invalid", allow_none=False)

def test_invalid_provider_shows_valid_options(self):
"""Test error message shows valid provider options."""
with pytest.raises(ProviderError) as exc_info:
validate_provider("invalid", allow_none=False)

error_msg = str(exc_info.value)
assert "cloudflare" in error_msg
assert "daytona" in error_msg
assert "e2b" in error_msg
assert "modal" in error_msg

def test_case_sensitive(self):
"""Test provider names are case-sensitive."""
with pytest.raises(ProviderError):
validate_provider("E2B", allow_none=False)

with pytest.raises(ProviderError):
validate_provider("Modal", allow_none=False)


class TestValidateProviders:
"""Test validate_providers function."""

def test_valid_providers_list(self):
"""Test validation passes for list of valid providers."""
validate_providers(["e2b", "modal", "daytona"], allow_none=False)

def test_empty_list(self):
"""Test empty list is valid."""
validate_providers([], allow_none=False)

def test_none_list(self):
"""Test None list is valid."""
validate_providers(None, allow_none=False)

def test_list_with_none_allowed(self):
"""Test list with None values when allow_none=True."""
validate_providers(["e2b", None, "modal"], allow_none=True)

def test_list_with_none_not_allowed(self):
"""Test list with None values raises error when allow_none=False."""
with pytest.raises(ProviderError, match="Provider cannot be None"):
validate_providers(["e2b", None, "modal"], allow_none=False)

def test_invalid_provider_in_list(self):
"""Test invalid provider in list raises error."""
with pytest.raises(ProviderError, match="Invalid provider"):
validate_providers(["e2b", "invalid", "modal"], allow_none=False)

def test_all_invalid(self):
"""Test all invalid providers raises error."""
with pytest.raises(ProviderError):
validate_providers(["bad1", "bad2"], allow_none=False)


class TestIntegrationWithManager:
"""Test validation integration with Manager."""

def test_manager_get_provider_validates(self):
"""Test Manager.get_provider validates provider names."""
from sandboxes.manager import SandboxManager

manager = SandboxManager()

with pytest.raises(ProviderError, match="Invalid provider"):
manager.get_provider("invalid")

@pytest.mark.asyncio
async def test_manager_create_sandbox_validates(self):
"""Test Manager.create_sandbox validates provider names."""
from sandboxes import SandboxConfig
from sandboxes.manager import SandboxManager

manager = SandboxManager()

with pytest.raises(ProviderError, match="Invalid provider"):
await manager.create_sandbox(SandboxConfig(), provider="invalid")

@pytest.mark.asyncio
async def test_manager_create_sandbox_validates_fallback(self):
"""Test Manager.create_sandbox validates fallback provider names."""
from sandboxes import SandboxConfig
from sandboxes.manager import SandboxManager

manager = SandboxManager()

with pytest.raises(ProviderError, match="Invalid provider"):
await manager.create_sandbox(
SandboxConfig(), provider="e2b", fallback_providers=["modal", "invalid"]
)


class TestIntegrationWithSandbox:
"""Test validation integration with Sandbox."""

def test_sandbox_configure_validates_default_provider(self):
"""Test Sandbox.configure validates default_provider."""
from sandboxes import Sandbox

with pytest.raises(ProviderError, match="Invalid provider"):
Sandbox.configure(default_provider="invalid")

@pytest.mark.asyncio
async def test_sandbox_create_validates_provider(self):
"""Test Sandbox.create validates provider."""
from sandboxes import Sandbox

with pytest.raises(ProviderError, match="Invalid provider"):
await Sandbox.create(provider="invalid")

@pytest.mark.asyncio
async def test_sandbox_create_validates_fallback(self):
"""Test Sandbox.create validates fallback providers."""
from sandboxes import Sandbox

with pytest.raises(ProviderError, match="Invalid provider"):
await Sandbox.create(fallback=["e2b", "invalid"])

@pytest.mark.asyncio
async def test_sandbox_find_validates_provider(self):
"""Test Sandbox.find validates provider."""
from sandboxes import Sandbox

with pytest.raises(ProviderError, match="Invalid provider"):
await Sandbox.find(labels={"test": "true"}, provider="invalid")


class TestIntegrationWithCLI:
"""Test validation integration with CLI."""

def test_cli_get_provider_validates(self):
"""Test CLI get_provider validates provider names."""
from sandboxes.cli import get_provider

with pytest.raises(SystemExit):
get_provider("invalid")