10 - Middleware Basics (中间件基础)
核心概念
Middleware(中间件)= Agent 执行过程中的钩子函数
在 LangChain 1.0 中,中间件是处理 Agent 生命周期的标准方式。
基本用法
创建自定义中间件
from langchain.agents.middleware import AgentMiddleware
class MyMiddleware(AgentMiddleware):
def before_model(self, state, runtime):
"""模型调用前执行"""
print("准备调用模型")
return None # 返回 None 表示继续正常流程
def after_model(self, state, runtime):
"""模型响应后执行"""
print("模型已响应")
return None # 返回 None 表示不修改状态
# 使用中间件
agent = create_agent(
model=model,
tools=[],
middleware=[MyMiddleware()]
)核心钩子方法
1. before_model(模型调用前)
def before_model(self, state, runtime):
"""
在模型调用前执行
返回值:
- None: 继续正常流程
- dict: 更新状态(如 {"messages": [...]})
- {"jump_to": "..."}: 跳过正常流程
"""
messages = state.get('messages', [])
print(f"当前消息数: {len(messages)}")
return None用途:
- 消息修剪(trim messages)
- PII 脱敏
- 输入验证
- 条件路由
2. after_model(模型响应后)
def after_model(self, state, runtime):
"""
在模型响应后执行
返回值:
- None: 不修改状态
- dict: 更新状态
"""
# 统计调用次数
count = state.get("call_count", 0)
return {"call_count": count + 1}用途:
- 输出验证
- 格式化响应
- 统计信息
- 状态更新
返回值的作用
返回 None
def before_model(self, state, runtime):
print("日志记录")
return None # 不做任何修改,继续流程返回字典(更新状态)
def after_model(self, state, runtime):
count = state.get("count", 0)
return {"count": count + 1} # 更新状态中的 count返回 jump_to(控制流程)
def before_model(self, state, runtime):
if state.get("count", 0) > 10:
return {"jump_to": "__end__"} # 跳过模型,直接结束
return Nonejump_to 目标:
"__end__"- 结束 Agent"tools"- 跳到工具节点- 其他自定义节点
执行顺序(重要!)
agent = create_agent(
model=model,
middleware=[Middleware1(), Middleware2(), Middleware3()]
)执行流程:
1. Middleware1.before_model ↓ 正序
2. Middleware2.before_model ↓
3. Middleware3.before_model ↓
[模型调用]
6. Middleware3.after_model ↑ 逆序
5. Middleware2.after_model ↑
4. Middleware1.after_model ↑类似洋葱模型**:外层先进后出
用户请求
↓
[ Middleware1.before ] <-- 外层入口
↓
[ Middleware2.before ]
↓
[ Middleware3.before ] <-- 内层入口
↓
[ 🤖 模型调用 (Model) ]
↓
[ Middleware3.after ] <-- 内层出口 (最先处理模型结果)
↑
[ Middleware2.after ]
↑
[ Middleware1.after ] <-- 外层出口 (最后处理结果)
↑
返回给用户实际应用
1. 日志中间件
class LoggingMiddleware(AgentMiddleware):
def before_model(self, state, runtime):
print(f"[日志] 消息数: {len(state.get('messages', []))}")
return None
def after_model(self, state, runtime):
last_msg = state.get('messages', [])[-1]
print(f"[日志] 响应类型: {last_msg.__class__.__name__}")
return None2. 计数中间件
class CallCounterMiddleware(AgentMiddleware):
def after_model(self, state, runtime):
count = state.get("model_call_count", 0)
return {"model_call_count": count + 1}
# 需要 checkpointer 来保存自定义状态
agent = create_agent(
model=model,
middleware=[CallCounterMiddleware()],
checkpointer=InMemorySaver()
)3. 消息修剪中间件
class MessageTrimmerMiddleware(AgentMiddleware):
def __init__(self, max_messages=5):
super().__init__()
self.max_messages = max_messages
def before_model(self, state, runtime):
messages = state.get('messages', [])
if len(messages) > self.max_messages:
# 只保留最近的 N 条消息
return {"messages": messages[-self.max_messages:]}
return None4. 输出验证中间件
class OutputValidationMiddleware(AgentMiddleware):
def after_model(self, state, runtime):
last_msg = state.get('messages', [])[-1]
content = getattr(last_msg, 'content', '')
if len(content) > 1000:
print("[警告] 响应过长")
return None5. 限流中间件
class MaxCallsMiddleware(AgentMiddleware):
def __init__(self, max_calls=10):
super().__init__()
self.max_calls = max_calls
def before_model(self, state, runtime):
count = state.get("call_count", 0)
if count >= self.max_calls:
return {"jump_to": "__end__"} # 达到限制,直接结束
return None
def after_model(self, state, runtime):
count = state.get("call_count", 0)
return {"call_count": count + 1}内置中间件
SummarizationMiddleware(自动摘要)
from langchain.agents.middleware import SummarizationMiddleware
agent = create_agent(
model=model,
middleware=[
SummarizationMiddleware(
model="groq:llama-3.1-8b-instant", # 可用便宜模型
max_tokens_before_summary=500
)
],
checkpointer=InMemorySaver()
)作用:
- 消息超过 token 限制时自动摘要
- 保留最近消息 + 旧消息摘要
- 详见 08_context_management 章节
HumanInTheLoopMiddleware(人工审核)
from langchain.agents.middleware import HumanInTheLoopMiddleware
agent = create_agent(
model=model,
tools=[send_email],
middleware=[
HumanInTheLoopMiddleware(
interrupt_on={"send_email": True} # 调用此工具前暂停
)
]
)PIIMiddleware(敏感信息处理)
from langchain.agents.middleware import PIIMiddleware
agent = create_agent(
model=model,
middleware=[
PIIMiddleware("email", strategy="redact"), # 邮箱脱敏
PIIMiddleware("phone_number", strategy="block") # 电话拦截
]
)常见问题
1. 中间件能访问工具调用吗?
不能直接访问。before_model 和 after_model 只在模型节点执行。
如果需要拦截工具调用,使用 wrap_tool_call(高级特性)。
2. 多个中间件的顺序重要吗?
非常重要!
middleware=[
TrimmerMiddleware(), # 1. 先修剪消息
SummarizationMiddleware(), # 2. 再摘要
LoggingMiddleware() # 3. 最后记录日志
]before_model按列表顺序执行after_model按列表逆序执行
3. 修改状态需要 checkpointer 吗?
自定义状态需要,messages 不需要:
# 不需要 checkpointer(messages 自动保存)
def after_model(self, state, runtime):
return {"messages": [...]}
# 需要 checkpointer(自定义字段)
def after_model(self, state, runtime):
return {"my_custom_field": 123}4. 能在中间件里调用另一个模型吗?
可以,但要小心:
class ValidationMiddleware(AgentMiddleware):
def __init__(self):
self.validator_model = init_chat_model(...)
def after_model(self, state, runtime):
# 用另一个模型验证输出
last_msg = state['messages'][-1]
validation_result = self.validator_model.invoke(...)
return None最佳实践
# 1. 生产环境推荐配置
agent = create_agent(
model=model,
tools=[...],
middleware=[
MessageTrimmerMiddleware(max_messages=20), # 限制消息数
SummarizationMiddleware(model=..., max_tokens=2000), # 自动摘要
LoggingMiddleware(), # 日志记录
],
checkpointer=SqliteSaver.from_conn_string("...")
)
# 2. 开发环境
agent = create_agent(
model=model,
tools=[...],
middleware=[
LoggingMiddleware(), # 只要日志
]
)
# 3. 测试环境
agent = create_agent(
model=model,
tools=[...],
middleware=[
MaxCallsMiddleware(max_calls=5), # 防止测试费用爆炸
]
)核心要点
- 中间件 = Agent 生命周期钩子
- before_model - 模型调用前(正序执行)
- after_model - 模型响应后(逆序执行)
- 返回 None - 不修改状态
- 返回 dict - 更新状态
- 返回 {"jump_to": "..."} - 控制流程
- 顺序重要 - 类似洋葱模型
- 内置中间件 - SummarizationMiddleware 最常用
下一步
11_structured_output - 学习如何使用 Pydantic 获取结构化输出