Middleware Basics.py

# Middleware(中间件)= Agent 执行过程中的钩子函数
# 在 LangChain 1.0 中,中间件是处理 Agent 生命周期的标准方式。
# 基本用法
# 创建自定义中间件

import sys
import os

from openai import chat

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from langchain.tools import tool

from init_model import get_chat_model
from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from langgraph.checkpoint.sqlite import SqliteSaver
from langchain.agents.middleware import AgentMiddleware


chat_model = get_chat_model()


class MyMiddleware(AgentMiddleware):
    def before_model(self, state, runtime):
        """模型调用前执行"""
        print("准备调用模型")
        return None  # 返回 None 表示继续正常流程

    def after_model(self, state, runtime):
        """模型响应后执行"""
        print("模型已响应")
        return None  # 返回 None 表示不修改状态


def demo1():
    # 使用中间件
    agent = create_agent(model=chat_model, tools=[], middleware=[MyMiddleware()])
    result = agent.invoke({"messages": [{"role": "user", "content": "你好"}]})
    print(result["messages"][-1].content)


# 准备调用模型
# 模型已响应
# 你好!👋 很高兴见到你!
# 有什么我可以帮你的吗?无论是回答问题、创作内容、查阅资料,还是随便聊 聊,我都在这里哦~


# 核心钩子方法
# 1. before_model(模型调用前)
def before_model(self, state, runtime):
    """
    在模型调用前执行

    返回值:
    - None: 继续正常流程
    - dict: 更新状态(如 {"messages": [...]})
    - {"jump_to": "..."}: 跳过正常流程
    """
    messages = state.get("messages", [])
    print(f"当前消息数: {len(messages)}")
    return None


# 用途:
#     消息修剪(trim messages)
#     PII 脱敏
#     输入验证
#     条件路由


# 1. before_model: 模型调用前的“守门员”
# 执行时机:
# 在 Agent 收集完所有上下文信息,准备将请求发送给大语言模型(LLM)之前立即执行。
# 核心参数:
#     state: 当前累积的对话状态(包含历史消息 messages、变量、用户输入等)。
#     runtime: 运行时环境(包含配置、API Key、执行元数据等)。
# 主要功能场景:
# 1. 动态提示词注入 (Prompt Injection):
#     根据当前状态,动态在 messages 列表头部插入 System Prompt 或 Few-shot 示例。
#     例:检测到用户情绪愤怒时,插入“请保持冷静和同理心”的系统指令。
# 2.上下文裁剪/压缩 (Context Pruning):
#     如果 state['messages'] 超过了模型的 Token 限制,在此处截断旧消息或调用摘要模型进行压缩。
# 3. 安全拦截 (Safety Guardrails):
#     检查用户输入是否包含敏感词、攻击性内容或越狱指令。如果违规,可以直接返回跳转指令,阻止请求发送给 LLM。
# 4.路由决策 (Routing):
#     分析用户意图,决定是否需要跳过当前模型,直接跳转到另一个特定的节点(Node)或工作流。
# 5.工具预填充:
#     在某些架构中,可以预先计算某些工具的结果并放入上下文,让模型“看到”这些结果。


# 2. after_model(模型响应后)
def after_model(self, state, runtime):
    """
    在模型响应后执行

    返回值:
    - None: 不修改状态
    - dict: 更新状态
    """
    # 统计调用次数
    count = state.get("call_count", 0)
    return {"call_count": count + 1}


# 用途:
#     输出验证
#     格式化响应
#     统计信息
#     状态更新


# 2. after_model: 模型调用后的“质检员”
# 执行时机:
# 在 LLM 返回生成结果(Completion),但尚未将该结果正式存入状态或返回给用户之后立即执行。
# 核心参数:
#     state: 包含模型刚刚生成的原始响应(通常在 messages 的最后一条,或者在单独的 response 字段中)。
#     runtime: 运行时环境。
#     (部分框架还会传入 response 对象,包含 token 用量、延迟等元数据)
# 主要功能场景:
# 1.输出格式化与清洗:
#     移除模型生成的多余标记(如 Thought:, Action: 等中间思维链,如果不需要展示给用户)。
#     修复 JSON 格式错误(如果模型本应输出 JSON 但漏了括号)。
# 2.幻觉检测与事实核查 (Hallucination Check):
#     调用外部知识库或搜索工具,验证模型回答的事实准确性。如果不准确,可以标记为需要重试。
# 3.合规性审查 (Compliance):
#     检查模型输出是否包含偏见、泄露隐私或违反公司政策的内容。
# 4.结构化提取 (Structured Extraction):
#     如果模型输出了非结构化的文本,在此处将其解析为结构化数据(如提取订单号、日期),并存入 state 供后续节点使用。
# 5.自动重试逻辑 (Self-Correction):
#     如果检测到模型输出不符合预期(例如没有调用应有的工具),可以修改 state 并指示框架重新调用模型(通过返回特定的状态标志或抛出重试信号)。

