import re
import hashlib
import xml.etree.ElementTree as ET
from datetime import datetime
from fastapi import FastAPI, Request, Response
# RAG
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
# LLM
from langchain_ollama import OllamaLLM
# LangGraph
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import StateGraph, END
from typing import TypedDict
from langchain_core.messages import HumanMessage, AIMessage
import requests
import time
# FastAPI application object; the route decorators in this file register handlers on it.
app = FastAPI(title="微信公众号 RAG Agent")
# ===================== Configuration =====================
# Token that wechat_verify hashes together with the timestamp and nonce.
# NOTE(review): the token is hard-coded in source — consider loading it from the environment.
WECHAT_TOKEN = "wechat_rag_agent_2026"
# NOTE(review): not referenced anywhere in the visible code — confirm it is still needed.
WECHAT_ORIGIN_ID = "gh_cfbff1593cc5"
# Path passed to HuggingFaceEmbeddings as the model name.
EMBEDDING_PATH = "./rag/models--Qwen--Qwen3-Embedding-0.6B/snapshots/c54f2e6e80b2d7b7de06f51cec4959f6b3e03418"
# Directory the FAISS index is loaded from.
FAISS_INDEX_DIR = "./rag/faiss_index"
# Model name and base URL passed to OllamaLLM.
LLM_MODEL = "my-qwen"
OLLAMA_HOST = "http://localhost:11434"
# Seconds a cached reply stays valid (added to the current time to form "expire_at").
CACHE_EXPIRE_SECONDS = 30
# Seconds since a user's previous message after which the next one counts as a new session.
SESSION_TIMEOUT = 60
# ===========================================================
# ===================== Welcome text =====================
# Usage guide shown to the user; wechat_msg prepends it to the reply of the
# first message in a session.
WELCOME_TEXT = """✨ 你好,欢迎使用智能励志助手
在这里,你可以随时获取正能量与成长感悟~
📌 使用指南
1️⃣ 回复【今天】或【每日一句】
获取每日一句精美图文,中英双语励志短句
2️⃣ 发送关键词触发知识库
支持:励志、奋斗、努力、学习、知识、坚持
成长、名言、失败、成功、气馁、失望、沮丧、希望
3️⃣ 直接发送文字聊天
AI 会为你简洁回复,陪你聊天解惑
愿你步履不停,始终向阳 🌱
"""
# Session bookkeeping (process-local, in memory).
# Maps a sender's OpenID to the time.time() value of their most recent text message.
USER_SESSION = {}
# Maps a MsgId to {"reply": <reply text>, "expire_at": <time.time()-based deadline>}.
MESSAGE_CACHE = {}
# MsgId of the message most recently sent through the AI path; None until the first one.
LAST_MSG_ID = None
def clean_expired_cache():
    """Remove every MESSAGE_CACHE entry whose ``expire_at`` lies before the current time."""
    cutoff = time.time()
    # Collect the stale keys first so the dict is not resized while being iterated.
    stale_ids = []
    for msg_id, entry in MESSAGE_CACHE.items():
        if entry["expire_at"] < cutoff:
            stale_ids.append(msg_id)
    for msg_id in stale_ids:
        MESSAGE_CACHE.pop(msg_id)
# ---------------------- RAG ----------------------
# The embedding model runs on the CPU.
device = "cpu"
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_PATH,
    model_kwargs={"device": device},
)
# NOTE(review): allow_dangerous_deserialization=True opts in to deserialization that
# the library itself labels dangerous. Only load an index directory that this
# project produced; never point FAISS_INDEX_DIR at untrusted files.
rag_db = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings,
    allow_dangerous_deserialization=True,
)
# The number of returned documents is a search option and has to be passed via
# search_kwargs; a bare ``k=2`` keyword is not a retriever setting, so the
# library default count was being used instead of 2.
retriever = rag_db.as_retriever(search_kwargs={"k": 2})
# ---------------------- LLM ----------------------
# Ollama-backed model invoked by chat_node; the model name and server URL come
# from the configuration constants.
llm = OllamaLLM(model=LLM_MODEL, base_url=OLLAMA_HOST)
# ---------------------- 工具函数 ----------------------
def is_today_sentence(text):
    """Return True when *text*, trimmed and lower-cased, equals a daily-sentence trigger word."""
    triggers = ("今天", "每日一句", "今日一句", "句子")
    normalized = text.strip().lower()
    return any(normalized == trigger for trigger in triggers)
