diff --git a/README.md b/README.md
index 8e78eac1..9ea51b7c 100644
--- a/README.md
+++ b/README.md
@@ -57,8 +57,38 @@ appreciated.
 
 ## Getting Started
 
-To run simulations locally (Docker Compose, single machine), see the [Tutorial](docs/TUTORIAL.md).
-For cluster or SLURM deployment, see `src/tools/run-on-slurm`.
+### Quick Start
+
+1. **Setup Environment**
+   Follow the [Onboarding Guide](docs/ONBOARDING.md) to install dependencies and configure your
+   environment.
+
+2. **Prepare Scene Data**
+   AlpaSim uses [trajdata](https://github.com/NVlabs/trajdata/tree/alpasim) (custom `alpasim`
+   branch) for unified data loading. This dependency is automatically installed via `uv`. The wizard
+   automatically prepares scene caches, but you can also do it manually:
+
+   ```bash
+   # For USDZ scenes (after downloading from Hugging Face)
+   uv run python -m alpasim_runtime.prepare_data \
+       --desired-data=usdz \
+       --data-dir=./data/nre-artifacts/all-usdzs \
+       --cache-location=./cache/trajdata_usdz
+   ```
+
+   See the [Tutorial](docs/TUTORIAL.md#data-preparation) for more details on data preparation
+   options.
+
+3. **Run Your First Simulation**
+   ```bash
+   source setup_local_env.sh
+   uv run alpasim_wizard +deploy=local wizard.log_dir=$PWD/my_first_run
+   ```
+
+   Results, including videos and metrics, will be written to `my_first_run/`.
+
+For detailed instructions, see the [Tutorial](docs/TUTORIAL.md). For cluster or SLURM deployment, see
+`src/tools/run-on-slurm`.
 
 ## Documentation & Resources
 
diff --git a/docs/TUTORIAL.md b/docs/TUTORIAL.md
index e5efebbd..8d7d909a 100644
--- a/docs/TUTORIAL.md
+++ b/docs/TUTORIAL.md
@@ -336,6 +336,91 @@ the predictions of a policy, you can set
 `runtime.simulation_config.physics_update_mode: NONE` and
 `runtime.simulation_config.force_gt_duration_us` to a very high value (20s+).
 
+## Data Preparation
+
+AlpaSim uses [trajdata](https://github.com/NVlabs/trajdata/tree/alpasim) (custom `alpasim` branch)
+for unified data loading across different autonomous driving datasets. The trajdata library is
+automatically installed via `uv` when you run `setup_local_env.sh`.
+
+The wizard automatically prepares the trajdata cache when needed. However, you can manually
+prepare or rebuild the cache for optimization or debugging purposes.
+
+### Preparing USDZ Scene Cache
+
+The wizard automatically handles data preparation for downloaded USDZ scenes. However, if you
+need to manually prepare or rebuild the cache, use the `prepare_data` tool:
+
+```bash
+# Prepare cache from USDZ scenes
+uv run python -m alpasim_runtime.prepare_data \
+    --desired-data=usdz \
+    --data-dir=./data/nre-artifacts/all-usdzs \
+    --cache-location=./cache/trajdata_usdz \
+    --log-level=INFO
+```
+
+Trajectory smoothing is controlled by the `smooth_trajectories` field in the user config rather
+than a command-line flag.
+
+This command:
+- Scans USDZ files in the specified directory
+- Extracts trajectory and map data
+- Builds a trajdata cache that is reused across runs for fast scene loading
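+
+If you prefer to drive the same preprocessing from Python (for example, from a setup script),
+the module also exposes it programmatically. A minimal sketch using the `PrepareDataConfig`
+and `preprocess_basic` exports (paths are placeholders):
+
+```python
+from alpasim_runtime.prepare_data import PrepareDataConfig, preprocess_basic
+
+# Mirrors the CLI invocation above; returns True on success.
+config = PrepareDataConfig(
+    desired_data=["usdz"],
+    data_dirs={"usdz": "./data/nre-artifacts/all-usdzs"},
+    cache_location="./cache/trajdata_usdz",
+)
+preprocess_basic(config, verbose=True)
+```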
+### Using Configuration Files
+
+For more complex setups or repeated use, create a configuration file (e.g.,
+`user_config/config_prepare_usdz.yaml`). The `data_source` block follows the
+`DataSourceConfig` schema, with one entry per dataset under `sources`:
+
+```yaml
+data_source:
+  cache_location: "./cache/trajdata_usdz"
+  incl_vector_map: true
+  rebuild_cache: false # Set to true to force rebuild
+  rebuild_maps: false # Set to true to force rebuild
+  desired_dt: 0.1 # 10 Hz sampling rate
+  num_workers: 4 # Parallel workers for cache creation
+  sources:
+    usdz:
+      data_dir: "./data/nre-artifacts/all-usdzs"
+
+smooth_trajectories: true
+```
+
+Then run:
+
+```bash
+uv run python -m alpasim_runtime.prepare_data \
+    --user-config=user_config/config_prepare_usdz.yaml
+```
+
+### Cache Location
+
+The trajdata cache contains:
+- Preprocessed scene metadata and indices
+- Trajectory data in a unified format
+- Vector map data (when enabled)
+- Scene lookup tables for fast access
+
+You can share the cache directory across machines to avoid redundant preprocessing.
+
+### Rebuilding the Cache
+
+If you encounter data inconsistencies or add new scenes, rebuild the cache:
+
+```bash
+# Command line
+uv run python -m alpasim_runtime.prepare_data \
+    --desired-data=usdz \
+    --data-dir=./data/nre-artifacts/all-usdzs \
+    --cache-location=./cache/trajdata_usdz \
+    --rebuild-cache
+
+# Or in the config file, set: rebuild_cache: true
+```
+
+> :green_book: The wizard uses configuration files from `user_config/` that include data source
+> settings. These configs are automatically used during simulation runs.
+
 ## Scenes
 
 The scene in AlpaSim is a NuRec reconstruction of a real-world driving log.
@@ -510,7 +595,17 @@ from shutting down the docker containers after each simulation by setting
 
 1. (Terminal 2) `cd` into the runtime src directory (`/src/runtime/`) and prepare to start the runtime.
   The exact command paths will vary, but to use the configuration generated in the earlier
   steps, an example command would be:
 
-   `bash cd /src/runtime/ # Following command is based on the docker-compose.yaml generated by the wizard uv run python -m alpasim_runtime.simulate \ --usdz-glob=../../data/nre-artifacts/all-usdzs/**/*.usdz \ --user-config=../../tutorial_dbg_runtime/generated-user-config-0.yaml \ --network-config=../../tutorial_dbg_runtime/generated-network-config.yaml \ --log-dir=../../tutorial_dbg_runtime \ --log-level=INFO `
+   ```bash
+   cd /src/runtime/
+   # Following command is based on the docker-compose.yaml generated by the wizard
+   # Ensure the user config contains the data_source configuration
+   uv run python -m alpasim_runtime.simulate \
+       --user-config=../../tutorial_dbg_runtime/generated-user-config-0.yaml \
+       --network-config=../../tutorial_dbg_runtime/generated-network-config.yaml \
+       --log-dir=../../tutorial_dbg_runtime \
+       --eval-config=../../tutorial_dbg_runtime/eval-config.yaml \
+       --log-level=INFO
+   ```
 
 ### Using VSCode Debugger (Optional)
 
@@ -528,13 +623,16 @@ built-in debugger:
       "justMyCode": false,
       "cwd": "${workspaceFolder}/src/runtime",
       "args": [
-        "--usdz-glob=../../data/nre-artifacts/all-usdzs/**/*.usdz",
         "--user-config=../../tutorial_dbg_runtime/generated-user-config-0.yaml",
         "--network-config=../../tutorial_dbg_runtime/generated-network-config.yaml",
+        "--eval-config=../../tutorial_dbg_runtime/eval-config.yaml",
         "--log-dir=../../tutorial_dbg_runtime",
         "--log-level=INFO"
       ],
-      "console": "integratedTerminal"
+      "console": "integratedTerminal",
+      "env": {
+        "PYTHONPATH": "${workspaceFolder}/src/grpc:${workspaceFolder}/src/eval/src:${workspaceFolder}/src/utils:${workspaceFolder}/src/runtime:${env:PYTHONPATH}"
+      }
     }
 ```
 
diff --git a/src/runtime/alpasim_runtime/config.py b/src/runtime/alpasim_runtime/config.py
index e6f883c9..81fce70b 100644
--- a/src/runtime/alpasim_runtime/config.py
+++ b/src/runtime/alpasim_runtime/config.py
@@ -8,7 +8,7 @@
 from dataclasses import dataclass, field
 from enum import Enum
 from pathlib import Path
-from typing import Optional, Type, TypeVar, cast
+from typing import Any, Dict, Optional, Type, TypeVar, cast
 
 from alpasim_utils.scenario import VehicleConfig
 from alpasim_utils.yaml_utils import load_yaml_dict
@@ -17,6 +17,105 @@
 C = TypeVar("C")
 
 
+@dataclass
+class GenericSourceConfig:
+    """Generic configuration for any trajdata-supported dataset.
+
+    This unified config supports all trajdata datasets (USDZ, NuPlan, Waymo,
+    nuScenes, Lyft, Argoverse, etc.) with a flexible extra_params field for
+    dataset-specific options.
+
+    Attributes:
+        enabled: Whether this data source is enabled
+        data_dir: Path to dataset directory
+        extra_params: Dataset-specific parameters (e.g., NuPlan's config_dir,
+            USDZ's asset_base_path, etc.)
+
+    Example extra_params:
+        - NuPlan: {"config_dir": "/path", "num_timesteps_before": 30, "num_timesteps_after": 80}
+        - USDZ: {"asset_base_path": "/assets"}
+        - Waymo: {} (no extra params needed)
+    """
+
+    enabled: bool = True
+    data_dir: Optional[str] = None
+    extra_params: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class DataSourceConfig:
+    """Configuration for unified data loading through trajdata.
+
+    Supports dynamic registration of any trajdata dataset (USDZ, NuPlan, Waymo,
+    nuScenes, Lyft, Argoverse, etc.) through the 'sources' dictionary.
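+
+    Example (YAML consumed via ``UserSimulatorConfig.data_source``; paths are
+    illustrative):
+
+        data_source:
+          cache_location: ./cache/trajdata
+          num_workers: 4
+          sources:
+            usdz:
+              data_dir: ./data/usdz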
+
+    Attributes:
+        cache_location: Path to shared trajdata cache directory
+        desired_dt: Desired time delta between trajectory frames in seconds
+        incl_vector_map: Whether to load vector maps (roads, lanes, etc.)
+        rebuild_cache: Whether to force rebuild the cache for all sources
+        rebuild_maps: Whether to force rebuild maps for all sources
+        num_workers: Number of parallel workers for cache creation
+        sources: Dictionary mapping dataset names to their configurations
+            (e.g., {"usdz": GenericSourceConfig(...), "waymo": ...})
+    """
+
+    # Common configuration (applies to all data sources)
+    cache_location: str = MISSING
+    desired_dt: float = 0.1  # 10 Hz sampling
+    incl_vector_map: bool = True
+    rebuild_cache: bool = False
+    rebuild_maps: bool = False
+    num_workers: int = 1  # Conservative default for stability; increase for production
+
+    # Extensible per-dataset source configuration
+    sources: Dict[str, GenericSourceConfig] = field(default_factory=dict)
+
+    def to_trajdata_params(self) -> dict:
+        """Convert hierarchical config to flat parameters for trajdata's UnifiedDataset.
+
+        Returns:
+            Dictionary with keys expected by the UnifiedDataset constructor
+
+        Raises:
+            ValueError: If no data sources are enabled
+        """
+        desired_data = []
+        data_dirs = {}
+        dataset_kwargs = {}
+
+        for dataset_name, source in self.sources.items():
+            if source.enabled:
+                if source.data_dir is None:
+                    raise ValueError(
+                        f"data_source.sources.{dataset_name}.data_dir is required when enabled"
+                    )
+                desired_data.append(dataset_name)
+                data_dirs[dataset_name] = source.data_dir
+                if source.extra_params:
+                    dataset_kwargs[dataset_name] = source.extra_params
+
+        if not desired_data:
+            raise ValueError("No data sources enabled in configuration")
+
+        params = {
+            "desired_data": desired_data,
+            "data_dirs": data_dirs,
+            "cache_location": self.cache_location,
+            "incl_vector_map": self.incl_vector_map,
+            "rebuild_cache": self.rebuild_cache,
+            "rebuild_maps": self.rebuild_maps,
+            "num_workers": self.num_workers,
+            "desired_dt": self.desired_dt,
+        }
+
+        # Add dataset-specific kwargs if any source has extra_params
+        if dataset_kwargs:
+            params["dataset_kwargs"] = dataset_kwargs
+
+        return params
+
+
 def typed_parse_config(path: str | Path, config_type: Type[C]) -> C:
     """Reads a yaml file at `path` and parses it into a provided type using omegaconf."""
     yaml_config = OmegaConf.create(load_yaml_dict(path))
@@ -243,9 +342,6 @@ class UserSimulatorConfig:
     endpoints: UserEndpointConfig = MISSING
     smooth_trajectories: bool = True  # whether to smooth trajectories with cubic spline
 
-    # Max worker-local artifact cache size.
-    # None = unlimited, 0 = disable cache and always reload artifacts.
-    artifact_cache_size: Optional[int] = None
     extra_cameras: list[CameraDefinitionConfig] = field(default_factory=list)
 
     # Number of worker processes for parallel rollout execution.
@@ -253,6 +349,10 @@ class UserSimulatorConfig: # >1 = multi-worker mode with subprocess-based parallelism nr_workers: int = MISSING + # Unified data source configuration (required) + # Data loading goes through trajdata's UnifiedDataset + data_source: DataSourceConfig = MISSING + @dataclass class SimulatorConfig: diff --git a/src/runtime/alpasim_runtime/daemon/__init__.py b/src/runtime/alpasim_runtime/daemon/__init__.py index 5d7ab7b6..40324f86 100644 --- a/src/runtime/alpasim_runtime/daemon/__init__.py +++ b/src/runtime/alpasim_runtime/daemon/__init__.py @@ -2,6 +2,7 @@ # Copyright (c) 2026 NVIDIA Corporation from alpasim_runtime.daemon.engine import DaemonEngine +from alpasim_runtime.daemon.exceptions import InvalidRequestError, UnknownSceneError from alpasim_runtime.daemon.request_store import RequestStore from alpasim_runtime.daemon.scheduler import DaemonScheduler, DaemonUnavailableError @@ -10,4 +11,6 @@ "DaemonScheduler", "DaemonUnavailableError", "RequestStore", + "InvalidRequestError", + "UnknownSceneError", ] diff --git a/src/runtime/alpasim_runtime/daemon/engine.py b/src/runtime/alpasim_runtime/daemon/engine.py index 7985dc6d..ead01497 100644 --- a/src/runtime/alpasim_runtime/daemon/engine.py +++ b/src/runtime/alpasim_runtime/daemon/engine.py @@ -5,17 +5,21 @@ import logging from collections import defaultdict +from typing import Callable from uuid import uuid4 from alpasim_grpc.v0 import logging_pb2, runtime_pb2 from alpasim_runtime.address_pool import AddressPool +from alpasim_runtime.daemon.exceptions import InvalidRequestError from alpasim_runtime.daemon.scheduler import DaemonScheduler, DaemonUnavailableError from alpasim_runtime.runtime_context import ( build_runtime_context, compute_num_consumers_per_worker, ) +from alpasim_runtime.scene_loader import SceneLoader from alpasim_runtime.worker.ipc import JobResult, PendingRolloutJob from alpasim_runtime.worker.runtime import WorkerRuntime, start_worker_runtime +from alpasim_utils.scene_data_source import SceneDataSource from eval.data import AggregationType @@ -97,37 +101,27 @@ def build_simulation_return( ) -class InvalidRequestError(ValueError): - """Raised when a simulation request contains invalid parameters.""" - - pass - - -class UnknownSceneError(InvalidRequestError): - """Raised when a simulation request references a scene_id with no known artifact.""" - - def __init__(self, scene_id: str): - super().__init__(f"No artifact found for scene_id: {scene_id}") - self.scene_id = scene_id - - def build_pending_jobs_from_request( request: runtime_pb2.SimulationRequest, - scene_id_to_artifact_path: dict[str, str], + get_data_source: Callable[[str], SceneDataSource], ) -> list[PendingRolloutJob]: """Expand a SimulationRequest into individual PendingRolloutJob entries. Each RolloutSpec is expanded by its ``nr_rollouts`` count. Specs with ``nr_rollouts=0`` are silently dropped with a warning. + Args: + request: The simulation request to expand. + get_data_source: Callable that returns a SceneDataSource for a given scene_id. + Should raise UnknownSceneError if the scene_id is not found. + Raises: - UnknownSceneError: If a spec references a scene_id not present in - *scene_id_to_artifact_path*. + UnknownSceneError: If a spec references an unknown scene_id. 
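+
+    Example (sketch; ``loader`` is an initialized SceneLoader):
+
+        jobs = build_pending_jobs_from_request(request, loader.get_data_source)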
""" jobs: list[PendingRolloutJob] = [] for spec_index, spec in enumerate(request.rollout_specs): - if spec.scenario_id not in scene_id_to_artifact_path: - raise UnknownSceneError(spec.scenario_id) + # This will raise UnknownSceneError if scene_id is not found + data_source = get_data_source(spec.scenario_id) if spec.nr_rollouts == 0: logger.warning( @@ -142,7 +136,7 @@ def build_pending_jobs_from_request( job_id=uuid4().hex, scene_id=spec.scenario_id, rollout_spec_index=spec_index, - artifact_path=scene_id_to_artifact_path[spec.scenario_id], + data_source=data_source, ) ) return jobs @@ -166,19 +160,18 @@ def __init__( user_config: str, network_config: str, eval_config: str, - usdz_glob: str, log_dir: str, validate_config_scenes: bool = True, ) -> None: self._user_config_path = user_config self._network_config_path = network_config self._eval_config_path = eval_config - self._usdz_glob = usdz_glob self._log_dir = log_dir self._validate_config_scenes = validate_config_scenes self._version_ids: logging_pb2.RolloutMetadata.VersionIds | None = None - self._scene_id_to_artifact_path: dict[str, str] = {} + self._config = None # Will be set during startup + self._scene_loader: SceneLoader | None = None self._scheduler: DaemonScheduler | None = None self._worker_runtime: WorkerRuntime | None = None self._started = False @@ -189,12 +182,30 @@ def version_ids(self) -> logging_pb2.RolloutMetadata.VersionIds: raise RuntimeError("daemon is not started") return self._version_ids + def _get_data_source(self, scene_id: str) -> SceneDataSource: + """Get or create a data source for the given scene_id. + + Delegates to SceneLoader for lazy loading and caching. + + Args: + scene_id: Scene identifier to load + + Returns: + SceneDataSource for the requested scene + + Raises: + RuntimeError: If SceneLoader not initialized + """ + if self._scene_loader is None: + raise RuntimeError("SceneLoader not initialized") + return self._scene_loader.get_data_source(scene_id) + async def startup(self) -> None: """Initialize the runtime context, start workers, and begin scheduling. Builds the RuntimeContext (parses configs, probes service versions, - validates scenarios, discovers scene artifacts), then creates the - worker runtime and daemon scheduler. Idempotent: subsequent calls + validates scenarios, creates scene mapping from trajdata), then creates + the worker runtime and daemon scheduler. Idempotent: subsequent calls after the first are no-ops. 
""" if self._started: @@ -204,10 +215,17 @@ async def startup(self) -> None: user_config_path=self._user_config_path, network_config_path=self._network_config_path, eval_config_path=self._eval_config_path, - usdz_glob=self._usdz_glob, validate_config_scenes=self._validate_config_scenes, ) + # Create SceneLoader from RuntimeContext + self._scene_loader = SceneLoader( + dataset=runtime_context.dataset, + scene_id_to_idx=runtime_context.scene_id_to_idx, + config=runtime_context.config, + ) + self._config = runtime_context.config + num_consumers_per_worker = compute_num_consumers_per_worker( max_in_flight=runtime_context.max_in_flight, nr_workers=runtime_context.config.user.nr_workers, @@ -228,7 +246,6 @@ async def startup(self) -> None: ) self._version_ids = runtime_context.version_ids - self._scene_id_to_artifact_path = runtime_context.scene_id_to_artifact_path self._worker_runtime = worker_runtime self._scheduler = scheduler self._started = True @@ -252,7 +269,9 @@ async def simulate( assert self._scheduler is not None request_id = uuid4().hex - jobs = build_pending_jobs_from_request(request, self._scene_id_to_artifact_path) + + # Use instance method for getting data sources + jobs = build_pending_jobs_from_request(request, self._get_data_source) driver_pool: AddressPool | None = None if request.available_drivers: diff --git a/src/runtime/alpasim_runtime/daemon/exceptions.py b/src/runtime/alpasim_runtime/daemon/exceptions.py new file mode 100644 index 00000000..2beaed16 --- /dev/null +++ b/src/runtime/alpasim_runtime/daemon/exceptions.py @@ -0,0 +1,18 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +"""Exception classes for runtime daemon.""" + + +class InvalidRequestError(ValueError): + """Raised when a simulation request contains invalid parameters.""" + + pass + + +class UnknownSceneError(InvalidRequestError): + """Raised when a simulation request references a scene_id with no known data source.""" + + def __init__(self, scene_id: str): + super().__init__(f"No data source found for scene_id: {scene_id}") + self.scene_id = scene_id diff --git a/src/runtime/alpasim_runtime/daemon/scheduler.py b/src/runtime/alpasim_runtime/daemon/scheduler.py index 35099758..7ef1b96f 100644 --- a/src/runtime/alpasim_runtime/daemon/scheduler.py +++ b/src/runtime/alpasim_runtime/daemon/scheduler.py @@ -146,7 +146,7 @@ async def dispatch_once(self) -> None: job_id=pending_job.job_id, scene_id=pending_job.scene_id, rollout_spec_index=pending_job.rollout_spec_index, - artifact_path=pending_job.artifact_path, + data_source=pending_job.data_source, endpoints=ServiceEndpoints( driver=acquired["driver"], sensorsim=acquired["sensorsim"], diff --git a/src/runtime/alpasim_runtime/prepare_data/__init__.py b/src/runtime/alpasim_runtime/prepare_data/__init__.py new file mode 100644 index 00000000..0ccbb123 --- /dev/null +++ b/src/runtime/alpasim_runtime/prepare_data/__init__.py @@ -0,0 +1,72 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +""" +Data preprocessing module for building trajdata cache. + +This module provides two approaches for preparing scene data before simulations: + +1. **User Config Path** (Recommended for complex scenarios): + - Load configuration from YAML file + - Supports multiple data sources with individual settings + - Automatic NuPlan YAML batch processing when config_dir is provided + - Full control over per-dataset parameters + +2. 
**CLI Path** (For simple, quick preprocessing): + - Specify parameters via command line + - Uniform parameters applied to all datasets + - Good for testing or simple caching tasks + +Main exports: + - preprocess_basic: Unified preprocessing function (handles both basic and NuPlan YAML modes) + - process_nuplan_yaml_configs: Process NuPlan YAML configs into central_tokens format + - load_yaml_configs: Load and parse YAML configuration files + - PrepareDataConfig: Configuration class for CLI mode + - main: CLI entry point + +Example usage (programmatic): + + from alpasim_runtime.prepare_data import preprocess_basic, PrepareDataConfig + + # Simple preprocessing with CLI config + config = PrepareDataConfig( + desired_data=["waymo"], + data_dirs={"waymo": "/path/to/waymo"}, + cache_location="/path/to/cache", + desired_dt=0.1, + ) + preprocess_basic(config, verbose=True) + + # Or use user config (recommended for production) + from alpasim_runtime.config import typed_parse_config, UserSimulatorConfig + user_config = typed_parse_config("user.yaml", UserSimulatorConfig) + preprocess_basic(user_config.data_source, verbose=True) + +CLI usage: + # Simple mode + python -m alpasim_runtime.prepare_data \\ + --desired-data waymo \\ + --data-dir /path/to/waymo \\ + --cache-location /path/to/cache + + # Complex mode with user config + python -m alpasim_runtime.prepare_data \\ + --user-config user.yaml \\ + --rebuild-cache +""" + +from alpasim_runtime.prepare_data.__main__ import ( + PrepareDataConfig, + load_yaml_configs, + main, + preprocess_basic, + process_nuplan_yaml_configs, +) + +__all__ = [ + "preprocess_basic", + "process_nuplan_yaml_configs", + "load_yaml_configs", + "PrepareDataConfig", + "main", +] diff --git a/src/runtime/alpasim_runtime/prepare_data/__main__.py b/src/runtime/alpasim_runtime/prepare_data/__main__.py new file mode 100644 index 00000000..fba88dc8 --- /dev/null +++ b/src/runtime/alpasim_runtime/prepare_data/__main__.py @@ -0,0 +1,547 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +""" +Data preprocessing CLI for building trajdata cache. + +This module provides two clear paths for data preprocessing: + +1. **User Config Path** (--user-config): For complex scenarios + - Load full configuration from YAML file + - Supports multiple data sources, hierarchical config + - Supports YAML batch mode (NuPlan central_tokens) + - CLI overrides limited to: --rebuild-cache, --rebuild-maps, --verbose + +2. 
**CLI Path**: For simple, quick preprocessing + - Specify all parameters via command line + - Single dataset preprocessing only + - Basic preprocessing mode only (no YAML batch mode) + - Good for testing or simple caching tasks + +Usage Examples: + + # Complex: Use user config with optional overrides + python -m alpasim_runtime.prepare_data --user-config user.yaml --rebuild-cache + + # Simple: Direct CLI parameters for basic preprocessing + python -m alpasim_runtime.prepare_data \\ + --desired-data nuplan_test \\ + --data-dir /path/to/nuplan \\ + --cache-location /path/to/cache +""" + +from __future__ import annotations + +import argparse +import logging +import sys +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional + +import yaml +from alpasim_runtime.config import UserSimulatorConfig, typed_parse_config +from omegaconf import OmegaConf +from trajdata.dataset import UnifiedDataset + +logger = logging.getLogger(__name__) + + +@dataclass +class PrepareDataConfig: + """Configuration for CLI-based data preprocessing. + + This is used ONLY for CLI mode. User config mode uses DataSourceConfig directly. + + Note: CLI mode only supports basic preprocessing. For YAML batch mode (NuPlan), + use user-config files with config_dir in extra_params. + """ + + # Data source parameters (required) + desired_data: List[str] + data_dirs: Dict[str, str] + cache_location: str + + # Optional preprocessing parameters + rebuild_cache: bool = False + rebuild_maps: bool = False + incl_vector_map: bool = True + desired_dt: float = 0.1 + num_workers: int = 1 + + def to_trajdata_params(self) -> dict: + """Convert to flat parameters for trajdata's UnifiedDataset. + + Returns: + Dictionary with keys expected by UnifiedDataset constructor + """ + return { + "desired_data": self.desired_data, + "data_dirs": self.data_dirs, + "cache_location": self.cache_location, + "rebuild_cache": self.rebuild_cache, + "rebuild_maps": self.rebuild_maps, + "num_workers": self.num_workers, + "desired_dt": self.desired_dt, + "incl_vector_map": self.incl_vector_map, + } + + +def load_yaml_configs(config_dir: Path) -> Dict[str, List[Dict[str, str]]]: + """ + Load all yaml configuration files and group them by central_log. + + Supports both simple YAML files and NuPlan-generated files with Python object tags. + Uses a custom loader to handle Python objects without requiring module imports. + + Args: + config_dir: Directory containing yaml configuration files. + + Returns: + Dict where key is central_log and value is the list of central_tokens configs for that log. + """ + configs_by_log = defaultdict(list) + + yaml_files = list(config_dir.glob("*.yaml")) + logger.info(f"Found {len(yaml_files)} yaml configuration files.") + + # Custom YAML loader that converts unknown Python objects to dicts + class SafeLoaderWithObjects(yaml.SafeLoader): + """Custom YAML loader that treats Python objects as plain dicts.""" + + pass + + def python_object_constructor(loader, tag_suffix, node): + """Convert Python object tags to plain dicts. + + Args: + loader: YAML loader instance + tag_suffix: Tag suffix (for multi_constructor, ignored for single constructor) + node: YAML node to construct + """ + return loader.construct_mapping(node, deep=True) + + def python_tuple_constructor(loader, tag_suffix, node): + """Convert Python tuple tags to lists. 
+ + Args: + loader: YAML loader instance + tag_suffix: Tag suffix (ignored) + node: YAML node to construct + """ + return loader.construct_sequence(node, deep=True) + + # Register constructors for Python objects and tuples + # Note: add_multi_constructor passes 3 args (loader, tag_suffix, node) + SafeLoaderWithObjects.add_multi_constructor( + "tag:yaml.org,2002:python/object", python_object_constructor + ) + SafeLoaderWithObjects.add_multi_constructor( + "tag:yaml.org,2002:python/tuple", python_tuple_constructor + ) + + for yaml_file in yaml_files: + try: + # Load YAML with custom loader that handles Python objects + config = yaml.load(yaml_file.read_text(), Loader=SafeLoaderWithObjects) + + # Support both attribute-style (config.central_log) and dict-style access + if hasattr(config, "central_log"): + central_log = config.central_log + central_tokens = config.central_tokens + else: + central_log = config.get("central_log", "") + central_tokens = config.get("central_tokens", []) + + if not central_log or not central_tokens: + logger.warning( + f"{yaml_file.name} is missing central_log or central_tokens, skipping." + ) + continue + + for token in central_tokens: + configs_by_log[central_log].append( + { + "central_token": token, + "logfile": central_log, + "yaml_file": str(yaml_file), + } + ) + + except Exception as e: + logger.error(f"Failed to load {yaml_file.name}: {e}") + continue + + logger.info( + f"\nAfter grouping by central_log, there are {len(configs_by_log)} different log files." + ) + for log, configs in configs_by_log.items(): + logger.info(f" {log}: {len(configs)} central tokens") + + return dict(configs_by_log) + + +def process_nuplan_yaml_configs( + dataset_name: str, extra_params: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """Process NuPlan YAML configuration files into central_tokens_config format. + + Args: + dataset_name: Name of the NuPlan dataset (e.g., 'nuplan_mini', 'nuplan_test') + extra_params: Dictionary containing 'config_dir' and optional timestep parameters + + Returns: + Processed dataset kwargs with central_tokens_config, or None if no configs found + """ + logger.info(f"Processing NuPlan YAML configs for {dataset_name}") + config_dir = Path(extra_params["config_dir"]) + + # Load YAML configs + configs_by_log = load_yaml_configs(config_dir) + + if not configs_by_log: + logger.warning(f"No valid YAML configs found in {config_dir}") + return None + + # Build central_tokens_config list + all_central_tokens_config: List[Dict[str, Any]] = [] + for _, configs in configs_by_log.items(): + for cfg in configs: + all_central_tokens_config.append( + { + "central_token": cfg["central_token"], + "logfile": cfg["logfile"], + } + ) + + logger.info(f" Found {len(all_central_tokens_config)} central tokens") + + # Return processed config + return { + "central_tokens_config": all_central_tokens_config, + "num_timesteps_before": extra_params.get("num_timesteps_before", 30), + "num_timesteps_after": extra_params.get("num_timesteps_after", 80), + } + + +def preprocess_basic(config: Any, verbose: bool = True) -> bool: + """Basic preprocessing - build trajdata cache for all scenes. + + Args: + config: Configuration (supports both PrepareDataConfig from CLI and + DataSourceConfig from user config). + verbose: Whether to show verbose output from trajdata. + + Returns: + True if successful, False otherwise. 
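+
+    Example (sketch; paths are placeholders):
+
+        cfg = PrepareDataConfig(
+            desired_data=["usdz"],
+            data_dirs={"usdz": "/data/usdz"},
+            cache_location="/tmp/trajdata_cache",
+        )
+        ok = preprocess_basic(cfg, verbose=False)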
+ """ + params = config.to_trajdata_params() + + logger.info("Data source configuration:") + logger.info(f" cache_location: {params['cache_location']}") + logger.info(f" desired_dt: {params['desired_dt']}") + logger.info(f" rebuild_cache: {params['rebuild_cache']}") + logger.info(f" rebuild_maps: {params['rebuild_maps']}") + logger.info(f" desired_data: {params['desired_data']}") + logger.info(f" data_dirs: {params['data_dirs']}") + + # Process NuPlan-specific YAML configs if present + dataset_kwargs = params.get("dataset_kwargs", {}) + if dataset_kwargs: + for dataset_name, extra_params in dataset_kwargs.items(): + # Check if this is a NuPlan dataset (nuplan, nuplan_mini, nuplan_test, etc.) + if "nuplan" in dataset_name.lower() and "config_dir" in extra_params: + processed_config = process_nuplan_yaml_configs( + dataset_name, extra_params + ) + if processed_config: + dataset_kwargs[dataset_name] = processed_config + + # Create cache directory + cache_path = Path(params["cache_location"]) + cache_path.mkdir(parents=True, exist_ok=True) + + # Build UnifiedDataset (this triggers cache building) + logger.info("Creating UnifiedDataset...") + start_time = time.perf_counter() + + try: + dataset = UnifiedDataset( + desired_data=params["desired_data"], + data_dirs=params["data_dirs"], + cache_location=params["cache_location"], + incl_vector_map=params["incl_vector_map"], + rebuild_cache=params["rebuild_cache"], + rebuild_maps=params["rebuild_maps"], + desired_dt=params["desired_dt"], + num_workers=params["num_workers"], + dataset_kwargs=dataset_kwargs, + verbose=verbose, + ) + except Exception as e: + logger.error(f"Failed to create UnifiedDataset: {e}") + import traceback + + traceback.print_exc() + return False + + elapsed = time.perf_counter() - start_time + logger.info(f"UnifiedDataset created in {elapsed:.2f} seconds") + + # Get scene count + num_scenes = dataset.num_scenes() + logger.info(f"Scene files (logs): {num_scenes}") + + logger.info("Data preparation complete!") + return True + + +def create_arg_parser() -> argparse.ArgumentParser: + """Create argument parser for prepare_data CLI.""" + parser = argparse.ArgumentParser( + description=( + "Prepare scene data and build trajdata cache for alpasim simulations.\n\n" + "Two modes:\n" + " 1. User config (--user-config): Complex scenarios with full YAML config\n" + " 2. 
CLI mode (--desired-data + --data-dir): Simple, direct preprocessing" + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + # Mode selection + mode_group = parser.add_argument_group("Mode Selection") + mode_group.add_argument( + "--user-config", + type=str, + help="Path to user config YAML file containing data_source configuration", + ) + + # Data source parameters + data_group = parser.add_argument_group("Data Source") + data_group.add_argument( + "--desired-data", + type=str, + nargs="+", + help="List of dataset names to prepare (e.g., nuplan_test, waymo_val, usdz)", + ) + data_group.add_argument( + "--data-dir", + type=str, + action="append", + dest="data_dirs", + help="Data directory (format: dataset_name=/path/to/data or just /path/to/data)", + ) + data_group.add_argument( + "--cache-location", + type=str, + help="Path to trajdata cache directory", + ) + + # Preprocessing options + preprocess_group = parser.add_argument_group("Preprocessing Options") + preprocess_group.add_argument( + "--rebuild-cache", + action="store_true", + help="Force rebuild cache even if it already exists", + ) + preprocess_group.add_argument( + "--rebuild-maps", + action="store_true", + help="Force rebuild map cache", + ) + preprocess_group.add_argument( + "--desired-dt", + type=float, + default=0.1, + help="Desired timestep duration in seconds (default: 0.1)", + ) + preprocess_group.add_argument( + "--num-workers", + type=int, + default=1, + help="Number of worker processes (default: 1)", + ) + preprocess_group.add_argument( + "--no-vector-map", + action="store_true", + help="Exclude vector maps (default: include)", + ) + + # Output options + output_group = parser.add_argument_group("Output Options") + output_group.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + help="Logging level (default: INFO)", + ) + output_group.add_argument( + "--verbose", + action=argparse.BooleanOptionalAction, + default=True, + help="Show verbose output (default: enabled)", + ) + + return parser + + +def parse_data_dirs( + data_dirs_args: Optional[List[str]], desired_data: List[str] +) -> Dict[str, str]: + """Parse data directory arguments into a dict. + + Supports two formats: + - "dataset_name=/path/to/data" - explicit mapping + - "/path/to/data" - auto-map to desired_data entries in order + + Args: + data_dirs_args: List of data directory arguments + desired_data: List of dataset names + + Returns: + Dictionary mapping dataset names to data directories + """ + if not data_dirs_args: + return {} + + result: Dict[str, str] = {} + + for i, arg in enumerate(data_dirs_args): + if "=" in arg: + # Explicit mapping: dataset_name=/path/to/data + parts = arg.split("=", 1) + result[parts[0]] = parts[1] + else: + # Implicit mapping: use desired_data order + if i < len(desired_data): + result[desired_data[i]] = arg + else: + # Use as default for remaining datasets (with warning) + for ds in desired_data[len(result) :]: + if ds not in result: + logger.warning( + f"Dataset '{ds}' has no explicit data_dir, using last provided: '{arg}'" + ) + result[ds] = arg + + return result + + +def run_from_user_config(config_path: str, args: argparse.Namespace) -> bool: + """Run preprocessing from user config file with minimal CLI overrides. 
+ + Args: + config_path: Path to user config YAML file + args: Parsed command line arguments (used only for overrides) + + Returns: + True if successful, False otherwise + """ + logger.info(f"Loading configuration from: {config_path}") + user_config = typed_parse_config(config_path, UserSimulatorConfig) + user_config = OmegaConf.to_object(user_config) + + config = user_config.data_source + + # Apply minimal CLI overrides (only top-level flags) + if args.rebuild_cache: + config.rebuild_cache = True + logger.info("CLI override: rebuild_cache=True") + if args.rebuild_maps: + config.rebuild_maps = True + logger.info("CLI override: rebuild_maps=True") + + # Use unified preprocessing (handles both basic and YAML batch mode) + return preprocess_basic(config, verbose=args.verbose) + + +def run_from_cli(args: argparse.Namespace) -> bool: + """Run preprocessing from CLI arguments directly. + + CLI mode only supports basic preprocessing with uniform parameters applied + to all datasets. For dataset-specific parameters (e.g., different + smooth_trajectories per dataset) or YAML batch mode, use --user-config. + + Args: + args: Parsed command line arguments + + Returns: + True if successful, False otherwise + """ + # Validate required arguments + if not args.desired_data: + logger.error("--desired-data is required when not using --user-config") + return False + if not args.data_dirs: + logger.error("--data-dir is required when not using --user-config") + return False + if not args.cache_location: + logger.error("--cache-location is required when not using --user-config") + return False + + # Build simple configuration + data_dirs = parse_data_dirs(args.data_dirs, args.desired_data) + + # Validate that all datasets have data directories + missing = set(args.desired_data) - set(data_dirs.keys()) + if missing: + logger.error(f"Missing data directories for datasets: {missing}") + logger.error("Provide --data-dir for each dataset: dataset=/path or in order") + return False + + incl_vector_map = not args.no_vector_map + + config = PrepareDataConfig( + desired_data=args.desired_data, + data_dirs=data_dirs, + cache_location=args.cache_location, + incl_vector_map=incl_vector_map, + rebuild_cache=args.rebuild_cache, + rebuild_maps=args.rebuild_maps, + desired_dt=args.desired_dt, + num_workers=args.num_workers, + ) + + logger.info("Mode: CLI-based basic preprocessing") + return preprocess_basic(config, verbose=args.verbose) + + +def main(arg_list: Optional[List[str]] = None) -> int: + """Main entry point for prepare_data CLI.""" + parser = create_arg_parser() + args = parser.parse_args(arg_list) + + # Configure logging + log_level = getattr(logging, args.log_level.upper(), logging.INFO) + logging.basicConfig( + level=log_level, + format="%(asctime)s.%(msecs)03d %(levelname)s:\t%(message)s", + datefmt="%H:%M:%S", + ) + + logger.info("=" * 60) + logger.info("Alpasim Data Preparation Tool") + logger.info("=" * 60) + + # Route to appropriate mode + try: + if args.user_config: + success = run_from_user_config(args.user_config, args) + else: + success = run_from_cli(args) + except Exception as e: + logger.error(f"Error during preprocessing: {e}") + import traceback + + traceback.print_exc() + return 1 + + return 0 if success else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/runtime/alpasim_runtime/runtime_context.py b/src/runtime/alpasim_runtime/runtime_context.py index eb1f4c12..47244e6d 100644 --- a/src/runtime/alpasim_runtime/runtime_context.py +++ 
b/src/runtime/alpasim_runtime/runtime_context.py @@ -4,6 +4,7 @@ from __future__ import annotations import copy +import logging import math from dataclasses import dataclass @@ -19,10 +20,13 @@ gather_versions_from_addresses, validate_scenarios, ) -from alpasim_utils.artifact import Artifact +from omegaconf import OmegaConf +from trajdata.dataset import UnifiedDataset from eval.schema import EvalConfig +logger = logging.getLogger(__name__) + ALL_SKIP_PER_WORKER_CONCURRENCY = 16 @@ -121,13 +125,14 @@ class RuntimeContext: """Immutable snapshot of all runtime state needed to dispatch simulation jobs. Built once during startup by ``build_runtime_context`` after config parsing, - service version probing, scenario validation, and address pool creation. + service version probing, scenario validation, and scene mapping creation. """ config: SimulatorConfig eval_config: EvalConfig version_ids: RolloutMetadata.VersionIds - scene_id_to_artifact_path: dict[str, str] + scene_id_to_idx: dict[str, int] + dataset: UnifiedDataset pools: dict[str, AddressPool] max_in_flight: int @@ -147,7 +152,6 @@ async def build_runtime_context( user_config_path: str, network_config_path: str, eval_config_path: str, - usdz_glob: str, validate_config_scenes: bool = True, ) -> RuntimeContext: """Build the RuntimeContext by parsing configs, probing services, and validating scenarios. @@ -156,20 +160,20 @@ async def build_runtime_context( 1. Parse user and network configs. 2. Probe all service addresses for version IDs. 3. Validate scenario compatibility (unless *validate_config_scenes* is False). - 4. Discover scene artifacts from *usdz_glob*. + 4. Create UnifiedDataset from data_source config and build scene data sources. 5. Create address pools and compute max in-flight concurrency. Args: user_config_path: Path to user YAML config. network_config_path: Path to network YAML config. eval_config_path: Path to evaluation YAML config. - usdz_glob: Glob pattern for USDZ artifact discovery. validate_config_scenes: If False, skip scene compatibility checks (useful for daemon mode where scenes come from requests). 
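+
+    Returns:
+        A RuntimeContext snapshot holding the parsed configs, probed version IDs,
+        the UnifiedDataset, the scene_id-to-index mapping, address pools, and the
+        max in-flight count.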
""" config = parse_simulator_config(user_config_path, network_config_path) eval_config = typed_parse_config(eval_config_path, EvalConfig) + # Validate configuration version_ids = await gather_versions_from_addresses( config.network, config.user.endpoints, @@ -185,13 +189,31 @@ async def build_runtime_context( ) await validate_scenarios(config_for_validation) - scene_id_to_artifact_path = { - scene_id: artifact.source - for scene_id, artifact in Artifact.discover_from_glob( - usdz_glob, - smooth_trajectories=config.user.smooth_trajectories, - ).items() - } + # Create UnifiedDataset and build scene_id to data source mapping + logger.info("Creating UnifiedDataset from config") + data_source_config = OmegaConf.to_object(config.user.data_source) + trajdata_params = data_source_config.to_trajdata_params() + dataset_kwargs = trajdata_params.pop("dataset_kwargs", None) + if dataset_kwargs: + trajdata_params["dataset_kwargs"] = dataset_kwargs + dataset = UnifiedDataset(**trajdata_params) + logger.info( + f"Created UnifiedDataset with {dataset.num_scenes()} scenes, " + f"desired_data={trajdata_params['desired_data']}" + ) + + # Build scene_id to index mapping (once, in main process) + scene_id_to_idx = {} + num_scenes = dataset.num_scenes() + for idx in range(num_scenes): + try: + scene = dataset.get_scene(idx) + scene_id_to_idx[scene.name] = idx + except Exception as e: + logger.warning(f"Failed to get scene at index {idx}: {e}") + continue + logger.info(f"Built scene_id mapping for {len(scene_id_to_idx)} scenes") + pools = create_address_pools(config) max_in_flight = compute_max_in_flight(pools, config) @@ -199,7 +221,8 @@ async def build_runtime_context( config=config, eval_config=eval_config, version_ids=version_ids, - scene_id_to_artifact_path=scene_id_to_artifact_path, + scene_id_to_idx=scene_id_to_idx, + dataset=dataset, pools=pools, max_in_flight=max_in_flight, ) diff --git a/src/runtime/alpasim_runtime/scene_loader.py b/src/runtime/alpasim_runtime/scene_loader.py new file mode 100644 index 00000000..8164bb9c --- /dev/null +++ b/src/runtime/alpasim_runtime/scene_loader.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +"""SceneLoader: Manages scene data loading and caching.""" + +from __future__ import annotations + +import logging + +from alpasim_runtime.config import SimulatorConfig +from alpasim_runtime.daemon.exceptions import UnknownSceneError +from alpasim_utils.scene_data_source import SceneDataSource +from alpasim_utils.trajdata_data_source import TrajdataDataSource +from omegaconf import OmegaConf +from trajdata.dataset import UnifiedDataset + +logger = logging.getLogger(__name__) + + +class SceneLoader: + """Manages scene data loading and caching. + + Encapsulates UnifiedDataset, scene ID to index mapping, and lazy loading + of SceneDataSource objects. Provides a clean interface for on-demand scene + data access with automatic caching. + + Attributes: + _dataset: UnifiedDataset for accessing trajdata scenes + _scene_id_to_idx: Mapping from scene IDs to dataset indices + _config: Simulator configuration for scene parameters + _cache: Cache of loaded SceneDataSource objects + _asset_base_path_map: Mapping from dataset name to asset_base_path + """ + + def __init__( + self, + dataset: UnifiedDataset, + scene_id_to_idx: dict[str, int], + config: SimulatorConfig, + ): + """Initialize SceneLoader with dataset and configuration. 
+    """
+
+    def __init__(
+        self,
+        dataset: UnifiedDataset,
+        scene_id_to_idx: dict[str, int],
+        config: SimulatorConfig,
+    ):
+        """Initialize SceneLoader with dataset and configuration.
+
+        Args:
+            dataset: UnifiedDataset for scene access
+            scene_id_to_idx: Mapping from scene ID to dataset index
+            config: Simulator configuration
+        """
+        self._dataset = dataset
+        self._scene_id_to_idx = scene_id_to_idx
+        self._config = config
+        self._cache: dict[str, SceneDataSource] = {}
+
+        # Build dataset_name -> asset_base_path mapping
+        self._asset_base_path_map: dict[str, str] = {}
+        data_source_config = OmegaConf.to_object(config.user.data_source)
+
+        for dataset_name, source in data_source_config.sources.items():
+            asset_base_path = source.extra_params.get("asset_base_path")
+            if asset_base_path is not None:
+                self._asset_base_path_map[dataset_name] = asset_base_path
+                logger.info(
+                    f"Registered asset_base_path for {dataset_name}: {asset_base_path}"
+                )
+
+    def get_data_source(self, scene_id: str) -> SceneDataSource:
+        """Get or create a data source for the given scene_id.
+
+        Implements lazy loading with caching. On first access, creates a
+        TrajdataDataSource from the UnifiedDataset scene. Subsequent accesses
+        return the cached instance.
+
+        Args:
+            scene_id: Scene identifier to load
+
+        Returns:
+            SceneDataSource for the requested scene
+
+        Raises:
+            UnknownSceneError: If scene_id is not found in the dataset
+            RuntimeError: If scene loading fails
+        """
+        # Check cache first
+        if scene_id in self._cache:
+            return self._cache[scene_id]
+
+        # Validate scene exists
+        scene_idx = self._scene_id_to_idx.get(scene_id)
+        if scene_idx is None:
+            raise UnknownSceneError(scene_id)
+
+        try:
+            # Load scene from dataset
+            scene = self._dataset.get_scene(scene_idx)
+            if scene is None:
+                raise UnknownSceneError(scene_id)
+
+            # Get asset_base_path for this scene's dataset
+            # Use scene.env_name to look up the correct asset_base_path
+            asset_base_path = self._asset_base_path_map.get(scene.env_name)
+
+            # Get map_api for lazy loading (lightweight, can be passed to workers)
+            map_api = getattr(self._dataset, "_map_api", None)
+
+            # Create scene_cache (pre-create to avoid pickle errors)
+            scene_cache = self._dataset.cache_class(
+                self._dataset.cache_path, scene, self._dataset.augmentations
+            )
+            scene_cache.set_obs_format(self._dataset.obs_format)
+
+            # Create TrajdataDataSource
+            data_source = TrajdataDataSource.from_trajdata_scene(
+                scene=scene,
+                dataset=None,  # Don't pass dataset to avoid pickle errors
+                map_api=map_api,  # Pass map_api
+                scene_cache=scene_cache,
+                smooth_trajectories=self._config.user.smooth_trajectories,
+                asset_base_path=asset_base_path,
+            )
+
+            # Cache for future use
+            self._cache[scene_id] = data_source
+            logger.debug(f"Loaded data source for scene {scene_id}")
+            return data_source
+
+        except UnknownSceneError:
+            # Preserve the documented contract: surface unknown scenes as-is
+            raise
+        except Exception as e:
+            logger.error(f"Failed to load scene {scene_id}: {e}")
+            raise RuntimeError(f"Scene loading failed for {scene_id}") from e
+
+    @property
+    def num_scenes(self) -> int:
+        """Return total number of scenes available in the dataset."""
+        return self._dataset.num_scenes()
+
+    @property
+    def num_cached(self) -> int:
+        """Return number of scenes currently cached."""
+        return len(self._cache)
diff --git a/src/runtime/alpasim_runtime/simulate/__main__.py b/src/runtime/alpasim_runtime/simulate/__main__.py
index ac8649b5..3cb802fc 100644
--- a/src/runtime/alpasim_runtime/simulate/__main__.py
+++ b/src/runtime/alpasim_runtime/simulate/__main__.py
@@ -54,7 +54,7 @@ def create_arg_parser() -> argparse.ArgumentParser:
     # we split user and network config files because the latter is commonly auto-generated by kubernetes
     parser.add_argument("--user-config", type=str, required=True)
parser.add_argument("--network-config", type=str, required=True) - parser.add_argument("--usdz-glob", type=str, required=True) + parser.add_argument( "--log-dir", type=str, @@ -95,7 +95,6 @@ async def _serve(args: argparse.Namespace) -> None: user_config=args.user_config, network_config=args.network_config, eval_config=args.eval_config, - usdz_glob=args.usdz_glob, log_dir=args.log_dir, validate_config_scenes=False, ) @@ -265,7 +264,6 @@ async def _run_one_shot_request( user_config=args.user_config, network_config=args.network_config, eval_config=args.eval_config, - usdz_glob=args.usdz_glob, log_dir=args.log_dir, ) diff --git a/src/runtime/alpasim_runtime/unbound_rollout.py b/src/runtime/alpasim_runtime/unbound_rollout.py index 8166f7a9..929c0d6b 100644 --- a/src/runtime/alpasim_runtime/unbound_rollout.py +++ b/src/runtime/alpasim_runtime/unbound_rollout.py @@ -21,9 +21,9 @@ VehicleConfig, ) from alpasim_runtime.services.sensorsim_service import ImageFormat -from alpasim_utils.artifact import Artifact from alpasim_utils.geometry import Pose, Trajectory from alpasim_utils.scenario import AABB, TrafficObjects +from alpasim_utils.scene_data_source import SceneDataSource from trajdata.maps import VectorMap logger = logging.getLogger(__name__) @@ -108,15 +108,15 @@ def create( simulation_config: SimulationConfig, scene_id: str, version_ids: RolloutMetadata.VersionIds, - available_artifacts: dict[str, Artifact], + data_source: SceneDataSource, rollouts_dir: str, ) -> UnboundRollout: - artifact = available_artifacts[scene_id] + """Create UnboundRollout from SceneDataSource.""" camera_configs = list(simulation_config.cameras) control_timestamps_us_arr: np.ndarray = ( - artifact.rig.trajectory.time_range_us.start + data_source.rig.trajectory.time_range_us.start + simulation_config.time_start_offset_us + np.arange( simulation_config.n_sim_steps + 2 @@ -125,19 +125,19 @@ def create( ) control_timestamps_us = [ - int(min(t, artifact.rig.trajectory.time_range_us.stop - 1)) + int(min(t, data_source.rig.trajectory.time_range_us.stop - 1)) for t in control_timestamps_us_arr if t - < artifact.rig.trajectory.time_range_us.stop + < data_source.rig.trajectory.time_range_us.stop + ORIGINAL_TRAJECTORY_DURATION_EXTENSION_US ] start_us = control_timestamps_us[0] end_us = control_timestamps_us[-1] - gt_ego_trajectory = artifact.rig.trajectory + gt_ego_trajectory = data_source.rig.trajectory # Filter out objects that are not in the time window - all_objs_in_window = artifact.traffic_objects.clip_trajectories( + all_objs_in_window = data_source.traffic_objects.clip_trajectories( start_us, end_us + 1, exclude_empty=True ) @@ -191,8 +191,8 @@ def create( if simulation_config.vehicle is not None: vehicle = simulation_config.vehicle - elif artifact.rig.vehicle_config is not None: - vehicle = artifact.rig.vehicle_config + elif data_source.rig.vehicle_config is not None: + vehicle = data_source.rig.vehicle_config else: raise ValueError("No vehicle config provided/found.") @@ -208,7 +208,7 @@ def create( gt_ego_trajectory=gt_ego_trajectory, traffic_objs=traffic_objects, n_sim_steps=simulation_config.n_sim_steps, - start_timestamp_us=artifact.rig.trajectory.time_range_us.start, + start_timestamp_us=data_source.rig.trajectory.time_range_us.start, force_gt_duration_us=simulation_config.force_gt_duration_us, control_timestep_us=simulation_config.control_timestep_us, follow_log=None, @@ -228,15 +228,15 @@ def create( vehicle ), ego_aabb=ego_aabb, - nre_runid=str(artifact.metadata.logger.run_id), - 
nre_version=artifact.metadata.version_string, - nre_uuid=str(artifact.metadata.uuid), + nre_runid=str(data_source.metadata.logger.run_id), + nre_version=data_source.metadata.version_string, + nre_uuid=str(data_source.metadata.uuid), planner_delay_us=simulation_config.planner_delay_us, pose_reporting_interval_us=simulation_config.pose_reporting_interval_us, route_generator_type=simulation_config.route_generator_type, send_recording_ground_truth=simulation_config.send_recording_ground_truth, vehicle_config=vehicle, - vector_map=artifact.map, + vector_map=data_source.map, hidden_traffic_objs=hidden_traffic_objs, group_render_requests=simulation_config.group_render_requests, ) diff --git a/src/runtime/alpasim_runtime/worker/ipc.py b/src/runtime/alpasim_runtime/worker/ipc.py index 9568f937..0e937fe2 100644 --- a/src/runtime/alpasim_runtime/worker/ipc.py +++ b/src/runtime/alpasim_runtime/worker/ipc.py @@ -14,6 +14,7 @@ from alpasim_grpc.v0.logging_pb2 import RolloutMetadata from alpasim_runtime.address_pool import ServiceAddress from alpasim_runtime.telemetry.rpc_wrapper import SharedRpcTracking +from alpasim_utils.scene_data_source import SceneDataSource from eval.scenario_evaluator import ScenarioEvalResult from eval.schema import EvalConfig @@ -42,8 +43,8 @@ class PendingRolloutJob: scene_id: str # Index of rollout spec in SimulationRequest.rollout_specs rollout_spec_index: int - # Artifact source path for this job's scene. - artifact_path: str + # SceneDataSource for this job's scene + data_source: SceneDataSource @dataclass @@ -58,8 +59,8 @@ class AssignedRolloutJob: scene_id: str # Index of rollout spec in SimulationRequest.rollout_specs rollout_spec_index: int - # Artifact source path for this job's scene. - artifact_path: str + # SceneDataSource for this job's scene + data_source: SceneDataSource # Concrete service addresses assigned by the parent dispatch loop. 
endpoints: ServiceEndpoints diff --git a/src/runtime/alpasim_runtime/worker/main.py b/src/runtime/alpasim_runtime/worker/main.py index cd1b4b02..f527d69a 100644 --- a/src/runtime/alpasim_runtime/worker/main.py +++ b/src/runtime/alpasim_runtime/worker/main.py @@ -39,14 +39,13 @@ from alpasim_runtime.telemetry.rpc_wrapper import set_shared_rpc_tracking from alpasim_runtime.telemetry.telemetry_context import TelemetryContext from alpasim_runtime.unbound_rollout import UnboundRollout -from alpasim_runtime.worker.artifact_cache import make_artifact_loader from alpasim_runtime.worker.ipc import ( AssignedRolloutJob, JobResult, WorkerArgs, _ShutdownSentinel, ) -from alpasim_utils.artifact import Artifact +from alpasim_utils.scene_data_source import SceneDataSource from eval.schema import EvalConfig @@ -61,7 +60,7 @@ def _is_orphaned(parent_pid: int) -> bool: async def run_single_rollout( job: AssignedRolloutJob, user_config: UserSimulatorConfig, - artifacts: dict[str, Artifact], + data_source: SceneDataSource, camera_catalog: CameraCatalog, version_ids: RolloutMetadata.VersionIds, rollouts_dir: str, @@ -105,7 +104,7 @@ async def run_single_rollout( simulation_config=user_config.simulation_config, scene_id=job.scene_id, version_ids=version_ids, - available_artifacts=artifacts, + data_source=data_source, rollouts_dir=rollouts_dir, ), ) @@ -161,8 +160,6 @@ async def run_worker_loop( result_queue: Queue, num_consumers: int, user_config: UserSimulatorConfig, - smooth_trajectories: bool, - artifact_cache_size: int | None, camera_catalog: CameraCatalog, version_ids: RolloutMetadata.VersionIds, rollouts_dir: str, @@ -178,9 +175,6 @@ async def run_worker_loop( result_queue: Queue to push JobResult to. num_consumers: Number of concurrent consumer tasks. user_config: User simulator configuration. - smooth_trajectories: Whether to smooth trajectories when loading artifacts. - artifact_cache_size: Max worker-local artifact cache size. - None = unlimited cache, 0 = disable cache. camera_catalog: Camera catalog for sensorsim. version_ids: Canonical version IDs from the parent process. rollouts_dir: Directory for rollout outputs. @@ -205,11 +199,6 @@ async def run_worker_loop( # Install event loop idle profiler install_event_loop_idle_profiler(loop) - load_artifact = make_artifact_loader( - smooth_trajectories=smooth_trajectories, - max_cache_size=artifact_cache_size, - ) - # Create a process pool for offloading CPU-bound eval computation. # One slot per consumer so no consumer blocks waiting for a pool slot. 
eval_executor = ProcessPoolExecutor(max_workers=num_consumers) @@ -251,13 +240,11 @@ def _poll_job() -> AssignedRolloutJob | _ShutdownSentinel | None: shutdown_event.set() break - artifact = load_artifact(job.scene_id, job.artifact_path) - - # Process the job + # Process the job using data_source from the job result = await run_single_rollout( job=job, user_config=user_config, - artifacts={job.scene_id: artifact}, + data_source=job.data_source, camera_catalog=camera_catalog, version_ids=version_ids, rollouts_dir=rollouts_dir, @@ -352,8 +339,6 @@ async def worker_async_main(args: WorkerArgs) -> None: result_queue=args.result_queue, num_consumers=args.num_consumers, user_config=user_config, - smooth_trajectories=user_config.smooth_trajectories, - artifact_cache_size=user_config.artifact_cache_size, camera_catalog=camera_catalog, version_ids=args.version_ids, rollouts_dir=rollouts_dir, diff --git a/src/runtime/tests/test_config.py b/src/runtime/tests/test_config.py index 9b0d321c..9fe691b7 100644 --- a/src/runtime/tests/test_config.py +++ b/src/runtime/tests/test_config.py @@ -50,3 +50,82 @@ def test_typed_parse_config_invalid_yaml(tmp_path): # TODO(mwatson, mtyszkiewicz): What should happen when the config is empty? Currently, # no error is raised, and we return an empty config object. Is this the desired behavior? + + +def test_data_source_config_defaults(): + """Test that DataSourceConfig has sensible defaults.""" + cfg = config.DataSourceConfig( + cache_location="/tmp/cache", + sources={"usdz": config.GenericSourceConfig(data_dir="/data/usdz")}, + ) + assert cfg.cache_location == "/tmp/cache" + assert cfg.rebuild_cache is False + assert cfg.rebuild_maps is False + assert cfg.num_workers == 1 + assert "usdz" in cfg.sources + assert cfg.sources["usdz"].data_dir == "/data/usdz" + assert cfg.sources["usdz"].extra_params == {} + + +def test_data_source_config_to_trajdata_params(): + """Test conversion from hierarchical config to flat trajdata parameters.""" + cfg = config.DataSourceConfig( + cache_location="/tmp/cache", + rebuild_cache=True, + num_workers=8, + desired_dt=0.05, + incl_vector_map=True, + sources={ + "usdz": config.GenericSourceConfig( + data_dir="/data/usdz", + extra_params={"asset_base_path": "/assets"}, + ) + }, + ) + params = cfg.to_trajdata_params() + + assert params["desired_data"] == ["usdz"] + assert params["data_dirs"] == {"usdz": "/data/usdz"} + assert params["cache_location"] == "/tmp/cache" + assert params["rebuild_cache"] is True + assert params["num_workers"] == 8 + assert params["desired_dt"] == 0.05 + assert params["incl_vector_map"] is True + assert params["dataset_kwargs"] == {"usdz": {"asset_base_path": "/assets"}} + + +def test_data_source_config_multiple_sources(): + """Test configuration with multiple data sources enabled.""" + cfg = config.DataSourceConfig( + cache_location="/tmp/cache", + sources={ + "usdz": config.GenericSourceConfig(data_dir="/data/usdz"), + "nuplan": config.GenericSourceConfig( + data_dir="/data/nuplan", + extra_params={ + "config_dir": "/configs", + "num_timesteps_before": 30, + "num_timesteps_after": 80, + }, + ), + }, + ) + params = cfg.to_trajdata_params() + + assert set(params["desired_data"]) == {"usdz", "nuplan"} + assert params["data_dirs"]["usdz"] == "/data/usdz" + assert params["data_dirs"]["nuplan"] == "/data/nuplan" + assert params["dataset_kwargs"]["nuplan"]["config_dir"] == "/configs" + assert params["dataset_kwargs"]["nuplan"]["num_timesteps_before"] == 30 + assert params["dataset_kwargs"]["nuplan"]["num_timesteps_after"] 
== 80 + + +def test_data_source_config_no_sources_enabled(): + """Test that error is raised when no data sources are enabled.""" + cfg = config.DataSourceConfig( + cache_location="/tmp/cache", + # All sources are None, so no sources enabled + ) + + with pytest.raises(ValueError, match="No data sources enabled"): + cfg.to_trajdata_params() diff --git a/src/runtime/tests/test_daemon_engine.py b/src/runtime/tests/test_daemon_engine.py index 13c1522a..5f7929c8 100644 --- a/src/runtime/tests/test_daemon_engine.py +++ b/src/runtime/tests/test_daemon_engine.py @@ -65,7 +65,7 @@ async def _fake_build_runtime_context(*args, **kwargs): config=config, eval_config=eval_config, version_ids=version_ids, - scene_id_to_artifact_path={"clipgt-a": "/tmp/scene-a.usdz"}, + scene_id_to_idx={"clipgt-a": 0}, pools={"driver": MagicMock()}, max_in_flight=1, ) @@ -87,7 +87,6 @@ async def _fake_build_runtime_context(*args, **kwargs): user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", ) @@ -130,7 +129,7 @@ async def _fake_build_runtime_context(*args, **kwargs): config=config, eval_config=eval_config, version_ids=version_ids, - scene_id_to_artifact_path={"clipgt-a": "/tmp/scene-a.usdz"}, + scene_id_to_idx={"clipgt-a": 0}, pools={"driver": MagicMock()}, max_in_flight=1, ) @@ -152,7 +151,6 @@ async def _fake_build_runtime_context(*args, **kwargs): user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", validate_config_scenes=False, ) @@ -184,11 +182,13 @@ async def wait_request(self, request_id: str): user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", ) engine._started = True - engine._scene_id_to_artifact_path = {"clipgt-a": "/tmp/scene-a.usdz"} + # Mock SceneLoader with a fake data source + mock_scene_loader = MagicMock() + mock_scene_loader.get_data_source.return_value = MagicMock() + engine._scene_loader = mock_scene_loader engine._version_ids = RolloutMetadata.VersionIds( runtime_version=VersionId(version_id="runtime", git_hash="a"), sensorsim_version=VersionId(version_id="sensorsim", git_hash="b"), @@ -232,11 +232,13 @@ async def wait_request(self, request_id): user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", ) engine._started = True - engine._scene_id_to_artifact_path = {"clipgt-a": "/tmp/scene-a.usdz"} + # Mock SceneLoader with a fake data source + mock_scene_loader = MagicMock() + mock_scene_loader.get_data_source.return_value = MagicMock() + engine._scene_loader = mock_scene_loader engine._version_ids = RolloutMetadata.VersionIds( runtime_version=VersionId(version_id="runtime", git_hash="a"), sensorsim_version=VersionId(version_id="sensorsim", git_hash="b"), @@ -285,11 +287,13 @@ async def wait_request(self, request_id): user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", ) engine._started = True - engine._scene_id_to_artifact_path = {"clipgt-a": "/tmp/scene-a.usdz"} + # Mock SceneLoader with a fake data source + mock_scene_loader = MagicMock() + mock_scene_loader.get_data_source.return_value = MagicMock() + engine._scene_loader = mock_scene_loader engine._version_ids = RolloutMetadata.VersionIds( runtime_version=VersionId(version_id="runtime", git_hash="a"), sensorsim_version=VersionId(version_id="sensorsim", git_hash="b"), diff --git a/src/runtime/tests/test_daemon_main.py 
b/src/runtime/tests/test_daemon_main.py index e31827bd..53139618 100644 --- a/src/runtime/tests/test_daemon_main.py +++ b/src/runtime/tests/test_daemon_main.py @@ -11,7 +11,7 @@ import pytest from alpasim_grpc.v0 import common_pb2, runtime_pb2 from alpasim_runtime.daemon.app import RuntimeDaemonApp -from alpasim_runtime.daemon.engine import UnknownSceneError +from alpasim_runtime.daemon.exceptions import UnknownSceneError from alpasim_runtime.daemon.servicer import RuntimeDaemonServicer from alpasim_runtime.simulate.__main__ import _serve, create_arg_parser, run_simulation @@ -143,7 +143,6 @@ async def run(self) -> None: user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", listen_address="[::]:50051", ) @@ -194,7 +193,6 @@ async def run(self) -> None: user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", listen_address="[::]:50051", ) @@ -418,7 +416,6 @@ def _make_one_shot_args() -> Namespace: user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", array_job_dir=None, ) @@ -532,7 +529,6 @@ async def test_run_simulation_one_shot_uses_daemon_engine( user_config="u.yaml", network_config="n.yaml", eval_config="e.yaml", - usdz_glob="/tmp/*.usdz", log_dir="/tmp/log", ) fake_engine.startup.assert_awaited_once() diff --git a/src/runtime/tests/test_daemon_request_plumbing.py b/src/runtime/tests/test_daemon_request_plumbing.py index 88600196..88a0d10c 100644 --- a/src/runtime/tests/test_daemon_request_plumbing.py +++ b/src/runtime/tests/test_daemon_request_plumbing.py @@ -3,12 +3,12 @@ from __future__ import annotations +from unittest.mock import MagicMock + import pytest from alpasim_grpc.v0 import runtime_pb2 -from alpasim_runtime.daemon.engine import ( - UnknownSceneError, - build_pending_jobs_from_request, -) +from alpasim_runtime.daemon.engine import build_pending_jobs_from_request +from alpasim_runtime.daemon.exceptions import UnknownSceneError def test_adapter_expands_nr_rollouts() -> None: @@ -16,13 +16,17 @@ def test_adapter_expands_nr_rollouts() -> None: rollout_specs=[runtime_pb2.RolloutSpec(scenario_id="clipgt-a", nr_rollouts=3)] ) - jobs = build_pending_jobs_from_request( - req, - scene_id_to_artifact_path={"clipgt-a": "/tmp/clipgt-a.usdz"}, - ) + mock_data_source = MagicMock() + + def fake_get_data_source(scene_id: str): + if scene_id == "clipgt-a": + return mock_data_source + raise UnknownSceneError(scene_id) + + jobs = build_pending_jobs_from_request(req, fake_get_data_source) assert [job.scene_id for job in jobs] == ["clipgt-a", "clipgt-a", "clipgt-a"] assert [job.rollout_spec_index for job in jobs] == [0, 0, 0] - assert all(job.artifact_path == "/tmp/clipgt-a.usdz" for job in jobs) + assert all(job.data_source is mock_data_source for job in jobs) def test_adapter_drops_zero_nr_rollouts_with_warning( @@ -33,10 +37,12 @@ def test_adapter_drops_zero_nr_rollouts_with_warning( rollout_specs=[runtime_pb2.RolloutSpec(scenario_id="clipgt-a")] ) - jobs = build_pending_jobs_from_request( - req, - scene_id_to_artifact_path={"clipgt-a": "/tmp/clipgt-a.usdz"}, - ) + mock_data_source = MagicMock() + + def fake_get_data_source(scene_id: str): + return mock_data_source + + jobs = build_pending_jobs_from_request(req, fake_get_data_source) assert jobs == [] assert "Dropping rollout spec with nr_rollouts=0" in caplog.text @@ -48,11 +54,13 @@ def test_adapter_rejects_scene_without_artifact() -> None: ] ) + def 
fake_get_data_source(scene_id: str): + if scene_id == "clipgt-missing": + raise UnknownSceneError(scene_id) + return MagicMock() + with pytest.raises(UnknownSceneError): - build_pending_jobs_from_request( - req, - scene_id_to_artifact_path={"clipgt-a": "/tmp/clipgt-a.usdz"}, - ) + build_pending_jobs_from_request(req, fake_get_data_source) def test_adapter_assigns_rollout_spec_indexes_in_request_order() -> None: @@ -63,20 +71,24 @@ def test_adapter_assigns_rollout_spec_indexes_in_request_order() -> None: ] ) - jobs = build_pending_jobs_from_request( - req, - scene_id_to_artifact_path={ - "clipgt-a": "/tmp/clipgt-a.usdz", - "clipgt-b": "/tmp/clipgt-b.usdz", - }, - ) + mock_data_source_a = MagicMock() + mock_data_source_b = MagicMock() + + def fake_get_data_source(scene_id: str): + if scene_id == "clipgt-a": + return mock_data_source_a + elif scene_id == "clipgt-b": + return mock_data_source_b + raise UnknownSceneError(scene_id) + + jobs = build_pending_jobs_from_request(req, fake_get_data_source) assert len(jobs) == 3 assert [job.scene_id for job in jobs] == ["clipgt-a", "clipgt-b", "clipgt-b"] assert [job.rollout_spec_index for job in jobs] == [0, 1, 1] - assert [job.artifact_path for job in jobs] == [ - "/tmp/clipgt-a.usdz", - "/tmp/clipgt-b.usdz", - "/tmp/clipgt-b.usdz", + assert [job.data_source for job in jobs] == [ + mock_data_source_a, + mock_data_source_b, + mock_data_source_b, ] @@ -88,13 +100,17 @@ def test_adapter_ignores_zero_rollout_specs_when_indexing() -> None: ] ) - jobs = build_pending_jobs_from_request( - req, - scene_id_to_artifact_path={ - "clipgt-a": "/tmp/clipgt-a.usdz", - "clipgt-b": "/tmp/clipgt-b.usdz", - }, - ) + mock_data_source_a = MagicMock() + mock_data_source_b = MagicMock() + + def fake_get_data_source(scene_id: str): + if scene_id == "clipgt-a": + return mock_data_source_a + elif scene_id == "clipgt-b": + return mock_data_source_b + raise UnknownSceneError(scene_id) + + jobs = build_pending_jobs_from_request(req, fake_get_data_source) assert len(jobs) == 2 assert [job.scene_id for job in jobs] == ["clipgt-b", "clipgt-b"] assert [job.rollout_spec_index for job in jobs] == [1, 1] diff --git a/src/runtime/tests/test_runtime_integration_replay.py b/src/runtime/tests/test_runtime_integration_replay.py index c55b50c8..ad84f9e8 100644 --- a/src/runtime/tests/test_runtime_integration_replay.py +++ b/src/runtime/tests/test_runtime_integration_replay.py @@ -3,6 +3,11 @@ """Manual integration replay test for runtime request determinism. +TODO: This test needs to be updated for the new trajdata-based data flow. +The test currently references usdz_glob which has been removed in favor of +DataSourceConfig. This test is marked as @pytest.mark.manual and is not +run in regular CI, so it can be updated as part of a follow-up task. + This test boots one replay gRPC server per service and re-runs runtime against recorded artifacts. It asserts that runtime emits requests matching the ASL recording exactly (ignoring expected dynamic fields). 
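The request-plumbing tests in `test_daemon_request_plumbing.py` above all drive `build_pending_jobs_from_request` through a `get_data_source` callback instead of the removed `scene_id_to_artifact_path` dict. For orientation, here is a minimal sketch of the adapter shape those tests imply; the job field names come straight from the assertions, while the exact signature and internals are assumptions (the real implementation lives in `alpasim_runtime.daemon.engine`):

```python
# Hypothetical sketch only -- not the shipped implementation.
# Job field names mirror the test assertions.
import logging
from dataclasses import dataclass
from typing import Any, Callable

logger = logging.getLogger(__name__)


@dataclass
class AssignedRolloutJob:
    scene_id: str
    rollout_spec_index: int
    data_source: Any  # a SceneDataSource resolved for this scene


def build_pending_jobs_from_request(
    request: Any,  # runtime_pb2.SimulationRequest
    get_data_source: Callable[[str], Any],  # raises UnknownSceneError if unknown
) -> list[AssignedRolloutJob]:
    jobs: list[AssignedRolloutJob] = []
    for spec_index, spec in enumerate(request.rollout_specs):
        if spec.nr_rollouts == 0:
            # Matches the warning the tests assert on via caplog.
            logger.warning(
                "Dropping rollout spec with nr_rollouts=0 (scene %s)", spec.scenario_id
            )
            continue
        # Resolve once per spec; one job per requested rollout.
        data_source = get_data_source(spec.scenario_id)
        jobs.extend(
            AssignedRolloutJob(spec.scenario_id, spec_index, data_source)
            for _ in range(spec.nr_rollouts)
        )
    return jobs
```

Note that a dropped `nr_rollouts=0` spec still consumes its index position, which is exactly what `test_adapter_ignores_zero_rollout_specs_when_indexing` pins down.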
diff --git a/src/runtime/tests/test_trajdata_integration.py b/src/runtime/tests/test_trajdata_integration.py new file mode 100644 index 00000000..3a547d50 --- /dev/null +++ b/src/runtime/tests/test_trajdata_integration.py @@ -0,0 +1,178 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +"""Integration tests for trajdata data source functionality.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest +from alpasim_grpc.v0 import runtime_pb2 +from alpasim_runtime.daemon.engine import DaemonEngine, build_pending_jobs_from_request +from alpasim_runtime.daemon.exceptions import UnknownSceneError + + +@pytest.fixture +def mock_trajdata_scene(): + """Create a mock trajdata Scene object.""" + scene = MagicMock() + scene.name = "test_scene_001" + scene.length_timesteps = 100 + scene.dt = 0.1 + return scene + + +@pytest.fixture +def mock_trajdata_dataset(mock_trajdata_scene): + """Create a mock UnifiedDataset.""" + dataset = MagicMock() + dataset.num_scenes.return_value = 1 + dataset.get_scene.return_value = mock_trajdata_scene + dataset.cache_class = MagicMock + dataset.cache_path = "/tmp/cache" + dataset.augmentations = None + dataset.obs_format = MagicMock() + return dataset + + +def test_build_pending_jobs_with_valid_scene(): + """Test building pending jobs from a simulation request with valid scene IDs.""" + request = runtime_pb2.SimulationRequest( + rollout_specs=[ + runtime_pb2.RolloutSpec(scenario_id="scene_a", nr_rollouts=2), + runtime_pb2.RolloutSpec(scenario_id="scene_b", nr_rollouts=1), + ] + ) + + mock_data_source_a = MagicMock() + mock_data_source_b = MagicMock() + + def fake_get_data_source(scene_id: str): + if scene_id == "scene_a": + return mock_data_source_a + elif scene_id == "scene_b": + return mock_data_source_b + raise UnknownSceneError(scene_id) + + jobs = build_pending_jobs_from_request(request, fake_get_data_source) + + # Should create 3 jobs total: 2 for scene_a, 1 for scene_b + assert len(jobs) == 3 + + # Check first two jobs are for scene_a + assert jobs[0].scene_id == "scene_a" + assert jobs[0].rollout_spec_index == 0 + assert jobs[0].data_source is mock_data_source_a + assert jobs[1].scene_id == "scene_a" + assert jobs[1].rollout_spec_index == 0 + assert jobs[1].data_source is mock_data_source_a + + # Check third job is for scene_b + assert jobs[2].scene_id == "scene_b" + assert jobs[2].rollout_spec_index == 1 + assert jobs[2].data_source is mock_data_source_b + + +def test_build_pending_jobs_with_unknown_scene(): + """Test that UnknownSceneError is raised for unknown scene IDs.""" + request = runtime_pb2.SimulationRequest( + rollout_specs=[ + runtime_pb2.RolloutSpec(scenario_id="unknown_scene", nr_rollouts=1), + ] + ) + + def fake_get_data_source(scene_id: str): + raise UnknownSceneError(scene_id) + + with pytest.raises(UnknownSceneError) as exc_info: + build_pending_jobs_from_request(request, fake_get_data_source) + + assert exc_info.value.scene_id == "unknown_scene" + + +def test_build_pending_jobs_drops_zero_rollouts(): + """Test that specs with nr_rollouts=0 are dropped with a warning.""" + request = runtime_pb2.SimulationRequest( + rollout_specs=[ + runtime_pb2.RolloutSpec(scenario_id="scene_a", nr_rollouts=1), + runtime_pb2.RolloutSpec(scenario_id="scene_b", nr_rollouts=0), + ] + ) + + mock_data_source = MagicMock() + + def fake_get_data_source(scene_id: str): + return mock_data_source + + jobs = build_pending_jobs_from_request(request, fake_get_data_source) + + # Should only create 1 
job (scene_b with 0 rollouts is dropped) + assert len(jobs) == 1 + assert jobs[0].scene_id == "scene_a" + + +def test_daemon_engine_get_data_source_caching( + mock_trajdata_dataset, mock_trajdata_scene +): + """Test that DaemonEngine uses SceneLoader for caching data sources.""" + engine = DaemonEngine( + user_config="u.yaml", + network_config="n.yaml", + eval_config="e.yaml", + log_dir="/tmp/log", + ) + + # Set up engine with mock SceneLoader + engine._started = True + mock_scene_loader = MagicMock() + mock_data_source = MagicMock() + mock_scene_loader.get_data_source.return_value = mock_data_source + engine._scene_loader = mock_scene_loader + + # First call should delegate to SceneLoader + data_source_1 = engine._get_data_source("test_scene_001") + assert data_source_1 is mock_data_source + assert mock_scene_loader.get_data_source.call_count == 1 + + # Second call should also delegate (SceneLoader handles caching internally) + data_source_2 = engine._get_data_source("test_scene_001") + assert data_source_2 is mock_data_source + assert mock_scene_loader.get_data_source.call_count == 2 + + +def test_daemon_engine_get_data_source_unknown_scene(): + """Test that _get_data_source raises UnknownSceneError for unknown scenes.""" + engine = DaemonEngine( + user_config="u.yaml", + network_config="n.yaml", + eval_config="e.yaml", + log_dir="/tmp/log", + ) + + engine._started = True + # Mock SceneLoader that raises UnknownSceneError + mock_scene_loader = MagicMock() + mock_scene_loader.get_data_source.side_effect = UnknownSceneError("unknown_scene") + engine._scene_loader = mock_scene_loader + + with pytest.raises(UnknownSceneError) as exc_info: + engine._get_data_source("unknown_scene") + + assert exc_info.value.scene_id == "unknown_scene" + + +def test_daemon_engine_get_data_source_without_dataset(): + """Test that _get_data_source raises RuntimeError if SceneLoader is not initialized.""" + engine = DaemonEngine( + user_config="u.yaml", + network_config="n.yaml", + eval_config="e.yaml", + log_dir="/tmp/log", + ) + + engine._started = True + engine._scene_loader = None + + with pytest.raises(RuntimeError, match="SceneLoader not initialized"): + engine._get_data_source("test_scene") diff --git a/src/utils/alpasim_utils/scene_data_source.py b/src/utils/alpasim_utils/scene_data_source.py new file mode 100644 index 00000000..865754d6 --- /dev/null +++ b/src/utils/alpasim_utils/scene_data_source.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2026 NVIDIA Corporation + +""" +SceneDataSource Protocol for abstracting scene data loading. + +This Protocol allows Runtime to work with different data sources (USDZ, Nuplan, Waymo, etc.) +without being tied to a specific implementation. Any class that implements this Protocol +can be used as a data source for alpasim Runtime. +""" + +from __future__ import annotations + +from typing import Optional, Protocol, runtime_checkable + +try: + from trajdata.maps import VectorMap +except ImportError: + VectorMap = None # type: ignore + +from alpasim_utils.artifact import Metadata +from alpasim_utils.scenario import Rig, TrafficObjects + + +@runtime_checkable +class SceneDataSource(Protocol): + """ + Protocol defining the interface for scene data sources. + + Any class implementing this protocol can be used as a data source for alpasim Runtime. + This allows supporting multiple data formats (USDZ, Nuplan, Waymo, etc.) without + modifying Runtime code. 
+
+    Attributes:
+        scene_id: Unique identifier for the scene
+    """
+
+    scene_id: str
+
+    @property
+    def rig(self) -> Rig:
+        """
+        Get the rig (ego vehicle) trajectory and configuration.
+
+        Returns:
+            Rig object containing trajectory, camera IDs, and vehicle config
+        """
+        ...
+
+    @property
+    def traffic_objects(self) -> TrafficObjects:
+        """
+        Get traffic objects (vehicles, pedestrians, etc.) in the scene.
+
+        Returns:
+            TrafficObjects dictionary mapping track_id to TrafficObject
+        """
+        ...
+
+    @property
+    def map(self) -> Optional[VectorMap]:
+        """
+        Get the vector map for the scene.
+
+        Returns:
+            VectorMap object or None if map data is not available
+        """
+        ...
+
+    @property
+    def metadata(self) -> Metadata:
+        """
+        Get metadata about the scene.
+
+        Returns:
+            Metadata object containing scene information
+        """
+        ...
diff --git a/src/utils/alpasim_utils/trajdata_data_source.py b/src/utils/alpasim_utils/trajdata_data_source.py
new file mode 100644
index 00000000..fac76fa2
--- /dev/null
+++ b/src/utils/alpasim_utils/trajdata_data_source.py
@@ -0,0 +1,925 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2026 NVIDIA Corporation
+
+"""
+TrajdataDataSource: Implementation for loading scene data directly from trajdata.
+
+This class demonstrates how to create a SceneDataSource implementation that loads
+data directly from trajdata-converted data without requiring the USDZ format. This is
+useful for researchers using trajdata datasets.
+
+Usage example:
+    from trajdata import UnifiedDataset
+    from alpasim_utils.trajdata_data_source import TrajdataDataSource
+
+    # Load trajdata dataset
+    dataset = UnifiedDataset(
+        desired_data=["nusc_mini"],
+        data_dirs={"nusc_mini": "/path/to/trajdata/data"},
+        ...
+    )
+
+    # Get a scene
+    scene = dataset.get_scene("nusc_mini", "scene-0001")
+
+    # Create data source
+    data_source = TrajdataDataSource.from_trajdata_scene(scene)
+
+    # The data source can now be passed to Runtime,
+    # e.g. run_single_rollout(..., data_source=data_source)
+"""
+
+from __future__ import annotations
+
+import copy
+import hashlib
+import logging
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import csaps
+import numpy as np
+from alpasim_utils.artifact import Metadata
+from alpasim_utils.geometry import Trajectory
+from alpasim_utils.scenario import (
+    AABB,
+    CameraId,
+    Rig,
+    TrafficObject,
+    TrafficObjects,
+    VehicleConfig,
+)
+from alpasim_utils.scene_data_source import SceneDataSource
+from scipy.spatial.transform import Rotation as R
+from trajdata.caching import EnvCache
+from trajdata.data_structures.agent import AgentMetadata
+from trajdata.data_structures.scene_metadata import Scene
+from trajdata.dataset import UnifiedDataset
+from trajdata.maps import VectorMap
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TrajdataDataSource(SceneDataSource):
+    """
+    Implementation for loading scene data directly from trajdata.
+
+    This class implements the SceneDataSource protocol, allowing direct loading
+    from trajdata Scene or AgentBatch objects without requiring the USDZ format.
+
+    Property loading order dependencies:
+    - rig: No dependencies (loads first, sets world_to_nre transformation)
+    - traffic_objects: Requires rig (uses world_to_nre for coordinate conversion)
+    - map: Requires rig (uses world_to_nre for coordinate conversion)
+    - metadata: Requires rig (uses trajectory time range)
+
+    All properties use lazy loading and caching for efficiency.
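+
+    Example (hypothetical sketch, assuming `dataset` is an already-built
+    UnifiedDataset and scenes are addressed by index, mirroring the daemon's
+    scene_id_to_idx mapping):
+
+        scene = dataset.get_scene(0)
+        source = TrajdataDataSource.from_trajdata_scene(scene, dataset=dataset)
+        rig = source.rig                  # loads first and fixes world_to_nre
+        objects = source.traffic_objects  # reuses the rig's world_to_nre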
+ """ + + _scene: Scene | None = None + _scene_cache: EnvCache | None = None + _dataset: UnifiedDataset | None = None + _map_api: Optional[object] = ( + None # MapAPI for loading maps (lightweight, can be passed separately) + ) + _rig: Rig | None = None + _traffic_objects: TrafficObjects | None = None + _map: VectorMap | None = None + _metadata: Metadata | None = None + _smooth_trajectories: bool = True + _scene_id: str = "" + _asset_base_path: str | None = None # Base path for rendering assets + + @classmethod + def from_trajdata_scene( + cls, + scene: Scene, + dataset: Optional[UnifiedDataset] = None, + map_api=None, + scene_cache: Optional[EnvCache] = None, + scene_id: Optional[str] = None, + smooth_trajectories: bool = True, + base_timestamp_us: int = 0, + asset_base_path: Optional[str] = None, + ) -> TrajdataDataSource: + """ + Create TrajdataDataSource from trajdata Scene object. + + Args: + scene: trajdata Scene object + dataset: UnifiedDataset instance (for getting scene_cache and map) + map_api: MapAPI instance for loading maps (lightweight alternative to dataset) + scene_cache: Optional EnvCache instance (if not provided, will be created from dataset) + scene_id: Optional scene ID (if not provided, uses scene.name) + smooth_trajectories: Whether to smooth trajectories + base_timestamp_us: Base timestamp in microseconds, starts from 0 if None + asset_base_path: Base path for rendering assets (e.g., MTGS assets) + + Returns: + TrajdataDataSource instance + """ + data_source = cls( + _scene=scene, + _dataset=dataset, + _map_api=map_api, + _scene_cache=scene_cache, + _scene_id=scene_id or scene.name, + _smooth_trajectories=smooth_trajectories, + _asset_base_path=asset_base_path, + ) + data_source._base_timestamp_us = base_timestamp_us + return data_source + + @property + def scene_id(self) -> str: + """Scene ID (immutable identifier).""" + return self._scene_id + + @property + def asset_path(self) -> str | None: + """ + Resolve asset folder path for this scene. + + The asset path is constructed by appending the scene name to _asset_base_path. + The _asset_base_path should already contain any dataset-specific subdirectories + (e.g., it might be /data/WE_processed/navtest/assets for MTGS). + + Returns: + Resolved asset folder path, or None if _asset_base_path is not set + """ + if self._asset_base_path is None: + return None + + # Extract asset folder name from scene metadata + scene_name = self._extract_asset_folder_name() + + # Simple join: asset_base_path already contains dataset-specific subdirs + return os.path.join(self._asset_base_path, scene_name) + + def _extract_asset_folder_name(self) -> str: + """ + Extract the asset folder name from scene metadata. + + This method attempts to determine the appropriate asset folder name + based on scene metadata. Override this in subclasses if needed. + + Resolution order: + 1. USDZ: Use usdz_stem from data_access_info + 2. Other datasets: Use log_id or asset_folder from data_access_info + 3. 
Fallback: Use scene_id with common suffixes removed + + Returns: + Asset folder name (defaults to scene_id if no specific name found) + """ + # Try to get from scene data_access_info + if self._scene is not None and self._scene.data_access_info is not None: + data_access_info = self._scene.data_access_info + + # USDZ: Use usdz_stem (filename without .usdz extension) + if "usdz_stem" in data_access_info: + return data_access_info["usdz_stem"] + + # Look for asset_folder or similar keys + if "asset_folder" in data_access_info: + return data_access_info["asset_folder"] + + # NuPlan and other datasets: use log_id + if "log_id" in data_access_info: + return data_access_info["log_id"] + + # Default: use scene_id (potentially with suffix removed) + scene_name = self.scene_id + # Remove common suffixes like "-001" + if len(scene_name) > 4 and scene_name[-4] == "-" and scene_name[-3:].isdigit(): + scene_name = scene_name[:-4] + return scene_name + + def set_asset_base_path(self, path: str | None) -> None: + """Set the base path for rendering assets.""" + self._asset_base_path = path + + def _get_scene_cache(self) -> EnvCache: + """Get or create scene_cache""" + if self._scene_cache is not None: + return self._scene_cache + + if self._scene is None: + raise ValueError("Cannot create scene_cache: scene is not set") + + # Try to create from dataset if available + if self._dataset is not None: + logger.debug(f"Creating scene_cache for scene: {self._scene.name}") + try: + self._scene_cache = self._dataset.cache_class( + self._dataset.cache_path, self._scene, self._dataset.augmentations + ) + self._scene_cache.set_obs_format(self._dataset.obs_format) + logger.debug("Scene cache created successfully") + return self._scene_cache + except Exception as e: + logger.error(f"Failed to create scene_cache: {e}") + raise + + # If dataset is not set, scene_cache must be provided externally + raise ValueError( + "Cannot create scene_cache: dataset is not set and scene_cache was not provided. " + "Either pass 'dataset' parameter or pre-create 'scene_cache' when creating TrajdataDataSource. " + "Example: TrajdataDataSource.from_trajdata_scene(scene, dataset=your_dataset) " + "or TrajdataDataSource.from_trajdata_scene(scene, scene_cache=your_cache)" + ) + + @staticmethod + def _get_state_value(state, attr_name: str, default=None): + """Extract scalar value from state attribute using StateArray.get_attr(). + + Args: + state: StateArray object from trajdata cache (from get_raw_state) + attr_name: Name of attribute to extract (e.g., "x", "y", "h") + default: Default value if attribute doesn't exist + + Returns: + Scalar float value + + Raises: + AttributeError: If attribute doesn't exist and no default provided + """ + try: + value = state.get_attr(attr_name) + + if value is None: + if default is not None: + return default + raise AttributeError(f"Attribute {attr_name} is None") + + # Convert to scalar if needed + if isinstance(value, np.ndarray): + return float(value.flat[0]) + return float(value) + except (KeyError, AttributeError, TypeError, IndexError) as e: + if default is not None: + return default + raise AttributeError(f"Failed to get {attr_name} from state: {e}") + + def _ensure_rig_loaded(self) -> None: + """Ensure rig is loaded before accessing world_to_nre. + + This method is called by properties that depend on world_to_nre + transformation (traffic_objects, map helper methods). 
+ """ + if self._rig is None: + _ = self.rig + + def _extract_agent_trajectory( + self, + agent: AgentMetadata, + ) -> tuple[Optional[Trajectory], Optional[VehicleConfig]]: + """Extract complete trajectory for agent (refer to trajdata_artifact_converter.py implementation)""" + if self._scene is None: + return None, None + + scene_cache = self._get_scene_cache() + dt = self._scene.dt + base_timestamp_us = getattr(self, "_base_timestamp_us", None) + + try: + timestamps_us = [] + positions_agent_world = [] + quaternions_agent_world = [] + + # Iterate through all timesteps + for ts in range(agent.first_timestep, agent.last_timestep + 1): + try: + state = scene_cache.get_raw_state(agent.name, ts) + + # Get position and orientation using helper + x_val = self._get_state_value(state, "x") + y_val = self._get_state_value(state, "y") + z_val = self._get_state_value(state, "z", default=0.0) + heading_val = self._get_state_value(state, "h") + + # Calculate timestamp + if base_timestamp_us is None: + timestamp_us = int(ts * dt * 1e6) + else: + timestamp_us = int(base_timestamp_us + ts * dt * 1e6) + + timestamps_us.append(timestamp_us) + positions_agent_world.append([x_val, y_val, z_val]) + + # Convert heading to quaternion + quat = R.from_euler("z", heading_val).as_quat() # [x, y, z, w] + quaternions_agent_world.append(quat) + + except Exception as e: + logger.debug( + f"Failed to get state for agent {agent.name} at ts {ts}: {e}" + ) + continue + + if len(timestamps_us) == 0: + return None, None + + # Create Trajectory + trajectory = Trajectory( + timestamps=np.array(timestamps_us, dtype=np.uint64), + positions=np.array(positions_agent_world, dtype=np.float32), + quaternions=np.array(quaternions_agent_world, dtype=np.float32), + ) + + # Create VehicleConfig (extract from extent) + vehicle_config = VehicleConfig( + aabb_x_m=agent.extent.length, + aabb_y_m=agent.extent.width, + aabb_z_m=agent.extent.height, + aabb_x_offset_m=-agent.extent.length / 2, + aabb_y_offset_m=0.0, + aabb_z_offset_m=-agent.extent.height / 2, + ) + + return trajectory, vehicle_config + + except Exception as e: + logger.error(f"Failed to extract trajectory for agent {agent.name}: {e}") + return None, None + + @property + def rig(self) -> Rig: + """Load and return Rig object for ego vehicle""" + if self._rig is not None: + return self._rig + + if self._scene is None: + raise ValueError("Cannot load rig: scene is not set") + + # Get all agents + all_agents = self._scene.agents if self._scene.agents else [] + + # Identify ego agent + ego_agent = next((a for a in all_agents if a.name == "ego"), None) + if ego_agent is None and len(all_agents) > 0: + # If no ego, use first agent + ego_agent = all_agents[0] + logger.warning(f"No ego agent found, using first agent: {ego_agent.name}") + + if ego_agent is None: + raise ValueError("No ego agent found in scene") + + # Extract ego trajectory + ego_trajectory, ego_vehicle_config = self._extract_agent_trajectory(ego_agent) + + if ego_trajectory is None: + logger.error( + f"Failed to extract ego trajectory for agent {ego_agent.name}. " + f"Check if scene_cache is properly initialized and agent data is available." 
+ ) + raise ValueError("Cannot extract ego trajectory") + + # Calculate world_to_nre transformation matrix (use first trajectory point as origin) + world_to_nre = np.eye(4) + if len(ego_trajectory) > 0: + position_ego_first_world = ego_trajectory.positions[0] + world_to_nre[:3, 3] = -position_ego_first_world + logger.info( + f"Setting world_to_nre origin at first pose: {position_ego_first_world}, " + f"translation: {world_to_nre[:3, 3]}" + ) + + # Convert ego trajectory to local coordinates (NRE) + if len(ego_trajectory) > 0: + translation = world_to_nre[:3, 3] + local_positions = ego_trajectory.positions + translation + + # Validate transform + position_ego_first_local = local_positions[0] + if np.linalg.norm(position_ego_first_local[:2]) > 1.0: + logger.warning( + f"First pose after transformation is not at origin: {position_ego_first_local}. " + f"Expected [0, 0, ~z], got {position_ego_first_local}" + ) + + local_quat = ego_trajectory.quaternions.copy() + ego_trajectory = Trajectory( + timestamps=ego_trajectory.timestamps_us.copy(), + positions=local_positions, + quaternions=local_quat, + ) + + logger.debug( + f"Transformed ego trajectory to local coordinates. " + f"First pose: {ego_trajectory.first_pose}, " + f"Range: X[{local_positions[:, 0].min():.2f}, {local_positions[:, 0].max():.2f}], " + f"Y[{local_positions[:, 1].min():.2f}, {local_positions[:, 1].max():.2f}], " + f"Z[{local_positions[:, 2].min():.2f}, {local_positions[:, 2].max():.2f}]" + ) + + # Extract camera information (refer to trajdata_artifact_converter.py) + camera_ids, _ = self._extract_camera_info_from_scene() + + self._rig = Rig( + sequence_id=self.scene_id, + trajectory=ego_trajectory, + camera_ids=camera_ids, + world_to_nre=world_to_nre, + vehicle_config=ego_vehicle_config, + ) + + return self._rig + + def _extract_camera_info_from_scene(self) -> tuple[list[CameraId], dict]: + """Extract camera information from scene (refer to trajdata_artifact_converter.py)""" + camera_ids = [] + camera_calibrations = {} + + if self._scene is None: + return camera_ids, camera_calibrations + + # Check if sensor_calibration information exists + if not self._scene.data_access_info: + logger.warning( + "scene.data_access_info is empty, skipping camera information extraction" + ) + return camera_ids, camera_calibrations + + sensor_calibration = self._scene.data_access_info.get("sensor_calibration") + if not sensor_calibration or not isinstance(sensor_calibration, dict): + logger.warning( + "sensor_calibration does not exist or has incorrect format, skipping camera information extraction" + ) + return camera_ids, camera_calibrations + + unique_sensor_idx = 0 + for camera_name, calibration_info in sensor_calibration.get( + "cameras", {} + ).items(): + try: + unique_camera_id = f"{camera_name}@{self.scene_id}" + + position_sensor_to_ego = calibration_info.get( + "sensor2ego_translation", [0.0, 0.0, 0.0] + ) + rotation_sensor_to_ego = calibration_info.get( + "sensor2ego_rotation", [0.0, 0.0, 0.0, 1.0] + ) + + if isinstance(position_sensor_to_ego, (int, float)): + position_sensor_to_ego = [float(position_sensor_to_ego), 0.0, 0.0] + elif len(position_sensor_to_ego) < 3: + position_sensor_to_ego = list(position_sensor_to_ego) + [0.0] * ( + 3 - len(position_sensor_to_ego) + ) + + if isinstance(rotation_sensor_to_ego, (int, float)): + rotation_sensor_to_ego = [0.0, 0.0, 0.0, 1.0] + elif len(rotation_sensor_to_ego) < 4: + if len(rotation_sensor_to_ego) == 3: + r = R.from_euler("xyz", rotation_sensor_to_ego) + rotation_sensor_to_ego = r.as_quat() 
+ else: + rotation_sensor_to_ego = [0.0, 0.0, 0.0, 1.0] + + camera_id = CameraId( + logical_name=camera_name, + trajectory_idx=0, + sequence_id=self.scene_id, + unique_id=unique_camera_id, + ) + camera_ids.append(camera_id) + unique_sensor_idx += 1 + + except Exception as e: + logger.warning( + f"Error extracting camera {camera_name} information: {e}" + ) + continue + + if len(camera_ids) == 0: + # If no camera information, create a default one + logger.warning( + f"Scene {self.scene_id} has no camera information, using default camera" + ) + camera_ids.append( + CameraId( + logical_name="camera_front", + trajectory_idx=0, + sequence_id=self.scene_id, + unique_id="0@camera_front", + ) + ) + + return camera_ids, camera_calibrations + + def _is_static_object( + self, trajectory: Trajectory, velocity_threshold: float = 0.1 + ) -> bool: + """Determine if object is static (based on velocity)""" + if len(trajectory) < 2: + return True + + positions = trajectory.positions + timestamps = trajectory.timestamps_us.astype(np.float64) / 1e6 + + velocities = [] + for i in range(1, len(positions)): + dt_sec = timestamps[i] - timestamps[i - 1] + if dt_sec > 0: + displacement = np.linalg.norm(positions[i] - positions[i - 1]) + velocity = displacement / dt_sec + velocities.append(velocity) + + if len(velocities) == 0: + return True + + avg_velocity = np.mean(velocities) + return avg_velocity < velocity_threshold + + def _transform_map_points( + self, + points: np.ndarray, + translation_xy: np.ndarray, + first_traj_z: float, + ) -> np.ndarray: + """Transform map points to local coordinates (translation only). + + Args: + points: Nx3 array of map points + translation_xy: 2-element XY translation + first_traj_z: Z coordinate of first trajectory point + + Returns: + Transformed points + """ + if ( + points is None + or len(points) == 0 + or points.ndim != 2 + or points.shape[1] < 3 + ): + return points + + points_copy = points.copy() + + # Apply XY translation + points_copy[:, 0] = points_copy[:, 0] + translation_xy[0] + points_copy[:, 1] = points_copy[:, 1] + translation_xy[1] + + # Align Z coordinate to trajectory baseline + points_copy[:, 2] = points_copy[:, 2] + first_traj_z + + return points_copy + + @property + def traffic_objects(self) -> TrafficObjects: + """Load and return traffic objects""" + if self._traffic_objects is not None: + return self._traffic_objects + + if self._scene is None: + raise ValueError("Cannot load traffic_objects: scene is not set") + + # Get all agents + all_agents = self._scene.agents if self._scene.agents else [] + + # Identify ego agent + ego_agent = next((a for a in all_agents if a.name == "ego"), None) + if ego_agent is None and len(all_agents) > 0: + ego_agent = all_agents[0] + + traffic_dict = {} + for agent in all_agents: + # Skip ego agent + if agent.name == "ego" or agent == ego_agent: + continue + + # Extract trajectory + trajectory, _ = self._extract_agent_trajectory(agent) + + # Filter out empty trajectories or trajectories with only 1 data point + if trajectory is None or len(trajectory) < 2: + continue + + # Convert trajectory to local coordinates (NRE) - use rig's world_to_nre + # Explicit dependency: need world_to_nre from rig + self._ensure_rig_loaded() + + world_to_nre = self._rig.world_to_nre + translation = world_to_nre[:3, 3] + local_positions = trajectory.positions + translation + local_quat = trajectory.quaternions.copy() + trajectory = Trajectory( + timestamps=trajectory.timestamps_us.copy(), + positions=local_positions, + quaternions=local_quat, + ) + + 
# Smooth if needed + if self._smooth_trajectories: + try: + css = csaps.CubicSmoothingSpline( + trajectory.timestamps_us / 1e6, + trajectory.positions.T, + normalizedsmooth=True, + ) + filtered_positions = css(trajectory.timestamps_us / 1e6).T + max_error = np.max( + np.abs(filtered_positions - trajectory.positions) + ) + if max_error > 1.0: + logger.warning( + f"Max error in cubic spline approximation: {max_error:.6f} m for {agent.name=}" + ) + # Create new trajectory with smoothed positions + trajectory = Trajectory( + timestamps=trajectory.timestamps_us.copy(), + positions=filtered_positions.astype(np.float32), + quaternions=trajectory.quaternions.copy(), + ) + except Exception as e: + logger.warning(f"Failed to smooth trajectory: {e}") + + # Get AABB + aabb = AABB( + x=agent.extent.length, y=agent.extent.width, z=agent.extent.height + ) + + # Determine if static object + is_static = self._is_static_object(trajectory) + + # Get category label + label_class = getattr(agent.type, "name", "UNKNOWN") + + traffic_dict[agent.name] = TrafficObject( + track_id=agent.name, + aabb=aabb, + trajectory=trajectory, + is_static=is_static, + label_class=label_class, + ) + + self._traffic_objects = TrafficObjects(**traffic_dict) + return self._traffic_objects + + def _apply_coordinate_transform_to_map(self, vec_map: VectorMap) -> None: + """Apply world_to_nre coordinate transformation to map in-place. + + Note: world_to_nre only contains translation (no rotation), as the local + coordinate frame maintains the same orientation as the world frame (ENU). + + Args: + vec_map: VectorMap to transform + """ + world_to_nre = self._rig.world_to_nre + translation = world_to_nre[:3, 3] + translation_xy = translation[:2] + first_traj_z = ( + self.rig.trajectory.positions[0][2] if len(self.rig.trajectory) > 0 else 0.0 + ) + + logger.info( + f"Map coordinate transformation: " + f"translation_xy={translation_xy}, " + f"first_traj_z={first_traj_z:.2f}m" + ) + + # Transform all lane points + if vec_map.lanes is None: + return + + for lane in vec_map.lanes: + # Transform center (always exists) + lane.center.points = self._transform_map_points( + lane.center.points, + translation_xy, + first_traj_z, + ) + + # Transform left_edge (optional) + if lane.left_edge is not None and lane.left_edge.points is not None: + lane.left_edge.points = self._transform_map_points( + lane.left_edge.points, + translation_xy, + first_traj_z, + ) + + # Transform right_edge (optional) + if lane.right_edge is not None and lane.right_edge.points is not None: + lane.right_edge.points = self._transform_map_points( + lane.right_edge.points, + translation_xy, + first_traj_z, + ) + + def _fix_map_datatypes(self, vec_map: VectorMap) -> None: + """Fix lane connectivity data types (convert lists to sets). + + This is needed because some trajdata loaders may incorrectly create + these as lists instead of sets. 
+ + Args: + vec_map: VectorMap to fix + """ + if vec_map.lanes is None: + return + + for lane in vec_map.lanes: + # Convert to set if they are lists (defensive, but based on observed issues) + if isinstance(lane.next_lanes, list): + lane.next_lanes = set(lane.next_lanes) + if isinstance(lane.prev_lanes, list): + lane.prev_lanes = set(lane.prev_lanes) + if isinstance(lane.adj_lanes_right, list): + lane.adj_lanes_right = set(lane.adj_lanes_right) + if isinstance(lane.adj_lanes_left, list): + lane.adj_lanes_left = set(lane.adj_lanes_left) + + def _verify_map_transformation(self, vec_map: VectorMap) -> None: + """Verify map coordinate transformation is correct. + + Args: + vec_map: Transformed VectorMap + """ + if vec_map.lanes is None or len(vec_map.lanes) == 0: + return + + # Get first lane and points (center always exists in RoadLane) + first_lane = vec_map.lanes[0] + first_map_point = first_lane.center.points[0, :3] + first_traj_point = self.rig.trajectory.positions[0] + + distance_xy = np.linalg.norm(first_map_point[:2]) + z_diff = abs(first_map_point[2] - first_traj_point[2]) + + logger.info( + f"Map transformation verification: " + f"first lane center: {first_map_point}, " + f"first trajectory: {first_traj_point}, " + f"XY distance: {distance_xy:.2f}m, " + f"Z difference: {z_diff:.2f}m" + ) + + if z_diff > 10.0: + logger.warning( + f"Map Z coordinate may not be correctly aligned. " + f"Map Z={first_map_point[2]:.2f}m, Traj Z={first_traj_point[2]:.2f}m" + ) + + def _load_map_from_scene_data(self) -> Optional[VectorMap]: + """Load map from scene.map_data (USDZ). + + Returns: + VectorMap if available, None otherwise + """ + if not hasattr(self._scene, "map_data") or self._scene.map_data is None: + return None + + logger.info(f"Loading map from scene.map_data for {self.scene_id}") + vec_map = copy.deepcopy(self._scene.map_data) + + # Ensure rig is loaded (need world_to_nre for transformation) + self._ensure_rig_loaded() + + # Apply coordinate transformation + self._apply_coordinate_transform_to_map(vec_map) + + # Fix datatypes and verify + self._fix_map_datatypes(vec_map) + self._verify_map_transformation(vec_map) + + logger.info("Successfully loaded map from scene.map_data") + return vec_map + + def _load_map_from_dataset_api(self) -> Optional[VectorMap]: + """Load map from dataset._map_api (datasets with map API). 
+ + Returns: + VectorMap if available, None otherwise + """ + # Try to get map_api from either self._map_api or self._dataset + map_api = self._map_api + if map_api is None and self._dataset is not None: + map_api = getattr(self._dataset, "_map_api", None) + + if map_api is None: + logger.warning("Cannot load map: map_api not available") + return None + + # Get vector_map_params from dataset if available + vector_map_params = {} + if self._dataset is not None: + vector_map_params = getattr(self._dataset, "vector_map_params", {}) + + # Build map name + if not self._scene.location: + logger.warning(f"Scene {self.scene_id} has no location, cannot load map") + return None + + map_name = f"{self._scene.env_name}:{self._scene.location}" + + try: + vec_map = map_api.get_map(map_name, **vector_map_params) + if vec_map is None: + logger.debug(f"Scene {self.scene_id} (map: {map_name}) has no map data") + return None + + # Deep copy to avoid modifying shared cache + vec_map = copy.deepcopy(vec_map) + + # Ensure rig is loaded (need world_to_nre for transformation) + self._ensure_rig_loaded() + + # Apply coordinate transformation + self._apply_coordinate_transform_to_map(vec_map) + + # Finalize map + vec_map.__post_init__() + vec_map.compute_search_indices() + + # Fix datatypes and verify + self._fix_map_datatypes(vec_map) + self._verify_map_transformation(vec_map) + + logger.info(f"Successfully loaded map: {map_name}") + return vec_map + except Exception as e: + logger.error(f"Error loading map from dataset API: {e}", exc_info=True) + return None + + @property + def map(self) -> Optional[VectorMap]: + """Load and return VectorMap.""" + if self._map is not None: + return self._map + + if self._scene is None: + logger.warning("Cannot load map: scene is not set") + return None + + # Try scene.map_data first (simpler path) + self._map = self._load_map_from_scene_data() + if self._map is not None: + return self._map + + # Fallback to dataset._map_api + self._map = self._load_map_from_dataset_api() + return self._map + + @property + def metadata(self) -> Metadata: + """Create and return Metadata object""" + if self._metadata is not None: + return self._metadata + + # Extract metadata from scene + scene_id = self.scene_id + + # Ensure rig is loaded + rig = self.rig + + # Extract camera ID list from rig + camera_id_names = [] + if rig and rig.camera_ids: + camera_id_names = [camera_id.logical_name for camera_id in rig.camera_ids] + + # Calculate time range + if self._scene is not None: + dt = self._scene.dt + length_timesteps = self._scene.length_timesteps + base_timestamp_us = getattr(self, "_base_timestamp_us", 0.0) + time_range_start = float(base_timestamp_us) / 1e6 + time_range_end = ( + float(base_timestamp_us + length_timesteps * dt * 1e6) / 1e6 + ) + else: + time_range_start = float(rig.trajectory.time_range_us.start) / 1e6 + time_range_end = float(rig.trajectory.time_range_us.stop) / 1e6 + + # Generate deterministic IDs based on scene identifiers + # Create hash from scene_id and time range for reproducibility + hash_input = f"{scene_id}_{time_range_start}_{time_range_end}" + dataset_hash = hashlib.sha256(hash_input.encode()).hexdigest()[:16] + uuid_str = hashlib.sha256(f"{hash_input}_uuid".encode()).hexdigest()[:32] + + # Use fixed training date instead of datetime.now() for determinism + training_date = "trajdata-generated" + + # Create metadata + self._metadata = Metadata( + scene_id=scene_id, + version_string="trajdata_direct", + training_date=training_date, + dataset_hash=dataset_hash, + 
uuid=uuid_str, + is_resumable=False, + sensors=Metadata.Sensors( + camera_ids=camera_id_names, + lidar_ids=[], + ), + logger=Metadata.Logger(), + time_range=Metadata.TimeRange( + start=time_range_start, + end=time_range_end, + ), + ) + + return self._metadata diff --git a/src/wizard/configs/base_config.yaml b/src/wizard/configs/base_config.yaml index cdde7a13..e5114e77 100644 --- a/src/wizard/configs/base_config.yaml +++ b/src/wizard/configs/base_config.yaml @@ -30,8 +30,10 @@ defines: drivers: "${defines.filesystem}/drivers" sensordata: "${defines.filesystem}/nre-artifacts" trafficsim_map_cache: "${defines.filesystem}/trafficsim/unified_data_cache" + trajdata_cache: "${defines.filesystem}/trajdata_cache" # Trajdata unified cache location sensorsim_entrypoint: "/app/internal/scripts/pycena/runtime/pycena_run" + nre_cache_size: 4 # Default; topology configs override based on concurrent rollouts helper: scripts vscode: sources/remote-vscode-server physics_cache_size: 16 # should match or exceed concurrent scenes to avoid cache thrashing @@ -110,6 +112,11 @@ scenes: suites_csv: - "${repo-relative:'data/scenes/sim_suites.csv'}" + # Relative path within scene_cache to the sceneset directory + # Set automatically by wizard; override for manual runtime usage + # Example: "scenesets/abc123" or "all-usdzs/sample_set/25.07_release/Batch0001" + sceneset_path: "all-usdzs/sample_set/25.07_release/Batch0001" + # \/ services.* defines the individual components of the simulation. Each of them is deployed from an image so the layout of # each item is similar. For the services that pull in local code (e.g., controller, runtime, eval) we mount the repo-relative # `src/` directory into `/mnt/src` in the container and run from there. The virtual environment is reused from the container. @@ -212,7 +219,6 @@ services: gpus: null # uses no GPUs command: - "uv run python -m alpasim_runtime.simulate" - - "--usdz-glob=/mnt/nre-data/{sceneset}/**/*.usdz" - "--user-config=/mnt/log_dir/{runtime_config_name}" - "--network-config=/mnt/log_dir/generated-network-config.yaml" - "--log-dir=/mnt/log_dir" @@ -224,6 +230,40 @@ services: runtime: # nr_workers and endpoints.*.n_concurrent_rollouts are set by topology configs. + # Unified data source configuration using trajdata + # Common settings apply to all sources; source-specific config lives under sources: + data_source: + # Common configuration (applies to all data sources) + cache_location: "${defines.trajdata_cache}" # Shared trajdata cache + desired_dt: 0.1 # 10 Hz sampling rate for trajectories + incl_vector_map: true # Include vector map data (roads, lanes, etc.) + rebuild_cache: false # Set to true to force rebuild cache + rebuild_maps: false # Set to true to force rebuild maps + num_workers: 4 # Parallel workers for cache creation + + sources: + # USDZ data source configuration (NuRec artifacts) + usdz: + enabled: true + # Use wizard's dynamic sceneset path. The wizard creates a sceneset directory + # based on selected scenes and sets sceneset_path at runtime. + # For manual runtime usage without wizard, set sceneset_path in scenes config. 
+        data_dir: "${scenes.scene_cache}/${scenes.sceneset_path}"
+        extra_params:
+          asset_base_path: null # Optional: Base path for MTGS rendering assets
+
+      # NuPlan data source configuration (disabled by default)
+      nuplan:
+        enabled: false
+        data_dir: null
+        extra_params:
+          config_dir: null # Set to enable YAML batch preprocessing mode
+          num_timesteps_before: 30 # Timesteps before central token
+          num_timesteps_after: 80 # Timesteps after central token
+
+  # Enable cubic spline smoothing for trajectories
+  smooth_trajectories: true
+
   endpoints:
     # shut down the system after simulation is finished. without this flag the microservice servers
     # will remain on forever, requiring a manual interrupt (useful for debugging)
     shutdown_on_finish: true
     sensorsim_cache_size: ${defines.nre_cache_size}
     enable_autoresume: false
-  # How many scenes (in particular maps) to cache in the worker local artifact cache.
-  artifact_cache_size: 10
   simulation_config:
     n_sim_steps: 200 # how many steps to simulate in a rollout
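Read together with the `DataSourceConfig` tests earlier in this diff, the `data_source:` block above implies roughly the following config surface. This is a minimal sketch only: the field names and the flattened trajdata parameter keys are taken from `tests/test_config.py`, while the unasserted defaults (`desired_dt`, `incl_vector_map`), the filtering of empty `extra_params`, and the assumption that the YAML `enabled:` toggle is resolved away during config parsing are mine:

```python
# Hypothetical reconstruction of the config objects exercised by
# tests/test_config.py; treat as a sketch, not the shipped implementation.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class GenericSourceConfig:
    data_dir: str
    extra_params: dict[str, Any] = field(default_factory=dict)


@dataclass
class DataSourceConfig:
    cache_location: str
    sources: dict[str, GenericSourceConfig] = field(default_factory=dict)
    rebuild_cache: bool = False
    rebuild_maps: bool = False
    num_workers: int = 1
    desired_dt: float = 0.1       # assumed default; the tests never assert it
    incl_vector_map: bool = True  # assumed default; the tests never assert it

    def to_trajdata_params(self) -> dict[str, Any]:
        """Flatten the hierarchical config into UnifiedDataset keyword arguments."""
        if not self.sources:
            raise ValueError("No data sources enabled")
        return {
            "desired_data": list(self.sources),
            "data_dirs": {name: s.data_dir for name, s in self.sources.items()},
            "cache_location": self.cache_location,
            "rebuild_cache": self.rebuild_cache,
            "rebuild_maps": self.rebuild_maps,
            "num_workers": self.num_workers,
            "desired_dt": self.desired_dt,
            "incl_vector_map": self.incl_vector_map,
            "dataset_kwargs": {
                name: s.extra_params
                for name, s in self.sources.items()
                if s.extra_params
            },
        }
```

The returned dict is shaped so it can be splatted directly into trajdata's `UnifiedDataset(**params)`, which is presumably what the runtime's SceneLoader does with it when building the shared cache described above.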