# xhs_server/ai_assistant.py
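"""Gradio-based AI teaching assistant for Xiaohongshu (XHS) cosmetics notes.

Given a user question, the app searches a FAISS vector store of note content,
maps the hits back to notes stored in MySQL (raw details plus pre-cleaned
guide / mind-map / summary content), and asks an LLM served through the
DashScope OpenAI-compatible API to answer using those notes as context.
Results are presented in a Gradio web UI.
"""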

import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# OpenAI-compatible client configuration
# api_key = "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"
# base_url = "http://52.90.243.11:8787/v1"
# client = OpenAI(api_key=api_key, base_url=base_url)
client = OpenAI(
    # If no environment variable is configured, replace the line below with your Bailian (DashScope) API key: api_key="sk-xxx"
    api_key="sk-9583aa36267540da97a70b0385809f2c",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)
# Database configuration
db_config = {
    'host': '183.11.229.79',
    'port': 3316,
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'db': '9Xin',
    'autocommit': True
}
async def get_db_connection():
    """Create a database connection."""
    return await aiomysql.connect(**db_config)
async def get_note_details(conn, note_id):
    """Fetch the full details of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            # Use SELECT * to fetch every column
            await cursor.execute("""
                SELECT *
                FROM xhs_notes
                WHERE id = %s
            """, (note_id,))
            note = await cursor.fetchone()
            if note:
                # Make sure every field used downstream has a default value
                default_fields = {
                    'linked_count': note.get('linked_count', ''),
                    'collected_count': note.get('collected_count', ''),
                    'comment_count': note.get('comment_count', ''),
                    'share_count': note.get('share_count', ''),
                    'title': note.get('title', '无标题'),
                    'description': note.get('description', '无描述')
                }
                # Write the defaults back so that missing keys are filled in
                note.update(default_fields)
            return note
        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # Return a placeholder note with default values
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }
async def get_clean_note_content(conn, note_id):
    """Fetch the cleaned (AI-processed) content of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT content_type, content
            FROM clean_note_store
            WHERE note_id = %s
        """, (note_id,))
        results = await cursor.fetchall()
        clean_content = {
            'guide': '',
            'mindmap': '',
            'summary': ''
        }
        for row in results:
            clean_content[row['content_type']] = row['content']
        return clean_content
async def get_vector_note_id(conn, vector_id):
    """Fetch the note ID associated with a vector."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT note_id
            FROM vector_store
            WHERE vector_id = %s
        """, (vector_id,))
        result = await cursor.fetchone()
        return result['note_id'] if result else None
async def chat_with_gpt(system_prompt, user_prompt):
    """Call the LLM with a system prompt and a user prompt."""
    try:
        # Validate the inputs
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "输入参数错误"
        response = client.chat.completions.create(
            model="qwen-plus",  # previously "gpt-4o-mini"; changed to the correct model name
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
        )
        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "未能获取回答"
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"
async def process_query(query):
    """Handle a user query end to end: vector search, note lookup, LLM answer."""
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")
        # Load the vector store
        vs_path = "./raw_vs"
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []
        logger.info("正在加载向量存储...")
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
        # Search for related vectors
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=10)
        logger.info(f"找到 {len(results)} 条相关结果")
        conn = await get_db_connection()
        if not conn:
            logger.error("数据库连接失败")
            return "数据库连接失败", []
        logger.info("数据库连接成功")
        context_notes = []
        notes_data = []  # Full note records for the UI
        # Keep the first 5 notes whose similarity score is above 0.5
        seen_note_ids = set()
        for doc, similarity_score in results:
            # Under the MAX_INNER_PRODUCT strategy a larger score means more similar,
            # so skip results whose similarity score is below 0.5
            if similarity_score < 0.5:
                continue
            try:
                content_hash = get_content_hash(doc.page_content)
                logger.info(f"处理content_hash: {content_hash}, 相似度分数: {similarity_score}")
                async with conn.cursor(aiomysql.DictCursor) as cursor:
                    await cursor.execute("""
                        SELECT note_id FROM vector_store
                        WHERE content_hash = %s
                    """, (content_hash,))
                    result = await cursor.fetchone()
                if not result or result['note_id'] in seen_note_ids:
                    continue
                note_id = result['note_id']
                seen_note_ids.add(note_id)
                # Fetch the note details and its cleaned content
                note = await get_note_details(conn, note_id)
                clean_content = await get_clean_note_content(conn, note_id)
                if not note or not clean_content:
                    continue
                # Build the full note record
                note_data = {
                    'id': note_id,
                    'title': note.get('title', '无标题'),
                    'description': note.get('description', '暂无描述'),
                    'collected_count': note.get('collected_count', 0),
                    'comment_count': note.get('comment_count', 0),
                    'share_count': note.get('share_count', 0),
                    'clean_content': clean_content
                }
                notes_data.append(note_data)
                context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")
                if len(notes_data) >= 5:  # Still cap the result at 5 notes
                    break
            except Exception as e:
                logger.error(f"处理笔记时出错: {str(e)}")
                continue
        # Continue even if no note met the similarity threshold
        logger.info(f"找到 {len(notes_data)} 条符合相似度要求的笔记")
        if not notes_data:
            logger.warning("未获取到任何有效的笔记内容")
            # return "未找到相关笔记内容", []
        # Build the LLM prompts
        logger.info("准备调用GPT...")
        system_prompt = """你是一位专业的化妆品行业教师,专门帮助产品经理理解和分析小红书笔记。你的回答分下面3种情况:
1、如果用户的输入与化妆品无关,不要参考相关笔记内容,正常与客户交流。
2、如果用户的输入与化妆品相关,而且找到相关笔记内容,参考相关笔记内容给出回答。
3、如果用户的输入与化妆品相关,但是没有找到相关笔记内容,请结合上下文和历史记录给出回答。
回答要突出重点,并结合化妆品行业的专业知识。\n""" + "相关笔记内容:\n" + "\n\n".join(context_notes)
        user_prompt = f"{query}"
        logger.info("调用GPT获取回答...")
        gpt_response = await chat_with_gpt(system_prompt, user_prompt)
        return gpt_response, notes_data
    except Exception as e:
        logger.error(f"处理数据时出错: {str(e)}")
        return f"处理数据时出错: {str(e)}", []
    finally:
        if conn is not None:
            try:
                await conn.close()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")
def create_ui():
    """Build the Gradio interface."""
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Holds the note data behind the current results
        state = gr.State([])
        gr.Markdown("# 化妆品产品经理AI助教")
        # CSS styles for the three page sections
        gr.HTML("""
            <style>
            .query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
            .ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
            .notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
            </style>
        """)
        # Query input area
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")
        # AI answer and related-notes area
        with gr.Row():
            # Left: AI answer and note cards
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                # Use a Radio component instead of custom cards
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )
            # Right: AI analysis of the selected note
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.Markdown()
        def format_note_choice(note):
            """Format a note as a Radio choice label."""
            title = note.get('title', '无标题')
            desc = note.get('description', '暂无描述')[:100]
            stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
            return f"{title}\n{desc}...\n{stats}"

        async def handle_submit(query, state):
            """Handle a submitted user query."""
            try:
                gpt_response, notes_data = await process_query(query)
                logger.info(f"GPT响应: {gpt_response}")
                if not notes_data:
                    return (
                        f"{gpt_response}",  # LLM answer
                        "请点击左侧笔记查看分析内容",  # analysis placeholder
                        [],  # state
                        gr.update(choices=[], value=None, visible=False)  # Radio update
                    )
                # Format the note choices
                note_choices = [format_note_choice(note) for note in notes_data]
                return (
                    f"{gpt_response}",  # LLM answer
                    "请点击左侧笔记查看分析内容",  # initial analysis placeholder
                    notes_data,  # state
                    gr.update(  # Radio update
                        choices=note_choices,
                        value=None,
                        visible=True
                    )
                )
            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )
        def show_note_analysis(choice, state):
            """Render the analysis for the selected note."""
            logger.info(f"笔记选择改变: {choice}")
            if not choice or not state:
                return "请选择笔记查看分析内容"
            # Map the selected label back to the corresponding note index
            try:
                idx = next(i for i, note in enumerate(state)
                           if format_note_choice(note) == choice)
            except StopIteration:
                return "无法找到对应的笔记"
            note = state[idx]
            clean_content = note.get('clean_content', {})

            # Strip code-fence markers from the stored content
            def clean_text(text):
                # Remove ``` as well as ```markdown and ```mermaid
                return re.sub(r'```(?:markdown|mermaid)?', '', text)

            markdown = ""
            if clean_content.get('guide'):
                guide = clean_text(clean_content['guide'])
                markdown += "## 导读\n\n" + guide + "\n\n"
            if clean_content.get('mindmap'):
                mindmap = clean_text(clean_content['mindmap'])
                markdown += "## 思维导图\n\n" + mindmap + "\n\n"
            if clean_content.get('summary'):
                summary = clean_text(clean_content['summary'])
                markdown += "## 总结\n\n" + summary
            return markdown or "该笔记暂无分析内容"

        def sync_handle_submit(query, state):
            """Synchronous wrapper around the async submit handler."""
            return asyncio.run(handle_submit(query, state))
        # Wire up events
        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )
        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )
    return demo
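# Markdown renderer for a single note's AI analysis (guide, mind map, summary,
# engagement statistics and tags).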
def format_ai_analysis(note_data):
    """Format a note's AI analysis as Markdown."""
    try:
        markdown = f"# {note_data.get('title', '无标题')}\n\n"
        content = note_data.get('content', {})
        if not isinstance(content, dict):
            content = {'guide': '', 'mindmap': '', 'summary': ''}
        # Guide section
        if content.get('guide'):
            markdown += "## 导读\n\n"
            markdown += f"{content['guide']}\n\n"
        # Mind-map section
        if content.get('mindmap'):
            markdown += "## 思维导图\n\n"
            markdown += f"```\n{content['mindmap']}\n```\n\n"
        # Summary section
        if content.get('summary'):
            markdown += "## 总结\n\n"
            markdown += f"{content['summary']}\n\n"
        # Engagement statistics
        markdown += "## 互动数据\n\n"
        markdown += f"- 点赞数:{note_data.get('collected_count', 0)}\n"
        markdown += f"- 评论数:{note_data.get('comment_count', 0)}\n"
        markdown += f"- 分享数:{note_data.get('share_count', 0)}\n"
        markdown += f"- 收藏数:{note_data.get('linked_count', 0)}\n\n"
        # Tags
        if note_data.get('tag_list'):
            markdown += "## 标签\n\n"
            tags = [f"#{tag}" for tag in note_data['tag_list']]
            markdown += ", ".join(tags)
        return markdown
    except Exception as e:
        logger.error(f"格式化AI分析时出错: {str(e)}")
        return "无法加载笔记分析内容"
if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7864)