diff --git a/mani_skill/trajectory/replay_trajectory.py b/mani_skill/trajectory/replay_trajectory.py index 4750cd810b..84ca072675 100644 --- a/mani_skill/trajectory/replay_trajectory.py +++ b/mani_skill/trajectory/replay_trajectory.py @@ -94,6 +94,14 @@ class ReplayResult: def sanity_check_and_format_seed(episode): """sanity checks the trajectory seed aligns with the episode seed. reformats the reset kwargs seed if missing or formatted wrong""" + # normalize episode_seed to a scalar int (may be stored as numpy array or list) + episode_seed = episode["episode_seed"] + if isinstance(episode_seed, np.ndarray): + episode_seed = episode_seed.item() if episode_seed.ndim == 0 else episode_seed[0] + elif isinstance(episode_seed, (list, tuple)): + episode_seed = episode_seed[0] + episode["episode_seed"] = int(episode_seed) + if "seed" in episode["reset_kwargs"]: if isinstance(episode["reset_kwargs"]["seed"], list):