# Middleware(中间件)= Agent 执行过程中的钩子函数
# 在 LangChain 1.0 中,中间件是处理 Agent 生命周期的标准方式。
# 基本用法
# 创建自定义中间件
import sys
import os
from openai import chat
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from langchain.tools import tool
from init_model import get_chat_model
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain.agents.middleware import AgentMiddleware
chat_model = get_chat_model()
class MyMiddleware(AgentMiddleware):
def before_model(self, state, runtime):
"""模型调用前执行"""
print("准备调用模型")
return None # 返回 None 表示继续正常流程
def after_model(self, state, runtime):
"""模型响应后执行"""
print("模型已响应")
return None # 返回 None 表示不修改状态
def demo1():
# 使用中间件
agent = create_agent(model=chat_model, tools=[], middleware=[MyMiddleware()])
result = agent.invoke({"messages": [{"role": "user", "content": "你好"}]})
print(result["messages"][-1].content)
# 准备调用模型
# 模型已响应
# 你好!👋 很高兴见到你!
# 有什么我可以帮你的吗?无论是回答问题、创作内容、查阅资料,还是随便聊 聊,我都在这里哦~
# 核心钩子方法
# 1. before_model(模型调用前)
def before_model(self, state, runtime):
"""
在模型调用前执行
返回值:
- None: 继续正常流程
- dict: 更新状态(如 {"messages": [...]})
- {"jump_to": "..."}: 跳过正常流程
"""
messages = state.get("messages", [])
print(f"当前消息数: {len(messages)}")
return None
# 用途:
# 消息修剪(trim messages)
# PII 脱敏
# 输入验证
# 条件路由
# 1. before_model: 模型调用前的“守门员”
# 执行时机:
# 在 Agent 收集完所有上下文信息,准备将请求发送给大语言模型(LLM)之前立即执行。
# 核心参数:
# state: 当前累积的对话状态(包含历史消息 messages、变量、用户输入等)。
# runtime: 运行时环境(包含配置、API Key、执行元数据等)。
# 主要功能场景:
# 1. 动态提示词注入 (Prompt Injection):
# 根据当前状态,动态在 messages 列表头部插入 System Prompt 或 Few-shot 示例。
# 例:检测到用户情绪愤怒时,插入“请保持冷静和同理心”的系统指令。
# 2.上下文裁剪/压缩 (Context Pruning):
# 如果 state['messages'] 超过了模型的 Token 限制,在此处截断旧消息或调用摘要模型进行压缩。
# 3. 安全拦截 (Safety Guardrails):
# 检查用户输入是否包含敏感词、攻击性内容或越狱指令。如果违规,可以直接返回跳转指令,阻止请求发送给 LLM。
# 4.路由决策 (Routing):
# 分析用户意图,决定是否需要跳过当前模型,直接跳转到另一个特定的节点(Node)或工作流。
# 5.工具预填充:
# 在某些架构中,可以预先计算某些工具的结果并放入上下文,让模型“看到”这些结果。
# 2. after_model(模型响应后)
def after_model(self, state, runtime):
"""
在模型响应后执行
返回值:
- None: 不修改状态
- dict: 更新状态
"""
# 统计调用次数
count = state.get("call_count", 0)
return {"call_count": count + 1}
# 用途:
# 输出验证
# 格式化响应
# 统计信息
# 状态更新
# 2. after_model: 模型调用后的“质检员”
# 执行时机:
# 在 LLM 返回生成结果(Completion),但尚未将该结果正式存入状态或返回给用户之后立即执行。
# 核心参数:
# state: 包含模型刚刚生成的原始响应(通常在 messages 的最后一条,或者在单独的 response 字段中)。
# runtime: 运行时环境。
# (部分框架还会传入 response 对象,包含 token 用量、延迟等元数据)
# 主要功能场景:
# 1.输出格式化与清洗:
# 移除模型生成的多余标记(如 Thought:, Action: 等中间思维链,如果不需要展示给用户)。
# 修复 JSON 格式错误(如果模型本应输出 JSON 但漏了括号)。
# 2.幻觉检测与事实核查 (Hallucination Check):
# 调用外部知识库或搜索工具,验证模型回答的事实准确性。如果不准确,可以标记为需要重试。
# 3.合规性审查 (Compliance):
# 检查模型输出是否包含偏见、泄露隐私或违反公司政策的内容。
# 4.结构化提取 (Structured Extraction):
# 如果模型输出了非结构化的文本,在此处将其解析为结构化数据(如提取订单号、日期),并存入 state 供后续节点使用。
# 5.自动重试逻辑 (Self-Correction):
# 如果检测到模型输出不符合预期(例如没有调用应有的工具),可以修改 state 并指示框架重新调用模型(通过返回特定的状态标志或抛出重试信号)。
# 返回值的作用
# 返回 None
# def before_model(self, state, runtime):
# print("日志记录")
# return None # 不做任何修改,继续流程
# 返回字典(更新状态)
# def after_model(self, state, runtime):
# count = state.get("count", 0)
# return {"count": count + 1} # 更新状态中的 count
# 返回 jump_to(控制流程)
# def before_model(self, state, runtime):
# if state.get("count", 0) > 10:
# return {"jump_to": "__end__"} # 跳过模型,直接结束
# return None
# jump_to 目标:
# "__end__" - 结束 Agent
# "tools" - 跳到工具节点
# 其他自定义节点
# 执行顺序(重要!)
# agent = create_agent(
# model=model,
# middleware=[Middleware1(), Middleware2(), Middleware3()]
# )
# 执行流程:
# 1. Middleware1.before_model ↓ 正序
# 2. Middleware2.before_model ↓
# 3. Middleware3.before_model ↓
# [模型调用]
# 6. Middleware3.after_model ↑ 逆序
# 5. Middleware2.after_model ↑
# 4. Middleware1.after_model ↑
# 类似洋葱模型:外层先进后出
# 用户请求
# ↓
# [ Middleware1.before ] <-- 外层入口
# ↓
# [ Middleware2.before ]
# ↓
# [ Middleware3.before ] <-- 内层入口
# ↓
# [ 🤖 模型调用 (Model) ]
# ↓
# [ Middleware3.after ] <-- 内层出口 (最先处理模型结果)
# ↑
# [ Middleware2.after ]
# ↑
# [ Middleware1.after ] <-- 外层出口 (最后处理结果)
# ↑
# 返回给用户
# 实际应用
class TestMiddleware:
# before_model 示例:添加情感安抚指令
def before_model(self, state, runtime):
last_msg = state["messages"][-1]
# 验证用户身份:如果用户不是授权用户,直接返回跳转指令
if last_msg["role"] != "user":
return {"jump_to": "unauthorized_node"}
# 安全检查:如果包含攻击性词汇,直接返回跳转指令
if "混蛋" in last_msg["content"]:
return {"jump_to": "safety_block_node"}
# 2. 动态注入:如果检测到用户很着急,注入安抚指令
if "着急" in last_msg["content"]:
system_prompt = {
"role": "system",
"content": "【动态指令】用户非常着急,请优先安抚情绪并快速给出方案。",
}
# 插入到 system prompt 之后,历史记录之前
new_messages = [state["messages"][0], system_prompt] + state["messages"][1:]
return {"messages": new_messages}
return None # 正常流程
# after_model 示例:确保输出包含工单号
def after_model(self, state, runtime):
# 获取模型最后生成的消息
response_content = state["messages"][-1].content
if "投诉" in state.get("intent", ""):
if "工单号" not in response_content:
response_content += " 请包含工单号,方便我们定位问题。"
return None # 接受输出
# 1. 日志中间件
class LogMiddleware:
def before_model(self, state, runtime):
print("日志记录:开始模型调用")
return None # 不做任何修改,继续流程
def after_model(self, state, runtime):
print("日志记录:模型调用结束")
return None # 不做任何修改,继续流程
# 2. 计数中间件
class CallCounterMiddleware(AgentMiddleware):
def after_model(self, state, runtime):
count = state.get("model_call_count", 0)
return {"model_call_count": count + 1}
# # 需要 checkpointer 来保存自定义状态
# agent = create_agent(
# model=model,
# middleware=[CallCounterMiddleware()],
# checkpointer=InMemorySaver()
# )
# 3. 消息修剪中间件
class MessageTrimmerMiddleware(AgentMiddleware):
def __init__(self, max_messages=5):
super().__init__()
self.max_messages = max_messages
def before_model(self, state, runtime):
messages = state.get("messages", [])
if len(messages) > self.max_messages:
# 只保留最近的 N 条消息
return {"messages": messages[-self.max_messages :]}
return None
# 4. 输出验证中间件
class OutputValidatorMiddleware(AgentMiddleware):
def after_model(self, state, runtime):
messages = state["messages"]
last_msg = messages[-1]
if last_msg["role"] == "assistant":
# 验证模型输出是否符合要求
if "工单号" not in last_msg["content"]:
# 触发跳转指令
return {"jump_to": "missing_order_id_node"}
return None
# 5. 限流中间件
class MaxCallsMiddleware(AgentMiddleware):
def __init__(self, max_calls=10):
super().__init__()
self.max_calls = max_calls
def before_model(self, state, runtime):
count = state.get("model_call_count", 0)
if count >= self.max_calls:
# 触发跳转指令
return {"jump_to": "rate_limit_node"}
return None
def after_model(self, state, runtime):
count = state.get("call_count", 0)
return {"call_count": count + 1}
# 内置中间件
# SummarizationMiddleware(自动摘要)
from langchain.agents.middleware import SummarizationMiddleware
def demo1():
agent = create_agent(
model=chat_model,
middleware=[
SummarizationMiddleware(
model="groq:llama-3.1-8b-instant", # 可用便宜模型
max_tokens_before_summary=500,
)
],
checkpointer=InMemorySaver(),
)
# 作用:
# 消息超过 token 限制时自动摘要
# 保留最近消息 + 旧消息摘要
# 详见 08_context_management 章节
# HumanInTheLoopMiddleware(人工审核)
from langchain.agents.middleware import HumanInTheLoopMiddleware
def demo2():
agent = create_agent(
model=chat_model,
middleware=[
HumanInTheLoopMiddleware(
interrupt_on={"send_email": True} # 调用此工具前暂停
)
],
checkpointer=InMemorySaver(),
)
# PIIMiddleware(敏感信息处理)
from langchain.agents.middleware import PIIMiddleware
def demo3():
agent = create_agent(
model=chat_model,
middleware=[
PIIMiddleware("email", strategy="redact"), # 邮箱脱敏
PIIMiddleware("phone_number", strategy="block"), # 电话拦截
# 输出阶段:防止 LLM 幻觉生成假的或泄露真的身份证号
PIIMiddleware("id_number", strategy="redact", direction="output"),
],
)
# graph LR
# User[用户输入] --> PII_Check{PIIMiddleware 检查}
# PII_Check -- 发现 Email --> Redact[执行脱敏 redact]
# Redact --> LLM[发送给大模型]
# PII_Check -- 发现 手机号 --> Block[执行拦截 block]
# Block --> Error[返回错误/拒绝服务]
# PII_Check -- 无敏感信息 --> LLM
# 常见问题
# 不能直接访问。before_model 和 after_model 只在模型节点执行。
# 如果需要拦截工具调用,使用 wrap_tool_call(高级特性)。
# 2. 多个中间件的顺序重要吗?
# 非常重要!
# middleware=[
# TrimmerMiddleware(), # 1. 先修剪消息
# SummarizationMiddleware(), # 2. 再摘要
# LoggingMiddleware() # 3. 最后记录日志
# ]
# before_model 按列表顺序执行
# after_model 按列表逆序执行
# 3. 修改状态需要 checkpointer 吗?
# 自定义状态需要,messages 不需要:
# # 不需要 checkpointer(messages 自动保存)
# def after_model(self, state, runtime):
# return {"messages": [...]}
# # 需要 checkpointer(自定义字段)
# def after_model(self, state, runtime):
# return {"my_custom_field": 123}
# 4. 能在中间件里调用另一个模型吗
# 可以,但要小心:
# class ValidationMiddleware(AgentMiddleware):
# def __init__(self):
# self.validator_model = init_chat_model(...)
# def after_model(self, state, runtime):
# # 用另一个模型验证输出
# last_msg = state['messages'][-1]
# validation_result = self.validator_model.invoke(...)
# return None
# 最佳实践
# # 1. 生产环境推荐配置
# agent = create_agent(
# model=model,
# tools=[...],
# middleware=[
# MessageTrimmerMiddleware(max_messages=20), # 限制消息数
# SummarizationMiddleware(model=..., max_tokens=2000), # 自动摘要
# LoggingMiddleware(), # 日志记录
# ],
# checkpointer=SqliteSaver.from_conn_string("...")
# )
# # 2. 开发环境
# agent = create_agent(
# model=model,
# tools=[...],
# middleware=[
# LoggingMiddleware(), # 只要日志
# ]
# )
# # 3. 测试环境
# agent = create_agent(
# model=model,
# tools=[...],
# middleware=[
# MaxCallsMiddleware(max_calls=5), # 防止测试费用爆炸
# ]
# )
if __name__ == "__main__":
demo1()