# xhs_server/process_clean_notes.py

from openai import AsyncOpenAI
import asyncio
import aiomysql
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
import os
import logging
from tqdm import tqdm
import hashlib

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Database configuration (reused by process_clean_notes below)
db_config = {
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'host': '183.11.229.79',
    'port': 3316,
    'database': '9Xin',
    'auth_plugin': 'mysql_native_password'
}

# OpenAI configuration
api_key = "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"  # Replace with your OpenAI API key
base_url = "http://52.90.243.11:8787/v1"
# Async client so the call in interact_with_chatgpt does not block the event loop
client = AsyncOpenAI(api_key=api_key, base_url=base_url)

# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)
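# Note: with normalize_embeddings=True every vector has unit length, so the
# MAX_INNER_PRODUCT strategy used for the FAISS stores below is equivalent to
# cosine similarity (for unit vectors, dot(a, b) == cos(a, b)).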


def get_content_hash(content):
    """Return the SHA-256 hex digest of the content, ignoring surrounding whitespace."""
    content_cleaned = content.strip()
    return hashlib.sha256(content_cleaned.encode('utf-8')).hexdigest()
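
# For example, whitespace-only differences collapse to the same hash, so
# re-running the pipeline will not insert duplicate rows (hypothetical values):
#   get_content_hash("  same text ") == get_content_hash("same text")  # -> True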


async def interact_with_chatgpt(content) -> dict:
    """Send the note content to the model and parse its three-section reply."""
    # The prompt and the section markers parsed below are intentionally kept in
    # Chinese: the parser matches the Chinese headings the model is asked to emit.
    system_prompt = """你是一位专业的化妆品行业专家。请以专业的视角分析提供的内容,并提供以下三个部分的分析:
1. 面向化妆品产品经理的笔记导读
2. 内容的思维导图使用mermaid格式
3. 内容的总结
请严格按照以下格式返回:
# 笔记导读
[导读内容]
# 思维导图
[思维导图内容]
# 内容总结
[总结内容]"""
    try:
        response = await client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content}
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        reply = response.choices[0].message.content
        # Split the reply into sections keyed 'guide', 'mindmap' and 'summary'
        parts = {}
        current_section = None
        current_content = []
        for line in reply.split('\n'):
            if line.startswith('# '):
                # Flush the previous section before starting a new one
                if current_section and current_content:
                    parts[current_section] = '\n'.join(current_content).strip()
                current_content = []
                current_section = line[2:].strip().lower()
                if '导读' in current_section:
                    current_section = 'guide'
                elif '思维导图' in current_section:
                    current_section = 'mindmap'
                elif '总结' in current_section:
                    current_section = 'summary'
            elif current_section:
                current_content.append(line)
        if current_section and current_content:
            parts[current_section] = '\n'.join(current_content).strip()
        return parts
    except Exception as e:
        logger.error(f"OpenAI API call failed: {e}")
        return None
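
# Illustrative reply shape (not a real model response) and the parsed result:
#
#   # 笔记导读
#   面向产品经理的两句导读……
#   # 思维导图
#   mindmap
#     root((笔记主题))
#   # 内容总结
#   一句话总结……
#
# parses into {"guide": "...", "mindmap": "...", "summary": "..."}.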


async def process_clean_notes():
    """Process notes and build the cleaned vector store."""
    conn = None
    try:
        logger.info("Connecting to the database...")
        # Connect using the shared db_config instead of repeating the credentials;
        # aiomysql.connect raises on failure, so no extra None check is needed
        conn = await aiomysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            db=db_config['database'],
            autocommit=True
        )
        logger.info("Database connection established")
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Fetch all notes of type 'normal'
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()

            # Initialise the vector store: reuse an existing index if present,
            # otherwise seed a new one with a placeholder document
            vs_path = "./clean_faiss"
            if not os.path.exists(vs_path):
                os.makedirs(vs_path)
            if os.path.isfile(os.path.join(vs_path, 'index.faiss')) and os.path.isfile(os.path.join(vs_path, 'index.pkl')):
                vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
            else:
                vs = FAISS.from_texts(["初始化向量存储"], embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
            # Analyse each note and index the generated sections
            for note in tqdm(notes, desc="Processing notes"):
                note_text = f"{note['title']}\n{note['description']}"
                # Ask the model for the guide / mindmap / summary sections
                gpt_response = await interact_with_chatgpt(note_text)
                if not gpt_response:
                    continue
                for content_type, section_text in gpt_response.items():
                    # Build a one-document store so the generated vector id can be read back
                    new_vs = FAISS.from_texts([section_text], embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
                    # Persist the section to the database, skipping duplicates
                    async with conn.cursor() as insert_cursor:
                        # Note: docstore._dict is an internal LangChain attribute
                        for vector_id, document in new_vs.docstore._dict.items():
                            section = document.page_content
                            content_hash = get_content_hash(section)
                            # Skip if this content/type pair is already stored
                            await insert_cursor.execute("""
                                SELECT id FROM clean_note_store
                                WHERE content_hash = %s AND content_type = %s
                            """, (content_hash, content_type))
                            if not await insert_cursor.fetchone():
                                await insert_cursor.execute("""
                                    INSERT INTO clean_note_store
                                    (note_id, vector_id, content_type, content, content_hash)
                                    VALUES (%s, %s, %s, %s, %s)
                                """, (note['id'], vector_id, content_type, section, content_hash))
                    # Merge the one-document store into the main index
                    vs.merge_from(new_vs)
            # Persist the final vector store to disk
            vs.save_local(vs_path)
    except aiomysql.Error as e:
        logger.error(f"MySQL error: {e}")
        raise
    except Exception as e:
        logger.error(f"Error while processing notes: {e}")
        raise
    finally:
        # Close the connection safely
        try:
            if conn is not None:
                logger.info("Closing database connection...")
                # conn.close() is synchronous in aiomysql; ensure_closed() is the awaitable
                await conn.ensure_closed()
                logger.info("Database connection closed")
        except Exception as e:
            logger.error(f"Error while closing the database connection: {e}")


if __name__ == "__main__":
    asyncio.run(process_clean_notes())
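

# Usage sketch (hypothetical helper, not called by this script): querying the
# index that the run above saves to ./clean_faiss.
def example_query(query: str, k: int = 5):
    """A minimal retrieval sketch, assuming the index was built by this script."""
    store = FAISS.load_local("./clean_faiss", embed_model, allow_dangerous_deserialization=True)
    # With unit-norm embeddings and MAX_INNER_PRODUCT, scores are cosine similarities
    return store.similarity_search_with_score(query, k=k)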