# wechat_agent_ai - main.py

import asyncio
import hashlib
import hmac
import os
import re
import time
import xml.etree.ElementTree as ET
from datetime import datetime
from typing import TypedDict

import requests
from fastapi import FastAPI, Request, Response

# RAG
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS

# LLM
from langchain_ollama import OllamaLLM

# LangGraph
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, END
from langchain_core.messages import HumanMessage, AIMessage

app = FastAPI(title="微信公众号 RAG Agent")

# ===================== Configuration =====================
# Token configured in the WeChat admin console; it is combined with timestamp and
# nonce to compute request signatures, so it is a shared secret. It is read from
# the WECHAT_TOKEN environment variable when set; the literal is only a fallback
# default and should not be relied on in a deployed service.
WECHAT_TOKEN = os.environ.get("WECHAT_TOKEN", "wechat_rag_agent_2026")
# Original ID of the official account. NOTE(review): not referenced by the code
# shown here -- confirm whether it is still needed.
WECHAT_ORIGIN_ID = "gh_cfbff1593cc5"

# Local snapshot of the embedding model and the directory of the saved FAISS index.
EMBEDDING_PATH = "./rag/models--Qwen--Qwen3-Embedding-0.6B/snapshots/c54f2e6e80b2d7b7de06f51cec4959f6b3e03418"
FAISS_INDEX_DIR = "./rag/faiss_index"
# Ollama model name and server base URL.
LLM_MODEL = "my-qwen"
OLLAMA_HOST = "http://localhost:11434"

# Seconds a generated reply stays in MESSAGE_CACHE under its MsgId.
CACHE_EXPIRE_SECONDS = 30
# Seconds of inactivity after which a user's next message starts a new session
# (and therefore gets the welcome text).
SESSION_TIMEOUT = 60
# ==================================================

# ===================== Welcome text =====================
# Prepended to the reply for the first message of a session (see SESSION_TIMEOUT).
# The keyword list below mirrors the one in is_motivation -- keep them in sync.
WELCOME_TEXT = """✨ 你好,欢迎使用智能励志助手
在这里,你可以随时获取正能量与成长感悟~

📌 使用指南
1️⃣ 回复【今天】或【每日一句】
   获取每日一句精美图文,中英双语励志短句

2️⃣ 发送关键词触发知识库
   支持:励志、奋斗、努力、学习、知识、坚持
   成长、名言、失败、成功、气馁、失望、沮丧、希望

3️⃣ 直接发送文字聊天
   AI 会为你简洁回复,陪你聊天解惑

愿你步履不停,始终向阳 🌱
"""

# Session bookkeeping, held in process memory only (plain module-level dicts).
# openid -> time.time() of that user's latest message; used to detect new sessions.
USER_SESSION: dict = {}
# MsgId -> {"reply": str, "expire_at": float}; short-lived cache of generated replies.
MESSAGE_CACHE: dict = {}
# MsgId of the most recently processed text message (updated by the POST handler).
LAST_MSG_ID = None


def clean_expired_cache():
    """Evict every MESSAGE_CACHE entry whose ``expire_at`` timestamp has passed.

    The dict is mutated in place, so every reference to MESSAGE_CACHE sees the
    eviction.
    """
    cutoff = time.time()
    # Collect the keys first: a dict must not change size while it is iterated.
    stale_keys = tuple(
        key for key, entry in MESSAGE_CACHE.items() if entry["expire_at"] < cutoff
    )
    for key in stale_keys:
        MESSAGE_CACHE.pop(key)


# ---------------------- RAG ----------------------
# Embedding model used to query the FAISS index; runs on CPU.
device = "cpu"
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_PATH,
    model_kwargs={"device": device}
)
# FAISS.load_local unpickles the saved docstore, which is why the explicit
# allow_dangerous_deserialization opt-in is required. That is acceptable only
# because the index is a trusted local artifact -- never point FAISS_INDEX_DIR at
# files that came from an untrusted source.
rag_db = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings,
    allow_dangerous_deserialization=True
)
# Top-k must be passed through search_kwargs; a bare ``k=2`` keyword to
# as_retriever is not applied to the similarity search, so the default k was used.
retriever = rag_db.as_retriever(search_kwargs={"k": 2})