# 返回值的作用
# 返回 None
# def before_model(self, state, runtime):
#     print("日志记录")
#     return None  # 不做任何修改,继续流程

# 返回字典(更新状态)
# def after_model(self, state, runtime):
#     count = state.get("count", 0)
#     return {"count": count + 1}  # 更新状态中的 count

# 返回 jump_to(控制流程)
# def before_model(self, state, runtime):
#     if state.get("count", 0) > 10:
#         return {"jump_to": "__end__"}  # 跳过模型,直接结束
#     return None

# jump_to 目标:

# "__end__" - 结束 Agent
# "tools" - 跳到工具节点
# 其他自定义节点


# 执行顺序(重要!)
# agent = create_agent(
#     model=model,
#     middleware=[Middleware1(), Middleware2(), Middleware3()]
# )

# 执行流程:

# 1. Middleware1.before_model   ↓ 正序
# 2. Middleware2.before_model   ↓
# 3. Middleware3.before_model   ↓

#    [模型调用]

# 6. Middleware3.after_model    ↑ 逆序
# 5. Middleware2.after_model    ↑
# 4. Middleware1.after_model    ↑
# 类似洋葱模型:外层先进后出

# 用户请求
#   ↓
# [ Middleware1.before ]  <-- 外层入口
#   ↓
#   [ Middleware2.before ]
#     ↓
#     [ Middleware3.before ] <-- 内层入口
#       ↓
#       [ 🤖 模型调用 (Model) ]
#       ↓
#     [ Middleware3.after ]  <-- 内层出口 (最先处理模型结果)
#     ↑
#   [ Middleware2.after ]
#   ↑
# [ Middleware1.after ]    <-- 外层出口 (最后处理结果)
#   ↑
# 返回给用户


# 实际应用
class TestMiddleware:
    # before_model 示例:添加情感安抚指令
    def before_model(self, state, runtime):
        last_msg = state["messages"][-1]

        # 验证用户身份:如果用户不是授权用户,直接返回跳转指令
        if last_msg["role"] != "user":
            return {"jump_to": "unauthorized_node"}

        # 安全检查:如果包含攻击性词汇,直接返回跳转指令
        if "混蛋" in last_msg["content"]:
            return {"jump_to": "safety_block_node"}

        # 2. 动态注入:如果检测到用户很着急,注入安抚指令
        if "着急" in last_msg["content"]:
            system_prompt = {
                "role": "system",
                "content": "【动态指令】用户非常着急,请优先安抚情绪并快速给出方案。",
            }
            # 插入到 system prompt 之后,历史记录之前
            new_messages = [state["messages"][0], system_prompt] + state["messages"][1:]
            return {"messages": new_messages}
        return None  # 正常流程

    # after_model 示例:确保输出包含工单号
    def after_model(self, state, runtime):
        # 获取模型最后生成的消息
        response_content = state["messages"][-1].content
        if "投诉" in state.get("intent", ""):
            if "工单号" not in response_content:
                response_content += " 请包含工单号,方便我们定位问题。"
        return None  # 接受输出


# 1. 日志中间件
class LogMiddleware:
    def before_model(self, state, runtime):
        print("日志记录:开始模型调用")
        return None  # 不做任何修改,继续流程

    def after_model(self, state, runtime):
        print("日志记录:模型调用结束")
        return None  # 不做任何修改,继续流程


# 2. 计数中间件
class CallCounterMiddleware(AgentMiddleware):
    def after_model(self, state, runtime):
        count = state.get("model_call_count", 0)
        return {"model_call_count": count + 1}


# # 需要 checkpointer 来保存自定义状态
# agent = create_agent(
#     model=model,
#     middleware=[CallCounterMiddleware()],
#     checkpointer=InMemorySaver()
# )


# 3. 消息修剪中间件
class MessageTrimmerMiddleware(AgentMiddleware):
    def __init__(self, max_messages=5):
        super().__init__()
        self.max_messages = max_messages

    def before_model(self, state, runtime):
        messages = state.get("messages", [])
        if len(messages) > self.max_messages:
            # 只保留最近的 N 条消息
            return {"messages": messages[-self.max_messages :]}
        return None


# 4. 输出验证中间件
class OutputValidatorMiddleware(AgentMiddleware):
    def after_model(self, state, runtime):
        messages = state["messages"]
        last_msg = messages[-1]
        if last_msg["role"] == "assistant":
            # 验证模型输出是否符合要求
            if "工单号" not in last_msg["content"]:
                # 触发跳转指令
                return {"jump_to": "missing_order_id_node"}
        return None


