"""
LangChain 1.0 - Middleware Basics (中间件基础)
==============================================
本模块重点讲解:
1. 什么是中间件(Middleware)
2. before_model 和 after_model 钩子
3. 自定义中间件的创建
4. 多个中间件的组合
5. 内置中间件的使用
"""
import os
import sys
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain.agents import create_agent
from langchain_core.tools import tool
from langchain.agents.middleware import AgentMiddleware
from langgraph.checkpoint.memory import InMemorySaver
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from init_model import get_chat_model
chat_model = get_chat_model()
@tool
def get_weather(city: str) -> str:
"""
获取指定城市的天气信息
参数:
city (str): 要查询天气的城市名称 例如 "北京"
返回:
str: 包含天气信息的字符串,格式为 "城市 天气描述 温度"
"""
# 模拟天气数据
weather_data = {
"北京": "晴天,温度 15°C,空气质量良好",
"上海": "多云,温度 18°C,有轻微雾霾",
"深圳": "阴天,温度 22°C,可能有小雨",
"成都": "小雨,温度 12°C,湿度较高",
}
# 检查城市是否在数据中
if city not in weather_data:
return f"城市 {city} 的天气信息未找到"
# 返回天气信息
return weather_data.get(city, f"抱歉,暂时没有{city}的天气数据")
# ============================================================================
# 示例 1:最简单的中间件
"""
关键规则
1. 必须继承 AgentMiddleware ← 这个固定
2. 方法名固定 (before_model, after_model) ← 这个固定
3. 类名随意 ← 这个不固定
LangGraph 只看:
- 是否继承 AgentMiddleware?
- 是否有 before_model / after_model 方法?
"""
import logging
from typing import Any, Dict, Optional
# 配置 logger
logger = logging.getLogger(__name__)
# ============================================================================
class LoggingMiddleware(AgentMiddleware):
"""
日志中间件 - 记录每次模型调用
"""
def before_model(
self, state: Dict[str, Any], runtime: Any
) -> Optional[Dict[str, Any]]:
"""模型调用前"""
messages = state.get("messages", [])
logger.info(f"[Middleware] Before Model Call | Message Count: {len(messages)}")
# 这里可以添加更多上下文信息,例如 runtime.request_id
return None # 继续正常流程
def after_model(
self, state: Dict[str, Any], runtime: Any
) -> Optional[Dict[str, Any]]:
"""模型响应后"""
messages = state.get("messages", [])
if not messages:
logger.warning(
"[Middleware] After Model Call | No messages found in state."
)
return None
last_message = messages[-1]
msg_type = last_message.__class__.__name__
logger.info(f"[Middleware] After Model Call | Response Type: {msg_type}")
# 如果是特定类型的消息,可以提取内容摘要记录
# if hasattr(last_message, 'content'):
# logger.debug(f"Response Content Preview: {str(last_message.content)[:100]}...")
return None # 不修改状态
# ============================================================================
# 示例 2:修改状态的中间件
# ============================================================================
class CallCounterMiddleware(AgentMiddleware):
"""
计数中间件 - 统计模型调用次数
在中间件内部维护计数器(简单版本)
"""
def __init__(self):
super().__init__()
self.count = 0 # 简单计数器
def after_model(self, state, runtime):
"""模型响应后,增加计数"""
self.count += 1
print(f"\n[计数器] 模型调用次数: {self.count}")
return None # 不修改 state
# ============================================================================
# 示例 3:消息修剪中间件
# ============================================================================
class MessageTrimmerMiddleware(AgentMiddleware):
"""
消息修剪中间件 - 限制消息数量
before_model 修改消息列表
注意:需要配合无 checkpointer 使用,否则历史会被恢复
"""
def __init__(self, max_messages=5):
super().__init__()
self.max_messages = max_messages
self.trimmed_count = 0 # 统计修剪次数
def before_model(self, state, runtime):
"""模型调用前,修剪消息"""
messages = state.get("messages", [])
if len(messages) > self.max_messages:
# 保留最近的 N 条消息
trimmed_messages = messages[-self.max_messages :]
self.trimmed_count += 1
print(
f"\n[修剪] 消息从 {len(messages)} 条减少到 {len(trimmed_messages)} 条 (第{self.trimmed_count}次修剪)"
)
return {"messages": trimmed_messages}
return None
# ============================================================================
# 示例 4:输出验证中间件
# ============================================================================
class OutputValidationMiddleware(AgentMiddleware):
"""
输出验证中间件 - 检查响应长度
after_model 验证输出
"""
def __init__(self, max_length=100):
super().__init__()
self.max_length = max_length
def after_model(self, state, runtime):
"""模型响应后,验证输出"""
messages = state.get("messages", [])
if not messages:
return None
last_message = messages[-1]
content = getattr(last_message, "content", "")
if len(content) > self.max_length:
print(
f"\n[警告] 响应���长 ({len(content)} 字符),已截断到 {self.max_length}"
)
# 这里可以实现截断或重试逻辑
return None
# ============================================================================
# 示例 5:条件跳转(高级)
# ============================================================================
class MaxCallsMiddleware(AgentMiddleware):
"""
最大调用限制中间件
通过抛出异常来阻止模型调用(更可靠的方式)
"""
def __init__(self, max_calls=3):
super().__init__()
self.max_calls = max_calls
self.count = 0 # 简单计数器
def before_model(self, state, runtime):
"""检查调用次数,超过限制则抛出异常"""
if self.count >= self.max_calls:
print(f"\n[限制] 已达到最大调用次数 {self.max_calls},停止调用")
# 抛出自定义异常来阻止继续执行
raise ValueError(f"已达到最大调用次数限制: {self.max_calls}")
print(f"[限制] 当前调用次数: {self.count}/{self.max_calls}")
return None
def after_model(self, state, runtime):
"""增加计数"""
self.count += 1
print("次数+1")
return None
# ============================================================================
# 示例 6:内置中间件使用
# ============================================================================
from langchain.agents.middleware import SummarizationMiddleware
def demo1():
"""
示例 1:最简单的中间件
"""
# 创建代理时添加中间件
agent = create_agent(
model=chat_model,
tools=[get_weather],
checkpointer=InMemorySaver(),
middleware=[
LoggingMiddleware(),
CallCounterMiddleware(),
MessageTrimmerMiddleware(),
SummarizationMiddleware( # 内置中间件
model="groq:llama-3.3-70b-versatile",
max_tokens_before_summary=200, # 超过 200 token 就摘要
),
], # 添加日志中间件
)
# 2. 构造 config 对象
config = {"configurable": {"thread_id": "demo1"}}
response = agent.invoke(
{"messages": [{"role": "user", "content": "你好,今天北京的天气怎么样"}]},
config,
)
print(f"Agent: {response['messages'][-1].content}")
# 结果
# [中间件] before_model: 准备调用模型
# [中间件] 当前消息数: 1
# [中间件] after_model: 模型已响应
# [中间件] 响应类型: AIMessage
# [中间件] before_model: 准备调用模型
# [中间件] 当前消息数: 3
# [中间件] after_model: 模型已响应
# [中间件] 响应类型: AIMessage
# Agent: 今天北京的天气是晴天,温度 15°C,空气质量良好。
if __name__ == "__main__":
demo1()