自定义代理#

您可能需要某些代理具备任何预设代理都无法提供的行为。在这种情况下,您可以构建自定义代理。

AgentChat 中的所有代理都继承自 BaseChatAgent 类,并实现以下抽象方法和属性

  • on_messages():定义代理响应消息行为的抽象方法。当代理被要求在 run() 中提供响应时,将调用此方法。它返回一个 Response 对象。

  • on_reset():将代理重置为其初始状态的抽象方法。当代理被要求重置自身时,将调用此方法。

  • produced_message_types:代理在其响应中可能产生的 BaseChatMessage 消息类型的序列(例如 (TextMessage,))。

可选地,您可以实现 on_messages_stream() 方法,以便在代理生成消息时对其进行流式传输。此方法由 run_stream() 调用以流式传输消息。如果未实现此方法,则代理使用 on_messages_stream() 的默认实现,该实现调用 on_messages() 方法并生成响应中的所有消息。

CountDownAgent#

在此示例中,我们创建一个简单的代理,该代理从给定的数字倒数到零,并生成包含当前计数的流式消息。

from typing import AsyncGenerator, List, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, TextMessage
from autogen_core import CancellationToken


class CountDownAgent(BaseChatAgent):
    """Toy agent that streams a countdown from ``count`` down to 1 and then says "Done!".

    Every tick is emitted as its own ``TextMessage``; the closing ``Response`` carries a
    final "Done!" message plus all of the tick messages as ``inner_messages``.
    """

    def __init__(self, name: str, count: int = 3):
        super().__init__(name, "A simple agent that counts down.")
        self._count = count

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """The only chat message type this agent emits is ``TextMessage``."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Drain :meth:`on_messages_stream` and hand back the last ``Response`` it yielded."""
        final: Response | None = None
        async for item in self.on_messages_stream(messages, cancellation_token):
            if not isinstance(item, Response):
                continue  # intermediate tick; only the final Response matters here
            final = item
        assert final is not None
        return final

    async def on_messages_stream(
        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
        """Yield one ``"<n>..."`` message per tick, then a ``Response`` wrapping "Done!"."""
        emitted: List[BaseAgentEvent | BaseChatMessage] = []
        for current in reversed(range(1, self._count + 1)):
            tick = TextMessage(content=f"{current}...", source=self.name)
            emitted.append(tick)
            yield tick
        # The stream always ends with a Response; it bundles every tick seen above.
        done = TextMessage(content="Done!", source=self.name)
        yield Response(chat_message=done, inner_messages=emitted)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Nothing to reset: the agent keeps no per-conversation state."""
        pass


async def run_countdown_agent() -> None:
    """Drive a ``CountDownAgent`` through its stream directly and print every item."""
    agent = CountDownAgent("countdown")
    token = CancellationToken()

    # No input messages are needed; the agent counts down on its own.
    async for item in agent.on_messages_stream([], token):
        # The final Response wraps the closing chat message; print that message itself.
        print(item.chat_message if isinstance(item, Response) else item)


# Top-level `await` only works in a notebook/IPython cell, which already runs an event loop.
# Use asyncio.run(run_countdown_agent()) when running in a script.
await run_countdown_agent()
3...
2...
1...
Done!

ArithmeticAgent#

在此示例中,我们创建一个代理类,该类可以对给定的整数执行简单的算术运算。然后,我们将在 SelectorGroupChat 中使用此代理类的不同实例,通过应用一系列算术运算将给定的整数转换为另一个整数。

ArithmeticAgent 类接受一个 operator_func,该函数接受一个整数并返回一个整数,即对该整数应用算术运算后得到的结果。在其 on_messages 方法中,它将 operator_func 应用于输入消息中的整数,并返回包含结果的响应。

from typing import Callable, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.messages import BaseChatMessage
from autogen_agentchat.teams import SelectorGroupChat
from autogen_agentchat.ui import Console
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


class ArithmeticAgent(BaseChatAgent):
    """Agent that applies a fixed integer operation to the most recent number in the chat.

    Each call to ``on_messages`` parses the content of the latest message in the agent's
    own history as an integer, applies ``operator_func`` to it and replies with the result
    as a ``TextMessage``. The reply is appended to the history too, so when the agent is
    selected again with no new messages it continues from its own last result.

    Args:
        name: Agent name; also used as the ``source`` of its replies.
        description: Short description of the operation the agent performs.
        operator_func: Function mapping the parsed integer to the result.
    """

    def __init__(self, name: str, description: str, operator_func: Callable[[int], int]) -> None:
        super().__init__(name, description=description)
        self._operator_func = operator_func
        # Every message this agent has seen or produced, oldest first.
        self._message_history: list[BaseChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """This agent only ever replies with ``TextMessage``."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Apply ``operator_func`` to the number in the latest message and reply with the result.

        Args:
            messages: New messages since the agent last spoke. May be empty when the agent
                is selected again without anything new having been said.
            cancellation_token: Unused; the computation is synchronous and immediate.

        Returns:
            A ``Response`` whose ``chat_message`` is a ``TextMessage`` holding the result.

        Raises:
            ValueError: If there is no message to read yet, if the latest message is not a
                ``TextMessage``, or (from ``int()``) if its content is not an integer.
        """
        # NOTE: `messages` may be empty, which means the agent was selected previously and
        # nobody spoke since; the history then still ends with the relevant number.
        self._message_history.extend(messages)
        if not self._message_history:
            raise ValueError(f"{self.name} has received no messages to read a number from.")
        last_message = self._message_history[-1]
        # Explicit check instead of `assert`, which is stripped when Python runs with -O.
        if not isinstance(last_message, TextMessage):
            raise ValueError(
                f"{self.name} expected the latest message to be a TextMessage, got {type(last_message).__name__}."
            )
        number = int(last_message.content)
        result = self._operator_func(number)
        response_message = TextMessage(content=str(result), source=self.name)
        # Record our own reply so a later call with no new messages builds on it.
        self._message_history.append(response_message)
        return Response(chat_message=response_message)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Forget all previously seen messages so the agent returns to its initial state."""
        self._message_history.clear()

注意

可能会使用空消息列表调用 on_messages 方法,这意味着先前已调用该代理,现在再次调用该代理,而没有来自调用方的任何新消息。因此,重要的是保留代理收到的先前消息的历史记录,并使用该历史记录来生成响应。

现在,我们可以创建一个包含 5 个 ArithmeticAgent 实例的 SelectorGroupChat

  • 一个将输入整数加 1,

  • 一个从输入整数中减去 1,

  • 一个将输入整数乘以 2,

  • 一个将输入整数除以 2 并向下舍入到最接近的整数,以及

  • 一个返回未更改的输入整数。

然后,我们使用这些代理创建一个 SelectorGroupChat,并设置适当的选择器设置

  • 允许连续选择同一代理以允许重复操作,以及

  • 自定义选择器提示以使模型的响应适应特定任务。

async def run_number_agents() -> None:
    """Run a SelectorGroupChat of ArithmeticAgents that tries to turn 10 into 25.

    A ``gpt-4o`` selector picks which arithmetic agent speaks next (repeats allowed); the
    run stops after 10 messages and its output is streamed to the console. The model
    client is always closed when the run finishes, including on error.
    """
    # Create agents for number operations.
    add_agent = ArithmeticAgent("add_agent", "Adds 1 to the number.", lambda x: x + 1)
    multiply_agent = ArithmeticAgent("multiply_agent", "Multiplies the number by 2.", lambda x: x * 2)
    subtract_agent = ArithmeticAgent("subtract_agent", "Subtracts 1 from the number.", lambda x: x - 1)
    divide_agent = ArithmeticAgent("divide_agent", "Divides the number by 2 and rounds down.", lambda x: x // 2)
    identity_agent = ArithmeticAgent("identity_agent", "Returns the number as is.", lambda x: x)

    # The termination condition is to stop after 10 messages.
    termination_condition = MaxMessageTermination(10)

    # Keep a handle on the model client so it can be closed once the run is over.
    model_client = OpenAIChatCompletionClient(model="gpt-4o")
    try:
        # Create a selector group chat.
        selector_group_chat = SelectorGroupChat(
            [add_agent, multiply_agent, subtract_agent, divide_agent, identity_agent],
            model_client=model_client,
            termination_condition=termination_condition,
            allow_repeated_speaker=True,  # Allow the same agent to speak multiple times, necessary for this task.
            selector_prompt=(
                "Available roles:\n{roles}\nTheir job descriptions:\n{participants}\n"
                "Current conversation history:\n{history}\n"
                "Please select the most appropriate role for the next message, and only return the role name."
            ),
        )

        # Run the selector group chat with a given task and stream the response.
        task: list[BaseChatMessage] = [
            TextMessage(content="Apply the operations to turn the given number into 25.", source="user"),
            TextMessage(content="10", source="user"),
        ]
        stream = selector_group_chat.run_stream(task=task)
        await Console(stream)
    finally:
        # Release the client's resources even if the run raises or is cancelled.
        await model_client.close()


# Top-level `await` only works in a notebook/IPython cell, which already runs an event loop.
# Use asyncio.run(run_number_agents()) when running in a script.
await run_number_agents()
---------- user ----------
Apply the operations to turn the given number into 25.
---------- user ----------
10
---------- multiply_agent ----------
20
---------- add_agent ----------
21
---------- multiply_agent ----------
42
---------- divide_agent ----------
21
---------- add_agent ----------
22
---------- add_agent ----------
23
---------- add_agent ----------
24
---------- add_agent ----------
25
---------- Summary ----------
Number of messages: 10
Finish reason: Maximum number of messages 10 reached, current message count: 10
Total prompt tokens: 0
Total completion tokens: 0
Duration: 2.40 seconds

从输出中,我们可以看到代理已通过选择按顺序应用算术运算的适当代理,成功地将输入整数从 10 转换为 25。

在自定义代理中使用自定义模型客户端#

AgentChat 中 AssistantAgent 预设的关键功能之一是,它采用 model_client 参数,并且可以在响应消息时使用它。但是,在某些情况下,您可能希望您的代理使用当前不支持的自定义模型客户端(请参阅受支持的模型客户端)或自定义模型行为。

为此,您可以编写一个在内部直接调用自定义模型客户端的自定义代理。

在下面的示例中,我们将逐步介绍一个自定义代理的示例,该代理直接使用 Google Gemini SDK 来响应消息。

注意:您需要安装 Google Gemini SDK 才能运行此示例。您可以使用以下命令进行安装

pip install google-genai
# !pip install google-genai
import os
from typing import AsyncGenerator, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_core import CancellationToken
from autogen_core.model_context import UnboundedChatCompletionContext
from autogen_core.models import AssistantMessage, RequestUsage, UserMessage
from google import genai
from google.genai import types


class GeminiAssistantAgent(BaseChatAgent):
    """Chat agent that answers by calling the Google Gemini API through ``google-genai``.

    The agent keeps its own unbounded model context. On every turn it appends the
    incoming messages, renders the whole context as a ``"source: content"`` transcript,
    sends it to Gemini in a single prompt and records the reply in the context.

    Args:
        name: Agent name; also the ``source`` of its replies.
        description: Description of the agent.
        model: Gemini model name passed to ``generate_content``.
        api_key: Gemini API key. When ``None`` (the default) it is read from the
            ``GEMINI_API_KEY`` environment variable when the agent is constructed.
        system_message: System instruction sent with every request, or ``None``.
    """

    def __init__(
        self,
        name: str,
        description: str = "An agent that provides assistance with ability to use tools.",
        model: str = "gemini-1.5-flash-002",
        api_key: str | None = None,
        system_message: str
        | None = "You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.",
    ):
        super().__init__(name=name, description=description)
        # Resolve the key lazily: a default of os.environ["GEMINI_API_KEY"] would be evaluated
        # once, when the class is defined, and raise KeyError there if the variable is unset.
        if api_key is None:
            api_key = os.environ["GEMINI_API_KEY"]
        self._model_context = UnboundedChatCompletionContext()
        self._model_client = genai.Client(api_key=api_key)
        self._system_message = system_message
        self._model = model

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """This agent only ever replies with ``TextMessage``."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Run :meth:`on_messages_stream` to completion and return its final ``Response``.

        Raises:
            AssertionError: If the stream ends without yielding a ``Response``.
        """
        final_response = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                final_response = message

        if final_response is None:
            raise AssertionError("The stream should have returned the final result.")

        return final_response

    async def on_messages_stream(
        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
        """Add ``messages`` to the context, query Gemini once and yield the reply.

        Yields a single ``Response`` whose ``chat_message`` holds the model's text and
        token usage; no inner messages are produced. Cancelling ``cancellation_token``
        cancels the in-flight Gemini request.
        """
        import asyncio  # local import keeps this notebook cell self-contained

        # Add messages to the model context
        for msg in messages:
            await self._model_context.add_message(msg.to_model_message())

        # Render the context as plain text, one "source: content" line per message.
        # Joining (instead of interpolating the list) avoids sending Python list syntax
        # and escaped "\n" sequences to the model.
        context_messages = await self._model_context.get_messages()
        history = "".join(
            (msg.source if hasattr(msg, "source") else "system")
            + ": "
            + (msg.content if isinstance(msg.content, str) else "")
            + "\n"
            for msg in context_messages
        )

        # Use the SDK's async client so the event loop is not blocked while waiting on
        # Gemini, and tie the request to the caller's cancellation token.
        request = asyncio.ensure_future(
            self._model_client.aio.models.generate_content(
                model=self._model,
                contents=f"History: {history}\nGiven the history, please provide a response",
                config=types.GenerateContentConfig(
                    system_instruction=self._system_message,
                    temperature=0.3,
                ),
            )
        )
        cancellation_token.link_future(request)
        response = await request

        # The SDK declares usage_metadata, its token counts and response.text as optional;
        # fall back to 0 / "" rather than failing on a missing value.
        usage_metadata = response.usage_metadata
        usage = RequestUsage(
            prompt_tokens=(usage_metadata.prompt_token_count or 0) if usage_metadata else 0,
            completion_tokens=(usage_metadata.candidates_token_count or 0) if usage_metadata else 0,
        )
        text = response.text or ""

        # Add response to model context so later turns see it.
        await self._model_context.add_message(AssistantMessage(content=text, source=self.name))

        # Yield the final response
        yield Response(
            chat_message=TextMessage(content=text, source=self.name, models_usage=usage),
            inner_messages=[],
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Reset the assistant by clearing the model context."""
        await self._model_context.clear()
# Build the agent with its constructor defaults (model, system message and API-key handling)
# and stream a single task to the console. Top-level `await` requires a notebook/IPython cell.
gemini_assistant = GeminiAssistantAgent("gemini_assistant")
await Console(gemini_assistant.run_stream(task="What is the capital of New York?"))
---------- user ----------
What is the capital of New York?
---------- gemini_assistant ----------
Albany
TERMINATE
TaskResult(messages=[TextMessage(source='user', models_usage=None, content='What is the capital of New York?', type='TextMessage'), TextMessage(source='gemini_assistant', models_usage=RequestUsage(prompt_tokens=46, completion_tokens=5), content='Albany\nTERMINATE\n', type='TextMessage')], stop_reason=None)

在上面的示例中,我们选择提供 model、api_key 和 system_message 作为参数——您也可以根据所用模型客户端的需要或应用程序的设计,提供任何其他参数。

现在,让我们探讨如何在 AgentChat 中将此自定义代理用作团队的一部分。

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.ui import Console

model_client = OpenAIChatCompletionClient(model="gpt-4o-mini")

# Create the primary agent.
primary_agent = AssistantAgent(
    "primary",
    model_client=model_client,
    system_message="You are a helpful AI assistant.",
)

# Create a critic agent based on our new GeminiAssistantAgent.
gemini_critic_agent = GeminiAssistantAgent(
    "gemini_critic",
    system_message="Provide constructive feedback. Respond with 'APPROVE' to when your feedbacks are addressed.",
)


# Define a termination condition that stops the task if the critic approves or after 10 messages.
termination = TextMentionTermination("APPROVE") | MaxMessageTermination(10)

# Create a team with the primary and critic agents.
team = RoundRobinGroupChat([primary_agent, gemini_critic_agent], termination_condition=termination)

await Console(team.run_stream(task="Write a Haiku poem with 4 lines about the fall season."))
await model_client.close()
---------- user ----------
Write a Haiku poem with 4 lines about the fall season.
---------- primary ----------
Crimson leaves cascade,  
Whispering winds sing of change,  
Chill wraps the fading,  
Nature's quilt, rich and warm.
---------- gemini_critic ----------
The poem is good, but it has four lines instead of three.  A haiku must have three lines with a 5-7-5 syllable structure.  The content is evocative of autumn, but the form is incorrect.  Please revise to adhere to the haiku's syllable structure.

---------- primary ----------
Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:

Crimson leaves drift down,  
Chill winds whisper through the gold,  
Autumn’s breath is near.
---------- gemini_critic ----------
The revised haiku is much improved.  It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn.  APPROVE
TaskResult(messages=[TextMessage(source='user', models_usage=None, content='Write a Haiku poem with 4 lines about the fall season.', type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=33, completion_tokens=31), content="Crimson leaves cascade,  \nWhispering winds sing of change,  \nChill wraps the fading,  \nNature's quilt, rich and warm.", type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=86, completion_tokens=60), content="The poem is good, but it has four lines instead of three.  A haiku must have three lines with a 5-7-5 syllable structure.  The content is evocative of autumn, but the form is incorrect.  Please revise to adhere to the haiku's syllable structure.\n", type='TextMessage'), TextMessage(source='primary', models_usage=RequestUsage(prompt_tokens=141, completion_tokens=49), content='Thank you for your feedback! Here’s a revised haiku that follows the 5-7-5 syllable structure:\n\nCrimson leaves drift down,  \nChill winds whisper through the gold,  \nAutumn’s breath is near.', type='TextMessage'), TextMessage(source='gemini_critic', models_usage=RequestUsage(prompt_tokens=211, completion_tokens=32), content='The revised haiku is much improved.  It correctly follows the 5-7-5 syllable structure and maintains the evocative imagery of autumn.  APPROVE\n', type='TextMessage')], stop_reason="Text 'APPROVE' mentioned")

在上面的部分中,我们展示了几个非常重要的概念

  • 我们开发了一个使用 Google Gemini SDK 响应消息的自定义代理。

  • 我们表明,只要它继承自 BaseChatAgent,此自定义代理就可以用作更广泛的 AgentChat 生态系统的一部分 - 在这种情况下,作为 RoundRobinGroupChat 中的参与者。

使自定义代理具有声明性#

Autogen 提供了一个 Component 接口,用于使组件的配置可序列化为声明性格式。这对于保存和加载配置,以及与他人共享配置很有用。

我们通过继承 Component 类并实现 _from_config 和 _to_config 方法来实现此目的。可以使用 dump_component 方法将声明性类序列化为 JSON 格式,并可以使用 load_component 方法从 JSON 格式反序列化。

import os
from typing import AsyncGenerator, Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage
from autogen_core import CancellationToken, Component
from pydantic import BaseModel
from typing_extensions import Self


class GeminiAssistantAgentConfig(BaseModel):
    """Declarative (serializable) configuration for ``GeminiAssistantAgent``.

    There is no API-key field, so a dumped component never contains the secret; the
    key is supplied when the agent itself is constructed.
    """

    name: str
    description: str = "An agent that provides assistance with ability to use tools."
    model: str = "gemini-1.5-flash-002"
    # Same default as the agent constructor, so loading a config that omits this field
    # yields the same agent as GeminiAssistantAgent(name). Set it to None explicitly to
    # send no system instruction.
    system_message: str | None = (
        "You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed."
    )


class GeminiAssistantAgent(BaseChatAgent, Component[GeminiAssistantAgentConfig]):  # type: ignore[no-redef]
    """Gemini-backed chat agent that can be serialized via ``dump_component``/``load_component``.

    Behaves like the plain Gemini agent: it keeps an unbounded model context, renders it
    as a ``"source: content"`` transcript and asks Gemini for one reply per turn. Its
    declarative form is ``GeminiAssistantAgentConfig`` (name, description, model and
    system message; the API key is not serialized).

    Args:
        name: Agent name; also the ``source`` of its replies.
        description: Description of the agent.
        model: Gemini model name passed to ``generate_content``.
        api_key: Gemini API key. When ``None`` (the default) it is read from the
            ``GEMINI_API_KEY`` environment variable when the agent is constructed.
        system_message: System instruction sent with every request, or ``None``.
    """

    component_config_schema = GeminiAssistantAgentConfig
    # component_provider_override = "mypackage.agents.GeminiAssistantAgent"

    def __init__(
        self,
        name: str,
        description: str = "An agent that provides assistance with ability to use tools.",
        model: str = "gemini-1.5-flash-002",
        api_key: str | None = None,
        system_message: str
        | None = "You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed.",
    ):
        super().__init__(name=name, description=description)
        # Resolve the key lazily: a default of os.environ["GEMINI_API_KEY"] would be evaluated
        # once, when the class is defined, and raise KeyError there if the variable is unset.
        if api_key is None:
            api_key = os.environ["GEMINI_API_KEY"]
        self._model_context = UnboundedChatCompletionContext()
        self._model_client = genai.Client(api_key=api_key)
        self._system_message = system_message
        self._model = model

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """This agent only ever replies with ``TextMessage``."""
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Run :meth:`on_messages_stream` to completion and return its final ``Response``.

        Raises:
            AssertionError: If the stream ends without yielding a ``Response``.
        """
        final_response = None
        async for message in self.on_messages_stream(messages, cancellation_token):
            if isinstance(message, Response):
                final_response = message

        if final_response is None:
            raise AssertionError("The stream should have returned the final result.")

        return final_response

    async def on_messages_stream(
        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
        """Add ``messages`` to the context, query Gemini once and yield the reply.

        Yields a single ``Response`` whose ``chat_message`` holds the model's text and
        token usage; no inner messages are produced. Cancelling ``cancellation_token``
        cancels the in-flight Gemini request.
        """
        import asyncio  # local import keeps this notebook cell self-contained

        # Add messages to the model context
        for msg in messages:
            await self._model_context.add_message(msg.to_model_message())

        # Render the context as plain text, one "source: content" line per message.
        # Joining (instead of interpolating the list) avoids sending Python list syntax
        # and escaped "\n" sequences to the model.
        context_messages = await self._model_context.get_messages()
        history = "".join(
            (msg.source if hasattr(msg, "source") else "system")
            + ": "
            + (msg.content if isinstance(msg.content, str) else "")
            + "\n"
            for msg in context_messages
        )

        # Use the SDK's async client so the event loop is not blocked while waiting on
        # Gemini, and tie the request to the caller's cancellation token.
        request = asyncio.ensure_future(
            self._model_client.aio.models.generate_content(
                model=self._model,
                contents=f"History: {history}\nGiven the history, please provide a response",
                config=types.GenerateContentConfig(
                    system_instruction=self._system_message,
                    temperature=0.3,
                ),
            )
        )
        cancellation_token.link_future(request)
        response = await request

        # The SDK declares usage_metadata, its token counts and response.text as optional;
        # fall back to 0 / "" rather than failing on a missing value.
        usage_metadata = response.usage_metadata
        usage = RequestUsage(
            prompt_tokens=(usage_metadata.prompt_token_count or 0) if usage_metadata else 0,
            completion_tokens=(usage_metadata.candidates_token_count or 0) if usage_metadata else 0,
        )
        text = response.text or ""

        # Add response to model context so later turns see it.
        await self._model_context.add_message(AssistantMessage(content=text, source=self.name))

        # Yield the final response
        yield Response(
            chat_message=TextMessage(content=text, source=self.name, models_usage=usage),
            inner_messages=[],
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Reset the assistant by clearing the model context."""
        await self._model_context.clear()

    @classmethod
    def _from_config(cls, config: GeminiAssistantAgentConfig) -> Self:
        """Rebuild an agent from its declarative config.

        No API key is passed, so the constructor resolves it from ``GEMINI_API_KEY``.
        """
        return cls(
            name=config.name, description=config.description, model=config.model, system_message=config.system_message
        )

    def _to_config(self) -> GeminiAssistantAgentConfig:
        """Capture the serializable settings of this agent (never the API key)."""
        return GeminiAssistantAgentConfig(
            name=self.name,
            description=self.description,
            model=self._model,
            system_message=self._system_message,
        )

实现了所需的方法之后,我们就可以将自定义代理转储为 JSON 格式,并从 JSON 格式重新加载。

注意:您应该将 component_provider_override 类变量设置为包含自定义代理类的模块的完整路径,例如(mypackage.agents.GeminiAssistantAgent)。load_component 方法使用此变量来确定如何实例化该类。

# Round-trip the agent through its declarative form.
gemini_assistant = GeminiAssistantAgent("gemini_assistant")
# dump_component() returns a component model; print it as indented JSON.
config = gemini_assistant.dump_component()
print(config.model_dump_json(indent=2))
# Build a fresh agent instance from the dumped component.
loaded_agent = GeminiAssistantAgent.load_component(config)
print(loaded_agent)
{
  "provider": "__main__.GeminiAssistantAgent",
  "component_type": "agent",
  "version": 1,
  "component_version": 1,
  "description": null,
  "label": "GeminiAssistantAgent",
  "config": {
    "name": "gemini_assistant",
    "description": "An agent that provides assistance with ability to use tools.",
    "model": "gemini-1.5-flash-002",
    "system_message": "You are a helpful assistant that can respond to messages. Reply with TERMINATE when the task has been completed."
  }
}
<__main__.GeminiAssistantAgent object at 0x11a5c5a90>

后续步骤#

到目前为止,我们已经了解了如何创建自定义代理、向代理添加自定义模型客户端以及使自定义代理具有声明性。可以通过以下几种方式扩展此基本示例