# 5. 限流中间件
class MaxCallsMiddleware(AgentMiddleware):
    def __init__(self, max_calls=10):
        super().__init__()
        self.max_calls = max_calls

    def before_model(self, state, runtime):
        count = state.get("model_call_count", 0)
        if count >= self.max_calls:
            # 触发跳转指令
            return {"jump_to": "rate_limit_node"}
        return None

    def after_model(self, state, runtime):
        count = state.get("call_count", 0)
        return {"call_count": count + 1}


# 内置中间件
# SummarizationMiddleware(自动摘要)
from langchain.agents.middleware import SummarizationMiddleware


def demo1():
    agent = create_agent(
        model=chat_model,
        middleware=[
            SummarizationMiddleware(
                model="groq:llama-3.1-8b-instant",  # 可用便宜模型
                max_tokens_before_summary=500,
            )
        ],
        checkpointer=InMemorySaver(),
    )


# 作用:

# 消息超过 token 限制时自动摘要
# 保留最近消息 + 旧消息摘要
# 详见 08_context_management 章节


# HumanInTheLoopMiddleware(人工审核)
from langchain.agents.middleware import HumanInTheLoopMiddleware


def demo2():
    agent = create_agent(
        model=chat_model,
        middleware=[
            HumanInTheLoopMiddleware(
                interrupt_on={"send_email": True}  # 调用此工具前暂停
            )
        ],
        checkpointer=InMemorySaver(),
    )


# PIIMiddleware(敏感信息处理)
from langchain.agents.middleware import PIIMiddleware


def demo3():
    agent = create_agent(
        model=chat_model,
        middleware=[
            PIIMiddleware("email", strategy="redact"),  # 邮箱脱敏
            PIIMiddleware("phone_number", strategy="block"),  # 电话拦截
            # 输出阶段:防止 LLM 幻觉生成假的或泄露真的身份证号
            PIIMiddleware("id_number", strategy="redact", direction="output"),
        ],
    )


# graph LR
#     User[用户输入] --> PII_Check{PIIMiddleware 检查}

#     PII_Check -- 发现 Email --> Redact[执行脱敏 redact]
#     Redact --> LLM[发送给大模型]

#     PII_Check -- 发现 手机号 --> Block[执行拦截 block]
#     Block --> Error[返回错误/拒绝服务]

#     PII_Check -- 无敏感信息 --> LLM

# 常见问题
# 不能直接访问。before_model 和 after_model 只在模型节点执行。
# 如果需要拦截工具调用,使用 wrap_tool_call(高级特性)。

# 2. 多个中间件的顺序重要吗?
# 非常重要!

# middleware=[
#     TrimmerMiddleware(),     # 1. 先修剪消息
#     SummarizationMiddleware(), # 2. 再摘要
#     LoggingMiddleware()      # 3. 最后记录日志
# ]
# before_model 按列表顺序执行
# after_model 按列表逆序执行

# 3. 修改状态需要 checkpointer 吗?
# 自定义状态需要,messages 不需要:

# # 不需要 checkpointer(messages 自动保存)
# def after_model(self, state, runtime):
#     return {"messages": [...]}

# # 需要 checkpointer(自定义字段)
# def after_model(self, state, runtime):
#     return {"my_custom_field": 123}


# 4. 能在中间件里调用另一个模型吗
# 可以,但要小心:
# class ValidationMiddleware(AgentMiddleware):
#     def __init__(self):
#         self.validator_model = init_chat_model(...)

#     def after_model(self, state, runtime):
#         # 用另一个模型验证输出
#         last_msg = state['messages'][-1]
#         validation_result = self.validator_model.invoke(...)
#         return None


# 最佳实践
# # 1. 生产环境推荐配置
# agent = create_agent(
#     model=model,
#     tools=[...],
#     middleware=[
#         MessageTrimmerMiddleware(max_messages=20),  # 限制消息数
#         SummarizationMiddleware(model=..., max_tokens=2000), # 自动摘要
#         LoggingMiddleware(),  # 日志记录
#     ],
#     checkpointer=SqliteSaver.from_conn_string("...")
# )

# # 2. 开发环境
# agent = create_agent(
#     model=model,
#     tools=[...],
#     middleware=[
#         LoggingMiddleware(),  # 只要日志
#     ]
# )

# # 3. 测试环境
# agent = create_agent(
#     model=model,
#     tools=[...],
#     middleware=[
#         MaxCallsMiddleware(max_calls=5),  # 防止测试费用爆炸
#     ]
# )


if __name__ == "__main__":
    demo1()
添加新评论