# ---------------------- LLM ----------------------
llm = OllamaLLM(model=LLM_MODEL, base_url=OLLAMA_HOST)


# ---------------------- 工具函数 ----------------------
def is_today_sentence(text):
    """Return True when *text* is exactly one of the daily-sentence trigger words.

    Surrounding whitespace is ignored and the text is lower-cased before the
    comparison; partial matches do not count.
    """
    triggers = ("今天", "每日一句", "今日一句", "句子")
    normalized = text.strip().lower()
    return normalized in triggers


def get_iciba():
    """Fetch iciba's daily sentence and shape it for a news-style reply.

    Returns:
        A dict with ``title`` (dated heading), ``desc`` (the API's ``note`` field)
        and ``pic`` (the API's ``fenxiang_img`` URL), or ``None`` when the request,
        JSON decoding or any field lookup fails (the error is printed).
    """
    endpoint = "https://open.iciba.com/dsapi/"
    try:
        payload = requests.get(endpoint, timeout=3).json()
        card = {
            "title": f"📅 每日一句 | {payload['dateline']}",
            "desc": f"{payload['note']}",
            "pic": payload["fenxiang_img"],
        }
    except Exception as err:
        # Best effort: callers fall back to the normal agent path on None.
        print(f"金山接口异常: {err}")
        return None
    return card


def is_motivation(text):
    """Return True if *text* contains any motivation keyword as a substring.

    The keyword list mirrors the one advertised in WELCOME_TEXT.
    """
    for keyword in ("励志", "奋斗", "努力", "学习", "知识", "坚持", "成长",
                    "名言", "失败", "成功", "气馁", "失望", "沮丧", "希望"):
        if keyword in text:
            return True
    return False


# ---------------------- LangGraph ----------------------
# State shared between the LangGraph nodes:
#   user_input - the user's message text
#   intent     - route chosen by classify_intent ("sentence", "rag" or "chat")
#   reply      - text sent back to the user
#   messages   - message list; no reducer is declared, so each node update
#                replaces the previous value instead of appending to it
AgentState = TypedDict(
    "AgentState",
    {"user_input": str, "intent": str, "reply": str, "messages": list},
)


def classify_intent(state: AgentState):
    """Pick the route for the current input.

    Returns ``{"intent": ...}`` with "sentence" for a daily-sentence trigger,
    "rag" when a motivation keyword appears, and "chat" otherwise. The checks
    run in that order, so a trigger word wins over keywords.
    """
    text = state["user_input"]
    if is_today_sentence(text):
        intent = "sentence"
    elif is_motivation(text):
        intent = "rag"
    else:
        intent = "chat"
    return {"intent": intent}


# ===================== RAG result formatting =====================
# Patterns used to split a bilingual (English + Chinese) quote into its halves.
_HAN_RE = re.compile(r"[\u4e00-\u9fff]")
# Chinese/full-width punctuation only, written as escapes so the set is
# unambiguous: fullwidth comma, ideographic full stop, ideographic comma,
# fullwidth semicolon, colon, exclamation and question marks, horizontal
# ellipsis, curly double quotes, fullwidth parentheses.
_CJK_PUNCT_RE = re.compile(r"[\uff0c\u3002\u3001\uff1b\uff1a\uff01\uff1f\u2026\u201c\u201d\uff08\uff09]")
# ASCII letters and ASCII punctuation, removed to obtain the Chinese half.
_LATIN_RE = re.compile(r'[a-zA-Z\'",;:!.?\-()\[\]]')
_WS_RE = re.compile(r"\s+")
# Trailing marks replaced by a single terminal "." (English) or "。" (Chinese).
_ENG_TRAILING = ".!?,;:"
_CHN_TRAILING = "\uff0c\u3002\uff1b\uff1a\uff01\uff1f"