def get_iciba():
    """Fetch the daily sentence from the iciba open API.

    Returns:
        A dict with ``title``, ``desc`` and ``pic`` built from the response's
        ``dateline``, ``note`` and ``fenxiang_img`` fields, or ``None`` when the
        request, the JSON decoding or a field lookup fails (the error is printed).
    """
    endpoint = "https://open.iciba.com/dsapi/"
    try:
        payload = requests.get(endpoint, timeout=3).json()
        card = {
            "title": f"📅 每日一句 | {payload['dateline']}",
            "desc": f"{payload['note']}",
            "pic": payload["fenxiang_img"],
        }
    except Exception as e:
        # Best effort: the caller falls back to the normal AI path when this is None.
        print(f"金山接口异常: {e}")
        return None
    return card
def is_motivation(text):
    """Return True if *text* contains any knowledge-base trigger keyword as a substring."""
    triggers = (
        "励志", "奋斗", "努力", "学习", "知识", "坚持", "成长",
        "名言", "失败", "成功", "气馁", "失望", "沮丧", "希望",
    )
    for word in triggers:
        if word in text:
            return True
    return False
# ---------------------- LangGraph ----------------------
class AgentState(TypedDict):
    """State dictionary passed between the LangGraph nodes defined in this file."""

    # Raw text of the user's message; read by every node.
    user_input: str
    # Routing label written by classify_intent: "sentence", "rag" or "chat".
    intent: str
    # Reply text written by rag_node or chat_node.
    reply: str
    # Message objects: a HumanMessage is supplied on input, chat_node writes an AIMessage.
    messages: list
def classify_intent(state: AgentState):
    """Label the user's message for routing.

    Returns ``{"intent": "sentence"}`` for a daily-sentence trigger,
    ``{"intent": "rag"}`` when a knowledge-base keyword is present, and
    ``{"intent": "chat"}`` otherwise. The daily-sentence check takes precedence.
    """
    text = state["user_input"]
    if is_today_sentence(text):
        label = "sentence"
    elif is_motivation(text):
        label = "rag"
    else:
        label = "chat"
    return {"intent": label}
# ===================== 最终美化版 RAG 展示 =====================
# def rag_node(state: AgentState):
# import re
# docs = retriever.invoke(state["user_input"])
# if not docs:
# return {"reply": "💡 未找到相关内容,换个关键词试试~"}
#
# lines = ["🌿 精选励志名言\n"]
# for d in docs:
# text = d.page_content.strip()
# eng = re.sub(r'[\u4e00-\u9fff]', '', text).strip()
# chn = re.sub(r'[a-zA-Z\',.;!?\"\- ]', '', text).strip()
#
# lines.append(f"❝ {eng} ❞")
# lines.append(f" {chn}\n")
#
# lines.append("📚 来源:本地 RAG 知识库")
# return {"reply": "\n".join(lines)}
# Patterns are compiled once instead of being re-specified on every loop pass.
_CJK_IDEOGRAPHS = re.compile(r"[\u4e00-\u9fff]")
# Full-width / CJK punctuation: comma, ideographic full stop, enumeration comma,
# semicolon, colon, exclamation mark, question mark, ellipsis, curly double and
# single quotes, parentheses. Written as code points so the class is unambiguous.
_FULLWIDTH_PUNCT = re.compile(
    r"[\uff0c\u3002\u3001\uff1b\uff1a\uff01\uff1f\u2026\u201c\u201d\u2018\u2019\uff08\uff09]"
)
_LATIN_AND_ASCII_PUNCT = re.compile(r'[a-zA-Z\'",;:!.?\-()\[\]]')
_WHITESPACE_RUN = re.compile(r"\s+")
# Trailing marks removed from the Chinese half before the final full stop is added:
# full-width comma, ideographic full stop, full-width semicolon, "!" and "?".
_CHINESE_TRAILING_MARKS = "\uff0c\u3002\uff1b\uff01\uff1f"


def _split_bilingual(text):
    """Split one bilingual quote into ``(english, chinese)``, each ending in its own full stop."""
    # English half: drop ideographs and full-width punctuation, keep ASCII punctuation.
    english = _CJK_IDEOGRAPHS.sub("", text)
    english = _FULLWIDTH_PUNCT.sub("", english)
    english = _WHITESPACE_RUN.sub(" ", english).strip()
    english = english.rstrip(".!?") + "."

    # Chinese half: drop Latin letters and ASCII punctuation, then normalise the ending
    # so a trailing full-width mark does not end up doubled with the appended full stop.
    chinese = _LATIN_AND_ASCII_PUNCT.sub("", text)
    chinese = _WHITESPACE_RUN.sub(" ", chinese).strip()
    chinese = chinese.rstrip(_CHINESE_TRAILING_MARKS) + "\u3002"
    return english, chinese


