# xhs_server/process_clean_notes.py
#
# 199 lines
# 7.5 KiB
# Python
# Raw Normal View History
#
# 2024-12-16 02:31:07 +00:00
from openai import OpenAI
import asyncio
import aiomysql
from langchain.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
import os
import logging
from tqdm import tqdm
import hashlib
# Logging setup: INFO level, each record rendered as "time - level - message".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# Database connection settings.
#
# SECURITY: these credentials were committed in plain text. Every value in
# this dict can be overridden through the environment variable named next to
# it; the literals are kept only as backward-compatible fallbacks.
# TODO(review): rotate the password and drop the literal fallbacks once the
# deployment environment provides the XHS_DB_* variables.
db_config = {
    'user': os.environ.get('XHS_DB_USER', 'root'),
    'password': os.environ.get('XHS_DB_PASSWORD', 'zaq12wsx@9Xin'),
    'host': os.environ.get('XHS_DB_HOST', '183.11.229.79'),
    # Kept as an int so it can be passed straight to a connect() call.
    'port': int(os.environ.get('XHS_DB_PORT', '3316')),
    'database': os.environ.get('XHS_DB_NAME', '9Xin'),
    'auth_plugin': 'mysql_native_password',
}
# OpenAI client settings.
#
# SECURITY: the API key was committed in plain text. Both values can be
# overridden through the XHS_OPENAI_* environment variables; the literals are
# kept only as backward-compatible fallbacks. The variable names are prefixed
# so that an unrelated OPENAI_API_KEY already present in the environment does
# not silently change which key is sent to this base_url.
# TODO(review): revoke/rotate this key and remove the literal fallback.
api_key = os.environ.get(
    "XHS_OPENAI_API_KEY",
    "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",
)  # replace with your own OpenAI API key
base_url = os.environ.get("XHS_OPENAI_BASE_URL", "http://52.90.243.11:8787/v1")
client = OpenAI(api_key=api_key, base_url=base_url)
# Embedding model configuration: the model is loaded from a local path, placed
# on the 'cuda' device, and encodes in batches of 32 with normalized output.
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = dict(device='cuda')
embedding_encode_kwargs = dict(batch_size=32, normalize_embeddings=True)
embed_model = HuggingFaceEmbeddings(
    encode_kwargs=embedding_encode_kwargs,
    model_kwargs=embedding_model_kwargs,
    model_name=embedding_model_name,
)
def get_content_hash(content):
    """Return the SHA-256 hex digest of ``content``.

    Leading and trailing whitespace is removed before hashing, so texts that
    differ only in surrounding whitespace produce the same digest. The text is
    encoded as UTF-8.
    """
    hasher = hashlib.sha256()
    hasher.update(content.strip().encode('utf-8'))
    return hasher.hexdigest()
def _normalize_section_name(heading):
    """Map a reply heading to its canonical key.

    Headings containing the guide / mind-map / summary keywords become
    'guide', 'mindmap' and 'summary'. Any other heading is returned stripped
    and lower-cased, unchanged otherwise.
    """
    name = heading.strip().lower()
    if '导读' in name:
        return 'guide'
    if '思维导图' in name:
        return 'mindmap'
    if '总结' in name:
        return 'summary'
    return name


def _parse_analysis_sections(text):
    """Split a reply made of '# heading' sections into a {key: body} dict.

    Lines before the first '# ' heading are ignored. Sections with no body
    lines are omitted. If a key occurs twice, the later section wins.
    """
    sections = {}
    section = None
    body_lines = []
    for line in text.split('\n'):
        if line.startswith('# '):
            # A new heading closes the section collected so far.
            if section and body_lines:
                sections[section] = '\n'.join(body_lines).strip()
            body_lines = []
            section = _normalize_section_name(line[2:])
        elif section:
            body_lines.append(line)
    # Flush the last section, which has no following heading to close it.
    if section and body_lines:
        sections[section] = '\n'.join(body_lines).strip()
    return sections


async def interact_with_chatgpt(content) -> "dict | None":
    """Ask the chat model to analyse ``content`` and return the parsed sections.

    The system prompt (in Chinese) asks the model to act as a cosmetics
    industry expert and to answer under three '# ' headings: a reading guide
    for product managers, a mind map in mermaid format, and a summary.

    Args:
        content: The note text sent as the user message.

    Returns:
        A dict mapping section keys to section text. The expected keys are
        'guide', 'mindmap' and 'summary'; an unrecognised heading is kept
        under its own lower-cased name. The dict may be empty when the reply
        contains no '# ' headings. Returns None when the API call fails or
        the reply has no text; the failure is logged, not raised.
    """
    system_prompt = """你是一位专业的化妆品行业专家。请以专业的视角分析提供的内容,并提供以下三个部分的分析:
1. 面向化妆品产品经理的笔记导读
2. 内容的思维导图使用mermaid格式
3. 内容的总结
请严格按照以下格式返回
# 笔记导读
[导读内容]
# 思维导图
[思维导图内容]
# 内容总结
[总结内容]"""
    request = {
        "model": "gpt-4o-mini",
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": content},
        ],
        "max_tokens": 1000,
        "temperature": 0.2,
    }
    try:
        # `client` is the synchronous OpenAI client, so calling it directly
        # would block the event loop for the whole HTTP round trip. Run it in
        # the default executor instead.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None, lambda: client.chat.completions.create(**request)
        )
        reply = response.choices[0].message.content
    except Exception as e:
        # Best-effort boundary: the caller skips the note when None comes back.
        logger.error(f"OpenAI API 调用错误: {e}")
        return None

    if reply is None:
        # The message carried no text, so there is nothing to parse.
        logger.error("OpenAI API 调用错误: 响应内容为空")
        return None

    return _parse_analysis_sections(reply)
