Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/changelogs/v0.0.25.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
### Features

* **New neural models**: Added 3 new auto neural models powered by `neuralforecast`: `AutoNBEATS`, `AutoDeepAR`, and `AutoPatchTST`. All support `quantiles` for probabilistic forecasts trained with `MQLoss` and follow the same interface as the existing `AutoNHITS` and `AutoTFT`.

```python
import pandas as pd
from timecopilot.models.neural import AutoDeepAR, AutoNBEATS, AutoPatchTST

df = pd.read_csv(
"https://timecopilot.s3.amazonaws.com/public/data/air_passengers.csv",
parse_dates=["ds"],
)

model = AutoNBEATS()
fcst_df = model.forecast(df, h=12, quantiles=[0.1, 0.5, 0.9])
```

* **New ML models**: Added 7 new auto ML models powered by `mlforecast`'s hyperparameter optimization: `AutoLinearRegression`, `AutoXGBoost`, `AutoRidge`, `AutoLasso`, `AutoElasticNet`, `AutoRandomForest`, and `AutoCatboost`. All models support `quantiles` for probabilistic forecasts via conformal prediction and follow the same interface as the existing `AutoLGBM`.

```python
Expand Down
38 changes: 37 additions & 1 deletion tests/models/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
AutoRidge,
AutoXGBoost,
)
from timecopilot.models.neural import AutoNHITS, AutoTFT
from timecopilot.models.neural import (
AutoDeepAR,
AutoNBEATS,
AutoNHITS,
AutoPatchTST,
AutoTFT,
)
from timecopilot.models.prophet import Prophet
from timecopilot.models.stats import (
ADIDA,
Expand Down Expand Up @@ -73,6 +79,36 @@ def disable_mps_session(monkeypatch):
hidden_size=8,
),
),
AutoNBEATS(
num_samples=2,
config=dict(
max_steps=1,
val_check_steps=1,
input_size=12,
n_harmonics=1,
n_polynomials=1,
),
),
AutoDeepAR(
num_samples=2,
config=dict(
max_steps=1,
val_check_steps=1,
input_size=12,
lstm_n_layers=1,
lstm_hidden_size=8,
),
),
AutoPatchTST(
num_samples=2,
config=dict(
max_steps=1,
val_check_steps=1,
input_size=12,
hidden_size=8,
n_heads=2,
),
),
AutoARIMA(),
SeasonalNaive(),
ZeroModel(),
Expand Down
20 changes: 18 additions & 2 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,14 @@ def test_correct_forecast_dates(model, freq, h):
"AutoRandomForest",
"AutoCatboost",
}
if model.alias in _ml_auto_aliases | {"AutoNHITS", "AutoTFT"}:
_neural_auto_aliases = {
"AutoNHITS",
"AutoTFT",
"AutoNBEATS",
"AutoDeepAR",
"AutoPatchTST",
}
if model.alias in _ml_auto_aliases | _neural_auto_aliases:
# These auto ML and neural models require a longer minimum series length
sizes_per_freq = {
freq: 1_000 for freq in ["10S", "10T", "15T", "5T", "H", "Q-DEC"]
Expand Down Expand Up @@ -230,7 +237,13 @@ def test_using_quantiles(model):
elif "moe" in model.alias.lower():
# MoE is a bit more lenient with the monotonicity condition
assert fcst_df[c1].le(fcst_df[c2]).mean() >= 0.5
elif model.alias in ["AutoNHITS", "AutoTFT"]:
elif model.alias in [
"AutoNHITS",
"AutoTFT",
"AutoNBEATS",
"AutoDeepAR",
"AutoPatchTST",
]:
# test config uses max_steps=1, so quantile ordering is not guaranteed
continue
else:
Expand All @@ -252,6 +265,9 @@ def test_using_level(model):
"AutoCatboost",
"AutoNHITS",
"AutoTFT",
"AutoNBEATS",
"AutoDeepAR",
"AutoPatchTST",
}
if model.alias in _level_unsupported:
# these models only support quantiles, not level
Expand Down
4 changes: 4 additions & 0 deletions timecopilot/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
AutoRidge,
AutoXGBoost,
)
from .neural import AutoDeepAR, AutoNBEATS, AutoPatchTST
from .stats import (
ADIDA,
IMAPA,
Expand All @@ -25,8 +26,11 @@
__all__ = [
"ADIDA",
"AutoCatboost",
"AutoDeepAR",
"AutoElasticNet",
"IMAPA",
"AutoNBEATS",
"AutoPatchTST",
"AutoARIMA",
"AutoCES",
"AutoETS",
Expand Down
Loading