199 lines
7.5 KiB
Python
199 lines
7.5 KiB
Python
|
from openai import OpenAI
|
|||
|
import asyncio
|
|||
|
import aiomysql
|
|||
|
from langchain.embeddings import HuggingFaceEmbeddings
|
|||
|
from langchain_community.vectorstores.utils import DistanceStrategy
|
|||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|||
|
from langchain.vectorstores import FAISS
|
|||
|
import os
|
|||
|
import logging
|
|||
|
from tqdm import tqdm
|
|||
|
import hashlib
|
|||
|
|
|||
|
# Logging: timestamped INFO-level records for the whole pipeline.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# MySQL connection settings.
# SECURITY NOTE(review): credentials were hardcoded in source. They can now be
# overridden via environment variables; the literals remain only as fallbacks
# and should be rotated and removed from version control.
db_config = {
    'user': os.getenv('DB_USER', 'root'),
    'password': os.getenv('DB_PASSWORD', 'zaq12wsx@9Xin'),
    'host': os.getenv('DB_HOST', '183.11.229.79'),
    'port': int(os.getenv('DB_PORT', '3316')),
    'database': os.getenv('DB_NAME', '9Xin'),
    'auth_plugin': 'mysql_native_password'
}

# OpenAI-compatible chat endpoint configuration.
# SECURITY NOTE(review): hardcoded API key — rotate it; the env var takes
# precedence, the literal is kept only for backward compatibility.
api_key = os.getenv("OPENAI_API_KEY", "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")  # replace with your OpenAI API key
base_url = os.getenv("OPENAI_BASE_URL", "http://52.90.243.11:8787/v1")
client = OpenAI(api_key=api_key, base_url=base_url)

# Embedding model configuration (bce-embedding-base_v1, run on GPU,
# normalized vectors so inner product equals cosine similarity).
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}

embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)
|
|||
|
|
|||
|
def get_content_hash(content):
    """Return the SHA-256 hex digest of *content* with surrounding whitespace stripped.

    Used to deduplicate rows in clean_note_store on (content_hash, content_type).
    """
    normalized = content.strip().encode('utf-8')
    return hashlib.sha256(normalized).hexdigest()
|
|||
|
|
|||
|
def _parse_analysis_sections(text: str) -> dict:
    """Split a model reply into its '# '-headed sections.

    Headings are matched by Chinese keyword and normalized to the keys
    'guide' (导读), 'mindmap' (思维导图) and 'summary' (总结); an unrecognized
    heading keeps its lower-cased text as the key.  Lines before the first
    heading are discarded; section bodies are whitespace-stripped.
    """
    parts = {}
    section = None
    buffered = []

    for line in text.split('\n'):
        if line.startswith('# '):
            # Flush the previous section before starting a new one.
            if section and buffered:
                parts[section] = '\n'.join(buffered).strip()
                buffered = []
            section = line[2:].strip().lower()
            if '导读' in section:
                section = 'guide'
            elif '思维导图' in section:
                section = 'mindmap'
            elif '总结' in section:
                section = 'summary'
        elif section:
            buffered.append(line)

    # Flush the trailing section.
    if section and buffered:
        parts[section] = '\n'.join(buffered).strip()

    return parts


async def interact_with_chatgpt(content) -> dict:
    """Ask the chat model to analyze *content* as a cosmetics-industry expert.

    Args:
        content: The note text (title + description) to analyze.

    Returns:
        A dict with keys among 'guide', 'mindmap', 'summary' parsed from the
        model reply, or None when the API call fails (callers must check).
    """
    system_prompt = """你是一位专业的化妆品行业专家。请以专业的视角分析提供的内容,并提供以下三个部分的分析:
1. 面向化妆品产品经理的笔记导读
2. 内容的思维导图(使用mermaid格式)
3. 内容的总结

请严格按照以下格式返回:
# 笔记导读
[导读内容]

# 思维导图
[思维导图内容]

# 内容总结
[总结内容]"""

    try:
        # NOTE(review): this OpenAI client call is synchronous and blocks the
        # event loop despite the async signature; switch to AsyncOpenAI if
        # real concurrency is ever needed.
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content}
            ],
            max_tokens=1000,
            temperature=0.2,
        )
        # Keep the reply in its own name — the original code shadowed the
        # *content* parameter here.
        reply = response.choices[0].message.content
        return _parse_analysis_sections(reply)
    except Exception as e:
        logger.error(f"OpenAI API 调用错误: {e}")
        return None
