Middleware Basics -main.py

"""
LangChain 1.0 - Middleware Basics (中间件基础)
==============================================

本模块重点讲解:
1. 什么是中间件(Middleware)
2. before_model 和 after_model 钩子
3. 自定义中间件的创建
4. 多个中间件的组合
5. 内置中间件的使用
"""

import os
import sys
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.middleware import AgentMiddleware
from langgraph.checkpoint.memory import InMemorySaver

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from init_model import get_chat_model


chat_model = get_chat_model()


@tool
def get_weather(city: str) -> str:
    """
    获取指定城市的天气信息
    参数:
        city (str): 要查询天气的城市名称 例如 "北京"

    返回:
        str: 包含天气信息的字符串,格式为 "城市 天气描述 温度"

    """
    # 模拟天气数据
    weather_data = {
        "北京": "晴天,温度 15°C,空气质量良好",
        "上海": "多云,温度 18°C,有轻微雾霾",
        "深圳": "阴天,温度 22°C,可能有小雨",
        "成都": "小雨,温度 12°C,湿度较高",
    }
    # 检查城市是否在数据中

    if city not in weather_data:
        return f"城市 {city} 的天气信息未找到"

    # 返回天气信息
    return weather_data.get(city, f"抱歉,暂时没有{city}的天气数据")


# ============================================================================
# 示例 1:最简单的中间件
"""
  关键规则

  1. 必须继承 AgentMiddleware ← 这个固定
  2. 方法名固定 (before_model, after_model) ← 这个固定
  3. 类名随意 ← 这个不固定

  LangGraph 只看:
  - 是否继承 AgentMiddleware?
  - 是否有 before_model / after_model 方法?
  """
import logging
from typing import Any, Dict, Optional

# 配置 logger
logger = logging.getLogger(__name__)


# ============================================================================
class LoggingMiddleware(AgentMiddleware):
    """
    日志中间件 - 记录每次模型调用
    """

    def before_model(
        self, state: Dict[str, Any], runtime: Any
    ) -> Optional[Dict[str, Any]]:
        """模型调用前"""
        messages = state.get("messages", [])
        logger.info(f"[Middleware] Before Model Call | Message Count: {len(messages)}")
        # 这里可以添加更多上下文信息,例如 runtime.request_id
        return None  # 继续正常流程

    def after_model(
        self, state: Dict[str, Any], runtime: Any
    ) -> Optional[Dict[str, Any]]:
        """模型响应后"""
        messages = state.get("messages", [])

        if not messages:
            logger.warning(
                "[Middleware] After Model Call | No messages found in state."
            )
            return None

        last_message = messages[-1]
        msg_type = last_message.__class__.__name__

        logger.info(f"[Middleware] After Model Call | Response Type: {msg_type}")

        # 如果是特定类型的消息,可以提取内容摘要记录
        # if hasattr(last_message, 'content'):
        #     logger.debug(f"Response Content Preview: {str(last_message.content)[:100]}...")

        return None  # 不修改状态


# ============================================================================
# 示例 2:修改状态的中间件
# ============================================================================
class CallCounterMiddleware(AgentMiddleware):
    """
    计数中间件 - 统计模型调用次数

    在中间件内部维护计数器(简单版本)

    """

    def __init__(self):
        super().__init__()
        self.count = 0  # 简单计数器

    def after_model(self, state, runtime):
        """模型响应后,增加计数"""
        self.count += 1
        print(f"\n[计数器] 模型调用次数: {self.count}")
        return None  # 不修改 state


# ============================================================================
# 示例 3:消息修剪中间件
# ============================================================================
class MessageTrimmerMiddleware(AgentMiddleware):
    """
    消息修剪中间件 - 限制消息数量

    before_model 修改消息列表
    注意:需要配合无 checkpointer 使用,否则历史会被恢复
    """

    def __init__(self, max_messages=5):
        super().__init__()
        self.max_messages = max_messages
        self.trimmed_count = 0  # 统计修剪次数

    def before_model(self, state, runtime):
        """模型调用前,修剪消息"""
        messages = state.get("messages", [])

        if len(messages) > self.max_messages:
            # 保留最近的 N 条消息
            trimmed_messages = messages[-self.max_messages :]
            self.trimmed_count += 1
            print(
                f"\n[修剪] 消息从 {len(messages)} 条减少到 {len(trimmed_messages)} 条 (第{self.trimmed_count}次修剪)"
            )
            return {"messages": trimmed_messages}

        return None


# ============================================================================
# 示例 4:输出验证中间件
# ============================================================================
class OutputValidationMiddleware(AgentMiddleware):
    """
    输出验证中间件 - 检查响应长度

    after_model 验证输出
    """

    def __init__(self, max_length=100):
        super().__init__()
        self.max_length = max_length

    def after_model(self, state, runtime):
        """模型响应后,验证输出"""
        messages = state.get("messages", [])
        if not messages:
            return None

        last_message = messages[-1]
        content = getattr(last_message, "content", "")

        if len(content) > self.max_length:
            print(
                f"\n[警告] 响应���长 ({len(content)} 字符),已截断到 {self.max_length}"
            )
            # 这里可以实现截断或重试逻辑

        return None


# ============================================================================
# 示例 5:条件跳转(高级)
# ============================================================================
class MaxCallsMiddleware(AgentMiddleware):
    """
    最大调用限制中间件

    通过抛出异常来阻止模型调用(更可靠的方式)
    """

    def __init__(self, max_calls=3):
        super().__init__()
        self.max_calls = max_calls
        self.count = 0  # 简单计数器

    def before_model(self, state, runtime):
        """检查调用次数,超过限制则抛出异常"""
        if self.count >= self.max_calls:
            print(f"\n[限制] 已达到最大调用次数 {self.max_calls},停止调用")
            # 抛出自定义异常来阻止继续执行
            raise ValueError(f"已达到最大调用次数限制: {self.max_calls}")

        print(f"[限制] 当前调用次数: {self.count}/{self.max_calls}")
        return None

    def after_model(self, state, runtime):
        """增加计数"""
        self.count += 1
        print("次数+1")
        return None


# ============================================================================
# 示例 6:内置中间件使用
# ============================================================================
from langchain.agents.middleware import SummarizationMiddleware


def demo1():
    """
    示例 1:最简单的中间件
    """
    # 创建代理时添加中间件
    agent = create_agent(
        model=chat_model,
        tools=[get_weather],
        checkpointer=InMemorySaver(),
        middleware=[
            LoggingMiddleware(),
            CallCounterMiddleware(),
            MessageTrimmerMiddleware(),
            SummarizationMiddleware(  # 内置中间件
                model="groq:llama-3.3-70b-versatile",
                max_tokens_before_summary=200,  # 超过 200 token 就摘要
            ),
        ],  # 添加日志中间件
    )

    # 2. 构造 config 对象
    config = {"configurable": {"thread_id": "demo1"}}

    response = agent.invoke(
        {"messages": [{"role": "user", "content": "你好,今天北京的天气怎么样"}]},
        config,
    )
    print(f"Agent: {response['messages'][-1].content}")
    # 结果
    # [中间件] before_model: 准备调用模型
    # [中间件] 当前消息数: 1
    # [中间件] after_model: 模型已响应
    # [中间件] 响应类型: AIMessage

    # [中间件] before_model: 准备调用模型
    # [中间件] 当前消息数: 3
    # [中间件] after_model: 模型已响应
    # [中间件] 响应类型: AIMessage
    # Agent: 今天北京的天气是晴天,温度 15°C,空气质量良好。


if __name__ == "__main__":
    demo1()
    
添加新评论