xhs_server/assistant.py

476 lines
18 KiB
Python
Raw Normal View History

2024-12-16 02:31:07 +00:00
import asyncio
import html
import logging
import os
import re
import tempfile

import aiomysql
import gradio as gr
import graphviz
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from openai import OpenAI

from process_raw_notes import get_content_hash
# Logging setup: records at INFO and above go to the root logger, each line
# prefixed with its timestamp and level. This module logs through `logger`.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
# OpenAI-compatible chat endpoint configuration.
# Values can be supplied through project-specific environment variables
# (deliberately NOT the generic OPENAI_* names, so a developer's personal
# OpenAI key is never sent to this proxy). The literals below are legacy
# fallbacks kept for backward compatibility.
# SECURITY: the fallback key is committed to source control and base_url is
# plain HTTP — rotate the key, set XHS_OPENAI_API_KEY in deployments, and
# then delete the fallback literal.
api_key = os.environ.get(
    "XHS_OPENAI_API_KEY",
    "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",
)
base_url = os.environ.get("XHS_OPENAI_BASE_URL", "http://52.90.243.11:8787/v1")
client = OpenAI(api_key=api_key, base_url=base_url)
# Embedding model configuration. The model object is constructed at import
# time. The model path and device can be overridden with
# XHS_EMBEDDING_MODEL / XHS_EMBEDDING_DEVICE (e.g. "cpu" on a host without a
# GPU); the defaults are the original hard-coded values.
embedding_model_name = os.environ.get(
    'XHS_EMBEDDING_MODEL', '/datacenter/expert_models/bce-embedding-base_v1/'
)
embedding_model_kwargs = {'device': os.environ.get('XHS_EMBEDDING_DEVICE', 'cuda')}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs,
)
# Database configuration: keyword arguments for aiomysql.connect().
# Each value can be overridden through an XHS_DB_* environment variable; the
# literals are legacy fallbacks kept for backward compatibility.
# SECURITY: the fallback root password is committed to source control —
# rotate it, set XHS_DB_PASSWORD in deployments, then remove the fallback.
db_config = {
    'host': os.environ.get('XHS_DB_HOST', '183.11.229.79'),
    # Environment values are strings; the driver expects an int port.
    'port': int(os.environ.get('XHS_DB_PORT', '3316')),
    'user': os.environ.get('XHS_DB_USER', 'root'),
    'password': os.environ.get('XHS_DB_PASSWORD', 'zaq12wsx@9Xin'),
    'db': os.environ.get('XHS_DB_NAME', '9xin'),
    'autocommit': True,
}
async def get_db_connection():
    """Open and return a brand-new aiomysql connection built from ``db_config``.

    Each call creates a separate connection; the caller must close it.
    """
    connection = await aiomysql.connect(**db_config)
    return connection
async def get_note_details(conn, note_id):
    """Fetch a single ``xhs_notes`` row by id.

    Args:
        conn: an open aiomysql connection.
        note_id: value matched against ``xhs_notes.id``.

    Returns:
        The row as a dict with every column. ``title``, ``description`` and
        the count fields are filled with a fallback when they are absent or
        NULL. Returns ``None`` when no row matches. If the query raises, the
        error is logged and a placeholder dict (title '获取失败', zero counts,
        empty ``tag_list``) is returned instead of propagating.
    """
    # Fallbacks for missing/NULL fields, so callers can format the note
    # without None checks (e.g. slicing the description).
    # NOTE(review): 'linked_count' may be a typo for a 'liked_count' column —
    # confirm against the xhs_notes schema.
    field_defaults = {
        'linked_count': 0,
        'collected_count': 0,
        'comment_count': 0,
        'share_count': 0,
        'title': '无标题',
        'description': '无描述',
    }
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            await cursor.execute("""
SELECT *
FROM xhs_notes
WHERE id = %s
""", (note_id,))
            note = await cursor.fetchone()
        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # Placeholder note so the caller can still render something.
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }
    if note:
        for field, default in field_defaults.items():
            # `is None` covers both a missing key and a NULL column.
            if note.get(field) is None:
                note[field] = default
    return note
