135 lines
5.2 KiB
Python
135 lines
5.2 KiB
Python
|
import asyncio
import hashlib
import logging
import os
import uuid

import aiomysql
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from tqdm import tqdm
# Logging setup: INFO level, each record prefixed with its timestamp and severity.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the processing code in this script.
logger = logging.getLogger(__name__)
def get_content_hash(content):
    """Return the SHA-256 hex digest of *content*, ignoring surrounding whitespace.

    The text is stripped and UTF-8 encoded before hashing, so two strings that
    differ only in leading/trailing whitespace produce the same digest.
    """
    payload = content.strip().encode('utf-8')
    return hashlib.sha256(payload).hexdigest()
# Embedding model configuration.
# Local directory holding the bce-embedding-base_v1 checkpoint.
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
# Run the model on the CUDA device.
embedding_model_kwargs = dict(device='cuda')
# Encode 32 texts per batch and normalize the output vectors
# (presumably so inner-product search behaves like cosine similarity -- confirm).
embedding_encode_kwargs = dict(batch_size=32, normalize_embeddings=True)

# Shared embedding client, constructed once when the module is imported.
embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    encode_kwargs=embedding_encode_kwargs,
    model_kwargs=embedding_model_kwargs,
)
# MySQL connection settings: (aiomysql.connect keyword, env var override, fallback).
# NOTE(review): the fallback commits a live root password to source control.
# Rotate it and drop the default so the password only comes from XHS_DB_PASSWORD.
_DB_SETTINGS = (
    ('host', 'XHS_DB_HOST', '183.11.229.79'),
    ('port', 'XHS_DB_PORT', '3316'),
    ('user', 'XHS_DB_USER', 'root'),
    ('password', 'XHS_DB_PASSWORD', 'zaq12wsx@9Xin'),
    ('db', 'XHS_DB_NAME', '9Xin'),
)


def _db_connect_kwargs():
    """Return keyword arguments for ``aiomysql.connect``.

    Each setting comes from its environment variable when set, otherwise from
    the fallback in ``_DB_SETTINGS``. ``port`` is converted to ``int`` and
    ``autocommit`` is always enabled.
    """
    kwargs = {
        key: os.environ.get(env_var, default)
        for key, env_var, default in _DB_SETTINGS
    }
    kwargs['port'] = int(kwargs['port'])
    kwargs['autocommit'] = True
    return kwargs


def _note_text(note):
    """Join a note row's title and description into the text to be split.

    SQL NULL columns arrive as ``None``; they become empty strings so the
    literal word "None" is never embedded.
    """
    title = note['title'] or ''
    description = note['description'] or ''
    return f"{title}\n{description}"


def _load_vector_store(vs_path):
    """Load the FAISS store previously saved in *vs_path*, or return None if absent."""
    index_file = os.path.join(vs_path, 'index.faiss')
    pickle_file = os.path.join(vs_path, 'index.pkl')
    if not (os.path.isfile(index_file) and os.path.isfile(pickle_file)):
        return None
    # index.pkl is unpickled, which can run arbitrary code: acceptable only
    # because this script writes that file itself. Never point vs_path at
    # untrusted data. save_local does not appear to persist distance_strategy,
    # so it is passed again to match the stores created in process_xhs_notes.
    return FAISS.load_local(
        vs_path,
        embed_model,
        allow_dangerous_deserialization=True,
        distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
    )


async def _select_new_chunks(cursor, chunks, seen_hashes):
    """Return ``(chunk, content_hash)`` pairs for chunks that still need indexing.

    A chunk is skipped if its hash is already in *seen_hashes* (duplicate in
    this run) or a ``vector_store`` row with that ``content_hash`` exists
    (indexed by an earlier run). *seen_hashes* is updated in place.
    """
    fresh = []
    for chunk in chunks:
        content_hash = get_content_hash(chunk)
        if content_hash in seen_hashes:
            continue
        seen_hashes.add(content_hash)
        await cursor.execute(
            "SELECT id FROM vector_store WHERE content_hash = %s",
            (content_hash,),
        )
        if await cursor.fetchone() is None:
            fresh.append((chunk, content_hash))
    return fresh


