import gradio as gr import asyncio from openai import OpenAI from langchain_community.vectorstores import FAISS from langchain_community.embeddings import HuggingFaceEmbeddings import aiomysql import logging import os from process_raw_notes import get_content_hash import re import tempfile import graphviz # 配置日志 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # OpenAI配置 api_key = "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" base_url = "http://52.90.243.11:8787/v1" client = OpenAI(api_key=api_key, base_url=base_url) # Embedding模型配置 embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/' embedding_model_kwargs = {'device': 'cuda'} embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True} embed_model = HuggingFaceEmbeddings( model_name=embedding_model_name, model_kwargs=embedding_model_kwargs, encode_kwargs=embedding_encode_kwargs ) # 数据库配置 db_config = { 'host': '183.11.229.79', 'port': 3316, 'user': 'root', 'password': 'zaq12wsx@9Xin', 'db': '9xin', 'autocommit': True } async def get_db_connection(): """创建数据库连接""" return await aiomysql.connect(**db_config) async def get_note_details(conn, note_id): """获取笔记详情""" async with conn.cursor(aiomysql.DictCursor) as cursor: try: # 使用 SELECT * 获取所有字段 await cursor.execute(""" SELECT * FROM xhs_notes WHERE id = %s """, (note_id,)) note = await cursor.fetchone() if note: # 确保所有需要的字段都有默认值 default_fields = { 'linked_count': note.get('linked_count', ''), 'collected_count': note.get('collected_count', ''), 'comment_count': note.get('comment_count', ''), 'share_count': note.get('share_count', ''), 'title': note.get('title', '无标题'), 'description': note.get('description', '无描述') } return note except Exception as e: logger.error(f"获取笔记详情时出错: {str(e)}") # 返回带有默认值的空笔记 return { 'title': '获取失败', 'description': '无法获取笔记内容', 'linked_count': 0, 'collected_count': 0, 'comment_count': 0, 'share_count': 0, 'tag_list': [] } async def get_clean_note_content(conn, note_id): """获取清洗后的笔记内容""" async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute(""" SELECT content_type, content FROM clean_note_store WHERE note_id = %s """, (note_id,)) results = await cursor.fetchall() clean_content = { 'guide': '', 'mindmap': '', 'summary': '' } for row in results: clean_content[row['content_type']] = row['content'] return clean_content async def get_vector_note_id(conn, vector_id): """获取向量对应的笔记ID""" async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute(""" SELECT note_id FROM vector_store WHERE vector_id = %s """, (vector_id,)) result = await cursor.fetchone() return result['note_id'] if result else None async def chat_with_gpt(system_prompt, user_prompt): """与GPT交互""" try: # 添加错误检查 if not system_prompt or not user_prompt: logger.error("系统提示或用户提示为空") return "入参数错误" response = client.chat.completions.create( # 移除 await model="gpt-4o-mini", # 修改为正确的模型名称 messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} ], max_tokens=1000, temperature=0.7, ) if not response or not response.choices: logger.error("GPT响应为空") return "获取回答" return response.choices[0].message.content except Exception as e: logger.error(f"GPT API调用错误: {str(e)}") return f"调用出错: {str(e)}" async def process_query(query): """处理用户查询""" conn = None try: logger.info(f"开始处理查询: {query}") # 加载向量存储 vs_path = "./raw_vs" logger.info(f"正在检查向量存储路径: {vs_path}") if not os.path.exists(vs_path): logger.error(f"向量存储路径不存在: {vs_path}") return "错误:向量存储不存在", [] logger.info("正在加载向量存储...") vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True) # 搜索相关向量 logger.info("开始向量搜索...") results = vs.similarity_search_with_score(query, k=10) logger.info(f"找到 {len(results)} 条相关结果") try: conn = await get_db_connection() if not conn: logger.error("数据库连接失败") return "数据库连接失败", [] logger.info("数据库连接成功") context_notes = [] notes_data = [] # 存储完整的笔记数据 # 获取不重复的前5条笔记 seen_note_ids = set() for doc, score in results: try: content_hash = get_content_hash(doc.page_content) logger.info(f"处理content_hash: {content_hash}") async with conn.cursor(aiomysql.DictCursor) as cursor: await cursor.execute(""" SELECT note_id FROM vector_store WHERE content_hash = %s """, (content_hash,)) result = await cursor.fetchone() if not result or result['note_id'] in seen_note_ids: continue note_id = result['note_id'] seen_note_ids.add(note_id) # 获取笔记详情和清洗内容 note = await get_note_details(conn, note_id) clean_content = await get_clean_note_content(conn, note_id) if not note or not clean_content: continue # 构建完整的笔记数据 note_data = { 'id': note_id, 'title': note.get('title', '无标题'), 'description': note.get('description', '暂无描述'), 'collected_count': note.get('collected_count', 0), 'comment_count': note.get('comment_count', 0), 'share_count': note.get('share_count', 0), 'clean_content': clean_content } notes_data.append(note_data) context_notes.append(f"标题:{note['title']}\n内容:{note['description']}") if len(notes_data) >= 5: # 修改为5条 break except Exception as e: logger.error(f"处理笔记时出错: {str(e)}") continue if not notes_data: logger.warning("未获取到任何有效的笔记内容") return "未找到相关笔记内容", [] # 准备GPT提示 logger.info("准备调用GPT...") system_prompt = """你是一位专业的化妆行业教师,专门帮助产品经理理解和分析小红书笔记。 请基于提供的相关笔记内容,对用户的问题给出专业、简洁的回答。 回答要突出重点,并结合化妆行业的专业知识。""" user_prompt = f"问题:{query}\n\n相关笔记内容:\n" + "\n\n".join(context_notes) logger.info("调用GPT获取回答...") gpt_response = await chat_with_gpt(system_prompt, user_prompt) return gpt_response, notes_data except Exception as e: logger.error(f"处理数据时出错: {str(e)}") return f"处理数据时出错: {str(e)}", [] finally: if conn is not None: try: await conn.close() logger.info("数据库连接已关闭") except Exception as e: logger.error(f"关闭数据库连接时出错: {str(e)}") def create_ui(): """创建Gradio界面""" with gr.Blocks(title="化妆品产品经理AI助教", theme=gr.themes.Soft()) as demo: # 存储当前的笔记内容 state = gr.State([]) gr.Markdown("# 化妆品产品经理AI助教") # 修改 HTML 和 JavaScript 部分 gr.HTML(""" """) # 查询区域 with gr.Row(): with gr.Column(scale=2, elem_classes="query-section"): gr.Markdown("### 💡 问题输入区") query_input = gr.Textbox( label="请输入你的问题", placeholder="例如:最近的面霜趋势是什么?", lines=3 ) submit_btn = gr.Button("提交问题", variant="primary") # AI回答和笔记展示区域 with gr.Row(): # 左侧:AI回答和笔记卡片 with gr.Column(scale=1, elem_classes="ai-response"): gr.Markdown("### 🤖 AI助教回答") gpt_output = gr.Markdown() gr.Markdown("### 📑 相关笔记") # 使用 Radio 组件替代自定义卡片 note_selector = gr.Radio( choices=[], label="点击笔记查看详细分析", visible=False ) # 右侧:笔记AI分析 with gr.Column(scale=1, elem_classes="notes-analysis"): gr.Markdown("### 🔍 笔记AI分析") clean_content_output = gr.HTML() def format_note_choice(note): """格式化笔记选项""" title = note.get('title', '无标题') desc = note.get('description', '暂无描述')[:100] stats = f"❤️ {note.get('collected_count', 0)} 💬 {note.get('comment_count', 0)} ↗️ {note.get('share_count', 0)}" return f"{title}\n{desc}...\n{stats}" async def handle_submit(query, state): """处理用户提交的查询""" try: gpt_response, notes_data = await process_query(query) if not notes_data: return ( "未找到相关笔记", # GPT响应 "请点击左侧笔记查看分析内容", # 分析内容 [], # 状态 gr.update(choices=[], value=None, visible=False) # Radio组件更新 ) # 格式化笔记选项 note_choices = [format_note_choice(note) for note in notes_data] return ( f"## AI助教回答\n\n{gpt_response}", # GPT响应 "请点击左侧笔记查看分析内容", # 初始分析内容 notes_data, # 状态 gr.update( # Radio组件更新 choices=note_choices, value=None, visible=True ) ) except Exception as e: logger.error(f"处理查询时出错: {str(e)}") return ( f"处理查询时出错: {str(e)}", "", [], gr.update(choices=[], value=None, visible=False) ) def show_note_analysis(choice, state): """处理笔记选择""" logger.info(f"笔记选择改变: {choice}") if not choice or not state: return "请选择笔记查看分析内容" # 根据选择的文本找到对应的笔记索引 try: idx = next(i for i, note in enumerate(state) if format_note_choice(note) == choice) except StopIteration: return "无法找到对应的笔记" note = state[idx] clean_content = note.get('clean_content', {}) def clean_text(text): if not text: return "" return re.sub(r'```(?:markdown|mermaid)?', '', text).strip() html_parts = [] # 添加导读部分 if clean_content.get('guide'): guide = clean_text(clean_content['guide']) html_parts.append(f"