Advanced Concept
Building custom agents by directly implementing _run_async_impl provides powerful control, but it is more complex than using the predefined LlmAgent or standard WorkflowAgent types. Consider mastering those foundational agent types first before tackling custom orchestration logic.
Custom Agents
Custom agents provide the ultimate flexibility in ADK, allowing you to define arbitrary orchestration logic by inheriting directly from BaseAgent and implementing your own control flow. This goes beyond the predefined patterns of SequentialAgent, LoopAgent, and ParallelAgent, enabling you to build highly specific and complex agentic workflows.
Introduction: Beyond Predefined Workflows
What is a Custom Agent?
A custom agent is essentially any class that inherits from google.adk.agents.BaseAgent and implements its core execution logic in the asynchronous method _run_async_impl. You have complete control over how this method calls other agents (sub-agents), manages state, and handles events.
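As a minimal sketch of that shape (illustrative only: the class name and the single sub-agent attribute "worker" are placeholders, not part of the example developed later on this page):

from typing import AsyncGenerator
from typing_extensions import override

from google.adk.agents import BaseAgent, LlmAgent
from google.adk.agents.invocation_context import InvocationContext
from google.adk.events import Event

class MinimalCustomAgent(BaseAgent):
    # A sub-agent stored as an instance attribute, declared as a Pydantic field.
    worker: LlmAgent

    model_config = {"arbitrary_types_allowed": True}

    @override
    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        # Any Python logic can live here: call sub-agents, inspect or update
        # ctx.session.state, branch, loop, or call external code.
        async for event in self.worker.run_async(ctx):
            yield event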
When to Use Them
While the standard workflow agents (SequentialAgent, LoopAgent, ParallelAgent) cover common orchestration patterns, you will need a custom agent when your requirements include:
- Conditional Logic: Executing different sub-agents or paths based on runtime conditions or the results of previous steps.
- Complex State Management: Implementing intricate logic for maintaining and updating state throughout the workflow, beyond simple sequential passing.
- External Integrations: Incorporating calls to external APIs, databases, or custom Python libraries directly within the orchestration flow control.
- Dynamic Agent Selection: Choosing which sub-agent(s) to run next based on a dynamic evaluation of the situation or input (see the routing sketch after this list).
- Unique Workflow Patterns: Implementing orchestration logic that does not fit the standard sequential, parallel, or loop structures.
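For instance, dynamic agent selection is just ordinary Python inside _run_async_impl. A minimal sketch, reusing the imports from the skeleton above; the route names and agent attributes here are hypothetical placeholders:

class RoutingOrchestrator(BaseAgent):
    # Hypothetical sub-agents for each route, declared as Pydantic fields.
    billing_agent: LlmAgent
    support_agent: LlmAgent
    default_agent: LlmAgent

    model_config = {"arbitrary_types_allowed": True}

    @override
    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        # Pick the next sub-agent based on a value that an earlier step
        # (or the caller) placed in session state.
        route = ctx.session.state.get("route")
        selected = {
            "billing": self.billing_agent,
            "support": self.support_agent,
        }.get(route, self.default_agent)

        async for event in selected.run_async(ctx):
            yield event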
Implementing Custom Logic
The heart of any custom agent is the _run_async_impl method, where you define its unique behavior:
- Signature: async def _run_async_impl(self, ctx: InvocationContext) -> AsyncGenerator[Event, None]:
- Asynchronous generator: It must be an async def function that returns an AsyncGenerator, so that events produced by sub-agents or by its own logic can be yielded back to the runner.
- ctx (InvocationContext): Provides access to crucial runtime information, most importantly ctx.session.state, which is the primary way to share data between the steps your custom agent orchestrates.
Key capabilities within _run_async_impl (a combined sketch follows this list):
- Calling sub-agents: Invoke sub-agents (typically stored as instance attributes such as self.my_llm_agent) via their run_async method and yield their events.
- Managing state: Read from and write to the session state dictionary (ctx.session.state) to pass data between sub-agent calls or to make decisions.
- Implementing control flow: Use standard Python constructs (if/elif/else, for/while loops, try/except) to build sophisticated conditional or iterative workflows involving your sub-agents.
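The following sketch shows the three capabilities together, reusing the imports from the skeleton above. The attribute names (my_llm_agent, fallback_agent) and state keys are hypothetical placeholders, not part of the StoryFlowAgent example developed below:

class ConditionalOrchestrator(BaseAgent):
    # Hypothetical sub-agents, declared as Pydantic fields.
    my_llm_agent: LlmAgent
    fallback_agent: LlmAgent

    model_config = {"arbitrary_types_allowed": True}

    @override
    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        # 1. Calling a sub-agent: forward its events to the runner.
        async for event in self.my_llm_agent.run_async(ctx):
            yield event

        # 2. Managing state: read what the sub-agent wrote (e.g. via its
        #    output_key) and store values for later steps.
        result = ctx.session.state.get("my_result_key")
        ctx.session.state["last_step"] = "my_llm_agent"

        # 3. Control flow: plain Python decides what runs next.
        if result is None:
            async for event in self.fallback_agent.run_async(ctx):
                yield event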
Managing Sub-Agents and State
Typically, a custom agent orchestrates other agents (such as LlmAgent, LoopAgent, etc.):
- Initialization: You typically pass instances of these sub-agents into your custom agent's __init__ method and store them as instance attributes (e.g., self.story_generator = story_generator_instance) so they are accessible from _run_async_impl.
- sub_agents list: When initializing the BaseAgent via super().__init__(...), you should pass a sub_agents list. This tells the ADK framework which agents belong to this custom agent's immediate hierarchy, which matters for framework features such as lifecycle management and introspection, even if _run_async_impl calls the agents directly through self.xxx_agent.
- State: As noted above, ctx.session.state is the standard way sub-agents (especially LlmAgents using output_key) communicate results back to the orchestrator, and the channel through which the orchestrator passes necessary inputs down (see the short sketch after this list).
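A brief sketch of that handshake, reusing the imports from the skeleton above; the agent name, instruction text, and state keys are illustrative placeholders, not part of the example below. The sub-agent's output_key writes its final text into session state, and the orchestrator reads it back:

# Hypothetical sub-agent: its final response text is saved to state["summary"].
summarizer = LlmAgent(
    name="Summarizer",
    model="gemini-2.0-flash",
    instruction="Summarize the text stored in session state under key 'raw_text'.",
    output_key="summary",
)

class SummaryOrchestrator(BaseAgent):
    summarizer: LlmAgent

    model_config = {"arbitrary_types_allowed": True}

    @override
    async def _run_async_impl(
        self, ctx: InvocationContext
    ) -> AsyncGenerator[Event, None]:
        # Inputs are passed down via state (here, 'raw_text'); run the sub-agent.
        async for event in self.summarizer.run_async(ctx):
            yield event
        # The sub-agent's result comes back through the same state dict.
        if not ctx.session.state.get("summary"):
            return  # nothing produced; stop the workflow here

orchestrator = SummaryOrchestrator(
    name="SummaryOrchestrator",
    summarizer=summarizer,
    sub_agents=[summarizer],  # declare the immediate hierarchy for the framework
)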
Design Pattern Example: StoryFlowAgent
Let's illustrate the power of custom agents with an example pattern: a multi-stage content generation workflow with conditional logic.
Goal: Create a system that generates a story, iteratively refines it through critique and revision, performs final checks, and regenerates the story if the final tone check fails.
Why Custom? The core requirement here is the conditional regeneration based on the result of the tone check. Standard workflow agents have no built-in conditional branching based on the outcome of a sub-agent's task, so the orchestrator needs custom Python logic (if tone == "negative": ...).
Part 1: Simplified Custom Agent Initialization
We define the StoryFlowAgent inheriting from BaseAgent. In __init__, we store the necessary sub-agents (passed in) as instance attributes and tell the BaseAgent framework about the top-level agents this custom agent will directly orchestrate.
class StoryFlowAgent(BaseAgent):
"""
Custom agent for a story generation and refinement workflow.
This agent orchestrates a sequence of LLM agents to generate a story,
critique it, revise it, check grammar and tone, and potentially
regenerate the story if the tone is negative.
"""
# --- Field Declarations for Pydantic ---
# Declare the agents passed during initialization as class attributes with type hints
story_generator: LlmAgent
critic: LlmAgent
reviser: LlmAgent
grammar_check: LlmAgent
tone_check: LlmAgent
loop_agent: LoopAgent
sequential_agent: SequentialAgent
# model_config allows setting Pydantic configurations if needed, e.g., arbitrary_types_allowed
model_config = {"arbitrary_types_allowed": True}
def __init__(
self,
name: str,
story_generator: LlmAgent,
critic: LlmAgent,
reviser: LlmAgent,
grammar_check: LlmAgent,
tone_check: LlmAgent,
):
"""
Initializes the StoryFlowAgent.
Args:
name: The name of the agent.
story_generator: An LlmAgent to generate the initial story.
critic: An LlmAgent to critique the story.
reviser: An LlmAgent to revise the story based on criticism.
grammar_check: An LlmAgent to check the grammar.
tone_check: An LlmAgent to analyze the tone.
"""
# Create internal agents *before* calling super().__init__
loop_agent = LoopAgent(
name="CriticReviserLoop", sub_agents=[critic, reviser], max_iterations=2
)
sequential_agent = SequentialAgent(
name="PostProcessing", sub_agents=[grammar_check, tone_check]
)
# Define the sub_agents list for the framework
sub_agents_list = [
story_generator,
loop_agent,
sequential_agent,
]
# Pydantic will validate and assign them based on the class annotations.
super().__init__(
name=name,
story_generator=story_generator,
critic=critic,
reviser=reviser,
grammar_check=grammar_check,
tone_check=tone_check,
loop_agent=loop_agent,
sequential_agent=sequential_agent,
sub_agents=sub_agents_list, # Pass the sub_agents list directly
)
Part 2: Defining the Custom Execution Logic
This method orchestrates the sub-agents using standard Python async/await and control flow:
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""
Implements the custom orchestration logic for the story workflow.
Uses the instance attributes assigned by Pydantic (e.g., self.story_generator).
"""
logger.info(f"[{self.name}] Starting story generation workflow.")
# 1. Initial Story Generation
logger.info(f"[{self.name}] Running StoryGenerator...")
async for event in self.story_generator.run_async(ctx):
logger.info(f"[{self.name}] Event from StoryGenerator: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
# Check if story was generated before proceeding
if "current_story" not in ctx.session.state or not ctx.session.state["current_story"]:
logger.error(f"[{self.name}] Failed to generate initial story. Aborting workflow.")
return # Stop processing if initial story failed
logger.info(f"[{self.name}] Story state after generator: {ctx.session.state.get('current_story')}")
# 2. Critic-Reviser Loop
logger.info(f"[{self.name}] Running CriticReviserLoop...")
# Use the loop_agent instance attribute assigned during init
async for event in self.loop_agent.run_async(ctx):
logger.info(f"[{self.name}] Event from CriticReviserLoop: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
logger.info(f"[{self.name}] Story state after loop: {ctx.session.state.get('current_story')}")
# 3. Sequential Post-Processing (Grammar and Tone Check)
logger.info(f"[{self.name}] Running PostProcessing...")
# Use the sequential_agent instance attribute assigned during init
async for event in self.sequential_agent.run_async(ctx):
logger.info(f"[{self.name}] Event from PostProcessing: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
# 4. Tone-Based Conditional Logic
tone_check_result = ctx.session.state.get("tone_check_result")
logger.info(f"[{self.name}] Tone check result: {tone_check_result}")
if tone_check_result == "negative":
logger.info(f"[{self.name}] Tone is negative. Regenerating story...")
async for event in self.story_generator.run_async(ctx):
logger.info(f"[{self.name}] Event from StoryGenerator (Regen): {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
else:
logger.info(f"[{self.name}] Tone is not negative. Keeping current story.")
pass
logger.info(f"[{self.name}] Workflow finished.")
Explanation of the Logic:
- The initial story_generator runs first; its output is expected to land in ctx.session.state["current_story"].
- The loop_agent runs next. It internally calls the critic and reviser for up to max_iterations iterations; they read and write current_story and criticism in the state.
- The sequential_agent then runs, calling grammar_check and tone_check, reading current_story and writing grammar_suggestions and tone_check_result to the state.
- Custom part: the if statement checks tone_check_result in the state. If it is "negative", the story_generator is called again, overwriting current_story in the state; otherwise the flow ends.
Part 3: Defining the LLM Sub-Agents
These are standard LlmAgent definitions, each responsible for a specific task. Their output_key parameter is crucial: it places each result into session.state where other agents or the custom orchestrator can access it.
GEMINI_2_FLASH = "gemini-2.0-flash" # Define model constant
# --- Define the individual LLM agents ---
story_generator = LlmAgent(
name="StoryGenerator",
model=GEMINI_2_FLASH,
instruction="""You are a story writer. Write a short story (around 100 words) about a cat,
based on the topic provided in session state with key 'topic'""",
input_schema=None,
output_key="current_story", # Key for storing output in session state
)
critic = LlmAgent(
name="Critic",
model=GEMINI_2_FLASH,
instruction="""You are a story critic. Review the story provided in
session state with key 'current_story'. Provide 1-2 sentences of constructive criticism
on how to improve it. Focus on plot or character.""",
input_schema=None,
output_key="criticism", # Key for storing criticism in session state
)
reviser = LlmAgent(
name="Reviser",
model=GEMINI_2_FLASH,
instruction="""You are a story reviser. Revise the story provided in
session state with key 'current_story', based on the criticism in
session state with key 'criticism'. Output only the revised story.""",
input_schema=None,
output_key="current_story", # Overwrites the original story
)
grammar_check = LlmAgent(
name="GrammarCheck",
model=GEMINI_2_FLASH,
instruction="""You are a grammar checker. Check the grammar of the story
provided in session state with key 'current_story'. Output only the suggested
corrections as a list, or output 'Grammar is good!' if there are no errors.""",
input_schema=None,
output_key="grammar_suggestions",
)
tone_check = LlmAgent(
name="ToneCheck",
model=GEMINI_2_FLASH,
instruction="""You are a tone analyzer. Analyze the tone of the story
provided in session state with key 'current_story'. Output only one word: 'positive' if
the tone is generally positive, 'negative' if the tone is generally negative, or 'neutral'
otherwise.""",
input_schema=None,
output_key="tone_check_result", # This agent's output determines the conditional flow
)
Part 4: Instantiating and Running the Custom Agent
Finally, instantiate the StoryFlowAgent and use the Runner as usual:
# --- Create the custom agent instance ---
story_flow_agent = StoryFlowAgent(
name="StoryFlowAgent",
story_generator=story_generator,
critic=critic,
reviser=reviser,
grammar_check=grammar_check,
tone_check=tone_check,
)
# --- Setup Runner and Session ---
session_service = InMemorySessionService()
initial_state = {"topic": "a brave kitten exploring a haunted house"}
session = session_service.create_session(
app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID,
state=initial_state # Pass initial state here
)
logger.info(f"Initial session state: {session.state}")
runner = Runner(
agent=story_flow_agent, # Pass the custom orchestrator agent
app_name=APP_NAME,
session_service=session_service
)
# --- Function to Interact with the Agent ---
def call_agent(user_input_topic: str):
"""
Sends a new topic to the agent (overwriting the initial one if needed)
and runs the workflow.
"""
current_session = session_service.get_session(app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID)
if not current_session:
logger.error("Session not found!")
return
current_session.state["topic"] = user_input_topic
logger.info(f"Updated session state topic to: {user_input_topic}")
content = types.Content(role='user', parts=[types.Part(text=f"Generate a story about: {user_input_topic}")])
events = runner.run(user_id=USER_ID, session_id=SESSION_ID, new_message=content)
final_response = "No final response captured."
for event in events:
if event.is_final_response() and event.content and event.content.parts:
logger.info(f"Potential final response from [{event.author}]: {event.content.parts[0].text}")
final_response = event.content.parts[0].text
print("\n--- Agent Interaction Result ---")
print("Agent Final Response: ", final_response)
final_session = session_service.get_session(app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID)
print("Final Session State:")
import json
print(json.dumps(final_session.state, indent=2))
print("-------------------------------\n")
# --- Run the Agent ---
call_agent("a lonely robot finding a friend in a junkyard")
(Note: The full runnable code, including imports and execution logic, is provided below.)
Full Code Example
Storyflow Agent
# Full runnable code for the StoryFlowAgent example
import logging
from typing import AsyncGenerator
from typing_extensions import override
from google.adk.agents import LlmAgent, BaseAgent, LoopAgent, SequentialAgent
from google.adk.agents.invocation_context import InvocationContext
from google.genai import types
from google.adk.sessions import InMemorySessionService
from google.adk.runners import Runner
from google.adk.events import Event
from pydantic import BaseModel, Field
# --- Constants ---
APP_NAME = "story_app"
USER_ID = "12345"
SESSION_ID = "123344"
GEMINI_2_FLASH = "gemini-2.0-flash"
# --- Configure Logging ---
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Custom Orchestrator Agent ---
class StoryFlowAgent(BaseAgent):
"""
Custom agent for a story generation and refinement workflow.
This agent orchestrates a sequence of LLM agents to generate a story,
critique it, revise it, check grammar and tone, and potentially
regenerate the story if the tone is negative.
"""
# --- Field Declarations for Pydantic ---
# Declare the agents passed during initialization as class attributes with type hints
story_generator: LlmAgent
critic: LlmAgent
reviser: LlmAgent
grammar_check: LlmAgent
tone_check: LlmAgent
loop_agent: LoopAgent
sequential_agent: SequentialAgent
# model_config allows setting Pydantic configurations if needed, e.g., arbitrary_types_allowed
model_config = {"arbitrary_types_allowed": True}
def __init__(
self,
name: str,
story_generator: LlmAgent,
critic: LlmAgent,
reviser: LlmAgent,
grammar_check: LlmAgent,
tone_check: LlmAgent,
):
"""
Initializes the StoryFlowAgent.
Args:
name: The name of the agent.
story_generator: An LlmAgent to generate the initial story.
critic: An LlmAgent to critique the story.
reviser: An LlmAgent to revise the story based on criticism.
grammar_check: An LlmAgent to check the grammar.
tone_check: An LlmAgent to analyze the tone.
"""
# Create internal agents *before* calling super().__init__
loop_agent = LoopAgent(
name="CriticReviserLoop", sub_agents=[critic, reviser], max_iterations=2
)
sequential_agent = SequentialAgent(
name="PostProcessing", sub_agents=[grammar_check, tone_check]
)
# Define the sub_agents list for the framework
sub_agents_list = [
story_generator,
loop_agent,
sequential_agent,
]
# Pydantic will validate and assign them based on the class annotations.
super().__init__(
name=name,
story_generator=story_generator,
critic=critic,
reviser=reviser,
grammar_check=grammar_check,
tone_check=tone_check,
loop_agent=loop_agent,
sequential_agent=sequential_agent,
sub_agents=sub_agents_list, # Pass the sub_agents list directly
)
@override
async def _run_async_impl(
self, ctx: InvocationContext
) -> AsyncGenerator[Event, None]:
"""
Implements the custom orchestration logic for the story workflow.
Uses the instance attributes assigned by Pydantic (e.g., self.story_generator).
"""
logger.info(f"[{self.name}] Starting story generation workflow.")
# 1. Initial Story Generation
logger.info(f"[{self.name}] Running StoryGenerator...")
async for event in self.story_generator.run_async(ctx):
logger.info(f"[{self.name}] Event from StoryGenerator: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
# Check if story was generated before proceeding
if "current_story" not in ctx.session.state or not ctx.session.state["current_story"]:
logger.error(f"[{self.name}] Failed to generate initial story. Aborting workflow.")
return # Stop processing if initial story failed
logger.info(f"[{self.name}] Story state after generator: {ctx.session.state.get('current_story')}")
# 2. Critic-Reviser Loop
logger.info(f"[{self.name}] Running CriticReviserLoop...")
# Use the loop_agent instance attribute assigned during init
async for event in self.loop_agent.run_async(ctx):
logger.info(f"[{self.name}] Event from CriticReviserLoop: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
logger.info(f"[{self.name}] Story state after loop: {ctx.session.state.get('current_story')}")
# 3. Sequential Post-Processing (Grammar and Tone Check)
logger.info(f"[{self.name}] Running PostProcessing...")
# Use the sequential_agent instance attribute assigned during init
async for event in self.sequential_agent.run_async(ctx):
logger.info(f"[{self.name}] Event from PostProcessing: {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
# 4. Tone-Based Conditional Logic
tone_check_result = ctx.session.state.get("tone_check_result")
logger.info(f"[{self.name}] Tone check result: {tone_check_result}")
if tone_check_result == "negative":
logger.info(f"[{self.name}] Tone is negative. Regenerating story...")
async for event in self.story_generator.run_async(ctx):
logger.info(f"[{self.name}] Event from StoryGenerator (Regen): {event.model_dump_json(indent=2, exclude_none=True)}")
yield event
else:
logger.info(f"[{self.name}] Tone is not negative. Keeping current story.")
pass
logger.info(f"[{self.name}] Workflow finished.")
# --- Define the individual LLM agents ---
story_generator = LlmAgent(
name="StoryGenerator",
model=GEMINI_2_FLASH,
instruction="""You are a story writer. Write a short story (around 100 words) about a cat,
based on the topic provided in session state with key 'topic'""",
input_schema=None,
output_key="current_story", # Key for storing output in session state
)
critic = LlmAgent(
name="Critic",
model=GEMINI_2_FLASH,
instruction="""You are a story critic. Review the story provided in
session state with key 'current_story'. Provide 1-2 sentences of constructive criticism
on how to improve it. Focus on plot or character.""",
input_schema=None,
output_key="criticism", # Key for storing criticism in session state
)
reviser = LlmAgent(
name="Reviser",
model=GEMINI_2_FLASH,
instruction="""You are a story reviser. Revise the story provided in
session state with key 'current_story', based on the criticism in
session state with key 'criticism'. Output only the revised story.""",
input_schema=None,
output_key="current_story", # Overwrites the original story
)
grammar_check = LlmAgent(
name="GrammarCheck",
model=GEMINI_2_FLASH,
instruction="""You are a grammar checker. Check the grammar of the story
provided in session state with key 'current_story'. Output only the suggested
corrections as a list, or output 'Grammar is good!' if there are no errors.""",
input_schema=None,
output_key="grammar_suggestions",
)
tone_check = LlmAgent(
name="ToneCheck",
model=GEMINI_2_FLASH,
instruction="""You are a tone analyzer. Analyze the tone of the story
provided in session state with key 'current_story'. Output only one word: 'positive' if
the tone is generally positive, 'negative' if the tone is generally negative, or 'neutral'
otherwise.""",
input_schema=None,
output_key="tone_check_result", # This agent's output determines the conditional flow
)
# --- Create the custom agent instance ---
story_flow_agent = StoryFlowAgent(
name="StoryFlowAgent",
story_generator=story_generator,
critic=critic,
reviser=reviser,
grammar_check=grammar_check,
tone_check=tone_check,
)
# --- Setup Runner and Session ---
session_service = InMemorySessionService()
initial_state = {"topic": "a brave kitten exploring a haunted house"}
session = session_service.create_session(
app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID,
state=initial_state # Pass initial state here
)
logger.info(f"Initial session state: {session.state}")
runner = Runner(
agent=story_flow_agent, # Pass the custom orchestrator agent
app_name=APP_NAME,
session_service=session_service
)
# --- Function to Interact with the Agent ---
def call_agent(user_input_topic: str):
"""
Sends a new topic to the agent (overwriting the initial one if needed)
and runs the workflow.
"""
current_session = session_service.get_session(app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID)
if not current_session:
logger.error("Session not found!")
return
current_session.state["topic"] = user_input_topic
logger.info(f"Updated session state topic to: {user_input_topic}")
content = types.Content(role='user', parts=[types.Part(text=f"Generate a story about: {user_input_topic}")])
events = runner.run(user_id=USER_ID, session_id=SESSION_ID, new_message=content)
final_response = "No final response captured."
for event in events:
if event.is_final_response() and event.content and event.content.parts:
logger.info(f"Potential final response from [{event.author}]: {event.content.parts[0].text}")
final_response = event.content.parts[0].text
print("\n--- Agent Interaction Result ---")
print("Agent Final Response: ", final_response)
final_session = session_service.get_session(app_name=APP_NAME,
user_id=USER_ID,
session_id=SESSION_ID)
print("Final Session State:")
import json
print(json.dumps(final_session.state, indent=2))
print("-------------------------------\n")
# --- Run the Agent ---
call_agent("a lonely robot finding a friend in a junkyard")