from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
import os
import logging
import hashlib
import aiomysql
from tqdm import tqdm
import asyncio

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def get_content_hash(content):
    """Generate a SHA-256 hash of the whitespace-stripped content."""
    content_cleaned = content.strip()
    return hashlib.sha256(content_cleaned.encode('utf-8')).hexdigest()


# Embedding model configuration
embedding_model_name = '/datacenter/expert_models/bce-embedding-base_v1/'
embedding_model_kwargs = {'device': 'cuda'}
embedding_encode_kwargs = {'batch_size': 32, 'normalize_embeddings': True}

embed_model = HuggingFaceEmbeddings(
    model_name=embedding_model_name,
    model_kwargs=embedding_model_kwargs,
    encode_kwargs=embedding_encode_kwargs
)


async def process_xhs_notes():
    """Process Xiaohongshu (XHS) note data."""
    conn = None
    try:
        logger.info("Attempting to connect to the database...")
        # Database connection settings.
        # NOTE: credentials are hardcoded here; consider loading them from
        # environment variables or a config file instead.
        conn = await aiomysql.connect(
            host='183.11.229.79',
            port=3316,
            user='root',
            password='zaq12wsx@9Xin',
            db='9Xin',
            autocommit=True
        )
        logger.info("Database connection established")

        if conn is None:
            raise Exception("Database connection failed")

        async with conn.cursor(aiomysql.DictCursor) as cursor:
            # Fetch all notes of type 'normal'
            await cursor.execute("""
                SELECT id, note_id, title, description
                FROM xhs_notes
                WHERE type = 'normal'
            """)
            notes = await cursor.fetchall()

        # Split each note into chunks with the text splitter
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=120,
            chunk_overlap=20,
            length_function=len,
            is_separator_regex=False,
        )

        vs_path = "./raw_vs"
        os.makedirs(vs_path, exist_ok=True)

        # Initialize the vector store: load an existing index if present,
        # otherwise seed a new one with a placeholder document
        if (os.path.isfile(os.path.join(vs_path, 'index.faiss'))
                and os.path.isfile(os.path.join(vs_path, 'index.pkl'))):
            vs = FAISS.load_local(vs_path, embed_model, allow_dangerous_deserialization=True)
        else:
            vs = FAISS.from_texts(
                ["vector store initialization"],
                embed_model,
                distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
            )

        # Show processing progress with tqdm
        for note in tqdm(notes, desc="Processing XHS notes"):
            # Combine title and description
            content = f"{note['title']}\n{note['description']}"
            # Split the combined text into chunks
            texts = text_splitter.split_text(content)
            if texts:
                # Build a new vector store from this note's chunks
                new_vs = FAISS.from_texts(
                    texts,
                    embed_model,
                    distance_strategy=DistanceStrategy.MAX_INNER_PRODUCT
                )
                # Persist the vector metadata to the database
                async with conn.cursor() as insert_cursor:
                    for vector_id, document in new_vs.docstore._dict.items():
                        chunk_content = document.page_content
                        content_hash = get_content_hash(chunk_content)
                        # Skip chunks whose content_hash already exists
                        await insert_cursor.execute("""
                            SELECT id FROM vector_store
                            WHERE content_hash = %s
                        """, (content_hash,))
                        if not await insert_cursor.fetchone():
                            # Insert the new vector record
                            await insert_cursor.execute("""
                                INSERT INTO vector_store
                                (note_id, vector_id, content, content_hash)
                                VALUES (%s, %s, %s, %s)
                            """, (note['id'], vector_id, chunk_content, content_hash))
                # Merge this note's vectors into the main store
                vs.merge_from(new_vs)

        # Save the final vector store
        vs.save_local(vs_path)

    except aiomysql.Error as e:
        logger.error(f"MySQL error: {e}")
        raise
    except Exception as e:
        logger.error(f"Error while processing XHS notes: {e}")
        raise
    finally:
        # Close the connection safely
        try:
            if conn is not None:
                logger.info("Closing database connection...")
                await conn.close()
                logger.info("Database connection closed")
        except Exception as e:
            logger.error(f"Error while closing database connection: {e}")


if __name__ == "__main__":
    asyncio.run(process_xhs_notes())
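

# --- Illustrative usage: querying the saved index ---
# A minimal sketch, not invoked by the pipeline above. It assumes the index
# written by process_xhs_notes() exists at ./raw_vs and reuses the same
# embed_model; search_notes, the query string, and k are hypothetical
# example values, not part of the original script.
def search_notes(query: str, k: int = 5):
    """Load the saved FAISS index and return the top-k (Document, score) pairs."""
    vs = FAISS.load_local("./raw_vs", embed_model, allow_dangerous_deserialization=True)
    # With MAX_INNER_PRODUCT and normalize_embeddings=True this is cosine
    # similarity, so higher scores indicate closer matches.
    return vs.similarity_search_with_score(query, k=k)

# Example call (hypothetical query):
# for doc, score in search_notes("camping gear recommendations"):
#     print(f"{score:.4f}  {doc.page_content[:60]}")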