第6周 国产大模型API调用学习指南

第6周 多国产大模型API调用 详细学习指南

一、知识点讲解

(一)文心一言API调用核心知识点

1. 知识点说明

文心一言API(ERNIE Bot API)是百度提供的大模型调用接口,可实现文本生成、摘要、多轮对话等功能,支持通过HTTP请求集成到各类应用中,适用于快速开发基于大模型的文本处理工具。

2. 核心参数详解

参数说明常用可选值及说明传参示例
prompt用户输入的提示词,用于引导模型生成内容无固定可选值,需根据业务场景设计(如摘要需求、对话需求)"请摘要以下济南政府工作报告内容:[报告文本]"
history多轮对话历史,用于维持上下文连贯性列表格式,每个元素为包含"role"和"content"的字典[{"role":"user","content":"济南2025年GDP目标是多少?"},{"role":"assistant","content":"XX亿元"}]
temperature控制生成内容的随机性,值越高越灵活,越低越严谨0.0~1.0,摘要场景建议0.2~0.4,创意场景0.6~0.80.3
top_p与temperature互补,控制生成内容的多样性,优先选择概率高的词汇0.0~1.0,一般不与temperature同时调整,摘要场景建议0.7~0.90.8
max_tokens控制生成内容的最大长度(含输入和输出 tokens)1~4096,根据需求调整,摘要场景建议512~10241024
stream是否开启流式输出,开启后逐字返回结果,提升交互体验true(开启)、false(关闭),命令行工具建议开启false

3. 典型用法示例


import requests
import json

# 文心一言API配置
API_KEY = "你的API_KEY"
SECRET_KEY = "你的SECRET_KEY"
TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
CHAT_URL = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot"

def get_access_token():
    """获取访问令牌(有效期30天,需缓存避免重复请求)"""
    params = {"grant_type": "client_credentials", "client_id": API_KEY, "client_secret": SECRET_KEY}
    response = requests.get(TOKEN_URL, params=params)
    return response.json().get("access_token")

def ernie_chat(prompt, history=None, temperature=0.3, max_tokens=1024):
    """
    文心一言对话函数
    :param prompt: 当前提示词
    :param history: 对话历史列表
    :param temperature: 随机性参数
    :param max_tokens: 最大长度
    :return: 模型生成结果
    """
    access_token = get_access_token()
    url = f"{CHAT_URL}?access_token={access_token}"
    headers = {"Content-Type": "application/json"}
    # 初始化对话历史
    history = history if history else []
    data = {
        "prompt": prompt,
        "history": history,
        "temperature": temperature,
        "max_tokens": max_tokens
    }
    response = requests.post(url, headers=headers, data=json.dumps(data))
    result = response.json()
    return result.get("result"), history + [{"role":"user","content":prompt}, {"role":"assistant","content":result.get("result")}]

# 调用示例
if __name__ == "__main__":
    prompt1 = "请介绍济南的支柱产业"
    res1, history1 = ernie_chat(prompt1)
    print("回答1:", res1)
    # 多轮对话,基于上一轮上下文
    prompt2 = "这些产业2025年的发展目标是什么?"
    res2, history2 = ernie_chat(prompt2, history1)
    print("回答2:", res2)

4. 常见陷阱提醒

  • 访问令牌未缓存:频繁请求令牌会触发接口限流,需将令牌缓存到文件或数据库,有效期内复用。
  • max_tokens设置过小:若输入文本过长,可能导致输出被截断,需预留足够 tokens 给输出内容。
  • 对话历史未清理:多轮对话历史累积过多会占用 tokens,需定期清理早期无关历史,仅保留关键上下文。
  • 未处理接口异常:网络波动、API额度不足会导致请求失败,需添加异常捕获和重试机制。

(二)通义千问API调用核心知识点

1. 知识点说明

通义千问API是阿里云提供的大模型接口,功能与文心一言类似,支持文本生成、摘要、多轮对话,接口设计更贴近开发者习惯,支持批量请求和成本控制相关参数。

2. 核心参数详解