async def get_clean_note_content(conn, note_id):
    """Collect the cleaned content pieces stored for one note.

    Reads every ``(content_type, content)`` row from ``clean_note_store`` for
    ``note_id`` and returns a dict keyed by content type. The keys ``guide``,
    ``mindmap`` and ``summary`` are always present (empty string when no row
    exists). Any other content_type found is added as an extra key. When
    several rows share a type, the last row fetched wins.
    """
    sql = """
SELECT content_type, content
FROM clean_note_store
WHERE note_id = %s
"""
    async with conn.cursor(aiomysql.DictCursor) as cur:
        await cur.execute(sql, (note_id,))
        rows = await cur.fetchall()
    content_by_type = dict.fromkeys(('guide', 'mindmap', 'summary'), '')
    content_by_type.update((row['content_type'], row['content']) for row in rows)
    return content_by_type
async def get_vector_note_id(conn, vector_id):
    """Return the ``note_id`` that ``vector_store`` records for ``vector_id``.

    Args:
        conn: an open aiomysql connection.
        vector_id: value matched against ``vector_store.vector_id``.

    Returns:
        The matching ``note_id``, or ``None`` when no row matches.
    """
    sql = """
SELECT note_id
FROM vector_store
WHERE vector_id = %s
"""
    async with conn.cursor(aiomysql.DictCursor) as cur:
        await cur.execute(sql, (vector_id,))
        row = await cur.fetchone()
    if not row:
        return None
    return row['note_id']
async def chat_with_gpt(system_prompt, user_prompt, *, model="gpt-4o-mini",
                        max_tokens=1000, temperature=0.7):
    """Send one system + user exchange to the chat model and return its reply.

    The module-level ``client`` is the synchronous ``OpenAI`` client. The
    request is therefore run in the event loop's default executor, so this
    coroutine does not block the loop while waiting for the HTTP response.

    Args:
        system_prompt: system message text; must be non-empty.
        user_prompt: user message text; must be non-empty.
        model: chat model name passed to ``chat.completions.create``.
        max_tokens: completion token limit passed through unchanged.
        temperature: sampling temperature passed through unchanged.
        The keyword defaults equal the values previously hard-coded here.

    Returns:
        The first choice's message content on success. For empty prompts, an
        empty response, or an exception from the API call, a short error
        message string is returned instead of raising.
    """
    try:
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "输入参数错误"

        def _request():
            return client.chat.completions.create(
                model=model,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                max_tokens=max_tokens,
                temperature=temperature,
            )

        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, _request)
        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "未能获取回答"
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"
async def _collect_notes(conn, results, max_notes=5):
    """Resolve vector-search hits to distinct notes and load their details.

    For each ``(doc, score)`` pair in ``results`` (in order):
    1. Hash the document text with ``get_content_hash``.
    2. Look the hash up in ``vector_store.content_hash`` to find its
       ``note_id``.
    3. Skip hits with no matching row or whose note was already collected.
    4. Load the note row and its cleaned content for the remaining notes.

    Errors for a single hit are logged and that hit is skipped. Collection
    stops once ``max_notes`` notes have been gathered.

    Returns:
        ``(notes_data, context_notes)``:
        - ``notes_data`` is a list of dicts with keys id / title /
          description / collected_count / comment_count / share_count /
          clean_content.
        - ``context_notes`` holds one title + content text per note, for the
          GPT prompt.
    """
    notes_data = []
    context_notes = []
    seen_note_ids = set()
    for doc, _score in results:
        try:
            content_hash = get_content_hash(doc.page_content)
            logger.info(f"处理content_hash: {content_hash}")
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute("""
SELECT note_id FROM vector_store
WHERE content_hash = %s
""", (content_hash,))
                result = await cursor.fetchone()
            if not result or result['note_id'] in seen_note_ids:
                continue
            note_id = result['note_id']
            seen_note_ids.add(note_id)
            note = await get_note_details(conn, note_id)
            clean_content = await get_clean_note_content(conn, note_id)
            if not note or not clean_content:
                continue
            # `or` rather than a .get() default, so NULL columns also get a
            # fallback instead of reaching the prompt / UI as None.
            title = note.get('title') or '无标题'
            description = note.get('description') or '暂无描述'
            notes_data.append({
                'id': note_id,
                'title': title,
                'description': description,
                'collected_count': note.get('collected_count') or 0,
                'comment_count': note.get('comment_count') or 0,
                'share_count': note.get('share_count') or 0,
                'clean_content': clean_content
            })
            context_notes.append(f"标题:{title}\n内容:{description}")
            if len(notes_data) >= max_notes:
                break
        except Exception as e:
            logger.error(f"处理笔记时出错: {str(e)}")
            continue
    return notes_data, context_notes


