Skip to content

Commit ebfdcc4

Browse files
committed
feat(cloud-agent): user-authored PRs
chore: cleaner
1 parent d1270c1 commit ebfdcc4

File tree

9 files changed

+500
-41
lines changed

9 files changed

+500
-41
lines changed

products/tasks/backend/api.py

Lines changed: 77 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
from .services.connection_token import create_sandbox_connection_token
6464
from .stream.redis_stream import TaskRunRedisStream, TaskRunStreamError, get_task_run_stream_key
6565
from .temporal.client import execute_posthog_code_agent_relay_workflow, execute_task_processing_workflow
66+
from .temporal.process_task.utils import PR_AUTHORSHIP_MODE_USER, cache_github_user_token
6667

6768
logger = logging.getLogger(__name__)
6869

@@ -247,6 +248,10 @@ def run(self, request, pk=None, **kwargs):
247248
resume_from_run_id = request.validated_data.get("resume_from_run_id")
248249
pending_user_message = request.validated_data.get("pending_user_message")
249250
sandbox_environment_id = request.validated_data.get("sandbox_environment_id")
251+
pr_authorship_mode = request.validated_data.get("pr_authorship_mode")
252+
run_source = request.validated_data.get("run_source")
253+
signal_report_id = request.validated_data.get("signal_report_id")
254+
github_user_token = request.validated_data.get("github_user_token")
250255

251256
extra_state = None
252257
if resume_from_run_id:
@@ -269,6 +274,32 @@ def run(self, request, pk=None, **kwargs):
269274
if prev_sandbox_env_id and sandbox_environment_id is None:
270275
sandbox_environment_id = prev_sandbox_env_id
271276

277+
previous_state = previous_run.state or {}
278+
if pr_authorship_mode is None:
279+
pr_authorship_mode = previous_state.get("pr_authorship_mode")
280+
if run_source is None:
281+
run_source = previous_state.get("run_source")
282+
if signal_report_id is None:
283+
signal_report_id = previous_state.get("signal_report_id")
284+
if branch is None:
285+
previous_base_branch = previous_state.get("pr_base_branch")
286+
if isinstance(previous_base_branch, str):
287+
branch = previous_base_branch
288+
289+
for key, value in {
290+
"pr_base_branch": branch,
291+
"pr_authorship_mode": pr_authorship_mode,
292+
"run_source": run_source,
293+
"signal_report_id": signal_report_id,
294+
}.items():
295+
if value is not None:
296+
extra_state = extra_state or {}
297+
extra_state[key] = value
298+
299+
# Only require a user token when the task has a repo (no-repo cloud runs skip GitHub operations)
300+
if pr_authorship_mode == PR_AUTHORSHIP_MODE_USER and task.repository and not github_user_token:
301+
return Response({"detail": "github_user_token is required for user-authored cloud runs"}, status=400)
302+
272303
if sandbox_environment_id is not None:
273304
sandbox_environment = SandboxEnvironment.objects.filter(id=sandbox_environment_id, team=task.team).first()
274305
if not sandbox_environment:
@@ -291,6 +322,9 @@ def run(self, request, pk=None, **kwargs):
291322

292323
task_run = task.create_run(mode=mode, branch=branch, extra_state=extra_state)
293324

325+
if github_user_token and pr_authorship_mode == PR_AUTHORSHIP_MODE_USER:
326+
cache_github_user_token(str(task_run.id), github_user_token)
327+
294328
logger.info(f"Triggering workflow for task {task.id}, run {task_run.id}")
295329

296330
self._trigger_workflow(task, task_run)
@@ -379,34 +413,50 @@ def update(self, request, *args, **kwargs):
379413
)
380414
def partial_update(self, request, *args, **kwargs):
381415
task_run = cast(TaskRun, self.get_object())
382-
old_status = task_run.status
383-
384-
# Update fields from validated data
385-
for key, value in request.validated_data.items():
386-
setattr(task_run, key, value)
416+
has_output_merge = "output" in request.validated_data and isinstance(request.validated_data["output"], dict)
387417