def _load_or_create_vector_store(vs_path):
    """Return the FAISS store persisted under ``vs_path``, or a new one."""
    os.makedirs(vs_path, exist_ok=True)
    index_file = os.path.join(vs_path, 'index.faiss')
    meta_file = os.path.join(vs_path, 'index.pkl')
    if os.path.isfile(index_file) and os.path.isfile(meta_file):
        # SECURITY: load_local unpickles index.pkl. Only point vs_path at a
        # directory this script (or another trusted process) has written.
        return FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
    # NOTE(review): the placeholder text is embedded like any other document
    # and stays in the index; confirm that consumers filter it out.
    return FAISS.from_texts(
        ["初始化向量存储"],
        embed_model,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    )


async def _store_section(conn, vs, note_row_id, content_type, text):
    """Embed one analysis section, record it in the DB and merge it into ``vs``.

    The (content_hash, content_type) duplicate check runs before the text is
    embedded, so a duplicate costs no embedding work and is not merged into
    the index. Every vector in the index therefore has a database row.

    Returns:
        True when a new row was written and merged, False when the section
        was blank or already stored.
    """
    if not text or not text.strip():
        return False

    content_hash = get_content_hash(text)
    async with conn.cursor() as cursor:
        await cursor.execute(
            """
            SELECT id FROM clean_note_store
            WHERE content_hash = %s AND content_type = %s
            """,
            (content_hash, content_type),
        )
        if await cursor.fetchone():
            return False

        section_vs = FAISS.from_texts(
            [text],
            embed_model,
            distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
        )
        # NOTE(review): docstore._dict is a private attribute of the in-memory
        # docstore; it is used to read back the generated vector ids.
        for vector_id, document in section_vs.docstore._dict.items():
            await cursor.execute(
                """
                INSERT INTO clean_note_store
                (note_id, vector_id, content_type, content, content_hash)
                VALUES (%s, %s, %s, %s, %s)
                """,
                (note_row_id, vector_id, content_type, document.page_content, content_hash),
            )

    vs.merge_from(section_vs)
    return True


async def process_clean_notes():
    """Analyse every 'normal' note and build the cleaned vector store.

    Steps:
      1. Connect to MySQL using ``db_config`` and read id, note_id, title and
         description of all rows in ``xhs_notes`` with type = 'normal'.
      2. Load the FAISS store under ./clean_faiss, or create it.
      3. For each note, request the model analysis. Store each returned
         section in ``clean_note_store`` and merge its embedding into the
         FAISS store. Sections whose (content_hash, content_type) already
         exist are skipped.
      4. Save the FAISS store, including when processing stops early, so
         rows already written have their vectors on disk.

    Raises:
        aiomysql.Error: on database failures (logged, then re-raised).
        Exception: any other failure (logged, then re-raised).
    """
    conn = None
    try:
        logger.info("正在尝试连接数据库...")
        # aiomysql.connect raises on failure, so no None check is needed.
        conn = await aiomysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            db=db_config['database'],
            autocommit=True,
        )
        logger.info("数据库连接成功")

        # Fetch all notes of type 'normal'.
        async with conn.cursor(aiomysql.DictCursor) as cursor:
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()

        vs_path = "./clean_faiss"
        vs = _load_or_create_vector_store(vs_path)

        try:
            for note in tqdm(notes, desc="处理笔记"):
                # NULL columns arrive as None; treat them as empty text so
                # the literal string "None" is never sent to the model.
                title = note['title'] or ''
                description = note['description'] or ''
                content = f"{title}\n{description}"
                if not content.strip():
                    continue

                gpt_response = await interact_with_chatgpt(content)
                if not gpt_response:
                    continue

                # NOTE(review): the note_id column receives note['id'], not
                # note['note_id'], as in the original code; confirm intent.
                for content_type, section_text in gpt_response.items():
                    await _store_section(conn, vs, note['id'], content_type, section_text)
        finally:
            # Inserts autocommit, so persist the index even on failure.
            vs.save_local(vs_path)
    except aiomysql.Error as e:
        logger.error(f"MySQL错误: {e}")
        raise
    except Exception as e:
        logger.error(f"处理笔记时出错: {e}")
        raise
    finally:
        if conn is not None:
            try:
                logger.info("正在关闭数据库连接...")
                # Connection.close() is synchronous in aiomysql; the
                # awaitable shutdown is ensure_closed().
                await conn.ensure_closed()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {e}")
                # Still release the socket if the graceful shutdown failed.
                conn.close()
# Script entry point: run the note-processing coroutine to completion.
if __name__ == "__main__":
    pipeline = process_clean_notes()
    asyncio.run(pipeline)