Skip to content

Commit ca47d07

Browse files
Kovboclaude
andauthored
fix: lazy-import heavy dependencies in CLI to fix commands without backend extras (#571)
`art install-skills`, `art train-sft`, `art train-rl`, and `art migrate` all crash with ModuleNotFoundError when backend extras (torch, fastapi, uvicorn, numpy) are not installed, even though they don't need them. - Move fastapi, uvicorn, pydantic, LocalBackend, and other backend-only imports from top-level cli.py into the `run` command where they're actually used - Remove unused imports (Optional, Provider, TogetherDeploymentConfig, WandbDeploymentConfig) - Lazy-import calculate_step_std_dev in model.py to break the __init__ -> utils -> model -> numpy eager import chain Co-authored-by: Claude Opus 4.6 <[email protected]>
1 parent 8714773 commit ca47d07

2 files changed

Lines changed: 17 additions & 18 deletions

File tree

src/art/cli.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,10 @@
11
import json
22
from pathlib import Path
33
import socket
4-
from typing import Any, AsyncIterator, Optional
4+
from typing import Any, AsyncIterator
55

66
from dotenv import load_dotenv
7-
from fastapi import Body, FastAPI, Request
8-
from fastapi.responses import JSONResponse, StreamingResponse
9-
import pydantic
107
import typer
11-
import uvicorn
12-
13-
from . import dev
14-
from .errors import ARTError
15-
from .local import LocalBackend
16-
from .model import Model, TrainableModel
17-
from .trajectories import TrajectoryGroup
18-
from .types import TrainConfig
19-
from .utils.deployment import (
20-
Provider,
21-
TogetherDeploymentConfig,
22-
WandbDeploymentConfig,
23-
)
248

259
load_dotenv()
2610

@@ -302,6 +286,18 @@ def migrate(
302286
def run(host: str = "0.0.0.0", port: int = 7999) -> None:
303287
"""Run the ART CLI."""
304288

289+
from fastapi import Body, FastAPI, Request
290+
from fastapi.responses import JSONResponse, StreamingResponse
291+
import pydantic
292+
import uvicorn
293+
294+
from . import dev
295+
from .errors import ARTError
296+
from .local import LocalBackend
297+
from .model import Model, TrainableModel
298+
from .trajectories import TrajectoryGroup
299+
from .types import TrainConfig
300+
305301
# check if port is available
306302
def is_port_available(port: int) -> bool:
307303
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:

src/art/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from .costs import CostCalculator
1616
from .trajectories import Trajectory, TrajectoryGroup
1717
from .types import TrainConfig, TrainSFTConfig
18-
from .utils.old_benchmarking.calculate_step_metrics import calculate_step_std_dev
1918
from .utils.trajectory_logging import write_trajectory_groups_parquet
2019

2120
if TYPE_CHECKING:
@@ -638,6 +637,10 @@ def _add_costs(metrics_dict: dict[str, float | int | bool]) -> None:
638637
averages[f"group_metric_{metric}"] = sum(values) / len(values)
639638

640639
# Calculate average standard deviation of rewards within groups
640+
from .utils.old_benchmarking.calculate_step_metrics import (
641+
calculate_step_std_dev,
642+
)
643+
641644
averages["reward_std_dev"] = calculate_step_std_dev(trajectory_groups)
642645

643646
# Merge in any additional metrics passed directly

0 commit comments

Comments
 (0)