|
|||
|
|
|||
|
async def process_clean_notes():
    """Clean 'normal' xhs_notes via GPT analysis and index them in FAISS.

    For every note of type 'normal':
      1. send title + description to the chat model (interact_with_chatgpt),
      2. embed each returned section (guide / mindmap / summary) into FAISS,
      3. insert the section text and its hash into clean_note_store, skipping
         rows whose (content_hash, content_type) pair already exists.
    The merged FAISS index is persisted to ./clean_faiss once at the end.

    Raises:
        aiomysql.Error: on database failures (logged, then re-raised).
        Exception: any other processing error (logged, then re-raised).
    """
    conn = None
    try:
        logger.info("正在尝试连接数据库...")
        # Source connection parameters from the module-level db_config so
        # credentials live in a single place instead of a second hardcoded copy.
        conn = await aiomysql.connect(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            db=db_config['database'],
            autocommit=True
        )
        logger.info("数据库连接成功")

        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Fetch all plain (non-video) notes.
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()

            # Load the persistent vector store, or bootstrap a fresh one.
            vs_path = "./clean_faiss"
            if not os.path.exists(vs_path):
                os.makedirs(vs_path)

            if os.path.isfile(os.path.join(vs_path, 'index.faiss')) and os.path.isfile(os.path.join(vs_path, 'index.pkl')):
                vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
            else:
                vs = FAISS.from_texts(["初始化向量存储"], embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)

            # Process each note.
            for note in tqdm(notes, desc="处理笔记"):
                note_text = f"{note['title']}\n{note['description']}"

                # GPT analysis; skip the note when the API call failed.
                gpt_response = await interact_with_chatgpt(note_text)
                if not gpt_response:
                    continue

                # One sub-store per section so we can read back the vector ids
                # FAISS assigns, then merge into the main store.
                for content_type, section_text in gpt_response.items():
                    new_vs = FAISS.from_texts([section_text], embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)

                    async with conn.cursor() as insert_cursor:
                        for vector_id, document in new_vs.docstore._dict.items():
                            doc_text = document.page_content
                            content_hash = get_content_hash(doc_text)

                            # Deduplicate on (hash, type) before inserting.
                            await insert_cursor.execute("""
                                SELECT id FROM clean_note_store
                                WHERE content_hash = %s AND content_type = %s
                            """, (content_hash, content_type))

                            if not await insert_cursor.fetchone():
                                await insert_cursor.execute("""
                                    INSERT INTO clean_note_store
                                    (note_id, vector_id, content_type, content, content_hash)
                                    VALUES (%s, %s, %s, %s, %s)
                                """, (note['id'], vector_id, content_type, doc_text, content_hash))

                    vs.merge_from(new_vs)

            # Persist the merged index.
            vs.save_local(vs_path)

    except aiomysql.Error as e:
        logger.error(f"MySQL错误: {e}")
        raise
    except Exception as e:
        logger.error(f"处理笔记时出错: {e}")
        raise
    finally:
        # Close the connection defensively.
        try:
            if conn is not None:
                logger.info("正在关闭数据库连接...")
                # BUGFIX: aiomysql's Connection.close() is a plain synchronous
                # method returning None, so `await conn.close()` raised
                # TypeError; ensure_closed() is the awaitable variant.
                await conn.ensure_closed()
                logger.info("数据库连接已关闭")
        except Exception as e:
            logger.error(f"关闭数据库连接时出错: {e}")
|
|||
|
|
|||
|
if __name__ == "__main__":
    # Script entry point: run the full note-cleaning pipeline once.
    asyncio.run(process_clean_notes())
|