import gradio as gr
import asyncio
from openai import OpenAI
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
import aiomysql
import logging
import os
from process_raw_notes import get_content_hash
import re

# Logging setup: timestamped INFO-level messages to stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# OpenAI-compatible client pointed at the DashScope endpoint.
# SECURITY: the API key was hard-coded in source (and an older leaked key sat
# in commented-out code, now removed). Prefer the DASHSCOPE_API_KEY env var;
# the original literal is kept only as a backward-compatible fallback —
# rotate this key and delete the fallback.
client = OpenAI(
    api_key=os.environ.get("DASHSCOPE_API_KEY", "sk-9583aa36267540da97a70b0385809f2c"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

# Embedding model configuration (local bce-embedding-base_v1 on CUDA,
# normalized vectors so inner-product scores behave like cosine similarity).
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)

# Database configuration.
# SECURITY: credentials are hard-coded; move host/user/password to
# environment variables or a secrets manager before deploying.
db_config = {
    'host': '183.11.229.79',
    'port': 3316,
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'db': '9Xin',
    'autocommit': True
}


async def get_db_connection():
    """Open and return a new aiomysql connection built from ``db_config``."""
    return await aiomysql.connect(**db_config)


async def get_note_details(conn, note_id):
    """Fetch the full row for one note from ``xhs_notes``.

    Returns:
        The row as a dict when the note exists, ``None`` when it does not,
        or a placeholder dict (title '获取失败') when the query raises.
    """
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        try:
            await cursor.execute("""
                SELECT *
                FROM xhs_notes
                WHERE id = %s
            """, (note_id,))
            note = await cursor.fetchone()
            # BUGFIX: the original built a `default_fields` dict here and never
            # used it (dead code). The raw row is returned as-is; all callers
            # already apply their own .get() defaults.
            if note:
                return note
        except Exception as e:
            logger.error(f"获取笔记详情时出错: {str(e)}")
            # On failure, return a note shaped like the success case so the
            # UI can still render something.
            return {
                'title': '获取失败',
                'description': '无法获取笔记内容',
                'linked_count': 0,
                'collected_count': 0,
                'comment_count': 0,
                'share_count': 0,
                'tag_list': []
            }


async def get_clean_note_content(conn, note_id):
    """Load the cleaned content pieces (guide/mindmap/summary) for a note.

    Missing content types are returned as empty strings.
    """
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT content_type, content
            FROM clean_note_store
            WHERE note_id = %s
        """, (note_id,))
        results = await cursor.fetchall()
        clean_content = {
            'guide': '',
            'mindmap': '',
            'summary': ''
        }
        for row in results:
            clean_content[row['content_type']] = row['content']
        return clean_content


async def get_vector_note_id(conn, vector_id):
    """Map a vector-store id to its note id; ``None`` if not found."""
    async with conn.cursor(aiomysql.DictCursor) as cursor:
        await cursor.execute("""
            SELECT note_id
            FROM vector_store
            WHERE vector_id = %s
        """, (vector_id,))
        result = await cursor.fetchone()
        return result['note_id'] if result else None


async def chat_with_gpt(system_prompt, user_prompt):
    """Call the chat model and return the reply text.

    Returns an error-description string (never raises) when the inputs are
    empty, the response is empty, or the API call fails.
    """
    try:
        if not system_prompt or not user_prompt:
            logger.error("系统提示或用户提示为空")
            return "入参数错误"
        # BUGFIX: client.chat.completions.create is a blocking synchronous
        # call; invoking it directly inside this coroutine froze the event
        # loop for the duration of the request. asyncio.to_thread moves it
        # onto a worker thread so other tasks keep running.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="qwen-plus",  # previously "gpt-4o-mini"
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            max_tokens=1000,
            temperature=0.7,
        )
        if not response or not response.choices:
            logger.error("GPT响应为空")
            return "获取回答"
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"GPT API调用错误: {str(e)}")
        return f"调用出错: {str(e)}"


async def process_query(query):
    """Answer a user query with RAG over the local FAISS store.

    Returns:
        (answer_text, notes_data) — notes_data is a list of dicts describing
        the retrieved notes (possibly empty). Errors are reported in
        answer_text instead of raising.
    """
    conn = None
    try:
        logger.info(f"开始处理查询: {query}")
        # Load the vector store from disk.
        vs_path = "./raw_vs"
        logger.info(f"正在检查向量存储路径: {vs_path}")
        if not os.path.exists(vs_path):
            logger.error(f"向量存储路径不存在: {vs_path}")
            return "错误:向量存储不存在", []
        logger.info("正在加载向量存储...")
        vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)

        # Retrieve candidate chunks.
        logger.info("开始向量搜索...")
        results = vs.similarity_search_with_score(query, k=10)
        logger.info(f"找到 {len(results)} 条相关结果")

        try:
            conn = await get_db_connection()
            if not conn:
                logger.error("数据库连接失败")
                return "数据库连接失败", []
            logger.info("数据库连接成功")

            context_notes = []
            notes_data = []  # full note payloads for the UI
            seen_note_ids = set()

            # Keep up to 5 distinct notes whose similarity score passes 0.5.
            for doc, similarity_score in results:
                # Under the MAX_INNER_PRODUCT strategy a larger score means
                # more similar, so scores below the threshold are skipped.
                if similarity_score < 0.5:
                    continue
                try:
                    content_hash = get_content_hash(doc.page_content)
                    logger.info(f"处理content_hash: {content_hash}, 相似度分数: {similarity_score}")
                    async with conn.cursor(aiomysql.DictCursor) as cursor:
                        await cursor.execute("""
                            SELECT note_id
                            FROM vector_store
                            WHERE content_hash = %s
                        """, (content_hash,))
                        result = await cursor.fetchone()
                        # Skip unknown hashes and notes already collected.
                        if not result or result['note_id'] in seen_note_ids:
                            continue
                        note_id = result['note_id']
                        seen_note_ids.add(note_id)

                        # Fetch note details and the cleaned analysis content.
                        note = await get_note_details(conn, note_id)
                        clean_content = await get_clean_note_content(conn, note_id)
                        if not note or not clean_content:
                            continue

                        note_data = {
                            'id': note_id,
                            'title': note.get('title', '无标题'),
                            'description': note.get('description', '暂无描述'),
                            'collected_count': note.get('collected_count', 0),
                            'comment_count': note.get('comment_count', 0),
                            'share_count': note.get('share_count', 0),
                            'clean_content': clean_content
                        }
                        notes_data.append(note_data)
                        context_notes.append(f"标题:{note['title']}\n内容:{note['description']}")
                        if len(notes_data) >= 5:  # cap at 5 notes
                            break
                except Exception as e:
                    logger.error(f"处理笔记时出错: {str(e)}")
                    continue

            # Proceed even when no note passed the threshold — the model is
            # instructed to answer without retrieved context in that case.
            logger.info(f"找到 {len(notes_data)} 条符合相似度要求的笔记")
            if not notes_data:
                logger.warning("未获取到任何有效的笔记内容")

            logger.info("准备调用GPT...")
            system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。你的回答分下面3种情况:
1、如果用户的输入与化妆品无关,不要参考相关笔记内容,正常与客户交流。
2、如果用户的输入与化妆品相关,而且找到相关笔记内容,参考相关笔记内容,给出回答。
3、如果用户的输入与化妆品相关,但是没有找到相关笔记内容,请结合上下文和历史记录,给出回答。
回答要突出重点,并结合化妆品行业的专业知识。\n""" + f"相关笔记内容:\n" + "\n\n".join(context_notes)
            user_prompt = f"{query}"

            logger.info("调用GPT获取回答...")
            gpt_response = await chat_with_gpt(system_prompt, user_prompt)
            return gpt_response, notes_data
        except Exception as e:
            logger.error(f"处理数据时出错: {str(e)}")
            return f"处理数据时出错: {str(e)}", []
    finally:
        if conn is not None:
            try:
                # BUGFIX: the original did `await conn.close()`, but aiomysql's
                # Connection.close() is synchronous and returns None, so the
                # await raised a TypeError on every call (silently logged by
                # the handler below). ensure_closed() is the awaitable form.
                await conn.ensure_closed()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {str(e)}")


def create_ui():
    """Build and return the Gradio Blocks UI."""
    with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo:
        # Holds the notes_data list returned by the last query so the radio
        # selector can map a choice string back to a note.
        state = gr.State([])

        gr.Markdown("# 化妆品产品经理AI助教")
        # Placeholder for custom CSS (currently empty in the original).
        gr.HTML(""" """)

        # Query input area.
        with gr.Row():
            with gr.Column(scale=2, elem_classes="query-section"):
                gr.Markdown("### 💡 问题输入区")
                query_input = gr.Textbox(
                    label="请输入你的问题",
                    placeholder="例如:最近的面霜趋势是什么?",
                    lines=3
                )
                submit_btn = gr.Button("提交问题", variant="primary")

        # Answer + notes area.
        with gr.Row():
            # Left: AI answer and retrieved-note selector.
            with gr.Column(scale=1, elem_classes="ai-response"):
                gr.Markdown("### 🤖 AI助教回答")
                gpt_output = gr.Markdown()
                gr.Markdown("### 📑 相关笔记")
                # A Radio stands in for clickable note cards.
                note_selector = gr.Radio(
                    choices=[],
                    label="点击笔记查看详细分析",
                    visible=False
                )
            # Right: per-note AI analysis.
            with gr.Column(scale=1, elem_classes="notes-analysis"):
                gr.Markdown("### 🔍 笔记AI分析")
                clean_content_output = gr.Markdown()

        def format_note_choice(note):
            """Render one note as the radio-choice label text."""
            title = note.get('title', '无标题')
            desc = note.get('description', '暂无描述')[:100]
            stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}"
            return f"{title}\n{desc}...\n{stats}"

        async def handle_submit(query, state):
            """Run the query pipeline and produce updates for all outputs."""
            try:
                gpt_response, notes_data = await process_query(query)
                logger.info(f"GPT响应: {gpt_response}")
                if not notes_data:
                    return (
                        f"{gpt_response}",                 # answer text
                        "请点击左侧笔记查看分析内容",       # analysis pane
                        [],                                 # state
                        gr.update(choices=[], value=None, visible=False)  # radio
                    )
                note_choices = [format_note_choice(note) for note in notes_data]
                return (
                    f"{gpt_response}",
                    "请点击左侧笔记查看分析内容",
                    notes_data,
                    gr.update(
                        choices=note_choices,
                        value=None,
                        visible=True
                    )
                )
            except Exception as e:
                logger.error(f"处理查询时出错: {str(e)}")
                return (
                    f"处理查询时出错: {str(e)}",
                    "",
                    [],
                    gr.update(choices=[], value=None, visible=False)
                )

        def show_note_analysis(choice, state):
            """Render the cleaned analysis of the selected note as Markdown."""
            logger.info(f"笔记选择改变: {choice}")
            if not choice or not state:
                return "请选择笔记查看分析内容"
            # Recover the note index from the selected label text.
            try:
                idx = next(i for i, note in enumerate(state)
                           if format_note_choice(note) == choice)
            except StopIteration:
                return "无法找到对应的笔记"
            note = state[idx]
            clean_content = note.get('clean_content', {})

            def clean_text(text):
                # Strip ``` / ```markdown / ```mermaid fences from stored content.
                return re.sub(r'```(?:markdown|mermaid)?', '', text)

            markdown = ""
            if clean_content.get('guide'):
                guide = clean_text(clean_content['guide'])
                markdown += "## 导读\n\n" + guide + "\n\n"
            if clean_content.get('mindmap'):
                mindmap = clean_text(clean_content['mindmap'])
                markdown += "## 思维导图\n\n" + mindmap + "\n\n"
            if clean_content.get('summary'):
                summary = clean_text(clean_content['summary'])
                markdown += "## 总结\n\n" + summary
            return markdown or "该笔记暂无分析内容"

        def sync_handle_submit(query, state):
            """Synchronous wrapper so the click handler can drive the coroutine."""
            return asyncio.run(handle_submit(query, state))

        # Wire up events.
        submit_btn.click(
            fn=sync_handle_submit,
            inputs=[query_input, state],
            outputs=[gpt_output, clean_content_output, state, note_selector]
        )
        note_selector.change(
            fn=show_note_analysis,
            inputs=[note_selector, state],
            outputs=[clean_content_output]
        )
    return demo