388-
new_status = request.validated_data.get("status")
389-
terminal_statuses = [
390-
TaskRun.Status.COMPLETED,
391-
TaskRun.Status.FAILED,
392-
TaskRun.Status.CANCELLED,
393-
]
394-
395-
# Auto-set completed_at if status is completed or failed
396-
if new_status in terminal_statuses:
397-
if not task_run.completed_at:
398-
task_run.completed_at = timezone.now()
399-
400-
# Signal Temporal workflow if status changed to terminal state
401-
if old_status != new_status:
402-
self._signal_workflow_completion(
403-
task_run,
404-
new_status,
405-
request.validated_data.get("error_message"),
406-
)
407-
408-
task_run.save()
409-
self._post_slack_update_for_pr(task_run)
418+
with transaction.atomic():
419+
# Re-fetch with row lock when merging output to prevent concurrent
420+
# PATCHes (e.g. branch sync + PR URL) from clobbering each other.
421+
if has_output_merge:
422+
task_run = TaskRun.objects.select_for_update().get(pk=task_run.pk)
423+
424+
old_status = task_run.status
425+
old_pr_url = (task_run.output or {}).get("pr_url") if isinstance(task_run.output, dict) else None
426+
427+
# Update fields from validated data
428+
for key, value in request.validated_data.items():
429+
if key == "output" and isinstance(value, dict):
430+
existing_output = task_run.output if isinstance(task_run.output, dict) else {}
431+
setattr(task_run, key, {**existing_output, **value})
432+
continue
433+
setattr(task_run, key, value)
434+
435+
new_status = request.validated_data.get("status")
436+
terminal_statuses = [
437+
TaskRun.Status.COMPLETED,
438+
TaskRun.Status.FAILED,
439+
TaskRun.Status.CANCELLED,
440+
]
441+
442+
# Auto-set completed_at if status is completed or failed
443+
if new_status in terminal_statuses:
444+
if not task_run.completed_at:
445+
task_run.completed_at = timezone.now()
446+
447+
task_run.save()
448+
449+
# Signal Temporal and post Slack updates after commit to avoid
450+
# holding the row lock during external calls.
451+
if new_status in terminal_statuses and old_status != new_status:
452+
self._signal_workflow_completion(
453+
task_run,
454+
new_status,
455+
request.validated_data.get("error_message"),
456+
)
457+
new_pr_url = (task_run.output or {}).get("pr_url") if isinstance(task_run.output, dict) else None
458+
if new_pr_url and new_pr_url != old_pr_url:
459+
self._post_slack_update_for_pr(task_run)
410460

411461
return Response(TaskRunDetailSerializer(task_run, context=self.get_serializer_context()).data)
412462

products/tasks/backend/sandbox/images/Dockerfile.sandbox-base

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,11 @@ RUN chmod +x /tmp/install-skills.sh && \
101101
/tmp/install-skills.sh /tmp/skills && \
102102
rm -rf /tmp/install-skills.sh /tmp/skills
103103

104-
# Configure git for commits - TODO: Use user's email and name
105-
RUN git config --global user.email "array@posthog.com" && \
106-
git config --global user.name "Array"
104+
# Default git identity for bot-authored commits.
105+
# Runs configured for user-authored pull requests receive
106+
# GIT_AUTHOR_*/GIT_COMMITTER_* env vars that override these defaults.
107+
RUN git config --global user.email "code@posthog.com" && \
108+
git config --global user.name "PostHog Code"
107109

108110
# This is required for the Claude Code SDK to allow --dangerously-skip-permissions as the root user
109111
ENV IS_SANDBOX=1

products/tasks/backend/sandbox/images/Dockerfile.sandbox-notebook

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,11 @@ RUN cd /scripts && \
8282
-o runAgent.mjs && \
8383
chmod +x runAgent.mjs
8484

85-
# Configure git for commits - TODO: Use user's email and name
86-
RUN git config --global user.email "array@posthog.com" && \
87-
git config --global user.name "Array"
85+
# Default git identity for bot-authored commits.
86+
# Runs configured for user-authored pull requests receive
87+
# GIT_AUTHOR_*/GIT_COMMITTER_* env vars that override these defaults.
88+
RUN git config --global user.email "code@posthog.com" && \
89+
git config --global user.name "PostHog Code"
8890

8991
# This is required for the Claude Code SDK to allow --dangerously-skip-permissions as the root user
9092
ENV IS_SANDBOX=1

products/tasks/backend/serializers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010

1111
from .models import SandboxEnvironment, Task, TaskRun
1212
from .services.title_generator import generate_task_title
13+
from .temporal.process_task.utils import (
14+
PR_AUTHORSHIP_MODE_BOT,
15+
PR_AUTHORSHIP_MODE_USER,
16+
RUN_SOURCE_MANUAL,
17+
RUN_SOURCE_SIGNAL_REPORT,
18+
)
1319

1420
PRESIGNED_URL_CACHE_TTL = 55 * 60 # 55 minutes (less than 1 hour URL expiry)
1521

@@ -351,6 +357,9 @@ class ConnectionTokenResponseSerializer(serializers.Serializer):
351357
class TaskRunCreateRequestSerializer(serializers.Serializer):
352358
"""Request body for creating a new task run"""
353359

