6363from .services .connection_token import create_sandbox_connection_token
6464from .stream .redis_stream import TaskRunRedisStream , TaskRunStreamError , get_task_run_stream_key
6565from .temporal .client import execute_posthog_code_agent_relay_workflow , execute_task_processing_workflow
66+ from .temporal .process_task .utils import PR_AUTHORSHIP_MODE_USER , cache_github_user_token
6667
6768logger = logging .getLogger (__name__ )
6869
@@ -247,6 +248,10 @@ def run(self, request, pk=None, **kwargs):
247248 resume_from_run_id = request .validated_data .get ("resume_from_run_id" )
248249 pending_user_message = request .validated_data .get ("pending_user_message" )
249250 sandbox_environment_id = request .validated_data .get ("sandbox_environment_id" )
251+ pr_authorship_mode = request .validated_data .get ("pr_authorship_mode" )
252+ run_source = request .validated_data .get ("run_source" )
253+ signal_report_id = request .validated_data .get ("signal_report_id" )
254+ github_user_token = request .validated_data .get ("github_user_token" )
250255
251256 extra_state = None
252257 if resume_from_run_id :
@@ -269,6 +274,32 @@ def run(self, request, pk=None, **kwargs):
269274 if prev_sandbox_env_id and sandbox_environment_id is None :
270275 sandbox_environment_id = prev_sandbox_env_id
271276
277+ previous_state = previous_run .state or {}
278+ if pr_authorship_mode is None :
279+ pr_authorship_mode = previous_state .get ("pr_authorship_mode" )
280+ if run_source is None :
281+ run_source = previous_state .get ("run_source" )
282+ if signal_report_id is None :
283+ signal_report_id = previous_state .get ("signal_report_id" )
284+ if branch is None :
285+ previous_base_branch = previous_state .get ("pr_base_branch" )
286+ if isinstance (previous_base_branch , str ):
287+ branch = previous_base_branch
288+
289+ for key , value in {
290+ "pr_base_branch" : branch ,
291+ "pr_authorship_mode" : pr_authorship_mode ,
292+ "run_source" : run_source ,
293+ "signal_report_id" : signal_report_id ,
294+ }.items ():
295+ if value is not None :
296+ extra_state = extra_state or {}
297+ extra_state [key ] = value
298+
299+ # Only require a user token when the task has a repo (no-repo cloud runs skip GitHub operations)
300+ if pr_authorship_mode == PR_AUTHORSHIP_MODE_USER and task .repository and not github_user_token :
301+ return Response ({"detail" : "github_user_token is required for user-authored cloud runs" }, status = 400 )
302+
272303 if sandbox_environment_id is not None :
273304 sandbox_environment = SandboxEnvironment .objects .filter (id = sandbox_environment_id , team = task .team ).first ()
274305 if not sandbox_environment :
@@ -291,6 +322,9 @@ def run(self, request, pk=None, **kwargs):
291322
292323 task_run = task .create_run (mode = mode , branch = branch , extra_state = extra_state )
293324
325+ if github_user_token and pr_authorship_mode == PR_AUTHORSHIP_MODE_USER :
326+ cache_github_user_token (str (task_run .id ), github_user_token )
327+
294328 logger .info (f"Triggering workflow for task { task .id } , run { task_run .id } " )
295329
296330 self ._trigger_workflow (task , task_run )
@@ -379,34 +413,50 @@ def update(self, request, *args, **kwargs):
379413 )
380414 def partial_update (self , request , * args , ** kwargs ):
381415 task_run = cast (TaskRun , self .get_object ())
382- old_status = task_run .status
383-
384- # Update fields from validated data
385- for key , value in request .validated_data .items ():
386- setattr (task_run , key , value )
416+ has_output_merge = "output" in request .validated_data and isinstance (request .validated_data ["output" ], dict )
387417
388- new_status = request .validated_data .get ("status" )
389- terminal_statuses = [
390- TaskRun .Status .COMPLETED ,
391- TaskRun .Status .FAILED ,
392- TaskRun .Status .CANCELLED ,
393- ]
394-
395- # Auto-set completed_at if status is completed or failed
396- if new_status in terminal_statuses :
397- if not task_run .completed_at :
398- task_run .completed_at = timezone .now ()
399-
400- # Signal Temporal workflow if status changed to terminal state
401- if old_status != new_status :
402- self ._signal_workflow_completion (
403- task_run ,
404- new_status ,
405- request .validated_data .get ("error_message" ),
406- )
407-
408- task_run .save ()
409- self ._post_slack_update_for_pr (task_run )
418+ with transaction .atomic ():
419+ # Re-fetch with row lock when merging output to prevent concurrent
420+ # PATCHes (e.g. branch sync + PR URL) from clobbering each other.
421+ if has_output_merge :
422+ task_run = TaskRun .objects .select_for_update ().get (pk = task_run .pk )
423+
424+ old_status = task_run .status
425+ old_pr_url = (task_run .output or {}).get ("pr_url" ) if isinstance (task_run .output , dict ) else None
426+
427+ # Update fields from validated data
428+ for key , value in request .validated_data .items ():
429+ if key == "output" and isinstance (value , dict ):
430+ existing_output = task_run .output if isinstance (task_run .output , dict ) else {}
431+ setattr (task_run , key , {** existing_output , ** value })
432+ continue
433+ setattr (task_run , key , value )
434+
435+ new_status = request .validated_data .get ("status" )
436+ terminal_statuses = [
437+ TaskRun .Status .COMPLETED ,
438+ TaskRun .Status .FAILED ,
439+ TaskRun .Status .CANCELLED ,
440+ ]
441+
442+ # Auto-set completed_at if status is completed or failed
443+ if new_status in terminal_statuses :
444+ if not task_run .completed_at :
445+ task_run .completed_at = timezone .now ()
446+
447+ task_run .save ()
448+
449+ # Signal Temporal and post Slack updates after commit to avoid
450+ # holding the row lock during external calls.
451+ if new_status in terminal_statuses and old_status != new_status :
452+ self ._signal_workflow_completion (
453+ task_run ,
454+ new_status ,
455+ request .validated_data .get ("error_message" ),
456+ )
457+ new_pr_url = (task_run .output or {}).get ("pr_url" ) if isinstance (task_run .output , dict ) else None
458+ if new_pr_url and new_pr_url != old_pr_url :
459+ self ._post_slack_update_for_pr (task_run )
410460
411461 return Response (TaskRunDetailSerializer (task_run , context = self .get_serializer_context ()).data )
412462
0 commit comments