xhs_server/assistant.py
2024-12-16 10:31:07 +08:00

476 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re
import tempfile
import graphviz
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# OpenAI配置
api_key = "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"
base_url = "http://52.90.243.11:8787/v1"
client = OpenAI(api_key=api_key, base_url=base_url)
# Embedding模型配置
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
model_name=embedding_model_name,
model_kwargs=embedding_model_kwargs,
encode_kwargs=embedding_encode_kwargs
)
# 数据库配置
db_config = {
'host': '183.11.229.79',
'port': 3316,
'user': 'root',
'password': 'zaq12wsx@9Xin',
'db': '9xin',
'autocommit': True
}
async def get_db_connection():
"""创建数据库连接"""
return await aiomysql.connect(**db_config)
async def get_note_details(conn, note_id):
"""获取笔记详情"""
async with conn.cursor(aiomysql.DictCursor) as cursor:
try:
# 使用 SELECT * 获取所有字段
await cursor.execute("""
SELECT *
FROM xhs_notes
WHERE id = %s
""", (note_id,))
note = await cursor.fetchone()
if note:
# 确保所有需要的字段都有默认值
default_fields = {
'linked_count': note.get('linked_count', ''),
'collected_count': note.get('collected_count', ''),
'comment_count': note.get('comment_count', ''),
'share_count': note.get('share_count', ''),
'title': note.get('title', '无标题'),
'description': note.get('description', '无描述')
}
return note
except Exception as e:
logger.error(f"获取笔记详情时出错: {str(e)}")
# 返回带有默认值的空笔记
return {
'title': '获取失败',
'description': '无法获取笔记内容',
'linked_count': 0,
'collected_count': 0,
'comment_count': 0,
'share_count': 0,
'tag_list': []
}
async def get_clean_note_content(conn, note_id):
"""获取清洗后的笔记内容"""
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute("""
SELECT content_type, content
FROM clean_note_store
WHERE note_id = %s
""", (note_id,))
results = await cursor.fetchall()
clean_content = {
'guide': '',
'mindmap': '',
'summary': ''
}
for row in results:
clean_content[row['content_type']] = row['content']
return clean_content
async def get_vector_note_id(conn, vector_id):
"""获取向量对应的笔记ID"""
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute("""
SELECT note_id
FROM vector_store
WHERE vector_id = %s
""", (vector_id,))
result = await cursor.fetchone()
return result['note_id'] if result else None
async def chat_with_gpt(system_prompt, user_prompt):
"""与GPT交互"""
try:
# 添加错误检查
if not system_prompt or not user_prompt:
logger.error("系统提示或用户提示为空")
return "入参数错误"
response = client.chat.completions.create( # 移除 await
model="gpt-4o-mini", # 修改为正确的模型名称
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
],
max_tokens=1000,
temperature=0.7,
)
if not response or not response.choices:
logger.error("GPT响应为空")
return "获取回答"
return response.choices[0].message.content
except Exception as e:
logger.error(f"GPT API调用错误: {str(e)}")
return f"调用出错: {str(e)}"
async def process_query(query):
"""处理用户查询"""
conn = None
try:
logger.info(f"开始处理查询: {query}")
# 加载向量存储
vs_path = "./raw_vs"
logger.info(f"正在检查向量存储路径: {vs_path}")
if not os.path.exists(vs_path):
logger.error(f"向量存储路径不存在: {vs_path}")
return "错误:向量存储不存在", []
logger.info("正在加载向量存储...")
vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
# 搜索相关向量
logger.info("开始向量搜索...")
results = vs.similarity_search_with_score(query, k=10)
logger.info(f"找到 {len(results)} 条相关结果")
try:
conn = await get_db_connection()
if not conn:
logger.error("数据库连接失败")
return "数据库连接失败", []
logger.info("数据库连接成功")
context_notes = []
notes_data = [] # 存储完整的笔记数据
# 获取不重复的前5条笔记
seen_note_ids = set()
for doc, score in results:
try:
content_hash = get_content_hash(doc.page_content)
logger.info(f"处理content_hash: {content_hash}")
async with conn.cursor(aiomysql.DictCursor) as cursor:
await cursor.execute("""
SELECT note_id FROM vector_store
WHERE content_hash = %s
""", (content_hash,))
result = await cursor.fetchone()
if not result or result['note_id'] in seen_note_ids:
continue
note_id = result['note_id']
seen_note_ids.add(note_id)
# 获取笔记详情和清洗内容
note = await get_note_details(conn, note_id)
clean_content = await get_clean_note_content(conn, note_id)
if not note or not clean_content:
continue
# 构建完整的笔记数据
note_data = {
'id': note_id,
'title': note.get('title', '无标题'),
'description': note.get('description', '暂无描述'),
'collected_count': note.get('collected_count', 0),
'comment_count': note.get('comment_count', 0),
'share_count': note.get('share_count', 0),
'clean_content': clean_content
}
notes_data.append(note_data)
context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")
if len(notes_data) >= 5: # 修改为5条
break
except Exception as e:
logger.error(f"处理笔记时出错: {str(e)}")
continue
if not notes_data:
logger.warning("未获取到任何有效的笔记内容")
return "未找到相关笔记内容", []
# 准备GPT提示
logger.info("准备调用GPT...")
system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。
请基于提供的相关笔记内容,对用户的问题给出专业、简洁的回答。
回答要突出重点,并结合化妆行业的专业知识。"""
user_prompt = f"问题:{query}\n\n相关笔记内容:\n" + "\n\n".join(context_notes)
logger.info("调用GPT获取回答...")
gpt_response = await chat_with_gpt(system_prompt, user_prompt)
return gpt_response, notes_data
except Exception as e:
logger.error(f"处理数据时出错: {str(e)}")
return f"处理数据时出错: {str(e)}", []
finally:
if conn is not None:
try:
await conn.close()
logger.info("数据库连接已关闭")
except Exception as e:
logger.error(f"关闭数据库连接时出错: {str(e)}")
def create_ui():
"""创建Gradio界面"""
with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
# 存储当前的笔记内容
state = gr.State([])
gr.Markdown("# 化妆品产品经理AI助教")
# 修改 HTML 和 JavaScript 部分
gr.HTML("""
<style>
.query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
.ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
.notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
</style>
<!-- 添加 Mermaid 相关脚本 -->
<script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
<script>
// 初始化 Mermaid
mermaid.initialize({
startOnLoad: true,
theme: 'default',
securityLevel: 'loose',
flowchart: {
useMaxWidth: true,
htmlLabels: true,
curve: 'basis'
}
});
// 监听 DOM 变化并渲染新的图表
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
if (mutation.addedNodes.length) {
document.querySelectorAll('.mermaid:not(.mermaid-processed)').forEach(async (element) => {
try {
const graphDefinition = element.textContent;
const { svg } = await mermaid.render('mermaid-svg-' + Math.random(), graphDefinition);
element.innerHTML = svg;
element.classList.add('mermaid-processed');
} catch (error) {
console.error('Mermaid rendering error:', error);
}
});
}
});
});
observer.observe(document.body, {
childList: true,
subtree: true
});
</script>
""")
# 查询区域
with gr.Row():
with gr.Column(scale=2, elem_classes="query-section"):
gr.Markdown("### 💡 问题输入区")
query_input = gr.Textbox(
label="请输入你的问题",
placeholder="例如:最近的面霜趋势是什么?",
lines=3
)
submit_btn = gr.Button("提交问题", variant="primary")
# AI回答和笔记展示区域
with gr.Row():
# 左侧AI回答和笔记卡片
with gr.Column(scale=1, elem_classes="ai-response"):
gr.Markdown("### 🤖 AI助教回答")
gpt_output = gr.Markdown()
gr.Markdown("### 📑 相关笔记")
# 使用 Radio 组件替代自定义卡片
note_selector = gr.Radio(
choices=[],
label="点击笔记查看详细分析",
visible=False
)
# 右侧笔记AI分析
with gr.Column(scale=1, elem_classes="notes-analysis"):
gr.Markdown("### 🔍 笔记AI分析")
clean_content_output = gr.HTML()
def format_note_choice(note):
"""格式化笔记选项"""
title = note.get('title', '无标题')
desc = note.get('description', '暂无描述')[:100]
stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
return f"{title}\n{desc}...\n{stats}"
async def handle_submit(query, state):
"""处理用户提交的查询"""
try:
gpt_response, notes_data = await process_query(query)
if not notes_data:
return (
"未找到相关笔记", # GPT响应
"请点击左侧笔记查看分析内容", # 分析内容
[], # 状态
gr.update(choices=[], value=None, visible=False) # Radio组件更新
)
# 格式化笔记选项
note_choices = [format_note_choice(note) for note in notes_data]
return (
f"## AI助教回答\n\n{gpt_response}", # GPT响应
"请点击左侧笔记查看分析内容", # 初始分析内容
notes_data, # 状态
gr.update( # Radio组件更新
choices=note_choices,
value=None,
visible=True
)
)
except Exception as e:
logger.error(f"处理查询时出错: {str(e)}")
return (
f"处理查询时出错: {str(e)}",
"",
[],
gr.update(choices=[], value=None, visible=False)
)
def show_note_analysis(choice, state):
"""处理笔记选择"""
logger.info(f"笔记选择改变: {choice}")
if not choice or not state:
return "请选择笔记查看分析内容"
# 根据选择的文本找到对应的笔记索引
try:
idx = next(i for i, note in enumerate(state)
if format_note_choice(note) == choice)
except StopIteration:
return "无法找到对应的笔记"
note = state[idx]
clean_content = note.get('clean_content', {})
def clean_text(text):
if not text:
return ""
return re.sub(r'```(?:markdown|mermaid)?', '', text).strip()
html_parts = []
# 添加导读部分
if clean_content.get('guide'):
guide = clean_text(clean_content['guide'])
html_parts.append(f"<h2>导读</h2><div class='guide-content'>{guide}</div>")
# 处理思维导图
if clean_content.get('mindmap'):
mindmap = clean_text(clean_content['mindmap'])
iframe_content = f"""
<iframe srcdoc='
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
<script>
mermaid.initialize({{ startOnLoad: true }});
</script>
</head>
<body>
<div class="mermaid">
{mindmap}
</div>
</body>
</html>
' width="100%" height="600px" style="border:none;">
</iframe>
"""
html_parts.append(f"""
<h2>思维导图</h2>
{iframe_content}
""")
# 添加总结部分
if clean_content.get('summary'):
summary = clean_text(clean_content['summary'])
html_parts.append(f"<h2>总结</h2><div class='summary-content'>{summary}</div>")
# 组合所有内容
html_content = f"""
<div class='analysis-content' style='padding: 20px;'>
{' '.join(html_parts)}
</div>
"""
return html_content
def sync_handle_submit(query, state):
"""同步包装异步函数"""
return asyncio.run(handle_submit(query, state))
# 绑定事件
submit_btn.click(
fn=sync_handle_submit,
inputs=[query_input, state],
outputs=[gpt_output, clean_content_output, state, note_selector]
)
note_selector.change(
fn=show_note_analysis,
inputs=[note_selector, state],
outputs=[clean_content_output]
)
return demo
if __name__ == "__main__":
demo = create_ui()
demo.launch(server_name="0.0.0.0", server_port=7865)