def _split_bilingual(text):
    """Split one bilingual quote into ``(english, chinese)`` display strings.

    English: Han characters and Chinese punctuation are removed; ASCII
    punctuation such as commas and apostrophes is kept. Chinese: ASCII letters
    and ASCII punctuation are removed. Each half is whitespace-collapsed and
    ends with exactly one terminal mark; a half with no content is returned as
    "" so the caller can omit it.
    """
    eng = _CJK_PUNCT_RE.sub("", _HAN_RE.sub("", text))
    eng = _WS_RE.sub(" ", eng).strip().rstrip(_ENG_TRAILING).rstrip()
    eng = f"{eng}." if eng else ""

    chn = _WS_RE.sub(" ", _LATIN_RE.sub("", text)).strip()
    chn = chn.rstrip(_CHN_TRAILING).rstrip()
    chn = f"{chn}。" if chn else ""
    return eng, chn


def rag_node(state: AgentState):
    """Answer a motivation-keyword query with quotes from the local FAISS index.

    Each retrieved document is shown as a numbered English line followed by an
    indented Chinese line. Entries are separated by "---" and the reply ends
    with a source footer. When a document has only one language, that half is
    shown alone on the numbered line.

    Returns:
        ``{"reply": text}``. This is a "not found" hint when retrieval yields no
        usable content.
    """
    docs = retriever.invoke(state["user_input"])
    entries = []
    for doc in docs:
        eng, chn = _split_bilingual(doc.page_content.strip())
        if eng or chn:
            entries.append((eng, chn))
    if not entries:
        return {"reply": "未找到相关内容,换个关键词试试~"}

    lines = ["🌿 精选励志名言", ""]
    for idx, (eng, chn) in enumerate(entries, 1):
        if idx > 1:
            lines.append("---")  # separator only between entries
        if eng and chn:
            lines.append(f"{idx}. {eng}")
            lines.append(f"\t\t{chn}")
        else:
            lines.append(f"{idx}. {eng or chn}")

    lines.append("")
    lines.append("📍 来源:本地 RAG 知识库")
    return {"reply": "\n".join(lines)}


def chat_node(state: AgentState):
    """Answer free-form chat with a single LLM call.

    Only the current ``user_input`` is sent to the model; earlier turns are not
    included. Returns the generated text as ``reply``, plus a one-element
    ``messages`` list wrapping it in an AIMessage.
    """
    answer = llm.invoke(state["user_input"])
    update = {"reply": answer}
    update["messages"] = [AIMessage(content=answer)]
    return update


# In-memory checkpointer: graph state is kept per thread_id (the caller passes
# the user's openid) for the lifetime of the process.
# NOTE(review): chat_node sends only the current user_input to the LLM, so the
# checkpointed state does not currently feed conversation history into replies.
memory = InMemorySaver()
wf = StateGraph(AgentState)

# "cls" classifies the input; "rag" and "chat" produce the reply.
for node_name, node_fn in (
    ("cls", classify_intent),
    ("rag", rag_node),
    ("chat", chat_node),
):
    wf.add_node(node_name, node_fn)

wf.set_entry_point("cls")
# Route on the intent written by classify_intent. "sentence" maps to chat: the
# POST handler only reaches the agent with a daily-sentence trigger when the
# iciba fetch failed, so the LLM answers instead.
wf.add_conditional_edges("cls", lambda s: s["intent"], {
    "sentence": "chat",
    "rag": "rag",
    "chat": "chat"
})
for leaf in ("rag", "chat"):
    wf.add_edge(leaf, END)

agent = wf.compile(checkpointer=memory)

# Drop the loop variables so they do not linger as module globals.
del node_name, node_fn, leaf