360+
PR_AUTHORSHIP_MODE_CHOICES = [PR_AUTHORSHIP_MODE_USER, PR_AUTHORSHIP_MODE_BOT]
361+
RUN_SOURCE_CHOICES = [RUN_SOURCE_MANUAL, RUN_SOURCE_SIGNAL_REPORT]
362+
354363
mode = serializers.ChoiceField(
355364
choices=["interactive", "background"],
356365
required=False,
@@ -380,6 +389,31 @@ class TaskRunCreateRequestSerializer(serializers.Serializer):
380389
default=None,
381390
help_text="Optional sandbox environment to apply for this cloud run.",
382391
)
392+
pr_authorship_mode = serializers.ChoiceField(
393+
choices=PR_AUTHORSHIP_MODE_CHOICES,
394+
required=False,
395+
default=None,
396+
help_text="Whether pull requests for this run should be authored by the user or the bot.",
397+
)
398+
run_source = serializers.ChoiceField(
399+
choices=RUN_SOURCE_CHOICES,
400+
required=False,
401+
default=None,
402+
help_text="High-level source that triggered this run, used to distinguish manual and signal-based cloud runs.",
403+
)
404+
signal_report_id = serializers.CharField(
405+
required=False,
406+
default=None,
407+
allow_blank=False,
408+
help_text="Optional signal report identifier when this run was started from Inbox.",
409+
)
410+
github_user_token = serializers.CharField(
411+
required=False,
412+
default=None,
413+
allow_blank=False,
414+
write_only=True,
415+
help_text="Ephemeral GitHub user token from PostHog Code for user-authored cloud pull requests.",
416+
)
383417

384418

385419
class TaskRunCommandRequestSerializer(serializers.Serializer):

products/tasks/backend/temporal/process_task/activities/create_sandbox_from_snapshot.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from products.tasks.backend.temporal.observability import emit_agent_log, log_activity_execution
1818
from products.tasks.backend.temporal.process_task.utils import (
1919
build_sandbox_environment_variables,
20-
get_github_token,
20+
get_git_identity_env_vars,
21+
get_sandbox_github_token,
2122
get_sandbox_name_for_task,
2223
)
2324

@@ -70,7 +71,14 @@ def create_sandbox_from_snapshot(input: CreateSandboxFromSnapshotInput) -> Creat
7071
github_token = ""
7172
if ctx.github_integration_id is not None:
7273
try:
73-
github_token = get_github_token(ctx.github_integration_id) or ""
74+
github_token = (
75+
get_sandbox_github_token(
76+
ctx.github_integration_id,
77+
run_id=ctx.run_id,
78+
state=ctx.state,
79+
)
80+
or ""
81+
)
7482
except Exception as e:
7583
raise GitHubAuthenticationError(
7684
f"Failed to get GitHub token for integration {ctx.github_integration_id}",
@@ -99,6 +107,7 @@ def create_sandbox_from_snapshot(input: CreateSandboxFromSnapshotInput) -> Creat
99107
team_id=ctx.team_id,
100108
sandbox_environment=sandbox_env,
101109
)
110+
environment_variables.update(get_git_identity_env_vars(task, ctx.state))
102111

103112
config = SandboxConfig(
104113
name=get_sandbox_name_for_task(ctx.task_id),

products/tasks/backend/temporal/process_task/activities/get_sandbox_for_repository.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616
from products.tasks.backend.temporal.oauth import create_oauth_access_token
1717
from products.tasks.backend.temporal.observability import emit_agent_log, log_activity_execution
1818
from products.tasks.backend.temporal.process_task.utils import (
19-
get_github_token,
19+
get_git_identity_env_vars,
2020
get_sandbox_api_url,
21+
get_sandbox_github_token,
2122
get_sandbox_name_for_task,
2223
)
2324

@@ -31,6 +32,7 @@
3132
"POSTHOG_PROJECT_ID",
3233
"JWT_PUBLIC_KEY",
3334
"GITHUB_TOKEN",
35+
"GH_TOKEN",
3436
"LLM_GATEWAY_URL",
3537
"POSTHOG_RESUME_RUN_ID",
3638
}
@@ -85,7 +87,14 @@ def get_sandbox_for_repository(input: GetSandboxForRepositoryInput) -> GetSandbo
8587
if has_repo:
8688
assert github_integration_id is not None
8789
try:
88-
github_token = get_github_token(github_integration_id) or ""
90+
github_token = (
91+
get_sandbox_github_token(
92+
github_integration_id,
93+
run_id=ctx.run_id,
94+
state=ctx.state,
95+
)
96+
or ""
97+
)
8998
except Exception as e:
9099
raise GitHubAuthenticationError(
91100
f"Failed to get GitHub token for integration {github_integration_id}",
@@ -138,10 +147,13 @@ def get_sandbox_for_repository(input: GetSandboxForRepositoryInput) -> GetSandbo
138147

139148
if github_token:
140149
environment_variables["GITHUB_TOKEN"] = github_token
150+
environment_variables["GH_TOKEN"] = github_token
141151

142152
if settings.SANDBOX_LLM_GATEWAY_URL:
143153
environment_variables["LLM_GATEWAY_URL"] = settings.SANDBOX_LLM_GATEWAY_URL
144154

155+
environment_variables.update(get_git_identity_env_vars(task, ctx.state))
156+
145157
# Set resume run ID independently of snapshot so conversation history
146158
# can be rebuilt from logs even when the filesystem snapshot has expired.
147159
resume_from_run_id = (ctx.state or {}).get("resume_from_run_id", "")

0 commit comments

Comments
 (0)