import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenAI-compatible client configuration
# api_key = "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"
# base_url = "http://52.90.243.11:8787/v1"
# client = OpenAI(api_key=api_key, base_url=base_url)
client = OpenAI(
    # If no environment variable is configured, replace the line below with your Bailian (DashScope) API key: api_key="sk-xxx",
    api_key="sk-9583aa36267540da97a70b0385809f2c",
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
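# A safer option (a sketch, not what the code above does) is to read the key from
# the environment instead of hard-coding it, e.g. api_key=os.getenv("DASHSCOPE_API_KEY");
# the environment variable name here is an assumption.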

# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}

embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)
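# Note: this must be the same embedding model that was used to build the FAISS
# index loaded below from ./raw_vs, and 'device': 'cuda' assumes a GPU is
# available; 'cpu' should also work, only slower.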

# Database configuration
db_config = {
    'host': '183.11.229.79',
    'port': 3316,
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'db': '9Xin',
    'autocommit': True
}
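# Tables this module queries (inferred from the SQL below; column lists are not exhaustive):
#   xhs_notes        - id, title, description, collected_count, comment_count,
#                      share_count, linked_count, tag_list, ...
#   clean_note_store - note_id, content_type ('guide' / 'mindmap' / 'summary'), content
#   vector_store     - vector_id, note_id, content_hash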


async def get_db_connection():
    """Create a database connection."""
    return await aiomysql.connect(**db_config)


async def get_note_details(conn, note_id):
    """Fetch the full details of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            # Use SELECT * to fetch every column
            await cursor.execute("""
                SELECT *
                FROM xhs_notes
                WHERE id = %s
            """, (note_id,))
            note = await cursor.fetchone()

            if note:
                # Make sure every field the UI relies on has a default value
                default_fields = {
                    'linked_count': note.get('linked_count', ''),
                    'collected_count': note.get('collected_count', ''),
                    'comment_count': note.get('comment_count', ''),
                    'share_count': note.get('share_count', ''),
                    'title': note.get('title', '无标题'),
                    'description': note.get('description', '无描述')
                }
                note.update(default_fields)  # apply the defaults to the fetched row

            return note

        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # Return an empty note with default values
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }


async def get_clean_note_content(conn, note_id):
    """Fetch the cleaned (processed) content of a note."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT content_type, content
            FROM clean_note_store
            WHERE note_id = %s
        """, (note_id,))
        results = await cursor.fetchall()

        clean_content = {
            'guide': '',
            'mindmap': '',
            'summary': ''
        }

        for row in results:
            clean_content[row['content_type']] = row['content']

        return clean_content
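# get_clean_note_content() returns a dict shaped like
# {'guide': '...', 'mindmap': '...', 'summary': '...'}; content types that are
# missing from clean_note_store stay as empty strings.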


async def get_vector_note_id(conn, vector_id):
    """Fetch the note ID that a vector belongs to."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT note_id
            FROM vector_store
            WHERE vector_id = %s
        """, (vector_id,))
        result = await cursor.fetchone()
        return result['note_id'] if result else None


async def chat_with_gpt(system_prompt, user_prompt):
    """Call the chat model and return its reply."""
    try:
        # Basic input validation
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "输入参数错误"

        response = client.chat.completions.create(
            model="qwen-plus",  # was "gpt-4o-mini"; switched to the correct model name
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
        )

        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "未能获取回答"

        return response.choices[0].message.content

    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"
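# Note: client.chat.completions.create() is a synchronous call, so it blocks the
# event loop while this coroutine runs. If that matters, one option (a sketch,
# not what the code above does) is to offload it to a worker thread:
#     response = await asyncio.to_thread(
#         client.chat.completions.create,
#         model="qwen-plus",
#         messages=[{"role": "system", "content": system_prompt},
#                   {"role": "user", "content": user_prompt}],
#         max_tokens=1000,
#         temperature=0.7,
#     )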


async def process_query(query):
    """Handle a user query."""
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")

        # Load the vector store
        vs_path = "./raw_vs"
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []

        logger.info("正在加载向量存储...")
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)

        # Search for related vectors
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=10)
        logger.info(f"找到 {len(results)} 条相关结果")
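        # similarity_search_with_score() returns (Document, score) pairs; how the
        # score should be read depends on the distance strategy the index was
        # built with (see the MAX_INNER_PRODUCT note in the loop below).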

        try:
            conn = await get_db_connection()
            if not conn:
                logger.error("数据库连接失败")
                return "数据库连接失败", []

            logger.info("数据库连接成功")
            context_notes = []
            notes_data = []  # full note records for the UI

            # Keep at most 5 notes whose similarity score is above 0.5
            seen_note_ids = set()
            for doc, similarity_score in results:
                # With the MAX_INNER_PRODUCT strategy, a larger score means more similar,
                # so results scoring below 0.5 are skipped
                if similarity_score < 0.5:
                    continue

                try:
                    content_hash = get_content_hash(doc.page_content)
                    logger.info(f"处理content_hash: {content_hash}, 相似度分数: {similarity_score}")

                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        await cursor.execute("""
                            SELECT note_id FROM vector_store
                            WHERE content_hash = %s
                        """, (content_hash,))
                        result = await cursor.fetchone()

                    if not result or result['note_id'] in seen_note_ids:
                        continue

                    note_id = result['note_id']
                    seen_note_ids.add(note_id)

                    # Fetch the note details and its cleaned content
                    note = await get_note_details(conn, note_id)
                    clean_content = await get_clean_note_content(conn, note_id)

                    if not note or not clean_content:
                        continue

                    # Assemble the full note record
                    note_data = {
                        'id': note_id,
                        'title': note.get('title', '无标题'),
                        'description': note.get('description', '暂无描述'),
                        'collected_count': note.get('collected_count', 0),
                        'comment_count': note.get('comment_count', 0),
                        'share_count': note.get('share_count', 0),
                        'clean_content': clean_content
                    }

                    notes_data.append(note_data)
                    context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")

                    if len(notes_data) >= 5:  # still capped at 5 notes
                        break

                except Exception as e:
                    logger.error(f"处理笔记时出错: {str(e)}")
                    continue

            # Continue even if no note passed the similarity threshold
            logger.info(f"找到 {len(notes_data)} 条符合相似度要求的笔记")

            if not notes_data:
                logger.warning("未获取到任何有效的笔记内容")
                # return "未找到相关笔记内容", []

            # Build the GPT prompts
            logger.info("准备调用GPT...")
            system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。你的回答分下面3种情况:
1、如果用户的输入与化妆品无关,不要参考相关笔记内容,正常与客户交流。
2、如果用户的输入与化妆品相关,而且找到相关笔记内容,参考相关笔记内容,给出回答。
3、如果用户的输入与化妆品相关,但是没有找到相关笔记内容,请结合上下文和历史记录,给出回答。
回答要突出重点,并结合化妆品行业的专业知识。\n""" + "相关笔记内容:\n" + "\n\n".join(context_notes)
            user_prompt = f"{query}"
            logger.info("调用GPT获取回答...")
            gpt_response = await chat_with_gpt(system_prompt, user_prompt)

            return gpt_response, notes_data

        except Exception as e:
            logger.error(f"处理数据时出错: {str(e)}")
            return f"处理数据时出错: {str(e)}", []

    finally:
        if conn is not None:
            try:
                # aiomysql's Connection.close() is synchronous; ensure_closed() is the awaitable close
                await conn.ensure_closed()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")
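# Example of calling process_query() outside the Gradio UI (illustrative; it
# assumes the FAISS index at ./raw_vs and the MySQL tables above are reachable):
#     answer, notes = asyncio.run(process_query("最近的面霜趋势是什么?"))
#     print(answer)
#     for note in notes:
#         print(note['title'])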


def create_ui():
    """Build the Gradio UI."""
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Holds the notes returned for the current query
        state = gr.State([])

        gr.Markdown("# 化妆品产品经理AI助教")

        # CSS for the three panels
        gr.HTML("""
            <style>
            .query-section { background: #f0f7ff; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
            .ai-response { background: #f5f5f5; padding: 20px; border-radius: 10px; margin-bottom: 20px; }
            .notes-analysis { background: #fff5f5; padding: 20px; border-radius: 10px; }
            </style>
        """)

        # Query input area
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")

        # AI answer and note display area
        with gr.Row():
            # Left: AI answer and note cards
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                # A Radio component stands in for custom note cards
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )

            # Right: AI analysis of the selected note
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.Markdown()

        def format_note_choice(note):
            """Format a note as a Radio choice label."""
            title = note.get('title', '无标题')
            desc = note.get('description', '暂无描述')[:100]
            stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
            return f"{title}\n{desc}...\n{stats}"
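        # The same label text is recomputed in show_note_analysis() below to map the
        # selected Radio choice back to its note, so this formatting must stay in sync.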

        async def handle_submit(query, state):
            """Handle a submitted query."""
            try:
                gpt_response, notes_data = await process_query(query)
                logger.info(f"GPT响应: {gpt_response}")
                if not notes_data:
                    return (
                        f"{gpt_response}",  # GPT response
                        "请点击左侧笔记查看分析内容",  # analysis placeholder
                        [],  # state
                        gr.update(choices=[], value=None, visible=False)  # Radio update
                    )

                # Format the note choices
                note_choices = [format_note_choice(note) for note in notes_data]

                return (
                    f"{gpt_response}",  # GPT response
                    "请点击左侧笔记查看分析内容",  # initial analysis placeholder
                    notes_data,  # state
                    gr.update(  # Radio update
                        choices=note_choices,
                        value=None,
                        visible=True
                    )
                )

            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )

        def show_note_analysis(choice, state):
            """Handle a note selection."""
            logger.info(f"笔记选择改变: {choice}")
            if not choice or not state:
                return "请选择笔记查看分析内容"

            # Find the note whose formatted label matches the selected choice
            try:
                idx = next(i for i, note in enumerate(state)
                           if format_note_choice(note) == choice)
            except StopIteration:
                return "无法找到对应的笔记"

            note = state[idx]
            clean_content = note.get('clean_content', {})

            # Helper that strips code-fence markers from the stored content
            def clean_text(text):
                # Remove ``` as well as ```markdown and ```mermaid
                return re.sub(r'```(?:markdown|mermaid)?', '', text)

            markdown = ""
            if clean_content.get('guide'):
                guide = clean_text(clean_content['guide'])
                markdown += "## 导读\n\n" + guide + "\n\n"

            if clean_content.get('mindmap'):
                mindmap = clean_text(clean_content['mindmap'])
                markdown += "## 思维导图\n\n" + mindmap + "\n\n"

            if clean_content.get('summary'):
                summary = clean_text(clean_content['summary'])
                markdown += "## 总结\n\n" + summary

            return markdown or "该笔记暂无分析内容"

        def sync_handle_submit(query, state):
            """Synchronous wrapper around the async handler."""
            return asyncio.run(handle_submit(query, state))
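        # asyncio.run() starts a fresh event loop on every click; recent Gradio
        # versions can also accept the async handle_submit directly as the callback,
        # which would make this wrapper unnecessary.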

        # Wire up the events
        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )

        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )

    return demo
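
# Note: format_ai_analysis() below is not referenced by create_ui(); the UI builds
# its markdown in show_note_analysis() from 'clean_content' instead, so this helper
# appears to be an alternative, currently unused formatter.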


def format_ai_analysis(note_data):
    """Format a note's AI analysis as Markdown."""
    try:
        markdown = f"# {note_data.get('title', '无标题')}\n\n"

        content = note_data.get('content', {})
        if not isinstance(content, dict):
            content = {'guide': '', 'mindmap': '', 'summary': ''}

        # Guide section
        if content.get('guide'):
            markdown += "## 导读\n\n"
            markdown += f"{content['guide']}\n\n"

        # Mind map section
        if content.get('mindmap'):
            markdown += "## 思维导图\n\n"
            markdown += f"```\n{content['mindmap']}\n```\n\n"

        # Summary section
        if content.get('summary'):
            markdown += "## 总结\n\n"
            markdown += f"{content['summary']}\n\n"

        # Interaction statistics
        markdown += "## 互动数据\n\n"
        markdown += f"- 点赞数:{note_data.get('collected_count', 0)}\n"
        markdown += f"- 评论数:{note_data.get('comment_count', 0)}\n"
        markdown += f"- 分享数:{note_data.get('share_count', 0)}\n"
        markdown += f"- 收藏数:{note_data.get('linked_count', 0)}\n\n"

        # Tags
        if note_data.get('tag_list'):
            markdown += "## 标签\n\n"
            tags = [f"#{tag}" for tag in note_data['tag_list']]
            markdown += ", ".join(tags)

        return markdown

    except Exception as e:
        logger.error(f"格式化AI分析时出错: {str(e)}")
        return "无法加载笔记分析内容"


if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7864)