# ---------------------- 微信接口 ----------------------
@app.get("/wechat")
def wechat_verify(signature: str, timestamp: str, nonce: str, echostr: str):
    """Answer WeChat's server-URL verification request.

    The expected signature is the SHA-1 hex digest of WECHAT_TOKEN, ``timestamp``
    and ``nonce``, sorted lexicographically and concatenated. On a match the
    ``echostr`` query parameter is echoed back; otherwise the body is ``"err"``.
    """
    print(f"[GET] 微信接口验证")
    parts = sorted([WECHAT_TOKEN, timestamp, nonce])
    expected = hashlib.sha1("".join(parts).encode()).hexdigest()
    # Constant-time comparison so response timing does not reveal how many leading
    # characters matched. It is done on bytes because compare_digest rejects
    # non-ASCII str arguments, and query parameters are client-controlled.
    if hmac.compare_digest(expected.encode(), signature.encode()):
        return Response(content=echostr)
    return Response(content="err")


# MsgId -> asyncio.Task that is producing the final reply for that message.
# WeChat re-sends a message with the same MsgId when it gets no answer within
# about 5 seconds. A retry that arrives while the first attempt is still
# generating awaits the same task instead of invoking the agent again.
_PENDING_REPLIES = {}


def _cdata(value):
    """Return *value* as text that is safe inside one ``<![CDATA[...]]>`` section.

    A literal ``]]>`` (possible in LLM output) would close the section early and
    corrupt the reply XML, so it is split across two adjacent CDATA sections.
    """
    return str(value).replace("]]>", "]]]]><![CDATA[>")


def _text_reply_xml(openid, account, create_time, content):
    """Build a passive-reply XML document carrying a text message.

    Args:
        openid: recipient user (the incoming FromUserName).
        account: our official account (the incoming ToUserName).
        create_time: Unix timestamp in whole seconds.
        content: reply text.
    """
    return f"""<xml>
<ToUserName><![CDATA[{_cdata(openid)}]]></ToUserName>
<FromUserName><![CDATA[{_cdata(account)}]]></FromUserName>
<CreateTime>{create_time}</CreateTime>
<MsgType><![CDATA[text]]></MsgType>
<Content><![CDATA[{_cdata(content)}]]></Content>
</xml>"""


def _news_reply_xml(openid, account, create_time, article):
    """Build a passive-reply XML document carrying one news (picture + text) item.

    *article* is the dict returned by get_iciba (keys ``title``, ``desc``, ``pic``).
    The other arguments are as for _text_reply_xml.
    """
    return f"""<xml>
<ToUserName><![CDATA[{_cdata(openid)}]]></ToUserName>
<FromUserName><![CDATA[{_cdata(account)}]]></FromUserName>
<CreateTime>{create_time}</CreateTime>
<MsgType><![CDATA[news]]></MsgType>
<ArticleCount>1</ArticleCount>
<Articles>
<item>
<Title><![CDATA[{_cdata(article['title'])}]]></Title>
<Description><![CDATA[{_cdata(article['desc'])}]]></Description>
<PicUrl><![CDATA[{_cdata(article['pic'])}]]></PicUrl>
<Url><![CDATA[]]></Url>
</item>
</Articles>
</xml>"""


async def _generate_reply(openid, content, msg_id, first_message):
    """Run the agent for one message and cache the final reply under *msg_id*.

    The cached text already includes WELCOME_TEXT when *first_message* is true,
    so a retried delivery of the same MsgId receives exactly the same reply.
    """
    config = {"configurable": {"thread_id": openid}}
    # agent.invoke is synchronous and blocks for the whole LLM call; running it in
    # a worker thread keeps the event loop free to serve other requests.
    result = await asyncio.to_thread(
        agent.invoke,
        {"user_input": content, "messages": [HumanMessage(content=content)]},
        config=config,
    )
    reply = result["reply"]
    print(f"✅ AI生成回复: {reply}")
    if first_message:
        print(f"✅ 首次消息,拼接欢迎语")
        reply = WELCOME_TEXT + reply
    MESSAGE_CACHE[msg_id] = {
        "reply": reply,
        # Measured from completion, so retries get the full window to pick it up.
        "expire_at": time.time() + CACHE_EXPIRE_SECONDS,
    }
    return reply