async def process_xhs_notes():
    """Index every 'normal' Xiaohongshu note into the FAISS store at ./raw_vs.

    Each note's title and description are split into chunks of up to 120
    characters (20 overlap). Only chunks whose content hash is not yet in
    ``vector_store`` are embedded and merged into the store. The store is
    saved first, and only then are the matching
    ``(note_id, vector_id, content, content_hash)`` rows inserted. This way
    the table never references vectors that were not persisted, and re-runs
    do not duplicate vectors.

    The ``vector_store`` table is the record of what is already indexed. If
    the FAISS files are deleted, clear that table too, or nothing is
    re-indexed.

    Raises:
        aiomysql.Error: database failures (logged, then re-raised).
        Exception: any other failure (logged, then re-raised).
    """
    conn = None
    try:
        logger.info("正在尝试连接数据库...")
        # aiomysql.connect raises on failure instead of returning None,
        # so no separate None check is needed.
        conn = await aiomysql.connect(**_db_connect_kwargs())
        logger.info("数据库连接成功")

        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Fetch every note whose type is 'normal'.
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=120,
            chunk_overlap=20,
            length_function=len,
            is_separator_regex=False,
        )

        vs_path = "./raw_vs"
        os.makedirs(vs_path, exist_ok=True)
        # Stays None until an existing store is loaded or the first new chunks
        # are embedded, so a fresh index is never seeded with a placeholder text.
        vs = _load_vector_store(vs_path)

        pending_rows = []  # DB rows written only after the index is saved
        seen_hashes = set()
        async with conn.cursor() as cursor:
            for note in tqdm(notes, desc="处理小红书笔记"):
                chunks = text_splitter.split_text(_note_text(note))
                fresh = await _select_new_chunks(cursor, chunks, seen_hashes)
                if not fresh:
                    continue

                texts = [chunk for chunk, _ in fresh]
                # Assign ids up front so each DB row can reference its vector
                # without reading FAISS's private docstore.
                vector_ids = [str(uuid.uuid4()) for _ in fresh]
                new_vs = FAISS.from_texts(
                    texts,
                    embed_model,
                    ids=vector_ids,
                    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT,
                )
                if vs is None:
                    vs = new_vs
                else:
                    vs.merge_from(new_vs)

                pending_rows.extend(
                    (note['id'], vector_id, chunk, content_hash)
                    for vector_id, (chunk, content_hash) in zip(vector_ids, fresh)
                )

            if not pending_rows:
                logger.info("没有新的文本块需要写入向量存储")
            else:
                # Persist the vectors first, then record them in the DB.
                vs.save_local(vs_path)
                await cursor.executemany(
                    """
                    INSERT INTO vector_store
                    (note_id, vector_id, content, content_hash)
                    VALUES (%s, %s, %s, %s)
                    """,
                    pending_rows,
                )

    except aiomysql.Error as e:
        logger.error(f"MySQL错误: {e}")
        raise
    except Exception as e:
        logger.error(f"处理小红书笔记时出错: {e}")
        raise
    finally:
        # Close the connection without letting a close failure mask the
        # original error.
        if conn is not None:
            try:
                logger.info("正在关闭数据库连接...")
                # aiomysql's Connection.close() is synchronous (awaiting its
                # None result raises TypeError); ensure_closed() is the
                # awaitable graceful shutdown.
                await conn.ensure_closed()
                logger.info("数据库连接已关闭")
            except Exception as e:
                logger.error(f"关闭数据库连接时出错: {e}")
                conn.close()  # fall back to dropping the socket directly
if __name__ == "__main__":
    # Script entry point: drive the async pipeline to completion on a new event loop.
    pipeline = process_xhs_notes()
    asyncio.run(pipeline)