Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def parameters():

@pytest.fixture
def maketgg(monkeypatch, parameters):
def inner(target_tasks=None, kinds=None, params=None, enable_verifications=True):
def inner(target_tasks=None, kinds=None, params=None, enable_verifications=True, cached_params={}, cached_graphs={}):
kinds = kinds or [("_fake", [])]
params = params or {}
FakeKind.loaded_kinds = []
Expand All @@ -196,7 +196,7 @@ def target_tasks_method(full_task_graph, parameters, graph_config):
monkeypatch.setattr(generator, "load_graph_config", fake_load_graph_config)

return WithFakeKind(
"/root", parameters, enable_verifications=enable_verifications
"/root", parameters, enable_verifications=enable_verifications, cached_params=cached_params, cached_graphs=cached_graphs
)

return inner
Expand Down
35 changes: 34 additions & 1 deletion src/taskgraph/decision.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
# file, You can obtain one at http://mozilla.org/MPL/2.0/.


from functools import cache
import logging
import os
import pathlib
import shutil
import sys
import time
from pathlib import Path

Expand Down Expand Up @@ -64,7 +66,7 @@ def full_task_graph_to_runnable_tasks(full_task_json):
return runnable_tasks


def taskgraph_decision(options, parameters=None):
def taskgraph_decision(options, parameters=None, cache_dir=None):
"""
Run the decision task. This function implements `mach taskgraph decision`,
and is responsible for
Expand Down Expand Up @@ -104,13 +106,43 @@ def taskgraph_decision(options, parameters=None):

decision_task_id = os.environ["TASK_ID"]

cached_results = {}
if cache_dir:
# TODO: don't load irrelevant files
# TODO: ensure that we only use cached results if we have cached results
# for all prior steps
for name in os.listdir(cache_dir):
if name == "graph_config":
pass
elif name == "parameters":
cached_results[name] = Parameters(**load_yaml(cache_dir, name))
elif name == "kind_graph":
pass
elif name == "full_task_set":
# TODO: we should be able to get this from `full_task_graph`; don't need to load it separately
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))[0]
elif name == "full_task_graph":
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))[1]
elif name == "target_task_set":
# derivable from target_task_graph?
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))
elif name == "target_task_graph":
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))
elif name == "optimized_task_graph":
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))
elif name == "label_to_taskid":
cached_results[name] = json.load(open(os.path.join(cache_dir, name)))
elif name == "morphed_task_graph":
cached_results[name] = TaskGraph.from_json(json.load(open(os.path.join(cache_dir, name))))

# create a TaskGraphGenerator instance
tgg = TaskGraphGenerator(
root_dir=options.get("root"),
parameters=parameters,
decision_task_id=decision_task_id,
write_artifacts=True,
enable_verifications=options.get("verify", True),
initial_results=cached_results,
)

# write out the parameters used to generate this graph
Expand All @@ -126,6 +158,7 @@ def taskgraph_decision(options, parameters=None):
full_task_json = tgg.full_task_graph.to_json()
write_artifact("full-task-graph.json", full_task_json)

sys.exit(1)
# write out the public/runnable-jobs.json file
write_artifact(
"runnable-jobs.json", full_task_graph_to_runnable_tasks(full_task_json)
Expand Down
28 changes: 27 additions & 1 deletion src/taskgraph/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ def __init__(
decision_task_id: str = "DECISION-TASK",
write_artifacts: bool = False,
enable_verifications: bool = True,
cached_params: dict = {},
cached_graphs: dict[str, TaskGraph] = {},
):
"""
@param root_dir: root directory containing the Taskgraph config.yml file
Expand All @@ -162,6 +164,28 @@ def __init__(
# start the generator
self._run = self._run() # type: ignore
self._run_results = {}
# TODO: should we require all earlier results cached to cache a later result?
# this would mean that we would be required to cache and load graph_config, kind_graph
# the argument against it is that we still do things like, eg: load_kinds even
# when we use a cached full task graph
# it probably makes sense to have this requirement strictly for graphs, at least?
if cached_params:
self._run_results["parameters"] = cached_params
for k, v in cached_graphs.items():
if k == "full_task_graph":
# full task set is always the full task graph with the edges removed
self._run_results["full_task_set"] = TaskGraph(v.tasks, Graph(frozenset(v.tasks), frozenset()))
self._run_results["full_task_graph"] = v
elif k == "target_task_graph":
# target task set is always the full task graph with the edges removed
self._run_results["target_task_set"] = TaskGraph(v.tasks, Graph(frozenset(v.tasks), frozenset()))
self._run_results["target_task_graph"] = v
elif k == "optimized_task_graph":
self._run_results["optimized_task_graph"] = v
elif k == "morphed_task_graph":
self._run_results["morphed_task_graph"] = v
else:
raise ValueError(f"cached graph {k} not supported")

@property
def parameters(self):
Expand Down Expand Up @@ -564,7 +588,9 @@ def _run_until(self, name):
k, v = next(self._run) # type: ignore
except StopIteration:
raise AttributeError(f"No such run result {name}")
self._run_results[k] = v
# might have been in `cached_results`
if k not in self._run_results:
self._run_results[k] = v
return self._run_results[name]

def verify(self, name, *args, **kwargs):
Expand Down
5 changes: 4 additions & 1 deletion src/taskgraph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,10 @@ def load_task(args):
def decision(options):
from taskgraph.decision import taskgraph_decision # noqa: PLC0415

taskgraph_decision(options)
# TODO: add parameter that instructs us to go fetch cached artifacts from
# elsewhere, eg: an index
cache_dir = "/home/bhearsum/tmp/2026-01-07/tgcache"
taskgraph_decision(options, cache_dir=cache_dir)


@command("actions", help="Print the rendered actions.json")
Expand Down
38 changes: 37 additions & 1 deletion test/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
from concurrent.futures import ProcessPoolExecutor

import pytest
from pytest_taskgraph import WithFakeKind, fake_load_graph_config
from pytest_taskgraph import WithFakeKind, fake_load_graph_config, make_task

from taskgraph import generator, graph
from taskgraph.generator import Kind, load_tasks_for_kind, load_tasks_for_kinds
from taskgraph.loader.default import loader as default_loader
from taskgraph.parameters import Parameters
from taskgraph.taskgraph import TaskGraph

linuxonly = pytest.mark.skipif(
platform.system() != "Linux",
Expand Down Expand Up @@ -386,3 +388,37 @@ def test_kind_graph_with_target_kinds(maketgg):
# _fake3 and _other should not be included
assert "_fake3" not in kind_graph.nodes
assert "_other" not in kind_graph.nodes


def test_cached_results(maketgg):
"""Initial results are returned instead of regenerating parts of the taskgraph"""
fake1 = make_task("fake1", kind="_fake1")
fake2 = make_task("fake2", kind="_fake2", dependencies={"fake1": "fake1"})
fake3 = make_task("fake3", kind="_fake3", dependencies={"fake1": "fake1", "fake2": "fake2"})
tasks, full_task_graph = TaskGraph.from_json({
"fake1": fake1.to_json(),
"fake2": fake2.to_json(),
"fake3": fake3.to_json(),
})
tgg = maketgg(
target_tasks=["fake1", "fake2", "fake3"],
kinds=[
("_fake3", {"kind-dependencies": ["_fake2", "_fake1"]}),
("_fake2", {"kind-dependencies": ["_fake1"]}),
("_fake1", {"kind-dependencies": []}),
],
cached_graphs={"full_task_graph": full_task_graph},
)
assert tgg.full_task_set == TaskGraph(tasks, graph.Graph(frozenset(tasks), frozenset()))
assert tgg.full_task_graph == full_task_graph
# ensure we get different results when not using cached_results
tgg2 = maketgg(
target_tasks=["fake1", "fake2", "fake3"],
kinds=[
("_fake3", {"kind-dependencies": ["_fake2", "_fake1"]}),
("_fake2", {"kind-dependencies": ["_fake1"]}),
("_fake1", {"kind-dependencies": []}),
],
)
assert tgg.full_task_graph != tgg2.full_task_graph
tgg.target_task_graph