async def process_query(query, *, vs_path="./raw_vs", top_k=10, max_notes=5):
    """Answer a user question using related Xiaohongshu notes as context.

    Steps:
    1. Load the FAISS index saved at ``vs_path``.
    2. Retrieve the ``top_k`` nearest chunks for ``query``.
    3. Resolve them to at most ``max_notes`` distinct notes via MySQL.
    4. Ask the chat model to answer with those notes as context.

    Args:
        query: the user's question.
        vs_path: directory of the saved FAISS index.
        top_k: number of nearest chunks to retrieve.
        max_notes: maximum number of distinct notes to return.

    Returns:
        ``(answer, notes_data)``. On a handled failure, ``answer`` is an
        error message and ``notes_data`` is ``[]``. Exceptions raised while
        loading or searching the index are not caught here and propagate to
        the caller.
    """
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []
        logger.info("正在加载向量存储...")
        # allow_dangerous_deserialization lets load_local unpickle the stored
        # docstore. This is only acceptable while vs_path is written by a
        # trusted local process (presumably process_raw_notes — verify).
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=top_k)
        logger.info(f"找到 {len(results)} 条相关结果")
        try:
            conn = await get_db_connection()
            if not conn:
                logger.error("数据库连接失败")
                return "数据库连接失败", []
            logger.info("数据库连接成功")

            notes_data, context_notes = await _collect_notes(conn, results, max_notes)
            if not notes_data:
                logger.warning("未获取到任何有效的笔记内容")
                return "未找到相关笔记内容", []

            logger.info("准备调用GPT...")
            system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。
请基于提供的相关笔记内容对用户的问题给出专业简洁的回答
回答要突出重点并结合化妆行业的专业知识"""
            user_prompt = f"问题:{query}\n\n相关笔记内容:\n" + "\n\n".join(context_notes)
            logger.info("调用GPT获取回答...")
            gpt_response = await chat_with_gpt(system_prompt, user_prompt)
            return gpt_response, notes_data
        except Exception as e:
            logger.error(f"处理数据时出错: {str(e)}")
            return f"处理数据时出错: {str(e)}", []
    finally:
        if conn is not None:
            try:
                # aiomysql's Connection.close() is a plain synchronous method.
                # The previous `await conn.close()` awaited its None result and
                # raised TypeError on every request.
                conn.close()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")
# Static CSS + script injected at the top of the page.
# NOTE(review): gr.HTML content is most likely inserted via innerHTML, in
# which case browsers do not execute these <script> tags (the CSS still
# applies). That is why mind maps are rendered in a self-contained iframe by
# _build_mindmap_iframe. Confirm before relying on this script.
_PAGE_HEAD_HTML = """
<style>
.query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
.ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
.notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
</style>
<!-- 添加 Mermaid 相关脚本 -->
<script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
<script>
// 初始化 Mermaid
mermaid.initialize({
startOnLoad: true,
theme: 'default',
securityLevel: 'loose',
flowchart: {
useMaxWidth: true,
htmlLabels: true,
curve: 'basis'
}
});
// 监听 DOM 变化并渲染新的图表
const observer = new MutationObserver((mutations) => {
mutations.forEach((mutation) => {
if (mutation.addedNodes.length) {
document.querySelectorAll('.mermaid:not(.mermaid-processed)').forEach(async (element) => {
try {
const graphDefinition = element.textContent;
const { svg } = await mermaid.render('mermaid-svg-' + Math.random(), graphDefinition);
element.innerHTML = svg;
element.classList.add('mermaid-processed');
} catch (error) {
console.error('Mermaid rendering error:', error);
}
});
}
});
});
observer.observe(document.body, {
childList: true,
subtree: true
});
</script>
"""


def _format_note_choice(note):
    """Build the Radio label for one note: title, description preview, stats.

    The label is also used to map the selected Radio value back to its note,
    so it must be deterministic for a given note dict. Missing or None
    title/description fall back to placeholders instead of raising.
    """
    title = note.get('title') or '无标题'
    desc = (note.get('description') or '暂无描述')[:100]
    stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
    return f"{title}\n{desc}...\n{stats}"


def _strip_code_fences(text):
    """Remove ``` / ```markdown / ```mermaid fence markers and trim whitespace."""
    if not text:
        return ""
    return re.sub(r'```(?:markdown|mermaid)?', '', text).strip()