def rag_node(state: AgentState):
    """Answer from the local knowledge base with a numbered list of bilingual quotes.

    The user's text is sent to the retriever. Each returned document's
    ``page_content`` is split into an English line and a tab-indented Chinese
    line; entries are separated by ``---`` and followed by a source footer.

    Returns:
        ``{"reply": <formatted text>}``; a short "nothing found" message when the
        retriever returns no documents.
    """
    docs = retriever.invoke(state["user_input"])
    if not docs:
        return {"reply": "未找到相关内容,换个关键词试试~"}

    lines = ["🌿 精选励志名言", ""]
    for idx, doc in enumerate(docs, 1):
        if idx > 1:
            # Separator goes between entries only, never after the last one.
            lines.append("---")
        english, chinese = _split_bilingual(doc.page_content.strip())
        lines.append(f"{idx}. {english}")
        lines.append(f"\t\t{chinese}")
    lines.extend(["", "📍 来源:本地 RAG 知识库"])
    return {"reply": "\n".join(lines)}
def chat_node(state: AgentState):
    """Send the user's text to the LLM and return its answer.

    The answer is returned both as ``reply`` and wrapped in an ``AIMessage``
    under ``messages``.
    """
    answer = llm.invoke(state["user_input"])
    update = {"reply": answer}
    update["messages"] = [AIMessage(content=answer)]
    return update
# In-memory checkpointer handed to the compiled graph; wechat_msg passes the
# sender's OpenID as the thread_id when invoking it.
memory = InMemorySaver()
wf = StateGraph(AgentState)
# Nodes: intent classifier, knowledge-base answer, free-form LLM chat.
wf.add_node("cls", classify_intent)
wf.add_node("rag", rag_node)
wf.add_node("chat", chat_node)
wf.set_entry_point("cls")
# Route on the "intent" label written by classify_intent. "sentence" maps to the
# chat node as well; wechat_msg tries the daily-sentence API before invoking the graph.
wf.add_conditional_edges("cls", lambda s: s["intent"], {
    "sentence": "chat",
    "rag": "rag",
    "chat": "chat"
})
# Both answer nodes finish the run.
wf.add_edge("rag", END)
wf.add_edge("chat", END)
agent = wf.compile(checkpointer=memory)
# ---------------------- 微信接口 ----------------------
@app.get("/wechat")
def wechat_verify(signature: str, timestamp: str, nonce: str, echostr: str):
    """Answer the GET verification request on ``/wechat``.

    The token, timestamp and nonce are sorted, concatenated and SHA-1 hashed.

    Args:
        signature: Hex digest supplied by the caller.
        timestamp: Timestamp query parameter included in the hash.
        nonce: Nonce query parameter included in the hash.
        echostr: Value to echo back when the signature matches.

    Returns:
        A response whose body is ``echostr`` on a match and ``"err"`` otherwise.
    """
    import hmac  # local import: only this handler needs it

    print("[GET] 微信接口验证")
    joined = "".join(sorted([WECHAT_TOKEN, timestamp, nonce]))
    expected = hashlib.sha1(joined.encode()).hexdigest()
    # Constant-time comparison of a caller-supplied value. Both sides are encoded
    # to bytes because compare_digest rejects str arguments containing non-ASCII text.
    is_valid = hmac.compare_digest(expected.encode(), signature.encode())
    return Response(content=echostr if is_valid else "err")
def _cdata(value):
    """Return *value* wrapped in a CDATA section that the value itself cannot close."""
    # A literal "]]>" would end the section early and corrupt the reply XML,
    # so it is split across two adjacent CDATA sections.
    return "<![CDATA[" + f"{value}".replace("]]>", "]]]]><![CDATA[>") + "]]>"


def _wechat_reply_xml(to_user, from_user, created_at, msg_type, payload_lines):
    """Build a reply XML document: the common header fields followed by *payload_lines*."""
    return "\n".join([
        "<xml>",
        f"<ToUserName>{_cdata(to_user)}</ToUserName>",
        f"<FromUserName>{_cdata(from_user)}</FromUserName>",
        f"<CreateTime>{int(created_at)}</CreateTime>",
        f"<MsgType>{_cdata(msg_type)}</MsgType>",
        *payload_lines,
        "</xml>",
    ])


