11import asyncio
22import logging
33import os
4- from shutil import rmtree
5- from pathlib import Path
64import subprocess
75import sys
86import traceback
9- from typing import Optional
7+ from pathlib import Path
108
119from pydantic import BaseModel , Field
1210
2220from beeai_framework .tools import Tool
2321from beeai_framework .tools .search .duckduckgo import DuckDuckGoSearchTool
2422from beeai_framework .tools .think import ThinkTool
23+ from beeai_framework .workflows import Workflow
2524
25+ import tasks
26+ from constants import COMMIT_PREFIX
27+ from observability import setup_observability
28+ from tools .commands import RunShellCommandTool
2629from tools .specfile import AddChangelogEntryTool , BumpReleaseTool
2730from tools .text import CreateTool , InsertTool , StrReplaceTool , ViewTool
2831from tools .wicked_git import GitLogSearchTool , GitPatchCreationTool
29- from constants import COMMIT_PREFIX , BRANCH_PREFIX
30- from observability import setup_observability
31- from tools .commands import RunShellCommandTool
3232from triage_agent import BackportData , ErrorData
33- from utils import get_agent_execution_config , mcp_tools , redis_client , get_git_finalization_steps
33+ from utils import check_subprocess , get_agent_execution_config , mcp_tools , redis_client
3434
3535logger = logging .getLogger (__name__ )
3636
3737
3838class InputSchema (BaseModel ):
39+ local_clone : Path = Field (description = "Path to the local clone of forked dist-git repository" )
40+ unpacked_sources : Path = Field (description = "Path to the unpacked (using `centpkg prep`) sources" )
3941 package : str = Field (description = "Package to update" )
42+ dist_git_branch : str = Field (description = "Git branch in dist-git to be updated" )
4043 upstream_fix : str = Field (description = "Link to an upstream fix for the issue" )
4144 jira_issue : str = Field (description = "Jira issue to reference as resolved" )
4245 cve_id : str = Field (default = "" , description = "CVE ID if the jira issue is a CVE" )
43- dist_git_branch : str = Field (description = "Git branch in dist-git to be updated" )
44- git_repo_basepath : str = Field (
45- description = "Base path for cloned git repos" ,
46- default = os .getenv ("GIT_REPO_BASEPATH" ),
47- )
48- unpacked_sources : str = Field (
49- description = "Path to the unpacked (using `centpkg prep`) sources" ,
50- default = "" ,
51- )
5246
5347
5448class OutputSchema (BaseModel ):
5549 success : bool = Field (description = "Whether the backport was successfully completed" )
5650 status : str = Field (description = "Backport status" )
57- mr_url : Optional [ str ] = Field (description = "URL to the opened merge request" )
58- error : Optional [ str ] = Field (description = "Specific details about an error" )
51+ mr_url : str | None = Field (description = "URL to the opened merge request" )
52+ error : str | None = Field (description = "Specific details about an error" )
5953
6054
6155def render_prompt (input : InputSchema ) -> str :
6256 template = (
63- 'Work inside the repository cloned at "{{ git_repo_basepath }}/{{ package }}" \n '
57+ 'Work inside the repository cloned in "{{ local_clone }}", it is your current working directory \n '
6458 "Use the `git_log_search` tool to check if the jira issue ({{ jira_issue }}) or CVE ({{ cve_id }}) is already resolved.\n "
6559 "If the issue or the cve are already resolved, exit the backporting process with success=True and status=\" Backport already applied\" \n "
6660 "Download the upstream fix from {{ upstream_fix }}\n "
@@ -74,64 +68,8 @@ def render_prompt(input: InputSchema) -> str:
7468 "Delete all *.rej files\n "
7569 "DO **NOT** RUN COMMAND `git am --continue`\n "
7670 "Once you resolve all conflicts, use tool git_patch_create to create a patch file\n "
77- "{{ backport_git_steps }}"
7871 )
79-
80- # Define template function that can be called from the template
81- def backport_git_steps (data ):
82- input_data = InputSchema .model_validate (data )
83- return get_git_finalization_steps (
84- package = input_data .package ,
85- jira_issue = input_data .jira_issue ,
86- commit_title = f"{ COMMIT_PREFIX } backport { input_data .jira_issue } " ,
87- files_to_commit = f"*.spec and { input_data .jira_issue } .patch" ,
88- branch_name = f"{ BRANCH_PREFIX } -{ input_data .jira_issue } " ,
89- dist_git_branch = input_data .dist_git_branch ,
90- )
91-
92- return PromptTemplate (
93- PromptTemplateInput (schema = InputSchema , template = template , functions = {"backport_git_steps" : backport_git_steps })
94- ).render (input )
95-
96-
97- def prepare_package (
98- package : str , jira_issue : str , dist_git_branch : str , input_schema : InputSchema
99- ) -> tuple [Path , Path ]:
100- """
101- Prepare the package for backporting by cloning the dist-git repository, switching to the appropriate branch,
102- and downloading the sources.
103- Returns the path to the unpacked sources.
104- """
105- git_repo = Path (input_schema .git_repo_basepath )
106- git_repo .mkdir (parents = True , exist_ok = True )
107- subprocess .check_call (
108- [
109- "centpkg" ,
110- "clone" ,
111- "--anonymous" ,
112- "--branch" ,
113- dist_git_branch ,
114- package ,
115- ],
116- cwd = git_repo ,
117- )
118- local_clone = git_repo / package
119- subprocess .check_call (
120- [
121- "git" ,
122- "switch" ,
123- "-c" ,
124- f"automated-package-update-{ jira_issue } " ,
125- dist_git_branch ,
126- ],
127- cwd = local_clone ,
128- )
129- subprocess .check_call (["centpkg" , "sources" ], cwd = local_clone )
130- subprocess .check_call (["centpkg" , "prep" ], cwd = local_clone )
131- unpacked_sources = list (local_clone .glob (f"*-build/*{ package } *" ))
132- if len (unpacked_sources ) != 1 :
133- raise ValueError (f"Expected exactly one unpacked source, got { unpacked_sources } " )
134- return unpacked_sources [0 ], local_clone
72+ return PromptTemplate (PromptTemplateInput (schema = InputSchema , template = template )).render (input )
13573
13674
13775async def main () -> None :
@@ -141,7 +79,7 @@ async def main() -> None:
14179 cve_id = os .getenv ("CVE_ID" , "" )
14280
14381 async with mcp_tools (os .getenv ("MCP_GATEWAY_URL" )) as gateway_tools :
144- agent = RequirementAgent (
82+ backport_agent = RequirementAgent (
14583 llm = ChatModel .from_name (os .getenv ("CHAT_MODEL" )),
14684 tools = [
14785 ThinkTool (),
@@ -155,11 +93,6 @@ async def main() -> None:
15593 GitLogSearchTool (),
15694 BumpReleaseTool (),
15795 AddChangelogEntryTool (),
158- ]
159- + [
160- t
161- for t in gateway_tools
162- if t .name in ("fork_repository" , "open_merge_request" , "push_to_remote_repository" )
16396 ],
16497 memory = UnconstrainedMemory (),
16598 requirements = [
@@ -182,41 +115,110 @@ async def main() -> None:
182115 ],
183116 )
184117
185- dry_run = os .getenv ("DRY_RUN" , "False" ).lower () == "true"
118+ class State (BaseModel ):
119+ jira_issue : str
120+ package : str
121+ dist_git_branch : str
122+ upstream_fix : str
123+ cve_id : str
124+ local_clone : Path | None = Field (default = None )
125+ update_branch : str | None = Field (default = None )
126+ fork_url : str | None = Field (default = None )
127+ unpacked_sources : Path | None = Field (default = None )
128+ backport_result : OutputSchema | None = Field (default = None )
129+ merge_request_url : str | None = Field (default = None )
130+
131+ workflow = Workflow (State )
186132
187- async def run (input ):
188- response = await agent .run (
189- prompt = render_prompt (input ),
190- expected_output = OutputSchema ,
191- execution = get_agent_execution_config (),
133+ async def fork_and_prepare_dist_git (state ):
134+ state .local_clone , state .update_branch , state .fork_url = await tasks .fork_and_prepare_dist_git (
135+ jira_issue = state .jira_issue ,
136+ package = state .package ,
137+ dist_git_branch = state .dist_git_branch ,
138+ available_tools = gateway_tools ,
192139 )
193- return OutputSchema .model_validate_json (response .answer .text )
140+ await check_subprocess (["centpkg" , "sources" ], cwd = state .local_clone )
141+ await check_subprocess (["centpkg" , "prep" ], cwd = state .local_clone )
142+ unpacked_sources = list (state .local_clone .glob (f"*-build/*{ state .package } *" ))
143+ if len (unpacked_sources ) != 1 :
144+ raise ValueError (f"Expected exactly one unpacked source, got { unpacked_sources } " )
145+ [state .unpacked_sources ] = unpacked_sources
146+ return "run_backport_agent"
147+
148+ async def run_backport_agent (state ):
149+ cwd = Path .cwd ()
150+ try :
151+ # make things easier for the LLM
152+ os .chdir (state .local_clone )
153+ response = await backport_agent .run (
154+ prompt = render_prompt (
155+ InputSchema (
156+ local_clone = state .local_clone ,
157+ unpacked_sources = state .unpacked_sources ,
158+ package = state .package ,
159+ dist_git_branch = state .dist_git_branch ,
160+ upstream_fix = state .upstream_fix ,
161+ jira_issue = state .jira_issue ,
162+ cve_id = state .cve_id ,
163+ ),
164+ ),
165+ expected_output = OutputSchema ,
166+ execution = get_agent_execution_config (),
167+ )
168+ state .backport_result = OutputSchema .model_validate_json (response .answer .text )
169+ finally :
170+ os .chdir (cwd )
171+ if state .backport_result .success :
172+ return "commit_push_and_open_mr"
173+ else :
174+ return Workflow .END
175+
176+ async def commit_push_and_open_mr (state ):
177+ state .merge_request_url = await tasks .commit_push_and_open_mr (
178+ local_clone = state .local_clone ,
179+ files_to_commit = ["*.spec" , f"{ state .jira_issue } .patch" ],
180+ commit_message = f"{ COMMIT_PREFIX } backport { state .jira_issue } " ,
181+ fork_url = state .fork_url ,
182+ dist_git_branch = state .dist_git_branch ,
183+ update_branch = state .update_branch ,
184+ mr_title = "{COMMIT_PREFIX} backport {state.jira_issue}" ,
185+ mr_description = "TODO" ,
186+ available_tools = gateway_tools ,
187+ commit_only = os .getenv ("DRY_RUN" , "False" ).lower () == "true" ,
188+ )
189+ return Workflow .END
190+
191+ workflow .add_step ("fork_and_prepare_dist_git" , fork_and_prepare_dist_git )
192+ workflow .add_step ("run_backport_agent" , run_backport_agent )
193+ workflow .add_step ("commit_push_and_open_mr" , commit_push_and_open_mr )
194+
195+ async def run_workflow (package , dist_git_branch , upstream_fix , jira_issue , cve_id ):
196+ response = await workflow .run (
197+ State (
198+ package = package ,
199+ dist_git_branch = dist_git_branch ,
200+ upstream_fix = upstream_fix ,
201+ jira_issue = jira_issue ,
202+ cve_id = cve_id ,
203+ ),
204+ )
205+ return response .state
194206
195207 if (
196208 (package := os .getenv ("PACKAGE" , None ))
209+ and (branch := os .getenv ("BRANCH" , None ))
197210 and (upstream_fix := os .getenv ("UPSTREAM_FIX" , None ))
198211 and (jira_issue := os .getenv ("JIRA_ISSUE" , None ))
199- and (branch := os .getenv ("BRANCH" , None ))
200212 ):
201213 logger .info ("Running in direct mode with environment variables" )
202- input = InputSchema (
214+ state = await run_workflow (
203215 package = package ,
216+ dist_git_branch = branch ,
204217 upstream_fix = upstream_fix ,
205218 jira_issue = jira_issue ,
206- dist_git_branch = branch ,
207- cve_id = cve_id ,
219+ cve_id = os .getenv ("CVE_ID" , "" ),
208220 )
209- unpacked_sources , local_clone = prepare_package (package , jira_issue , branch , input )
210- input .unpacked_sources = str (unpacked_sources )
211- try :
212- output = await run (input )
213- finally :
214- if not dry_run :
215- logger .info (f"Removing { local_clone } " )
216- rmtree (local_clone )
217- else :
218- logger .info (f"DRY RUN: Not removing { local_clone } " )
219- logger .info (f"Direct run completed: { output .model_dump_json (indent = 4 )} " )
221+ logger .info (f"Direct run completed: { state .backport_result .model_dump_json (indent = 4 )} " )
220222 return
221223
222224 class Task (BaseModel ):
@@ -245,18 +247,6 @@ class Task(BaseModel):
245247 f"JIRA: { backport_data .jira_issue } , attempt: { task .attempts + 1 } "
246248 )
247249
248- input = InputSchema (
249- package = backport_data .package ,
250- upstream_fix = backport_data .patch_url ,
251- jira_issue = backport_data .jira_issue ,
252- dist_git_branch = backport_data .branch ,
253- cve_id = backport_data .cve_id ,
254- )
255- unpacked_sources , local_clone = prepare_package (
256- backport_data .package , backport_data .jira_issue , backport_data .branch , input
257- )
258- input .unpacked_sources = str (unpacked_sources )
259-
260250 async def retry (task , error ):
261251 task .attempts += 1
262252 if task .attempts < max_retries :
@@ -274,23 +264,29 @@ async def retry(task, error):
274264
275265 try :
276266 logger .info (f"Starting backport processing for { backport_data .jira_issue } " )
277- output = await run (input )
267+ state = await run_workflow (
268+ package = backport_data .package ,
269+ dist_git_branch = backport_data .branch ,
270+ upstream_fix = backport_data .patch_url ,
271+ jira_issue = backport_data .jira_issue ,
272+ cve_id = backport_data .cve_id ,
273+ )
278274 logger .info (
279- f"Backport processing completed for { backport_data .jira_issue } , " f"success: { output .success } "
275+ f"Backport processing completed for { backport_data .jira_issue } , " f"success: { state . backport_result .success } "
280276 )
281277 except Exception as e :
282278 error = "" .join (traceback .format_exception (e ))
283279 logger .error (f"Exception during backport processing for { backport_data .jira_issue } : { error } " )
284- await retry (task , ErrorData (details = error , jira_issue = input .jira_issue ).model_dump_json ())
280+ await retry (task , ErrorData (details = error , jira_issue = backport_data .jira_issue ).model_dump_json ())
285281 rmtree (local_clone )
286282 else :
287283 rmtree (local_clone )
288- if output .success :
284+ if state . backport_data .success :
289285 logger .info (f"Backport successful for { backport_data .jira_issue } , " f"adding to completed list" )
290- await redis .lpush ("completed_backport_list" , output .model_dump_json ())
286+ await redis .lpush ("completed_backport_list" , state . backport_data .model_dump_json ())
291287 else :
292- logger .warning (f"Backport failed for { backport_data .jira_issue } : { output .error } " )
293- await retry (task , output .error )
288+ logger .warning (f"Backport failed for { backport_data .jira_issue } : { state . backport_data .error } " )
289+ await retry (task , state . backport_data .error )
294290
295291
296292if __name__ == "__main__" :
0 commit comments