Checkpointing - main.py

"""
LangChain 1.0 - Checkpointing (检查点持久化)
=========================================

本模块重点讲解:
1. SqliteSaver - SQLite 持久化(LangGraph 提供)
2. 与 InMemorySaver 的区别
3. 跨进程、跨重启的对话持久化
4. 实际应用场景


1. invoke 前:LangGraph 自动调用 checkpointer.get(thread_id="user_123")
    - 查询数据库,读取该 thread_id 的历史消息
    - 如果是第一次,返回空列表
  2. invoke 中:Agent 处理时会看到完整历史(默认确实会全部读取,)
  state = {
      "messages": [历史消息1, 历史消息2, 新消息]  # 自动合并
  }
  3. invoke 后:LangGraph 自动调用 checkpointer.put(thread_id, state)
    - 将新的完整状态写入数据库
    - 数据库存储:(thread_id, timestamp, messages)
"""

import sys
import os

from openai import chat

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from langchain.tools import tool

from init_model import get_chat_model

from langgraph.checkpoint.memory import InMemorySaver
from langchain.agents import create_agent
from langgraph.checkpoint.sqlite import SqliteSaver


chat_model = get_chat_model()

# 持久化数据库路径
db_path = "checkpoints.sqlite"


# ============================================================================
# SqliteSaver 持久化
# ============================================================================
def example_sqlite_saver():
    """
    示例2:使用 SqliteSaver 实现持久化

    关键:
    1. SqliteSaver.from_conn_string("sqlite:///path/to/db.sqlite")
    2. 对话持久化到 SQLite 文件
    3. 程序重启后仍可恢复
    """

    with SqliteSaver.from_conn_string(db_path) as checkpointer:
        agent = create_agent(
            model=chat_model, tools=[], checkpointer=checkpointer  # 使用 SQLite
        )
        config = {"configurable": {"thread_id": "user_zhangsan"}}
        # 第一次运行
        response = agent.invoke(
            {"messages": [{"role": "user", "content": "你好,我叫张三"}]}, config
        )
        print(response["messages"][-1].content)


def example_sqlite_saver_2():
    """
    示例3:使用 SqliteSaver 实现持久化

    关键:
    1. SqliteSaver.from_conn_string("sqlite:///path/to/db.sqlite")
    2. 对话持久化到 SQLite 文件
    3. 程序重启后仍可恢复
    """
    with SqliteSaver.from_conn_string(db_path) as checkpointer:
        agent = create_agent(
            model=chat_model, tools=[], checkpointer=checkpointer  # 使用 SQLite
        )
        config = {"configurable": {"thread_id": "user_zhangsan"}}
        # 第一次运行
        reponse = agent.invoke(
            {"messages": [{"role": "user", "content": "你好,我是谁"}]}, config
        )
        print(reponse["messages"][-1].content)


# ============================================================================
# 带工具的持久化 Agent
# ============================================================================


@tool
def get_order_status(order_id: str) -> str:
    """查询订单状态"""
    orders = {"12345": "已发货,预计明天送达", "67890": "配送中,今天下午送达"}
    return orders.get(order_id, "订单不存在")


"""
示例5:工具调用 + 持久化

Agent 记住工具调用历史
"""


def example_3_tools_with_persistence():
    """
    示例5:工具调用 + 持久化

    Agent 记住工具调用历史
    """
    print("\n" + "=" * 70)
    print("示例 5:工具调用 + 持久化")
    print("=" * 70)

    db_path = "tools.sqlite"

    with SqliteSaver.from_conn_string(db_path) as checkpointer:
        agent = create_agent(
            model=chat_model, tools=[get_order_status], checkpointer=checkpointer
        )

        config = {"configurable": {"thread_id": "customer_001"}}

        print("\n第一轮:查询订单")
        print("客户: 查询订单 12345 的状态")
        response1 = agent.invoke(
            {"messages": [{"role": "user", "content": "查询订单 12345 的状态"}]},
            config=config,
        )
        print(f"Agent: {response1['messages'][-1].content}")

        print("\n第二轮:询问之前的查询结果")
        print("客户: 我的订单什么时候到?")
        response2 = agent.invoke(
            {"messages": [{"role": "user", "content": "我的订单什么时候到?"}]},
            config=config,
        )
        print(f"Agent: {response2['messages'][-1].content}")

        print("\n关键点:")
        print("  - Agent 记住了订单 12345 的查询结果")
        print("  - 工具调用历史也被持久化")
        print("  - 无需重复调用工具")