参数说明常用可选值及说明传参示例
messages包含对话历史和当前请求的消息列表,替代文心一言的prompt和history列表格式,每个元素含"role"(user/assistant/system)和"content"[{"role":"user","content":"摘要济南政府工作报告"},{"role":"assistant","content":"好的,正在为您摘要"}]
temperature控制生成内容随机性,逻辑与文心一言一致0.0~1.0,摘要场景0.2~0.4,对话场景0.5~0.60.3
max_tokens生成内容的最大长度,仅计算输出tokens(与文心一言不同)1~8192,摘要场景建议512~10241024
top_k控制生成时选择词汇的范围,k值越小,生成内容越集中1~100,默认40,摘要场景建议20~3025
stop终止符,当模型生成到该内容时停止输出字符串或列表,可自定义终止内容["。", "?"]

3. 典型用法示例


import requests
import json

# 通义千问API配置
API_KEY = "你的API_KEY"
API_SECRET = "你的API_SECRET"
CHAT_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

def qwen_chat(messages, temperature=0.3, max_tokens=1024, top_k=25):
    """
    通义千问对话函数
    :param messages: 消息列表(含历史)
    :param temperature: 随机性参数
    :param max_tokens: 最大输出长度
    :param top_k: 词汇选择范围
    :return: 模型生成结果
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}"
    }
    data = {
        "model": "qwen-turbo",  # 通义千问轻量版模型
        "messages": messages,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "top_k": top_k
    }
    response = requests.post(CHAT_URL, headers=headers, data=json.dumps(data))
    result = response.json()
    # 提取生成内容
    content = result["output"]["choices"][0]["message"]["content"]
    # 更新对话历史
    messages.append({"role":"assistant","content":content})
    return content, messages

# 调用示例
if __name__ == "__main__":
    # 初始化消息列表(含当前请求)
    messages1 = [{"role":"user","content":"请摘要济南政府工作报告核心内容"}]
    res1, messages2 = qwen_chat(messages1)
    print("回答1:", res1)
    # 多轮对话
    messages2.append({"role":"user","content":"报告中关于科技创新的内容有哪些?"})
    res2, messages3 = qwen_chat(messages2)
    print("回答2:", res2)

4. 常见陷阱提醒

  • 模型参数混淆:通义千问的max_tokens仅计算输出长度,与文心一言(含输入)不同,需注意额度控制。
  • Authorization格式错误:需严格按照"Bearer + 空格 + API_KEY"格式设置请求头,否则会返回权限错误。
  • 未指定模型版本:不同模型版本(如qwen-turbo、qwen-plus)的性能和成本不同,需根据需求选择,避免超额。
  • 批量请求未限速:通义千问支持批量请求,但频繁批量调用会触发限流,需控制请求频率。

(三)多模型适配与成本控制知识点

1. 知识点说明

多模型适配核心是通过封装统一接口,实现模型切换无代码改动;成本控制则通过日志记录API调用详情、限制tokens使用量、选择合适模型版本等方式降低开销。

2. 核心方法与参数

方法/参数说明常用可选值及说明使用示例
模型封装接口定义统一的调用方法,接收模型类型参数,内部分发到对应模型逻辑model_type:ernie(文心一言)、qwen(通义千问)def chat(model_type, messages, **kwargs): ...
tokens计算方法估算输入输出tokens数量,控制单条请求成本中文按1字≈1token,英文按1词≈1.3token估算def count_tokens(text): return len(text)
日志记录字段记录API调用关键信息,用于成本统计和问题排查模型类型、调用时间、tokens数量、耗时、成本、请求内容{"model":"ernie","time":"2025-10-01","tokens":512,"cost":0.01}

3. 典型用法示例(多模型适配+日志)


import json
import time
from datetime import datetime
import logging
# 导入前文定义的文心一言和通义千问函数
from ernie_api import ernie_chat
from qwen_api import qwen_chat

# 配置日志(写入文件,支持成本排序)
logging.basicConfig(
    filename="api_call_logs.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

def count_tokens(text):
    """估算tokens数量(中文简化版)"""
    return len(text.strip())

def calculate_cost(model_type, input_tokens, output_tokens):
    """计算单条请求成本(参考官方定价,实际以官网为准)"""
    cost_map = {
        "ernie": {"input": 0.00001, "output": 0.00002},  # 文心一言:输入0.01元/千token,输出0.02元/千token
        "qwen": {"input": 0.000008, "output": 0.000018}  # 通义千问:输入0.008元/千token,输出0.018元/千token
    }
    cost = (input_tokens * cost_map[model_type]["input"] + output_tokens * cost_map[model_type]["output"]) / 1000
    return round(cost, 4)

def unified_chat(model_type, prompt, history=None, **kwargs):
    """
    统一对话接口,支持多模型切换
    :param model_type: 模型类型(ernie/qwen)
    :param prompt: 当前提示词
    :param history: 对话历史
    :param kwargs: 其他模型参数(temperature、max_tokens等)
    :return: 生成结果、更新后历史、调用日志
    """
    start_time = time.time()
    history = history if history else []
    input_text = prompt + "".join([item["content"] for item in history])
    input_tokens = count_tokens(input_text)
    
    # 模型分发
    if model_type == "ernie":
        result, new_history = ernie_chat(prompt, history, **kwargs)
    elif model_type == "qwen":
        # 转换历史格式为通义千问的messages
        messages = [{"role":"user" if item["role"]=="user" else "assistant","content":item["content"]} for item in history]
        messages.append({"role":"user","content":prompt})
        result, messages = qwen_chat(messages, **kwargs)
        # 转换回统一历史格式
        new_history = [{"role":item["role"],"content":item["content"]} for item in messages]
    else:
        raise ValueError("不支持的模型类型,可选:ernie、qwen")
    
    # 计算输出tokens和成本
    output_tokens = count_tokens(result)
    cost = calculate_cost(model_type, input_tokens, output_tokens)
   耗时 = round(time.time() - start_time, 2)
    
    # 生成日志
    log_info = {
        "model_type": model_type,
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
        "cost": cost,
        "duration": 耗时,
        "prompt": prompt[:50] + "..." if len(prompt) > 50 else prompt  # 截取提示词,避免日志过长
    }
    # 记录日志
    logging.info(json.dumps(log_info, ensure_ascii=False))
    return result, new_history, log_info

# 调用示例(模型切换无代码改动)
if __name__ == "__main__":
    # 使用文心一言
    res1, hist1, log1 = unified_chat("ernie", "摘要济南政府工作报告")
    print("文心一言结果:", res1)
    # 切换到通义千问,历史和参数可复用
    res2, hist2, log2 = unified_chat("qwen", "补充报告中的民生保障内容", hist1, temperature=0.3)
    print("通义千问结果:", res2)
    # 查看日志
    with open("api_call_logs.log", "r", encoding="utf-8") as f:
        print("日志内容:", f.readlines())

4. 常见陷阱提醒

  • 历史格式不统一:不同模型的对话历史格式不同,封装时需做好格式转换,否则会导致上下文丢失。
  • 成本计算偏差:简化版tokens估算与官方实际计算有差异,正式环境需调用官方tokens计算接口,避免成本超支。
  • 日志未异步写入:高频调用时同步写入日志会影响接口响应速度,需使用异步日志或队列批量写入。
  • 模型参数未适配:不同模型的参数默认值不同(如max_tokens范围),封装时需设置统一默认值,避免参数错误。

二、实战场景(FastApi实现多模型对话工具)

1. 实战需求

基于FastApi开发多模型对话接口,支持文心一言/通义千问切换、多轮对话、API调用日志记录、济南政府工作报告摘要功能,满足上下文连贯、模型无改动切换、日志完整的需求。

2. 完整示例代码


from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import json
import logging
import time
from datetime import datetime
# 导入前文封装的模型调用函数
from model_unified import unified_chat, count_tokens, calculate_cost

# 初始化FastApi应用
app = FastAPI(title="多国产大模型对话接口", version="1.0")

# 配置日志(支持成本排序,后续可导入数据库)
logging.basicConfig(
    filename="api_call_logs.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# 数据模型定义(请求体)
class ChatRequest(BaseModel):
    model_type: str = Query(default="ernie", description="模型类型,可选ernie(文心一言)、qwen(通义千问)")
    prompt: str = Query(description="用户提示词")
    history: list = Query(default=[], description="对话历史,格式:[{\"role\":\"user/assistant\",\"content\":\"内容\"}]")
    temperature: float = Query(default=0.3, ge=0.0, le=1.0, description="随机性参数")
    max_tokens: int = Query(default=1024, ge=1, le=4096, description="最大输出长度")

# 数据模型定义(响应体)
class ChatResponse(BaseModel):
    result: str
    history: list
    log: dict

# 读取济南政府工作报告(模拟文件存储,实际可从数据库读取)
def read_jinan_report():
    """读取济南政府工作报告文本"""
    with open("jinan_report.txt", "r", encoding="utf-8") as f:
        return f.read()

# 摘要接口(专门用于济南政府工作报告摘要)
@app.post("/jinan-report/summary", response_model=ChatResponse, description="济南政府工作报告摘要接口")
async def jinan_report_summary(
    model_type: str = Query(default="ernie", description="模型类型"),
    temperature: float = Query(default=0.2, ge=0.0, le=0.4, description="摘要场景建议0.2~0.4")
):
    try:
        # 读取报告内容
        report_content = read_jinan_report()
        # 构造摘要提示词(提升准确率)
        prompt = f"""请严格摘要以下济南政府工作报告内容,要求:
1. 涵盖核心数据、重点工作、发展目标三大板块;
2. 语言简洁,逻辑清晰,无冗余信息;
3. 字数控制在500字以内;
4. 准确还原原文含义,不得添加主观评价。
报告内容:{report_content}"""
        # 调用统一模型接口
        result, history, log = unified_chat(
            model_type=model_type,
            prompt=prompt,
            temperature=temperature,
            max_tokens=1024
        )
        # 验证摘要准确率(简化版:关键词匹配,实际可人工校验或用第三方工具)
        keywords = ["GDP", "科技创新", "民生保障", "产业升级", "生态保护"]
        match_count = sum(1 for keyword in keywords if keyword in result)
        if match_count < 3:
            logging.warning(f"摘要准确率不足,关键词匹配数:{match_count}")
        return ChatResponse(result=result, history=history, log=log)
    except Exception as e:
        logging.error(f"摘要接口异常:{str(e)}")
        raise HTTPException(status_code=500, detail=f"摘要失败:{str(e)}")

# 通用对话接口
@app.post("/chat", response_model=ChatResponse, description="多模型通用对话接口")
async def chat(request: ChatRequest):
    try:
        # 调用统一模型接口
        result, history, log = unified_chat(
            model_type=request.model_type,
            prompt=request.prompt,
            history=request.history,
            temperature=request.temperature,
            max_tokens=request.max_tokens
        )
        return ChatResponse(result=result, history=history, log=log)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logging.error(f"对话接口异常:{str(e)}")
        raise HTTPException(status_code=500, detail=f"对话失败:{str(e)}")

# 日志查询接口(支持按成本排序)
@app.get("/logs", description="API调用日志查询接口")
async def get_logs(
    sort_by_cost: bool = Query(default=False, description="是否按成本排序(降序)"),
    limit: int = Query(default=10, ge=1, le=100, description="返回日志条数")
):
    try:
        # 读取日志文件
        with open("api_call_logs.log", "r", encoding="utf-8") as f:
            logs = f.readlines()
        # 解析日志(过滤无效日志)
        parsed_logs = []
        for log in logs:
            try:
                # 提取日志中的JSON内容
                log_json = json.loads(log.split(" - INFO - ")[-1].strip())
                parsed_logs.append(log_json)
            except:
                continue
        # 按成本排序
        if sort_by_cost:
            parsed_logs.sort(key=lambda x: x["cost"], reverse=True)
        # 限制返回条数
        return {"logs": parsed_logs[:limit], "total": len(parsed_logs)}
    except Exception as e:
        logging.error(f"日志查询异常:{str(e)}")
        raise HTTPException(status_code=500, detail=f"日志查询失败:{str(e)}")

# 启动应用(命令:uvicorn main:app --reload)
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

3. 实战技巧与参数总结

  • 请求参数校验:使用FastApi的Query和Pydantic模型做参数校验,避免非法参数传入(如temperature超出范围),提升接口稳定性。
  • 摘要提示词优化:通过明确指令(分点要求、字数限制、核心板块)提升摘要准确率,达到≥90%的目标,这是摘要功能的关键技巧。
  • 异常捕获分层:针对参数错误(400)和服务器异常(500)分别处理,返回清晰的错误信息,便于问题排查。
  • 日志设计技巧:日志包含tokens、成本、耗时等关键信息,支持按成本排序,便于后续成本分析和优化,同时截取长提示词避免日志冗余。
  • 模型参数适配:摘要场景将temperature限制在0.2~0.4,降低随机性,提升摘要准确性;对话场景可适当提高至0.5~0.6,增强交互性。

三、练习题(FastApi实现)

练习题1:多模型API调用限流接口

题目描述

基于FastApi实现多模型对话接口的限流功能,限制每个用户(按请求头中的User-ID区分)每分钟最多调用10次接口,超过限流返回429状态码,同时记录限流日志。

实现思路

  1. 使用字典缓存用户调用记录,键为User-ID,值为调用时间列表。
  2. 定义限流依赖项,每次请求前检查用户近1分钟内的调用次数,超过10次则触发限流。
  3. 添加限流日志记录,包含用户ID、限流时间、当前调用次数。
  4. 在通用对话接口中引入限流依赖项,实现限流功能。

实现代码


from fastapi import FastAPI, HTTPException, Depends, Header
from pydantic import BaseModel
import time
import logging

app = FastAPI(title="限流版多模型对话接口")

# 配置限流日志
logging.basicConfig(
    filename="rate_limit_logs.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# 缓存用户调用记录:{user_id: [timestamp1, timestamp2, ...]}
user_call_cache = {}
# 限流配置
RATE_LIMIT = 10  # 每分钟最大调用次数
TIME_WINDOW = 60  # 时间窗口(秒)

def rate_limit_check(user_id: str = Header(..., description="用户唯一标识")):
    """限流依赖项"""
    current_time = time.time()
    # 初始化用户调用记录
    if user_id not in user_call_cache:
        user_call_cache[user_id] = []
    # 清理时间窗口外的记录
    user_call_cache[user_id] = [t for t in user_call_cache[user_id] if current_time - t <= TIME_WINDOW]
    # 检查是否超过限流
    if len(user_call_cache[user_id]) >= RATE_LIMIT:
        logging.warning(f"用户{user_id}触发限流,当前调用次数:{len(user_call_cache[user_id])}")
        raise HTTPException(
            status_code=429,
            detail=f"请求过于频繁,请{int(TIME_WINDOW - (current_time - user_call_cache[user_id][0]))}秒后再试"
        )
    # 记录本次调用时间
    user_call_cache[user_id].append(current_time)
    return user_id

# 复用之前的ChatRequest和ChatResponse模型
class ChatRequest(BaseModel):
    model_type: str = "ernie"
    prompt: str
    history: list = []
    temperature: float = 0.3
    max_tokens: int = 1024

class ChatResponse(BaseModel):
    result: str
    history: list
    log: dict

# 复用统一模型调用函数
from model_unified import unified_chat

# 带限流的对话接口
@app.post("/chat/rate-limit", response_model=ChatResponse, dependencies=[Depends(rate_limit_check)])
async def chat_with_rate_limit(request: ChatRequest, user_id: str = Depends(rate_limit_check)):
    try:
        result, history, log = unified_chat(
            model_type=request.model_type,
            prompt=request.prompt,
            history=request.history,
            temperature=request.temperature,
            max_tokens=request.max_tokens
        )
        # 日志添加用户ID
        log["user_id"] = user_id
        logging.info(f"用户{user_id}调用接口成功,日志:{log}")
        return ChatResponse(result=result, history=history, log=log)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

练习题2:API调用成本统计接口

题目描述

基于FastApi实现API调用成本统计接口,支持按模型类型、时间范围(今日/本周/本月)统计总调用次数、总tokens用量、总成本,返回统计结果和详细明细。

实现思路

  1. 读取之前记录的API调用日志,解析日志中的关键信息(模型类型、时间、tokens、成本)。
  2. 定义时间范围参数(today/week/month),根据参数筛选对应时间范围内的日志。
  3. 按模型类型分组统计,计算每组的总调用次数、总tokens、总成本,同时保留明细数据。
  4. 实现统计接口,返回分组统计结果和明细,支持按成本降序排序。

实现代码


from fastapi import FastAPI, Query, HTTPException
import json
import logging
from datetime import datetime, timedelta

app = FastAPI(title="API调用成本统计接口")

# 读取日志并解析
def parse_logs():
    """解析API调用日志"""
    parsed_logs = []
    try:
        with open("api_call_logs.log", "r", encoding="utf-8") as f:
            logs = f.readlines()
        for log in logs:
            try:
                log_json = json.loads(log.split(" - INFO - ")[-1].strip())
                # 转换时间格式为datetime
                log_json["timestamp"] = datetime.strptime(log_json["timestamp"], "%Y-%m-%d %H:%M:%S")
                parsed_logs.append(log_json)
            except Exception as e:
                logging.warning(f"解析日志失败:{str(e)},日志内容:{log}")
                continue
        return parsed_logs
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"读取日志失败:{str(e)}")

# 按时间范围筛选日志
def filter_logs_by_time(logs, time_range: str):
    """
    按时间范围筛选日志
    :param logs: 解析后的日志列表
    :param time_range: 时间范围(today/week/month)
    :return: 筛选后的日志
    """
    now = datetime.now()
    if time_range == "today":
        start_time = datetime(now.year, now.month, now.day)
    elif time_range == "week":
        # 本周一0点
        start_time = now - timedelta(days=now.weekday())
        start_time = datetime(start_time.year, start_time.month, start_time.day)
    elif time_range == "month":
        # 本月1号0点
        start_time = datetime(now.year, now.month, 1)
    else:
        raise ValueError("无效的时间范围,可选:today/week/month")
    return [log for log in logs if log["timestamp"] >= start_time]

# 成本统计接口
@app.get("/cost/stats", description="API调用成本统计接口")
async def cost_stats(
    time_range: str = Query(default="today", description="时间范围:today/week/month"),
    sort_by_cost: bool = Query(default=True, description="是否按成本降序排序")
):
    # 解析日志
    logs = parse_logs()
    # 按时间范围筛选
    filtered_logs = filter_logs_by_time(logs, time_range)
    if not filtered_logs:
        return {"message": "该时间范围内无API调用记录", "stats": {}, "details": []}
    
    # 按模型类型分组统计
    stats = {}
    for log in filtered_logs:
        model = log["model_type"]
        if model not in stats:
            stats[model] = {
                "call_count": 0,
                "total_tokens": 0,
                "total_cost": 0.0,
                "avg_cost_per_call": 0.0
            }
        # 累加统计数据
        stats[model]["call_count"] += 1
        stats[model]["total_tokens"] += log["total_tokens"]
        stats[model]["total_cost"] += log["cost"]
        # 计算平均成本
        stats[model]["avg_cost_per_call"] = round(stats[model]["total_cost"] / stats[model]["call_count"], 4)
    
    # 转换时间格式为字符串(便于返回)
    details = []
    for log in filtered_logs:
        log["timestamp"] = log["timestamp"].strftime("%Y-%m-%d %H:%M:%S")
        details.append(log)
    
    # 按成本排序
    if sort_by_cost:
        # 统计结果排序
        sorted_stats = dict(sorted(stats.items(), key=lambda x: x[1]["total_cost"], reverse=True))
        # 明细排序
        details.sort(key=lambda x: x["cost"], reverse=True)
    else:
        sorted_stats = stats
    
    # 计算总体统计
    total_stats = {
        "total_call_count": sum([v["call_count"] for v in sorted_stats.values()]),
        "total_tokens": sum([v["total_tokens"] for v in sorted_stats.values()]),
        "total_cost": round(sum([v["total_cost"] for v in sorted_stats.values()]), 4)
    }
    
    return {
        "time_range": time_range,
        "total_stats": total_stats,
        "model_stats": sorted_stats,
        "details": details
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

练习题3:多轮对话上下文管理接口

题目描述

基于FastApi实现多轮对话上下文管理接口,支持保存对话历史(按对话ID区分)、加载对话历史、清理过期对话历史(超过24小时无新消息的对话),确保多轮对话上下文连贯。

实现思路

  1. 使用字典缓存对话历史,键为对话ID,值为包含“history”(对话内容)和“last_update_time”(最后更新时间)的字典。
  2. 实现创建对话、加载对话、更新对话、清理过期对话四个接口。
  3. 每次更新对话时刷新最后更新时间,清理接口定期(或手动触发)删除过期对话,释放内存。
  4. 集成多模型对话功能,确保加载历史后上下文连贯。

实现代码


from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel
import uuid
import time
import logging

app = FastAPI(title="多轮对话上下文管理接口")

# 配置日志
logging.basicConfig(
    filename="conversation_logs.log",
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S"
)

# 对话缓存:{conv_id: {"history": [], "last_update_time": float}}
conversation_cache = {}
# 过期时间(24小时,单位:秒)
EXPIRE_TIME = 24 * 60 * 60

# 数据模型
class ConversationCreateRequest(BaseModel):
    model_type: str = "ernie"
    initial_prompt: str = Query(description="初始提示词")

class ConversationUpdateRequest(BaseModel):
    conv_id: str = Query(description="对话ID")
    prompt: str = Query(description="当前提示词")
    model_type: str = "ernie"

class ConversationResponse(BaseModel):
    conv_id: str
    history: list
    result: str
    last_update_time: str

# 创建对话
@app.post("/conversation/create", response_model=ConversationResponse, description="创建新对话并发送初始提示词")
async def create_conversation(request: ConversationCreateRequest):
    # 生成唯一对话ID
    conv_id = str(uuid.uuid4())
    current_time = time.time()
    # 调用模型获取初始响应
    from model_unified import unified_chat
    result, history, log = unified_chat(
        model_type=request.model_type,
        prompt=request.initial_prompt
    )
    # 缓存对话
    conversation_cache[conv_id] = {
        "history": history,
        "last_update_time": current_time
    }
    logging.info(f"创建新对话,ID:{conv_id},模型:{request.model_type}")
    return ConversationResponse(
        conv_id=conv_id,
        history=history,
        result=result,
        last_update_time=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(current_time))
    )

# 加载对话
@app.get("/conversation/load", description="加载对话历史")
async def load_conversation(conv_id: str = Query(description="对话ID")):
    if conv_id not in conversation_cache:
        raise HTTPException(status_code=404, detail="对话ID不存在")
    # 检查是否过期
    current_time = time.time()
    if current_time - conversation_cache[conv_id]["last_update_time"] > EXPIRE_TIME:
        # 移除过期对话
        del conversation_cache[conv_id]
        raise HTTPException(status_code=410, detail="对话已过期")
    # 返回对话信息
    conv_info = conversation_cache[conv_id]
    return {
        "conv_id": conv_id,
        "history": conv_info["history"],
        "last_update_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(conv_info["last_update_time"]))
    }

# 更新对话(发送新消息)
@app.post("/conversation/update", response_model=ConversationResponse, description="更新对话并获取响应")
async def update_conversation(request: ConversationUpdateRequest):
    if request.conv_id not in conversation_cache:
        raise HTTPException(status_code=404, detail="对话ID不存在")
    # 检查过期
    current_time = time.time()
    conv_info = conversation_cache[request.conv_id]
    if current_time - conv_info["last_update_time"] > EXPIRE_TIME:
        del conversation_cache[request.conv_id]
        raise HTTPException(status_code=410, detail="对话已过期")
    # 调用模型(携带历史上下文)
    from model_unified import unified_chat
    result, new_history, log = unified_chat(
        model_type=request.model_type,
        prompt=request.prompt,
        history=conv_info["history"]
    )
    # 更新缓存
    conversation_cache[request.conv_id] = {
        "history": new_history,
        "last_update_time": current_time
    }
    logging.info(f"更新对话,ID:{request.conv_id}")
    return ConversationResponse(
        conv_id=request.conv_id,
        history=new_history,
        result=result,
        last_update_time=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(current_time))
    )

# 清理过期对话
@app.delete("/conversation/clean", description="清理过期对话(手动触发)")
async def clean_expired_conversations():
    current_time = time.time()
    expired_count = 0
    # 遍历缓存,删除过期对话
    for conv_id in list(conversation_cache.keys()):
        if current_time - conversation_cache[conv_id]["last_update_time"] > EXPIRE_TIME:
            del conversation_cache[conv_id]
            expired_count += 1
    logging.info(f"清理过期对话完成,共删除{expired_count}条")
    return {"message": f"清理完成,删除过期对话{expired_count}条", "remaining_conversations": len(conversation_cache)}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

(注:文档部分内容可能由 AI 生成)

添加新评论