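"""Gradio demo: an AI teaching assistant for cosmetics product managers.

The app answers user questions by retrieving related Xiaohongshu (XHS) notes from a
FAISS vector store, pulling note details and cleaned content (guide / mindmap / summary)
from MySQL, and asking a Qwen model (via the DashScope OpenAI-compatible API) to compose
the reply. Results are shown in a Gradio UI with a note selector and an analysis panel.
"""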
import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenAI client configuration
# api_key = "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"
# base_url = "http://52.90.243.11:8787/v1"
# client = OpenAI(api_key=api_key, base_url=base_url)
client = OpenAI(
    # If no environment variable is configured, replace the line below with your Bailian (DashScope) API key: api_key="sk-xxx",
    api_key="sk-9583aa36267540da97a70b0385809f2c",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
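# Note: requests go to the DashScope OpenAI-compatible endpoint above; the chat model
# name ("qwen-plus") is selected in chat_with_gpt() below.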
# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}

embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)

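# Note: the FAISS index under ./raw_vs (loaded in process_query) is assumed to have been
# built with this same embedding model and normalization settings; mixing models would
# make the similarity scores meaningless.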
# Database configuration
db_config = {
    'host': '183.11.229.79',
    'port': 3316,
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'db': '9Xin',
    'autocommit': True
}

async def get_db_connection():
    """Create a database connection."""
    return await aiomysql.connect(**db_config)

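# Note: a fresh connection is opened for each query in process_query() and closed in its
# finally block; there is no connection pooling here.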
async def get_note_details(conn, note_id):
    """Fetch the full record for a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            # Use SELECT * to fetch every column
            await cursor.execute("""
                SELECT *
                FROM xhs_notes
                WHERE id = %s
            """, (note_id,))
            note = await cursor.fetchone()

            if note:
                # Backfill defaults for any expected fields that are missing
                default_fields = {
                    'linked_count': note.get('linked_count', ''),
                    'collected_count': note.get('collected_count', ''),
                    'comment_count': note.get('comment_count', ''),
                    'share_count': note.get('share_count', ''),
                    'title': note.get('title', '无标题'),
                    'description': note.get('description', '无描述')
                }
                note.update(default_fields)

                return note

        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # Return an empty note populated with default values
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }

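# Note: when the id is not found and no exception is raised, get_note_details() returns
# None implicitly; process_query() guards against this with `if not note`.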
async def get_clean_note_content(conn, note_id):
    """Fetch the cleaned (post-processed) content for a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT content_type, content
            FROM clean_note_store
            WHERE note_id = %s
        """, (note_id,))
        results = await cursor.fetchall()

        clean_content = {
            'guide': '',
            'mindmap': '',
            'summary': ''
        }

        for row in results:
            clean_content[row['content_type']] = row['content']

        return clean_content

async def get_vector_note_id(conn, vector_id):
    """Look up the note id that a vector belongs to."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT note_id
            FROM vector_store
            WHERE vector_id = %s
        """, (vector_id,))
        result = await cursor.fetchone()
        return result['note_id'] if result else None

async def chat_with_gpt(system_prompt, user_prompt):
    """Send the prompts to the LLM and return its reply."""
    try:
        # Basic input validation
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "输入参数错误"

        # Note: this OpenAI client call is synchronous and blocks the event loop
        # while the request is in flight.
        response = client.chat.completions.create(
            model="qwen-plus",  # previously "gpt-4o-mini"
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
        )

        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "未能获取回答"

        return response.choices[0].message.content

    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"

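# Example usage from an async context (illustrative only):
#     answer = await chat_with_gpt("你是一位化妆品行业助教", "最近的面霜趋势是什么?")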
async def process_query(query):
    """Handle a user query end to end."""
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")

        # Load the vector store
        vs_path = "./raw_vs"
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []

        logger.info("正在加载向量存储...")
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)

        # Search for related vectors
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=10)
        logger.info(f"找到 {len(results)} 条相关结果")

        try:
            conn = await get_db_connection()
            if not conn:
                logger.error("数据库连接失败")
                return "数据库连接失败", []

            logger.info("数据库连接成功")
            context_notes = []
            notes_data = []  # Full note records for the UI

            # Keep up to 5 notes whose similarity score is above 0.5
            seen_note_ids = set()
            for doc, similarity_score in results:
                # With the MAX_INNER_PRODUCT strategy, a larger score means a closer match
                # Skip results whose similarity score is below 0.5
                if similarity_score < 0.5:
                    continue

                try:
                    content_hash = get_content_hash(doc.page_content)
                    logger.info(f"处理content_hash: {content_hash}, 相似度分数: {similarity_score}")

                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        await cursor.execute("""
                            SELECT note_id FROM vector_store
                            WHERE content_hash = %s
                        """, (content_hash,))
                        result = await cursor.fetchone()

                    if not result or result['note_id'] in seen_note_ids:
                        continue

                    note_id = result['note_id']
                    seen_note_ids.add(note_id)

                    # Fetch note details and cleaned content
                    note = await get_note_details(conn, note_id)
                    clean_content = await get_clean_note_content(conn, note_id)

                    if not note or not clean_content:
                        continue

                    # Assemble the full note record
                    note_data = {
                        'id': note_id,
                        'title': note.get('title', '无标题'),
                        'description': note.get('description', '暂无描述'),
                        'collected_count': note.get('collected_count', 0),
                        'comment_count': note.get('comment_count', 0),
                        'share_count': note.get('share_count', 0),
                        'clean_content': clean_content
                    }

                    notes_data.append(note_data)
                    context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")

                    if len(notes_data) >= 5:  # Still capped at 5 notes
                        break

                except Exception as e:
                    logger.error(f"处理笔记时出错: {str(e)}")
                    continue

            # Continue even if no note met the similarity threshold
            logger.info(f"找到 {len(notes_data)} 条符合相似度要求的笔记")

            if not notes_data:
                logger.warning("未获取到任何有效的笔记内容")
                # return "未找到相关笔记内容", []

            # Build the GPT prompts
            logger.info("准备调用GPT...")
            system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。你的回答分下面3种情况:
1、如果用户的输入与化妆品无关,不要参考相关笔记内容,正常与客户交流。
2、如果用户的输入与化妆品相关,而且找到相关笔记内容,参考相关笔记内容,给出回答。
3、如果用户的输入与化妆品相关,但是没有找到相关笔记内容,请结合上下文和历史记录,给出回答。
回答要突出重点,并结合化妆品行业的专业知识。\n""" + "相关笔记内容:\n" + "\n\n".join(context_notes)
            user_prompt = f"{query}"
            logger.info("调用GPT获取回答...")
            gpt_response = await chat_with_gpt(system_prompt, user_prompt)

            return gpt_response, notes_data

        except Exception as e:
            logger.error(f"处理数据时出错: {str(e)}")
            return f"处理数据时出错: {str(e)}", []

    finally:
        if conn is not None:
            try:
                await conn.close()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")

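# Quick manual check outside the UI (illustrative only):
#     answer, notes = asyncio.run(process_query("最近的面霜趋势是什么?"))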
def create_ui():
    """Build the Gradio interface."""
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Holds the notes returned for the current query
        state = gr.State([])

        gr.Markdown("# 化妆品产品经理AI助教")

        # CSS styling for the three panels
        gr.HTML("""
        <style>
        .query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
        .ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
        .notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
        </style>
        """)

        # Query input area
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")

        # AI answer and note display area
        with gr.Row():
            # Left: AI answer and note cards
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                # Use a Radio component instead of custom cards
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )

            # Right: AI analysis of the selected note
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.Markdown()

        def format_note_choice(note):
            """Format a note as a Radio choice label."""
            title = note.get('title', '无标题')
            desc = note.get('description', '暂无描述')[:100]
            stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
            return f"{title}\n{desc}...\n{stats}"

        async def handle_submit(query, state):
            """Handle a submitted query."""
            try:
                gpt_response, notes_data = await process_query(query)
                logger.info(f"GPT响应: {gpt_response}")
                if not notes_data:
                    return (
                        f"{gpt_response}",  # GPT answer
                        "请点击左侧笔记查看分析内容",  # analysis panel placeholder
                        [],  # state
                        gr.update(choices=[], value=None, visible=False)  # Radio update
                    )

                # Format the note choices
                note_choices = [format_note_choice(note) for note in notes_data]

                return (
                    f"{gpt_response}",  # GPT answer
                    "请点击左侧笔记查看分析内容",  # initial analysis panel text
                    notes_data,  # state
                    gr.update(  # Radio update
                        choices=note_choices,
                        value=None,
                        visible=True
                    )
                )

            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )

        def show_note_analysis(choice, state):
            """Handle note selection."""
            logger.info(f"笔记选择改变: {choice}")
            if not choice or not state:
                return "请选择笔记查看分析内容"

            # Find the note index matching the selected label
            try:
                idx = next(i for i, note in enumerate(state)
                           if format_note_choice(note) == choice)
            except StopIteration:
                return "无法找到对应的笔记"

            note = state[idx]
            clean_content = note.get('clean_content', {})

            # Helper to strip code-fence markers from the stored content
            def clean_text(text):
                # Remove ``` as well as ```markdown and ```mermaid
                return re.sub(r'```(?:markdown|mermaid)?', '', text)

            markdown = ""
            if clean_content.get('guide'):
                guide = clean_text(clean_content['guide'])
                markdown += "## 导读\n\n" + guide + "\n\n"

            if clean_content.get('mindmap'):
                mindmap = clean_text(clean_content['mindmap'])
                markdown += "## 思维导图\n\n" + mindmap + "\n\n"

            if clean_content.get('summary'):
                summary = clean_text(clean_content['summary'])
                markdown += "## 总结\n\n" + summary

            return markdown or "该笔记暂无分析内容"

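        # Note: the selected Radio label is mapped back to its note by re-running
        # format_note_choice() over the stored state, so the two must stay in sync.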
        def sync_handle_submit(query, state):
            """Synchronous wrapper around the async handler (runs it in a fresh event loop)."""
            return asyncio.run(handle_submit(query, state))

        # Wire up events
        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )

        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )

    return demo

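# Note: format_ai_analysis below is not referenced elsewhere in this file; it appears to be
# an alternative renderer kept for reuse, and it reads note_data['content'] rather than
# the 'clean_content' key used above.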
def format_ai_analysis(note_data):
    """Format a note's AI analysis as Markdown."""
    try:
        markdown = f"# {note_data.get('title', '无标题')}\n\n"

        content = note_data.get('content', {})
        if not isinstance(content, dict):
            content = {'guide': '', 'mindmap': '', 'summary': ''}

        # Guide section
        if content.get('guide'):
            markdown += "## 导读\n\n"
            markdown += f"{content['guide']}\n\n"

        # Mind-map section
        if content.get('mindmap'):
            markdown += "## 思维导图\n\n"
            markdown += f"```\n{content['mindmap']}\n```\n\n"

        # Summary section
        if content.get('summary'):
            markdown += "## 总结\n\n"
            markdown += f"{content['summary']}\n\n"

        # Interaction statistics ('linked_count' holds the like count, likely a typo of liked_count in the schema)
        markdown += "## 互动数据\n\n"
        markdown += f"- 点赞数:{note_data.get('linked_count', 0)}\n"
        markdown += f"- 评论数:{note_data.get('comment_count', 0)}\n"
        markdown += f"- 分享数:{note_data.get('share_count', 0)}\n"
        markdown += f"- 收藏数:{note_data.get('collected_count', 0)}\n\n"

        # Tags
        if note_data.get('tag_list'):
            markdown += "## 标签\n\n"
            tags = [f"#{tag}" for tag in note_data['tag_list']]
            markdown += ", ".join(tags)

        return markdown

    except Exception as e:
        logger.error(f"格式化AI分析时出错: {str(e)}")
        return "无法加载笔记分析内容"

if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7864)