Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion examples/experimental/config/eval_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ test_dataset_size: 10_000 # Number of test scenarios to evaluate on

# Environment settings
train_dir: data/processed/training
test_dir: data/processed/validation
test_dir: data/processed/validation
file_prefix: null
Comment thread
EllingtonKirby marked this conversation as resolved.
Outdated

num_worlds: 50 # Number of parallel environments for evaluation
max_controlled_agents: 64 # Maximum number of agents controlled by the model.
Expand All @@ -26,6 +27,7 @@ obs_radius: 50.0 # Visibility radius of the agents
init_roadgraph: False
render_3d: True

action_type: "discrete"
# Number of discretizations in the action space
# Note: Make sure that this equals the discretizations that the policy
# has been trained with
Expand Down
37 changes: 37 additions & 0 deletions examples/experimental/config/expert_replay_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
res_path: examples/experimental/dataframes # Store dataframes here
test_dataset_size: 300 # Number of test scenarios to evaluate on

# Environment settings
train_dir: data/processed/training
test_dir: data/processed/validation
file_prefix: nuplan

num_worlds: 100 # Number of parallel environments for evaluation
max_controlled_agents: 64 # Maximum number of agents controlled by the model.
ego_state: true
road_map_obs: true
partner_obs: true
norm_obs: true
remove_non_vehicles: true # If false, all agents are included (vehicles, pedestrians, cyclists)
lidar_obs: false # NOTE: Setting this to true currently turns off the other observation types
reward_type: "weighted_combination"
collision_weight: -0.75
off_road_weight: -0.75
goal_achieved_weight: 1.0
dynamics_model: "delta_local"
collision_behavior: "ignore" # Options: "ignore", "remove", "stop"
dist_to_goal_threshold: 2.0
polyline_reduction_threshold: 0.1 # Rate at which to sample points from the polyline (0 uses all closest points, 1 is maximum sparsity); needs to be balanced with kMaxAgentMapObservationsCount
sampling_seed: 42 # If given, the set of scenes to sample from will be deterministic; if None, the set of scenes will be random
obs_radius: 50.0 # Visibility radius of the agents
init_roadgraph: False
render_3d: True

action_type: "continuous"
# Number of discretizations in the action space
# Note: Make sure that this equals the discretizations that the policy
# has been trained with
action_space_steer_disc: 13
action_space_accel_disc: 7

device: "cuda" # Options: "cpu", "cuda"
4 changes: 2 additions & 2 deletions examples/experimental/config/model_config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
models_path: examples/experimental/models

models:
- name: model_PPO____R_10000__02_27_09_19_10_626_003200
train_dataset_size: 10_000
- name: expert_replay
train_dataset_size: 1000
Comment thread
EllingtonKirby marked this conversation as resolved.
Outdated
wandb: null
trained_on: null
33 changes: 21 additions & 12 deletions examples/experimental/eval_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,18 @@ def __call__(self, obs, deterministic=False):
)
return random_action, None, None, None

class ExpertReplayPolicy:
    """Stateless marker type that selects expert-action replay.

    The class holds no attributes and defines no ``__call__``. Code that
    receives an instance is expected to recognise it with ``isinstance``
    and supply the logged expert actions itself instead of querying the
    policy for an action.
    """

    def __init__(self):
        """Create the marker; there is no state to initialise."""

def load_policy(path_to_cpt, model_name, device, env=None):
"""Load a policy from a given path."""

# Load the saved checkpoint
if model_name == "random_baseline":
return RandomPolicy(env.action_space.n)

if model_name == "expert_replay":
return ExpertReplayPolicy()
else: # Load a trained model
saved_cpt = torch.load(
f=f"{path_to_cpt}/{model_name}.pt",
Expand Down Expand Up @@ -110,22 +114,26 @@ def rollout(

control_mask = env.cont_agent_mask
live_agent_mask = control_mask.clone()

expert_actions, _, _, _ = env.get_expert_actions()

for time_step in range(episode_len):

print(f't: {time_step}')

# Get actions for active agents
if live_agent_mask.any():
action, _, _, _ = policy(
next_obs[live_agent_mask], deterministic=deterministic
)

# Insert actions into a template
action_template = torch.zeros(
(num_worlds, max_agent_count), dtype=torch.int64, device=device
)
action_template[live_agent_mask] = action.to(device)
if isinstance(policy, ExpertReplayPolicy):
action_template = expert_actions[:, :, time_step, :]
else:
action, _, _, _ = policy(
next_obs[live_agent_mask], deterministic=deterministic
)

# Insert actions into a template
action_template = torch.zeros(
(num_worlds, max_agent_count), dtype=torch.int64, device=device
)
action_template[live_agent_mask] = action.to(device)

# Step the environment
env.step_dynamics(action_template)
Expand Down Expand Up @@ -274,7 +282,8 @@ def make_env(config, train_loader, render_3d=False):
data_loader=train_loader,
max_cont_agents=config.max_controlled_agents,
device=config.device,
render_config=render_config
render_config=render_config,
action_type=config.action_type,
)

return env
Expand Down
2 changes: 2 additions & 0 deletions examples/experimental/get_model_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def set_seed(seed: int):
else 1000,
sample_with_replacement=False,
shuffle=False,
file_prefix=eval_config.file_predix
)

test_loader = SceneDataLoader(
Expand All @@ -74,6 +75,7 @@ def set_seed(seed: int):
else 1000,
sample_with_replacement=False,
shuffle=True,
file_prefix=eval_config.file_predix
)

# Rollouts
Expand Down
11 changes: 6 additions & 5 deletions gpudrive/visualize/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection, Line3DCollection
from matplotlib.colors import ListedColormap
from jaxlib.xla_extension import ArrayImpl
import numpy as np
import madrona_gpudrive
from gpudrive.visualize import utils
Expand Down Expand Up @@ -78,10 +77,12 @@ def initialize_static_scenario_data(self, controlled_agent_mask):
)
self.controlled_agent_mask = controlled_agent_mask

if isinstance(controlled_agent_mask, ArrayImpl):
self.controlled_agent_mask = torch.from_numpy(
np.array(controlled_agent_mask)
)
if self.backend == "jax":
from jaxlib.xla_extension import ArrayImpl
if isinstance(controlled_agent_mask, ArrayImpl):
self.controlled_agent_mask = torch.from_numpy(
np.array(controlled_agent_mask)
)

self.controlled_agent_mask = self.controlled_agent_mask.to(self.device)

Expand Down
Loading