@app.post("/wechat")
async def wechat_msg(request: Request):
    """Handle a message pushed by WeChat and answer with a passive-reply XML body.

    Flow:
      * Non-text messages, and any error, get an empty body.
      * A daily-sentence trigger is answered with an iciba news item when that
        fetch succeeds; otherwise it falls through to the agent.
      * Everything else goes through the LangGraph agent. Replies are cached per
        MsgId for CACHE_EXPIRE_SECONDS, and concurrent deliveries of the same
        MsgId share one in-flight generation.
      * The first message of a session (see SESSION_TIMEOUT) gets WELCOME_TEXT
        prepended.

    NOTE(review): POST requests are not signature-checked here. WeChat also sends
    ``signature``/``timestamp``/``nonce`` query parameters with pushed messages;
    verifying them would stop arbitrary clients from triggering LLM calls.
    NOTE(review): ET.fromstring parses the untrusted request body; see the Python
    ``xml`` security notes if hardening against malicious XML is needed.
    """
    global LAST_MSG_ID
    try:
        xml_body = await request.body()
        print(f"\n========== 收到微信请求 ==========")
        root = ET.fromstring(xml_body)
        msg_type = root.find("MsgType").text
        openid = root.find("FromUserName").text
        to_user = root.find("ToUserName").text

        print(f"消息类型: {msg_type}, 用户OPENID: {openid}")

        if msg_type != "text":
            print("非文本消息,直接返回")
            return Response("")

        content = root.find("Content").text.strip()
        msg_id = root.find("MsgId").text.strip()
        now = time.time()
        clean_expired_cache()
        print(f"用户消息: {content}, 消息ID: {msg_id}")

        first_message = (
            openid not in USER_SESSION
            or now - USER_SESSION[openid] > SESSION_TIMEOUT
        )
        if first_message:
            print("✅ 新会话,标记为首次消息")
        USER_SESSION[openid] = now

        # Daily sentence: answered with a news item when the iciba fetch works.
        if is_today_sentence(content):
            # get_iciba performs a blocking HTTP request; keep it off the event loop.
            data = await asyncio.to_thread(get_iciba)
            if data:
                print(f"✅ 调用金山接口成功,返回图文")
                reply_xml = _news_reply_xml(openid, to_user, int(now), data)
                return Response(content=reply_xml, media_type="application/xml")

        # LAST_MSG_ID is still kept up to date, but a new MsgId no longer clears the
        # cache: with one global id shared by all users, that evicted other users'
        # entries before their retries arrived. Expiry is left to clean_expired_cache.
        LAST_MSG_ID = msg_id

        cached = MESSAGE_CACHE.get(msg_id)
        if cached is not None and cached["expire_at"] > now:
            reply = cached["reply"]
            print(f"✅ 命中缓存,直接回复: {reply}")
        else:
            task = _PENDING_REPLIES.get(msg_id)
            if task is None:
                task = asyncio.create_task(
                    _generate_reply(openid, content, msg_id, first_message)
                )
                _PENDING_REPLIES[msg_id] = task
                task.add_done_callback(
                    lambda _done, key=msg_id: _PENDING_REPLIES.pop(key, None)
                )
            # shield(): if this request is cancelled, the generation keeps running so
            # the cache is still filled and a retried delivery can use the result.
            reply = await asyncio.shield(task)

        reply_xml = _text_reply_xml(openid, to_user, int(now), reply)
        print(f"✅ 最终回复XML: {reply_xml}")
        return Response(content=reply_xml, media_type="application/xml")

    except Exception as e:
        print(f"❌ 全局异常: {str(e)}")
        return Response("")


@app.get("/")
def index():
    """Health check: report that the service process is up."""
    payload = {"status": "running"}
    return payload