@app.post("/wechat")
async def wechat_msg(request: Request):
    """Handle a message POSTed to ``/wechat`` and return the reply XML.

    Flow:
        1. Parse the XML body; anything other than a ``text`` message gets an
           empty response.
        2. Record the sender in USER_SESSION; a sender not seen within
           SESSION_TIMEOUT seconds starts a new session.
        3. A daily-sentence trigger is answered with a ``news`` reply built from
           get_iciba(); if that call fails the message falls through to step 4.
        4. Otherwise the reply comes from MESSAGE_CACHE (keyed by MsgId) or from
           the agent graph, and is returned as a ``text`` reply. The welcome text
           is prepended to the first reply of a session.

    Returns:
        An XML response, or an empty response for non-text messages and on any
        exception (the error is printed).
    """
    global LAST_MSG_ID
    try:
        xml_body = await request.body()
        print("\n========== 收到微信请求 ==========")
        # NOTE(review): the request's signature/timestamp/nonce query parameters are
        # not checked here, so any client can post messages — consider verifying
        # them the same way wechat_verify does.
        # NOTE(review): xml.etree.ElementTree is not hardened against maliciously
        # constructed XML; the body is untrusted input.
        root = ET.fromstring(xml_body)
        msg_type = root.find("MsgType").text
        openid = root.find("FromUserName").text
        to_user = root.find("ToUserName").text
        print(f"消息类型: {msg_type}, 用户OPENID: {openid}")
        if msg_type != "text":
            print("非文本消息,直接返回")
            return Response("")

        content = root.find("Content").text.strip()
        msg_id = root.find("MsgId").text.strip()
        now = time.time()
        clean_expired_cache()
        print(f"用户消息: {content}, 消息ID: {msg_id}")

        first_message = (
            openid not in USER_SESSION
            or now - USER_SESSION[openid] > SESSION_TIMEOUT
        )
        if first_message:
            print("✅ 新会话,标记为首次消息")
        USER_SESSION[openid] = now

        # Daily sentence: answered with a picture article when the API call succeeds.
        if is_today_sentence(content):
            data = get_iciba()
            if data:
                print("✅ 调用金山接口成功,返回图文")
                reply_xml = _wechat_reply_xml(openid, to_user, now, "news", [
                    "<ArticleCount>1</ArticleCount>",
                    "<Articles>",
                    "<item>",
                    f"<Title>{_cdata(data['title'])}</Title>",
                    f"<Description>{_cdata(data['desc'])}</Description>",
                    f"<PicUrl>{_cdata(data['pic'])}</PicUrl>",
                    "<Url><![CDATA[]]></Url>",
                    "</item>",
                    "</Articles>",
                ])
                return Response(content=reply_xml, media_type="application/xml")

        # Normal AI path.
        # The cache is no longer flushed when a different MsgId arrives: doing so
        # dropped other users' entries, so a re-delivered message was recomputed
        # whenever another message had been handled in between. Entries are
        # removed by clean_expired_cache() once they expire.
        LAST_MSG_ID = msg_id  # still updated for any code that reads it

        cached = MESSAGE_CACHE.get(msg_id)
        if cached is not None and cached["expire_at"] > now:
            reply = cached["reply"]
            print(f"✅ 命中缓存,直接回复: {reply}")
        else:
            config = {"configurable": {"thread_id": openid}}
            result = agent.invoke({
                "user_input": content,
                "messages": [HumanMessage(content=content)],
            }, config=config)
            reply = result["reply"]
            print(f"✅ AI生成回复: {reply}")
            if first_message:
                print("✅ 首次消息,拼接欢迎语")
                reply = WELCOME_TEXT + reply
            # Cache the final text (welcome included): a re-delivery of the same
            # message no longer counts as a first message, so adding the welcome
            # after the cache lookup would lose it.
            MESSAGE_CACHE[msg_id] = {
                "reply": reply,
                "expire_at": now + CACHE_EXPIRE_SECONDS,
            }

        reply_xml = _wechat_reply_xml(openid, to_user, now, "text", [
            f"<Content>{_cdata(reply)}</Content>",
        ])
        print(f"✅ 最终回复XML: {reply_xml}")
        return Response(content=reply_xml, media_type="application/xml")
    except Exception as e:
        # Top-level boundary: always answer the request, even when handling fails.
        print(f"❌ 全局异常: {str(e)}")
        return Response("")
@app.get("/")
def index():
    """Liveness endpoint: report that the service is running."""
    return dict(status="running")