xhs_server/process_raw_notes.py

from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import os
import logging
import hashlib
import aiomysql
from tqdm import tqdm
import asyncio

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

def get_content_hash(content):
    """Return the SHA-256 hex digest of the stripped content."""
    content_cleaned = content.strip()
    return hashlib.sha256(content_cleaned.encode('utf-8')).hexdigest()
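
# The queries below assume two pre-existing MySQL tables; the column types in
# this sketch are an illustrative guess, not a schema shipped with this script:
#   xhs_notes(id, note_id, title, description, type, ...)
#   CREATE TABLE vector_store (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       note_id INT NOT NULL,
#       vector_id VARCHAR(64) NOT NULL,   -- FAISS docstore ID
#       content TEXT NOT NULL,
#       content_hash CHAR(64) NOT NULL,   -- SHA-256 hex digest
#       UNIQUE KEY uniq_content_hash (content_hash)
#   );
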
# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)
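
# Because normalize_embeddings=True produces unit-length vectors, the
# MAX_INNER_PRODUCT distance strategy used for the FAISS stores below is
# effectively cosine similarity.
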
async def process_xhs_notes():
    """Process Xiaohongshu (XHS) note data."""
    conn = None
    try:
        logger.info("Connecting to the database...")
        # Database connection settings
        conn = await aiomysql.connect(
            host='183.11.229.79',
            port=3316,
            user='root',
            password='zaq12wsx@9Xin',
            db='9Xin',
            autocommit=True
        )
        if conn is None:
            raise Exception("Failed to connect to the database")
        logger.info("Database connection established")

        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Fetch all notes whose type is 'normal'
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()
            # Split each note with the text splitter
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=120,
                chunk_overlap=20,
                length_function=len,
                is_separator_regex=False,
            )

            vs_path = "./raw_vs"
            if not os.path.exists(vs_path):
                os.makedirs(vs_path)

            # Initialise the vector store: load an existing index if present,
            # otherwise bootstrap a new one with a placeholder document
            if os.path.isfile(os.path.join(vs_path, 'index.faiss')) and os.path.isfile(os.path.join(vs_path, 'index.pkl')):
                vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
            else:
                vs = FAISS.from_texts(["vector store initialization placeholder"], embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
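
            # The placeholder document used to bootstrap a fresh index stays in the
            # docstore after merging; if that matters for retrieval quality it would
            # need to be filtered out or removed separately.
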
            # Iterate over the notes with a tqdm progress bar
            for note in tqdm(notes, desc="Processing XHS notes"):
                # Combine title and description
                content = f"{note['title']}\n{note['description']}"
                # Split the text into chunks
                texts = text_splitter.split_text(content)
                if texts:
                    # Build a new vector store from this note's chunks
                    new_vs = FAISS.from_texts(texts, embed_model, distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT)
                    # Persist the chunk records to the database
                    async with conn.cursor() as insert_cursor:
                        for vector_id, document in new_vs.docstore._dict.items():
                            content = document.page_content
                            content_hash = get_content_hash(content)
                            # Skip chunks whose content_hash already exists
                            await insert_cursor.execute("""
                                SELECT id FROM vector_store
                                WHERE content_hash = %s
                            """, (content_hash,))
                            if not await insert_cursor.fetchone():
                                # Insert the new vector record
                                await insert_cursor.execute("""
                                    INSERT INTO vector_store
                                    (note_id, vector_id, content, content_hash)
                                    VALUES (%s, %s, %s, %s)
                                """, (note['id'], vector_id, content, content_hash))
                    # Merge this note's vectors into the main store
                    vs.merge_from(new_vs)

            # Save the final vector store
            vs.save_local(vs_path)
    except aiomysql.Error as e:
        logger.error(f"MySQL error: {e}")
        raise
    except Exception as e:
        logger.error(f"Error while processing XHS notes: {e}")
        raise
    finally:
        # Close the connection safely
        try:
            if conn is not None:
                logger.info("Closing database connection...")
                # aiomysql's close() is synchronous; ensure_closed() is the awaitable variant
                await conn.ensure_closed()
                logger.info("Database connection closed")
        except Exception as e:
            logger.error(f"Error while closing the database connection: {e}")


if __name__ == "__main__":
    asyncio.run(process_xhs_notes())
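
# Illustrative usage only (not executed by this script): once ./raw_vs has been
# built, the index can be loaded and queried like this:
#
#     vs = FAISS.load_local("./raw_vs", embed_model,
#                           allow_dangerous_deserialization=True)
#     for doc in vs.similarity_search("your query here", k=5):
#         print(doc.page_content)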