def format_ai_analysis(note_data):
    """Format a note's AI analysis as a Markdown document.

    NOTE(review): currently unreferenced by the UI; kept for external callers.
    """
    try:
        markdown = f"# {note_data.get('title', '无标题')}\n\n"
        content = note_data.get('content', {})
        if not isinstance(content, dict):
            content = {'guide': '', 'mindmap': '', 'summary': ''}

        # Guide section.
        if content.get('guide'):
            markdown += "## 导读\n\n"
            markdown += f"{content['guide']}\n\n"
        # Mindmap section (rendered as a code block).
        if content.get('mindmap'):
            markdown += "## 思维导图\n\n"
            markdown += f"```\n{content['mindmap']}\n```\n\n"
        # Summary section.
        if content.get('summary'):
            markdown += "## 总结\n\n"
            markdown += f"{content['summary']}\n\n"

        # Engagement statistics.
        # BUGFIX: the original labelled collected_count as 点赞数 and
        # linked_count as 收藏数 (swapped), and the 分享数 label was
        # mojibake ("���享数"). `linked_count` is presumably a typo for
        # `liked_count` — confirm against the xhs_notes schema.
        markdown += "## 互动数据\n\n"
        markdown += f"- 点赞数:{note_data.get('linked_count', 0)}\n"
        markdown += f"- 评论数:{note_data.get('comment_count', 0)}\n"
        markdown += f"- 分享数:{note_data.get('share_count', 0)}\n"
        markdown += f"- 收藏数:{note_data.get('collected_count', 0)}\n\n"

        # Tags.
        if note_data.get('tag_list'):
            markdown += "## 标签\n\n"
            tags = [f"#{tag}" for tag in note_data['tag_list']]
            markdown += ", ".join(tags)
        return markdown
    except Exception as e:
        logger.error(f"格式化AI分析时出错: {str(e)}")
        return "无法加载笔记分析内容"


if __name__ == "__main__":
    demo = create_ui()
    demo.launch(server_name="0.0.0.0", server_port=7864)