Skip to content

ChatAgent 重构规划

重构日期:2026-03-18 状态:✅ 已完成 目标:将 ChatAgent 重构为 node 编排架构,与 ImageAgent 保持一致

1. 背景与目标

1.1 当前架构问题

当前 ChatAgent 只有一个 call_model node,业务逻辑散落在 ChatStreamService 中:

ChatStreamService (服务层):
├── 检查并迁移会话
├── 处理图片格式 (process_mixed_images)
├── 保存用户消息到 message 表
├── 调用 ChatAgent.chat_stream()
└── 保存助手消息到 message 表

ChatAgent (Agent 层):
└── call_model node (仅 LLM 调用)

问题

  1. 职责不清:预处理逻辑应该在 Agent 层而非服务层
  2. 不易扩展:添加新功能需要修改服务层代码
  3. 与 ImageAgent 架构不一致

1.2 目标架构

参考 ImageAgent 的 node 编排模式:

2. 目录结构

agent/chat/
├── __init__.py          # 模块导出
├── agent.py             # ChatAgent 实现
├── config.py            # ChatConfig 配置类
├── graph.py             # ChatGraphBuilder
├── state.py             # ChatState 状态定义 (新增)
├── message.py           # 消息工具函数 (保留)
└── nodes/               # LangGraph 节点 (新增)
    ├── __init__.py
    ├── preprocess.py    # 预处理节点
    ├── generate.py      # 生成节点
    └── postprocess.py   # 后处理节点

3. 状态定义

3.1 ChatState

python
from typing import TypedDict, Optional, List, Dict, Any, Annotated
from langchain_core.messages import BaseMessage
from langgraph.graph import add_messages

class ChatState(TypedDict, total=False):
    """ChatAgent 状态

    使用 LangGraph 的 add_messages reducer 自动合并消息。
    """
    # === 输入 ===
    prompt: str                                    # 用户提示词
    model: str                                     # 模型名称
    images: Optional[List[str]]                    # 原始图片(base64/COS key/URL)
    system_prompt: Optional[str]                   # 系统提示词
    temperature: float                             # 温度参数
    user_id: int                                   # 用户 ID
    thread_id: str                                 # 线程 ID(用于 checkpoint)
    conversation_id: Optional[str]                 # 会话 ID(用于 message 表)
    user_info: Optional[Dict[str, Any]]            # 用户信息(LangSmith)
    provider: Optional[str]                        # 提供商

    # === 预处理结果 ===
    preprocessed_images: Optional[List[str]]       # base64 格式,给 LLM 用
    db_image_keys: Optional[List[str]]             # COS key,给数据库用
    validated_prompt: str                          # 验证后的提示词

    # === 生成结果 ===
    messages: Annotated[List[BaseMessage], add_messages]  # 消息历史

    # === 后处理结果 ===
    response_content: str                          # 响应文本
    token_usage: Dict[str, int]                    # token 使用量

    # === 错误处理 ===
    error: Optional[str]

4. 节点设计

4.1 preprocess_node

职责:准备所有输入数据,验证并转换格式

输入

  • prompt: 用户提示词
  • images: 原始图片列表
  • user_id: 用户 ID
  • conversation_id: 会话 ID
  • thread_id: 线程 ID

输出

  • validated_prompt: 验证后的提示词
  • preprocessed_images: base64 格式图片(给 LLM)
  • db_image_keys: COS key(给数据库)
  • error: 错误信息(如果有)

代码位置agent/chat/nodes/preprocess.py

python
def preprocess_node(state: ChatState) -> Dict[str, Any]:
    """预处理节点"""
    result: Dict[str, Any] = {}

    # 1. 验证提示词
    prompt = state.get("prompt", "")
    if not prompt or not prompt.strip():
        return {"error": "提示词不能为空", "validated_prompt": ""}

    result["validated_prompt"] = prompt.strip()

    # 2. 处理图片格式
    images = state.get("images")
    if images:
        from services.image_processor import ImageProcessor
        processor = ImageProcessor()
        user_id = state.get("user_id", 0)

        llm_images, db_keys = processor.process_mixed_images(images, user_id)
        result["preprocessed_images"] = llm_images
        result["db_image_keys"] = db_keys

    return result

4.2 generate_node

职责:调用 LLM 生成响应

输入

  • validated_prompt: 验证后的提示词
  • preprocessed_images: base64 格式图片
  • messages: 历史消息(从 checkpoint 加载)

输出

  • messages: 更新后的消息列表(包含新的响应)

代码位置agent/chat/nodes/generate.py

python
def create_generate_node(config: ChatConfig):
    """创建生成节点(工厂函数)"""

    def generate_node(state: ChatState) -> Dict[str, Any]:
        # 1. 构建多模态消息
        content = build_multimodal_content(
            state["validated_prompt"],
            state.get("preprocessed_images")
        )

        # 2. 获取 LLM
        llm = ChatOpenAI(
            model=state.get("model") or config.default_model,
            api_key=config.api_key,
            base_url=config.base_url,
            temperature=state.get("temperature", 0.7)
        )

        # 3. 构建并调用 chain
        messages = state.get("messages", [])
        prompt_template = ChatPromptTemplate.from_messages([
            ("system", state.get("system_prompt") or config.default_system_prompt),
            MessagesPlaceholder("messages"),
        ])

        chain = prompt_template | llm
        response = chain.invoke({
            "messages": messages + [HumanMessage(content=content)]
        })

        return {"messages": [response]}

    return generate_node

4.3 postprocess_node

职责:提取结果,准备返回数据

输入

  • messages: 包含响应的消息列表

输出

  • response_content: 响应文本
  • token_usage: token 使用量

代码位置agent/chat/nodes/postprocess.py