# ============================================================================
# 实际应用 - 客服系统
# ============================================================================
def example_4_customer_service():
    """
    示例6:实际应用 - 持久化客服系统

    场景:客户可能分多次咨询,需要记住历史
    """
    print("\n" + "=" * 70)
    print("示例 6:实际应用 - 持久化客服系统")
    print("=" * 70)

    db_path = "customer_service.sqlite"

    with SqliteSaver.from_conn_string(db_path) as checkpointer:
        agent = create_agent(
            model=chat_model,
            tools=[get_order_status],
            system_prompt="""你是客服助手。
                    特点:
                    - 记住客户之前的咨询
                    - 友好、耐心
                    - 使用工具查询订单""",
            checkpointer=checkpointer,
        )

        customer_id = "customer_zhang"
        config = {"configurable": {"thread_id": customer_id}}

        print("\n第一次咨询(今天上午):")
        conversations_morning = ["你好,我想查询订单", "订单号是 12345"]

        for msg in conversations_morning:
            print(f"\n客户: {msg}")
            response = agent.invoke(
                {"messages": [{"role": "user", "content": msg}]}, config=config
            )
            print(f"客服: {response['messages'][-1].content}")

        print("\n" + "-" * 70)
        print("[几个小时后...]")
        print("-" * 70)

        print("\n第二次咨询(今天下午):")
        print("\n客户: 我的订单到哪了?")
        response = agent.invoke(
            {"messages": [{"role": "user", "content": "我的订单到哪了?"}]},
            config=config,
        )
        print(f"客服: {response['messages'][-1].content}")

        print("\n关键点:")
        print("  - 客户无需重复订单号")
        print("  - 系统记住了上午的咨询")
        print("  - 即使客服系统重启也不影响")
        print("  - 生产级应用的标准做法")


# ============================================================================
# SqliteSaver 参数说明
# ============================================================================
#     SqliteSaver 创建方式:

# 1. from_conn_string + with 语句(推荐)
# with SqliteSaver.from_conn_string("checkpoints.sqlite") as checkpointer:
#     agent = create_agent(model=model, checkpointer=checkpointer)
#     agent.invoke(...)

# - 自动管理连接和资源
# - 支持相对路径和绝对路径
# - 最简单安全的方式
# - 确保正确释放数据库连接
# - 注意:直接传文件路径,不要加 sqlite:/// 前缀

# 2. 使用 sqlite3.connect(高级)
# import sqlite3
# conn = sqlite3.connect("checkpoints.sqlite")
# checkpointer = SqliteSaver(conn)

# - 需要手动管理连接
# - 适合需要自定义连接参数的场景

# 数据库文件路径:
# - 相对路径:checkpoints.sqlite(当前目录)
# - 绝对路径:C:/Users/xxx/data/checkpoints.sqlite(Windows)
# - 内存数据库::memory:(测试用,程序退出即丢失)

# 最佳实践:
# ✅ 始终使用 with 语句管理 SqliteSaver
# ✅ 直接传文件路径,不要加 sqlite:/// 前缀
# ✅ 生产环境:使用绝对路径
# ✅ 开发测试:使用相对路径
# ✅ 单元测试:使用 :memory:
# ✅ 定期备份数据库文件
#     """


if __name__ == "__main__":
    example_sqlite_saver()
    example_sqlite_saver_2()
添加新评论