def _build_mindmap_iframe(mindmap):
    """Return an <iframe> whose srcdoc loads Mermaid and renders ``mindmap``.

    The inner document is HTML-escaped into a double-quoted ``srcdoc``
    attribute. Quotes or angle brackets in the diagram can therefore no
    longer terminate the attribute or inject markup; previously any ``'``
    broke the single-quoted attribute.
    """
    # The diagram text is escaped for the inner document too.
    # NOTE(review): this relies on Mermaid decoding HTML entities when it
    # reads the .mermaid element — verify against the (unpinned) CDN version.
    inner_doc = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8" />
<script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
<script>
mermaid.initialize({{ startOnLoad: true }});
</script>
</head>
<body>
<div class="mermaid">
{html.escape(mindmap, quote=False)}
</div>
</body>
</html>
"""
    return (
        f'<iframe srcdoc="{html.escape(inner_doc, quote=True)}" '
        'width="100%" height="600px" style="border:none;"></iframe>'
    )


def _render_note_analysis(choice, notes):
    """Render the guide / mind map / summary HTML for the selected note.

    Args:
        choice: the Radio label the user picked (as produced by
            ``_format_note_choice``), or a falsy value if nothing is selected.
        notes: the list of note dicts kept in the UI state.

    Returns:
        An HTML string for the analysis panel, or a short prompt/error text.
    """
    if not choice or not notes:
        return "请选择笔记查看分析内容"
    note = next((n for n in notes if _format_note_choice(n) == choice), None)
    if note is None:
        return "无法找到对应的笔记"
    clean_content = note.get('clean_content') or {}

    html_parts = []
    # Guide and summary come from processed third-party note text, so they
    # are escaped rather than injected as raw HTML.
    if clean_content.get('guide'):
        guide = html.escape(_strip_code_fences(clean_content['guide']))
        html_parts.append(f"<h2>导读</h2><div class='guide-content'>{guide}</div>")
    if clean_content.get('mindmap'):
        mindmap = _strip_code_fences(clean_content['mindmap'])
        html_parts.append(f"""
<h2>思维导图</h2>
{_build_mindmap_iframe(mindmap)}
""")
    if clean_content.get('summary'):
        summary = html.escape(_strip_code_fences(clean_content['summary']))
        html_parts.append(f"<h2>总结</h2><div class='summary-content'>{summary}</div>")

    return f"""
<div class='analysis-content' style='padding: 20px;'>
{' '.join(html_parts)}
</div>
"""


def create_ui():
    """Build the Gradio Blocks app for the cosmetics product-manager assistant.

    Layout:
    - Top: a question box and a submit button.
    - Below, left: the AI answer and a Radio list of related notes.
    - Below, right: the selected note's analysis (guide / mind map / summary).

    Returns:
        The ``gr.Blocks`` instance (not launched).
    """
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Note dicts returned by the most recent query.
        state = gr.State([])
        gr.Markdown("# 化妆品产品经理AI助教")
        gr.HTML(_PAGE_HEAD_HTML)

        # Query area.
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")

        # Answer + note list (left) and note analysis (right).
        with gr.Row():
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.HTML()

        async def handle_submit(query, state):
            """Run the query; returns (answer, analysis text, new state, radio update)."""
            try:
                gpt_response, notes_data = await process_query(query)
                if not notes_data:
                    # Show process_query's own message (e.g. missing vector
                    # store or DB error) instead of hiding it.
                    return (
                        gpt_response or "未找到相关笔记",
                        "请点击左侧笔记查看分析内容",
                        [],
                        gr.update(choices=[], value=None, visible=False)
                    )
                note_choices = [_format_note_choice(note) for note in notes_data]
                return (
                    f"## AI助教回答\n\n{gpt_response}",
                    "请点击左侧笔记查看分析内容",
                    notes_data,
                    gr.update(choices=note_choices, value=None, visible=True)
                )
            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )

        def show_note_analysis(choice, state):
            """Radio change handler: render the analysis for the chosen note."""
            logger.info(f"笔记选择改变: {choice}")
            return _render_note_analysis(choice, state)

        def sync_handle_submit(query, state):
            """Run the async handler on a private event loop in Gradio's worker thread."""
            return asyncio.run(handle_submit(query, state))

        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )
        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )
    return demo
if __name__ == "__main__":
    # Build the app and serve it on every interface at port 7865.
    demo = create_ui()
    demo.launch(server_port=7865, server_name="0.0.0.0")