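"""Gradio app: an AI teaching assistant for cosmetics product managers.

Retrieves Xiaohongshu notes relevant to a user query from a FAISS vector store,
loads note details and pre-cleaned content (guide / mindmap / summary) from MySQL,
and asks a GPT model to answer based on the retrieved notes.
"""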
import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re
import tempfile
import graphviz

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenAI configuration (the API key is read from the OPENAI_API_KEY environment variable
# instead of being hardcoded in source)
api_key = os.environ.get("OPENAI_API_KEY", "")
base_url = "http://52.90.243.11:8787/v1"
client = OpenAI(api_key=api_key, base_url=base_url)

# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}

embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)

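# Note: the FAISS index at ./raw_vs (loaded in process_query) must have been built
# with this same embedding model, or similarity search results will be meaningless.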
# Database configuration
db_config = {
    'host': '183.11.229.79',
    'port': 3316,
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'db': '9xin',
    'autocommit': True
}

async def get_db_connection():
    """Create a database connection."""
    return await aiomysql.connect(**db_config)

async def get_note_details(conn, note_id):
    """Fetch the details of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            # Use SELECT * to fetch all columns
            await cursor.execute("""
                SELECT *
                FROM xhs_notes
                WHERE id = %s
            """, (note_id,))
            note = await cursor.fetchone()

            if note:
                # Make sure every field used downstream has a default value
                default_fields = {
                    'linked_count': '',
                    'collected_count': '',
                    'comment_count': '',
                    'share_count': '',
                    'title': '无标题',
                    'description': '无描述'
                }
                for field, default in default_fields.items():
                    note.setdefault(field, default)

            return note

        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # Return an empty note populated with defaults
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }

async def get_clean_note_content(conn, note_id):
    """Fetch the cleaned (post-processed) content of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT content_type, content
            FROM clean_note_store
            WHERE note_id = %s
        """, (note_id,))
        results = await cursor.fetchall()

        clean_content = {
            'guide': '',
            'mindmap': '',
            'summary': ''
        }

        for row in results:
            clean_content[row['content_type']] = row['content']

        return clean_content

async def get_vector_note_id(conn, vector_id):
    """Look up the note ID that a vector belongs to."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT note_id
            FROM vector_store
            WHERE vector_id = %s
        """, (vector_id,))
        result = await cursor.fetchone()
        return result['note_id'] if result else None

async def chat_with_gpt(system_prompt, user_prompt):
    """Query GPT for a response."""
    try:
        # Basic input validation
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "输入参数错误"

        # The OpenAI client used here is synchronous, so the call is not awaited
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
        )

        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "未能获取回答"

        return response.choices[0].message.content

    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"

async def process_query(query):
    """Handle a user query."""
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")

        # Load the vector store
        vs_path = "./raw_vs"
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []

        logger.info("正在加载向量存储...")
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)

        # Search for related vectors
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=10)
        logger.info(f"找到 {len(results)} 条相关结果")

        try:
            conn = await get_db_connection()
            if not conn:
                logger.error("数据库连接失败")
                return "数据库连接失败", []

            logger.info("数据库连接成功")
            context_notes = []
            notes_data = []  # full note records returned to the UI

            # Collect up to 5 distinct notes from the search results
            seen_note_ids = set()
            for doc, score in results:
                try:
                    content_hash = get_content_hash(doc.page_content)
                    logger.info(f"处理content_hash: {content_hash}")

                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        await cursor.execute("""
                            SELECT note_id FROM vector_store
                            WHERE content_hash = %s
                        """, (content_hash,))
                        result = await cursor.fetchone()

                    if not result or result['note_id'] in seen_note_ids:
                        continue

                    note_id = result['note_id']
                    seen_note_ids.add(note_id)

                    # Fetch the note details and its cleaned content
                    note = await get_note_details(conn, note_id)
                    clean_content = await get_clean_note_content(conn, note_id)

                    if not note or not clean_content:
                        continue

                    # Build the full note record
                    note_data = {
                        'id': note_id,
                        'title': note.get('title', '无标题'),
                        'description': note.get('description', '暂无描述'),
                        'collected_count': note.get('collected_count', 0),
                        'comment_count': note.get('comment_count', 0),
                        'share_count': note.get('share_count', 0),
                        'clean_content': clean_content
                    }

                    notes_data.append(note_data)
                    context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")

                    if len(notes_data) >= 5:  # stop after 5 notes
                        break

                except Exception as e:
                    logger.error(f"处理笔记时出错: {str(e)}")
                    continue

            if not notes_data:
                logger.warning("未获取到任何有效的笔记内容")
                return "未找到相关笔记内容", []

            # Prepare the GPT prompts
            logger.info("准备调用GPT...")
            system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。
            请基于提供的相关笔记内容,对用户的问题给出专业、简洁的回答。
            回答要突出重点,并结合化妆行业的专业知识。"""

            user_prompt = f"问题:{query}\n\n相关笔记内容:\n" + "\n\n".join(context_notes)

            logger.info("调用GPT获取回答...")
            gpt_response = await chat_with_gpt(system_prompt, user_prompt)

            return gpt_response, notes_data

        except Exception as e:
            logger.error(f"处理数据时出错: {str(e)}")
            return f"处理数据时出错: {str(e)}", []

    finally:
        if conn is not None:
            try:
                # aiomysql's Connection.close() is synchronous, so it is not awaited
                conn.close()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")

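# UI layout: a query box on top; below it, the left column shows the GPT answer and a
# note selector, and the right column renders the selected note's cleaned analysis
# (guide, Mermaid mindmap, summary).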
def create_ui():
    """Build the Gradio interface."""
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Holds the notes returned by the current query
        state = gr.State([])

        gr.Markdown("# 化妆品产品经理AI助教")

        # Page styling plus the Mermaid setup script
        gr.HTML("""
        <style>
        .query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
        .ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
        .notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
        </style>

        <!-- Mermaid scripts -->
        <script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
        <script>
        // Initialize Mermaid
        mermaid.initialize({
            startOnLoad: true,
            theme: 'default',
            securityLevel: 'loose',
            flowchart: {
                useMaxWidth: true,
                htmlLabels: true,
                curve: 'basis'
            }
        });

        // Watch for DOM changes and render newly added diagrams
        const observer = new MutationObserver((mutations) => {
            mutations.forEach((mutation) => {
                if (mutation.addedNodes.length) {
                    document.querySelectorAll('.mermaid:not(.mermaid-processed)').forEach(async (element) => {
                        try {
                            const graphDefinition = element.textContent;
                            const { svg } = await mermaid.render('mermaid-svg-' + Math.random(), graphDefinition);
                            element.innerHTML = svg;
                            element.classList.add('mermaid-processed');
                        } catch (error) {
                            console.error('Mermaid rendering error:', error);
                        }
                    });
                }
            });
        });

        observer.observe(document.body, {
            childList: true,
            subtree: true
        });
        </script>
        """)

        # Query input area
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")

        # AI answer and note display area
        with gr.Row():
            # Left: AI answer and note cards
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                # Use a Radio component instead of custom cards
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )

            # Right: AI analysis of the selected note
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.HTML()

        def format_note_choice(note):
            """Format a note as a Radio choice label."""
            title = note.get('title', '无标题')
            desc = note.get('description', '暂无描述')[:100]
            stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
            return f"{title}\n{desc}...\n{stats}"

        async def handle_submit(query, state):
            """Handle a submitted query."""
            try:
                gpt_response, notes_data = await process_query(query)

                if not notes_data:
                    return (
                        "未找到相关笔记",  # GPT answer
                        "请点击左侧笔记查看分析内容",  # analysis panel
                        [],  # state
                        gr.update(choices=[], value=None, visible=False)  # Radio update
                    )

                # Format the note choices
                note_choices = [format_note_choice(note) for note in notes_data]

                return (
                    f"## AI助教回答\n\n{gpt_response}",  # GPT answer
                    "请点击左侧笔记查看分析内容",  # initial analysis panel
                    notes_data,  # state
                    gr.update(  # Radio update
                        choices=note_choices,
                        value=None,
                        visible=True
                    )
                )

            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )

        def show_note_analysis(choice, state):
            """Render the analysis for the selected note."""
            logger.info(f"笔记选择改变: {choice}")
            if not choice or not state:
                return "请选择笔记查看分析内容"

            # Find the note that matches the selected label
            try:
                idx = next(i for i, note in enumerate(state)
                           if format_note_choice(note) == choice)
            except StopIteration:
                return "无法找到对应的笔记"

            note = state[idx]
            clean_content = note.get('clean_content', {})

            def clean_text(text):
                # Strip Markdown/Mermaid code fences from stored content
                if not text:
                    return ""
                return re.sub(r'```(?:markdown|mermaid)?', '', text).strip()

            html_parts = []

            # Guide section
            if clean_content.get('guide'):
                guide = clean_text(clean_content['guide'])
                html_parts.append(f"<h2>导读</h2><div class='guide-content'>{guide}</div>")

            # Mindmap section, rendered with Mermaid inside an iframe
            if clean_content.get('mindmap'):
                mindmap = clean_text(clean_content['mindmap'])
                iframe_content = f"""
                <iframe srcdoc='
                    <!DOCTYPE html>
                    <html>
                    <head>
                        <meta charset="utf-8" />
                        <script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
                        <script>
                            mermaid.initialize({{ startOnLoad: true }});
                        </script>
                    </head>
                    <body>
                        <div class="mermaid">
                            {mindmap}
                        </div>
                    </body>
                    </html>
                ' width="100%" height="600px" style="border:none;">
                </iframe>
                """

                html_parts.append(f"""
                <h2>思维导图</h2>
                {iframe_content}
                """)

            # Summary section
            if clean_content.get('summary'):
                summary = clean_text(clean_content['summary'])
                html_parts.append(f"<h2>总结</h2><div class='summary-content'>{summary}</div>")

            # Combine all sections
            html_content = f"""
            <div class='analysis-content' style='padding: 20px;'>
                {' '.join(html_parts)}
            </div>
            """

            return html_content

        def sync_handle_submit(query, state):
            """Synchronous wrapper around the async handler."""
            return asyncio.run(handle_submit(query, state))

        # Wire up events
        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )

        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )

    return demo

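# For a quick sanity check without the UI, process_query can also be run directly,
# e.g. asyncio.run(process_query("最近的面霜趋势是什么?")).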
if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7865)