python
def postprocess_node(state: ChatState) -> Dict[str, Any]:
    """后处理节点"""
    messages = state.get("messages", [])
    if not messages:
        return {"error": "没有生成响应"}

    # 获取最后一条消息(AI 响应)
    last_message = messages[-1]
    response_content = last_message.content if hasattr(last_message, 'content') else str(last_message)

    # 提取 token 使用量
    token_usage = extract_token_usage(last_message)

    return {
        "response_content": response_content,
        "token_usage": token_usage,
    }

5. Graph 构建

python
# agent/chat/graph.py

class ChatGraphBuilder:
    def build(self, checkpointer) -> StateGraph:
        """构建 StateGraph"""
        generate_node = create_generate_node(self.config)

        workflow = StateGraph(ChatState)

        # 添加节点
        workflow.add_node("preprocess", preprocess_node)
        workflow.add_node("generate", generate_node)
        workflow.add_node("postprocess", postprocess_node)

        # 定义边
        workflow.add_edge(START, "preprocess")
        workflow.add_conditional_edges("preprocess", self._route_on_error, {"next": "generate", "end": END})
        workflow.add_conditional_edges("generate", self._route_on_error, {"next": "postprocess", "end": END})
        workflow.add_edge("postprocess", END)

        return workflow.compile(checkpointer=checkpointer)

6. ChatAgent API 变化

6.1 新接口

python
class ChatAgent:
    def chat(
        self,
        thread_id: str,
        prompt: str,
        model: Optional[str] = None,
        images: Optional[List[str]] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.7,
        user_id: int = 0,
        conversation_id: Optional[str] = None,
        user_info: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """同步对话调用

        Returns:
            {
                "success": bool,
                "content": str,           # 响应文本
                "db_image_keys": List[str],  # COS key,用于保存到 message 表
                "token_usage": dict,
                "error": Optional[str]
            }
        """
        pass

    def chat_stream(
        self,
        thread_id: str,
        prompt: str,
        # ... 同上参数
    ) -> Iterator[str]:
        """流式对话调用

        Yields:
            str: token 级别的响应片段
        """
        pass

6.2 返回值

python
{
    "success": True,
    "content": "AI 响应文本...",
    "db_image_keys": ["cos://uploads/images/..."],  # 用户上传图片的 COS key
    "token_usage": {
        "prompt_tokens": 100,
        "completion_tokens": 50,
        "total_tokens": 150
    }
}

7. ChatStreamService 变化

重构后,ChatStreamService 的职责更清晰:

python
class ChatStreamService:
    async def iterate_chunks(
        self,
        conversation_id: Optional[str],
        prompt: str,
        model: Optional[str],
        system_prompt: Optional[str],
        images: list,
        user_id: int,
        user_info: dict,
        thread_id: Optional[str],
        db: Session,
        provider: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """迭代聊天响应块"""

        # 检查并迁移会话
        if conversation_id and thread_id:
            self._check_and_migrate(conversation_id, thread_id)

        # 调用 ChatAgent(图片处理在 Agent 内部完成)
        chat_agent = get_chat_agent()

        # ... 流式调用逻辑 ...

        # 保存用户消息(使用 Agent 返回的 db_image_keys)
        if conversation_id:
            MessageService.save_user_message(
                db, conversation_id, prompt,
                result["db_image_keys"],  # 来自 Agent
                model, provider
            )

        # 保存助手消息
        if conversation:
            MessageService.save_assistant_message(
                db, conversation, prompt,
                result["content"],  # 响应文本
                model, provider,
                result["token_usage"]
            )

8. 数据流对比

8.1 重构前

用户请求 (images: base64)


ChatStreamService

    ├─► process_mixed_images(images)
    │       → llm_images (base64), db_keys (COS key)

    ├─► save_user_message(db_keys)  ← 直接用 db_keys

    ├─► ChatAgent.chat_stream(images=llm_images)

    └─► save_assistant_message()

8.2 重构后

用户请求 (images: base64)


ChatStreamService

    └─► ChatAgent.chat_stream(images=base64)  ← 传原始图片


        preprocess_node

            ├─► process_mixed_images(images)
            │       → llm_images, db_keys

            └─► 返回 { llm_images, db_keys }


        generate_node(llm_images)

            └─► 调用 LLM


        postprocess_node

            └─► 返回 { content, db_keys, token_usage }


        返回给 ChatStreamService


        save_user_message(result["db_keys"])
        save_assistant_message(result["content"])

9. 实施步骤

  1. 创建 agent/chat/state.py:定义 ChatState
  2. 创建 agent/chat/nodes/:实现 preprocess、generate、postprocess 节点
  3. 重构 agent/chat/graph.py:改为多节点架构
  4. 修改 agent/chat/agent.py:适配新接口
  5. 修改 services/chat_stream_service.py:调用新接口

10. 注意事项

10.1 流式输出

流式输出需要在 generate_node 中特殊处理:

python
# 使用 stream_mode="messages" 实现流式
for chunk in app.stream(
    initial_state,
    config=config,
    stream_mode="messages"
):
    if isinstance(chunk, tuple) and len(chunk) == 2:
        message_chunk, _ = chunk
        if hasattr(message_chunk, 'content') and message_chunk.content:
            yield message_chunk.content

10.2 数据库操作位置

保存消息到 message 表仍然在 ChatStreamService 中进行,原因是:

  1. 需要数据库 session
  2. 需要在流式输出完成后才保存
  3. 保持 Agent 层与数据库解耦

Agent 只负责:

  • 处理图片格式
  • 调用 LLM
  • 返回处理后的数据(包括 db_image_keys

Service 负责:

  • 保存消息到数据库
  • 处理事务

11. 参考