1121 lines
44 KiB
Python
1121 lines
44 KiB
Python
import sys
|
||
import os
|
||
import json
|
||
import asyncio
|
||
import subprocess
|
||
from datetime import datetime
|
||
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
|
||
QHBoxLayout, QLineEdit, QPushButton, QProgressBar,
|
||
QScrollArea, QLabel, QFrame, QTabWidget, QComboBox)
|
||
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QUrl
|
||
from PyQt6.QtGui import QDesktopServices
|
||
import mysql.connector
|
||
import aiohttp
|
||
import aiofiles
|
||
from urllib.parse import urlparse
|
||
|
||
# Database connection settings.
# Every value can be overridden through an environment variable, so deployments
# do not need to edit this file.
# SECURITY NOTE(review): the literal fallbacks below commit a plaintext password
# for a publicly reachable host to source control. Rotate that password, provide
# it via XHS_DB_PASSWORD, and then remove the literal fallback.
db_config = {
    'user': os.environ.get('XHS_DB_USER', 'root'),
    'password': os.environ.get('XHS_DB_PASSWORD', 'zaq12wsx@9Xin'),
    'host': os.environ.get('XHS_DB_HOST', '183.11.229.79'),
    # Environment values are strings; the MySQL connector expects an int port.
    'port': int(os.environ.get('XHS_DB_PORT', '3316')),
    'database': os.environ.get('XHS_DB_NAME', '9Xin'),
    'auth_plugin': 'mysql_native_password',
}
|
||
|
||
class NoteCard(QFrame):
    """Card widget showing one ``xhs_notes`` row and its media-download state.

    ``note_data`` is a row from a dictionary cursor. This widget reads the keys
    type, title, description, liked_count, collected_count, comment_count,
    share_count, tag_list, note_id and note_url. Left-clicking the card opens
    note_url in the default browser.
    """

    def __init__(self, note_data, is_downloaded=False, parent=None):
        super().__init__(parent)
        self.note_data = note_data
        self.is_downloaded = is_downloaded
        self.setup_ui()

    def _apply_card_style(self):
        """Apply background/hover colours matching ``self.is_downloaded``."""
        bg_color = "#f0f9ff" if self.is_downloaded else "white"
        hover_color = "#e1f3ff" if self.is_downloaded else "#f0f0f0"
        self.setStyleSheet(f"""
            NoteCard {{
                background-color: {bg_color};
                border-radius: 10px;
                margin: 5px;
                padding: 10px;
                font-family: "Microsoft YaHei", Arial;
            }}
            NoteCard:hover {{
                background-color: {hover_color};
            }}
        """)

    def _refresh_status_label(self):
        """Update the header badge text and colour from ``self.is_downloaded``."""
        self.status_label.setText("✓ 已下载" if self.is_downloaded else "⋯ 下载中")
        self.status_label.setStyleSheet(
            "color: #4CAF50;" if self.is_downloaded else "color: #FFA500;"
        )

    def setup_ui(self):
        """Build the header badge, title, description, stats, tags, progress area and retry button."""
        self.setFrameStyle(QFrame.Shape.Box | QFrame.Shadow.Raised)
        self._apply_card_style()

        layout = QVBoxLayout(self)

        # Note type and download-state badge
        header_layout = QHBoxLayout()
        type_label = QLabel("视频笔记" if self.note_data['type'] == 'video' else "图文笔记")
        type_label.setStyleSheet("color: #666; font-size: 12px;")
        header_layout.addWidget(type_label)

        # Kept as an attribute so download_complete() can update it later.
        self.status_label = QLabel()
        self._refresh_status_label()
        header_layout.addWidget(self.status_label, alignment=Qt.AlignmentFlag.AlignRight)
        layout.addLayout(header_layout)

        # Title
        title = QLabel(self.note_data['title'])
        title.setStyleSheet("font-size: 16px; font-weight: bold;")
        title.setWordWrap(True)
        layout.addWidget(title)

        # Description
        desc = QLabel(self.note_data['description'])
        desc.setWordWrap(True)
        desc.setStyleSheet("color: #333;")
        layout.addWidget(desc)

        # Engagement statistics
        stats_layout = QHBoxLayout()
        stats = [
            f"❤️ {self.note_data['liked_count']}",
            f"⭐ {self.note_data['collected_count']}",
            f"💬 {self.note_data['comment_count']}",
            f"↗️ {self.note_data['share_count']}",
        ]
        for stat in stats:
            stat_label = QLabel(stat)
            stat_label.setStyleSheet("color: #666; font-size: 12px;")
            stats_layout.addWidget(stat_label)
        layout.addLayout(stats_layout)

        # Tags (comma-separated); blank entries are skipped instead of rendering a bare '#'.
        tag_list = self.note_data['tag_list']
        if tag_list:
            tags = [tag.strip() for tag in tag_list.split(',') if tag.strip()]
            if tags:
                tags_label = QLabel(' '.join(f'#{tag}' for tag in tags))
                tags_label.setStyleSheet("color: #0066cc; font-size: 12px;")
                tags_label.setWordWrap(True)
                layout.addWidget(tags_label)

        # Download progress area (hidden until a download reports progress)
        self.status_area = QWidget()
        status_layout = QHBoxLayout(self.status_area)

        self.progress_label = QLabel()
        self.progress_label.setStyleSheet("color: #666; font-size: 12px;")
        self.progress_label.setVisible(False)
        status_layout.addWidget(self.progress_label)

        self.progress_bar = QProgressBar()
        self.progress_bar.setStyleSheet("""
            QProgressBar {
                border: 1px solid #ddd;
                border-radius: 3px;
                text-align: center;
                height: 10px;
            }
            QProgressBar::chunk {
                background-color: #1a73e8;
                border-radius: 2px;
            }
        """)
        self.progress_bar.setVisible(False)
        status_layout.addWidget(self.progress_bar)

        layout.addWidget(self.status_area)

        # Retry button, only for notes whose media is not fully downloaded
        if not self.is_downloaded:
            self.retry_button = QPushButton("重新下载")
            self.retry_button.setStyleSheet("""
                QPushButton {
                    padding: 5px 10px;
                    background-color: #ff4d4f;
                    color: white;
                    border: none;
                    border-radius: 3px;
                    font-size: 12px;
                }
                QPushButton:hover {
                    background-color: #ff7875;
                }
                QPushButton:disabled {
                    background-color: #ffccc7;
                }
            """)
            self.retry_button.clicked.connect(self.retry_download)
            layout.addWidget(self.retry_button)

    def retry_download(self):
        """Ask the enclosing MainWindow to re-download this note's media."""
        parent = self.parent()
        while parent and not isinstance(parent, MainWindow):
            parent = parent.parent()
        if parent:
            parent.retry_download(self.note_data['note_id'])

    def update_progress(self, value, message):
        """Show *message*; move the bar to *value* only when ``value >= 0``.

        A negative value means "update the text only".
        """
        if value >= 0:
            self.progress_bar.setValue(value)
        # Reset the colour in case a previous show_error() turned the label red.
        self.progress_label.setStyleSheet("color: #666; font-size: 12px;")
        self.progress_label.setText(message)
        self.progress_bar.setVisible(True)
        self.progress_label.setVisible(True)

    def set_downloading(self, is_downloading):
        """Disable/relabel the retry button while a download is running."""
        if hasattr(self, 'retry_button'):
            self.retry_button.setEnabled(not is_downloading)
            self.retry_button.setText("下载中..." if is_downloading else "重新下载")

    def show_error(self, message):
        """Show *message* as an error in place of the progress bar."""
        self.progress_label.setText(f"错误: {message}")
        self.progress_label.setStyleSheet("color: red;")
        self.progress_label.setVisible(True)
        self.progress_bar.setVisible(False)

    def download_complete(self):
        """Switch the card to its downloaded look and hide progress and retry widgets."""
        self.is_downloaded = True
        self.progress_label.setText("下载完成")
        self.progress_label.setStyleSheet("color: #4CAF50;")
        self.progress_label.setVisible(True)
        self.progress_bar.setVisible(False)
        # Keep the badge and background consistent with the new state.
        self._refresh_status_label()
        self._apply_card_style()
        if hasattr(self, 'retry_button'):
            self.retry_button.setVisible(False)

    def mousePressEvent(self, event):
        """Open the note's URL on a left click; pass other clicks to the base class."""
        if event.button() == Qt.MouseButton.LeftButton:
            note_url = self.note_data.get('note_url')
            if note_url:
                QDesktopServices.openUrl(QUrl(note_url))
        else:
            super().mousePressEvent(event)
|
||
|
||
class WorkerThread(QThread):
    """Background thread that crawls, imports and downloads note media.

    * With ``retry_note_id`` set, ``run`` only downloads that note's media.
    * Otherwise it runs the crawler for ``keywords``, imports today's JSON output
      into ``xhs_notes`` and downloads media for every note of that keyword.
    """

    progress = pyqtSignal(int, str)            # (percent, message)
    status_update = pyqtSignal(str)            # detailed status text
    download_progress = pyqtSignal(str, bool)  # (note_id, True) once a note's media is complete
    # NOTE: shadows QThread.finished; emitted explicitly with the processed records.
    finished = pyqtSignal(list)
    error = pyqtSignal(str)
    data_imported = pyqtSignal()               # crawler output was written to the DB

    MAX_RETRIES = 5          # download attempts per file
    DOWNLOAD_TIMEOUT = 300   # seconds (per request and per session)
    RETRY_DELAY = 2          # seconds between attempts
    CHUNK_SIZE = 8192        # bytes per streamed chunk
    BATCH_SIZE = 3           # concurrent downloads in download_media
    MEDIA_ROOT = './data/xhs/json/media'

    _INSERT_NOTE_SQL = """INSERT INTO xhs_notes (
        note_id, type, title, description, video_url, time,
        last_update_time, user_id, nickname, avatar,
        liked_count, collected_count, comment_count, share_count,
        ip_location, image_list, tag_list, last_modify_ts,
        note_url, source_keyword
    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
              %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"""

    def __init__(self, keywords):
        super().__init__()
        self.keywords = keywords
        self.retry_note_id = None  # when set, run() only re-downloads this note

    def run(self):
        """Thread entry point: drive the selected workflow on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            if self.retry_note_id:
                records = loop.run_until_complete(self.download_single_note(self.retry_note_id))
                self.progress.emit(100, "处理完成")
                self.finished.emit(records)
            elif loop.run_until_complete(self.run_crawler()):
                if loop.run_until_complete(self.import_to_db()):
                    self.data_imported.emit()
                    records = loop.run_until_complete(self.download_media())
                    self.progress.emit(100, "处理完成")
                    self.finished.emit(records)
        except Exception as e:
            # Report unexpected failures instead of letting the thread die silently.
            self.error.emit(str(e))
        finally:
            loop.close()

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _media_dir(cls, note_id):
        """Directory holding the media files of *note_id*."""
        return f'{cls.MEDIA_ROOT}/{note_id}'

    @classmethod
    def _create_session(cls):
        """Create the aiohttp session used for media downloads.

        Certificate verification is disabled on the connector. That was the
        intent of the former ``ClientSession(ssl=False)`` call, which
        ClientSession does not accept.
        """
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=cls.DOWNLOAD_TIMEOUT),
            connector=aiohttp.TCPConnector(ssl=False),
        )

    @staticmethod
    def _media_targets(record, media_dir):
        """List the media files a note record should have on disk.

        Returns ``(url, save_path, is_video, position, count)`` tuples.
        - Images are saved as ``image_<n><ext>``, where n is the 1-based position in
          the comma-separated image_list; blank entries keep their slot.
        - The video is saved as ``video<ext>``.
        """
        targets = []
        image_list = record.get('image_list')
        if image_list:
            image_urls = image_list.split(',')
            for position, raw_url in enumerate(image_urls, 1):
                url = raw_url.strip()
                if not url:
                    continue
                # The extension comes from the unstripped entry, exactly as before,
                # so file names match those written by earlier runs.
                ext = os.path.splitext(urlparse(raw_url).path)[1] or '.jpg'
                save_path = os.path.join(media_dir, f'image_{position}{ext}')
                targets.append((url, save_path, False, position, len(image_urls)))

        video_url = (record.get('video_url') or '').strip()
        if video_url:
            ext = os.path.splitext(urlparse(video_url).path)[1] or '.mp4'
            targets.append((video_url, os.path.join(media_dir, f'video{ext}'), True, 1, 1))
        return targets

    @staticmethod
    def _fetch_note(note_id):
        """Return the (note_id, image_list, video_url) row of *note_id*, or None."""
        conn = mysql.connector.connect(**db_config)
        try:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id, image_list, video_url
                FROM xhs_notes
                WHERE note_id = %s
            """, (note_id,))
            return cursor.fetchone()
        finally:
            conn.close()

    def _fetch_keyword_notes(self):
        """Return the (note_id, image_list, video_url) rows tagged with self.keywords."""
        conn = mysql.connector.connect(**db_config)
        try:
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id, image_list, video_url
                FROM xhs_notes
                WHERE source_keyword = %s
            """, (self.keywords,))
            return cursor.fetchall()
        finally:
            conn.close()

    @staticmethod
    def _python_executable():
        """Interpreter used to launch main.py.

        Prefers the running interpreter, which keeps the active virtualenv. A frozen
        build's sys.executable is the app itself, so it falls back to 'python' there.
        """
        if getattr(sys, 'frozen', False) or not sys.executable:
            return 'python'
        return sys.executable

    # --------------------------------------------------------------- workflows

    async def download_single_note(self, note_id):
        """Download every media file of one note and refresh its download_flag.

        Returns ``[record]`` whenever the note exists, even if some files failed;
        the outcome is reflected in progress messages and download_flag. Returns
        ``[]`` after emitting ``error`` when the note is missing or an exception occurs.
        """
        try:
            self.progress.emit(0, "开始下载...")
            # The DB connection is closed before downloading, so it is not held open for minutes.
            record = self._fetch_note(note_id)
            if not record:
                self.error.emit("找不到指定的笔记")
                return []

            base_dir = self._media_dir(record['note_id'])
            os.makedirs(base_dir, exist_ok=True)
            targets = self._media_targets(record, base_dir)
            if not targets:
                self.progress.emit(100, "没有需要下载的文件")
                return [record]

            total_files = len(targets)
            completed_files = 0
            async with self._create_session() as session:
                for url, save_path, is_video, position, count in targets:
                    self.status_update.emit("下载视频..." if is_video else f"下载图片 {position}/{count}")
                    if await self.download_file_with_retry(session, url, save_path, note_id, is_video):
                        completed_files += 1
                        percent = int(completed_files / total_files * 100)
                        self.progress.emit(percent, f"已完成: {completed_files}/{total_files}")

            # Re-evaluate the completion state and update the database.
            await self.check_download_complete(note_id)
            self.progress.emit(100, "下载完成")
            return [record]
        except Exception as e:
            self.error.emit(str(e))
            return []

    async def _forward_stream(self, stream):
        """Emit every non-empty line read from *stream* as a crawler status update."""
        import codecs

        # Incremental decoding keeps multi-byte UTF-8 characters intact when they
        # straddle two reads; decoding each chunk separately dropped them.
        decoder = codecs.getincrementaldecoder('utf-8')(errors='ignore')
        buffer = ""
        while True:
            try:
                chunk = await stream.read(8192)
            except Exception as e:
                # Stop rather than retry: a failing stream would otherwise loop forever.
                print(f"读取爬虫输出错误: {str(e)}")
                break
            if not chunk:
                break
            buffer += decoder.decode(chunk)
            while '\n' in buffer:
                line, buffer = buffer.split('\n', 1)
                line = line.strip()
                if line:
                    self.status_update.emit(f"爬虫进度: {line}")

        buffer += decoder.decode(b'', final=True)
        if buffer.strip():
            self.status_update.emit(f"爬虫进度: {buffer.strip()}")

    async def run_crawler(self):
        """Run ``main.py`` for self.keywords and stream its output to the UI.

        Returns True only when the process exits with status 0. A failed crawl
        must not lead to importing a stale JSON file from an earlier run.
        """
        try:
            self.progress.emit(10, "启动爬虫...")
            process = await asyncio.create_subprocess_exec(
                self._python_executable(), 'main.py', '--platform', 'xhs', '--lt', 'qrcode',
                '--keywords', self.keywords,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                limit=1024 * 1024,  # 1 MB stream buffer
            )
            # Drain stdout and stderr concurrently so neither pipe can fill up and block.
            await asyncio.gather(
                self._forward_stream(process.stdout),
                self._forward_stream(process.stderr),
            )
            returncode = await process.wait()
            if returncode != 0:
                self.error.emit(f"爬虫执行失败: 退出码 {returncode}")
                return False
            self.progress.emit(30, "爬虫数据获取完成")
            return True
        except Exception as e:
            self.error.emit(f"爬虫执行失败: {str(e)}")
            return False

    async def import_to_db(self):
        """Insert today's crawler JSON output into ``xhs_notes``.

        Items without a note_id, or whose note_id already exists, are skipped. New
        rows get ``self.keywords`` as source_keyword. Returns True on success.
        """
        conn = None
        try:
            self.progress.emit(40, "导入数据到数据库...")
            json_path = f'./data/xhs/json/search_contents_{datetime.now().strftime("%Y-%m-%d")}.json'
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            for item in data:
                note_id = item.get('note_id')
                if not note_id:
                    continue
                cursor.execute("SELECT 1 FROM xhs_notes WHERE note_id = %s LIMIT 1", (note_id,))
                if cursor.fetchone() is not None:
                    continue
                values = (
                    note_id, item.get('type'), item.get('title'),
                    item.get('desc'), item.get('video_url'), item.get('time'),
                    item.get('last_update_time'), item.get('user_id'),
                    item.get('nickname'), item.get('avatar'),
                    item.get('liked_count'), item.get('collected_count'),
                    item.get('comment_count'), item.get('share_count'),
                    item.get('ip_location'), item.get('image_list'),
                    item.get('tag_list'), item.get('last_modify_ts'),
                    item.get('note_url'), self.keywords,
                )
                cursor.execute(self._INSERT_NOTE_SQL, values)

            conn.commit()
            self.progress.emit(60, "数据导入完成")
            return True
        except Exception as e:
            self.error.emit(f"数据导入失败: {str(e)}")
            return False
        finally:
            if conn is not None:
                conn.close()

    async def download_file_with_retry(self, session, url, save_path, note_id, is_video=False):
        """Download *url* to *save_path*, making up to MAX_RETRIES attempts.

        The body is streamed into ``save_path + '.part'`` and renamed only after a
        complete, non-empty transfer. An interrupted download therefore never leaves a
        truncated file that later looks "already downloaded". On success the note's
        completion state is re-checked. ``is_video`` is kept for API compatibility
        and does not change behaviour. Returns True when the file was saved.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Referer': 'https://www.xiaohongshu.com/',
        }
        request_timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT)
        temp_path = save_path + '.part'
        success = False

        for attempt in range(1, self.MAX_RETRIES + 1):
            give_up = False
            try:
                # aiohttp follows redirects itself (allow_redirects defaults to True).
                async with session.get(url, headers=headers, timeout=request_timeout) as response:
                    if response.status == 200:
                        # A missing Content-Length (chunked transfer) is fine; only an
                        # explicit zero length means there is nothing to fetch.
                        declared_length = response.headers.get('Content-Length')
                        if declared_length is not None and declared_length.strip() == '0':
                            print(f"警告:{url} 的内容长度为0")
                            give_up = True
                        else:
                            downloaded = 0
                            async with aiofiles.open(temp_path, 'wb') as f:
                                async for chunk in response.content.iter_chunked(self.CHUNK_SIZE):
                                    await f.write(chunk)
                                    downloaded += len(chunk)
                            if downloaded > 0:
                                os.replace(temp_path, save_path)
                                success = True
                            else:
                                print(f"下载的文件大小为0: {url}")
                    else:
                        print(f"下载失败,状态码: {response.status}, URL: {url}")
                        # Client errors other than timeout or rate-limit will not improve on retry.
                        if 400 <= response.status < 500 and response.status not in (408, 429):
                            give_up = True
            except asyncio.TimeoutError:
                print(f"下载超时 {url}, 尝试次数 {attempt}/{self.MAX_RETRIES}")
            except Exception as e:
                print(f"下载出错 {url}: {str(e)}, 尝试次数 {attempt}/{self.MAX_RETRIES}")

            if success or give_up:
                break
            if attempt < self.MAX_RETRIES:
                await asyncio.sleep(self.RETRY_DELAY)

        if not success and os.path.exists(temp_path):
            os.remove(temp_path)

        if success:
            # Re-check the note only after a file was actually saved.
            await self.check_download_complete(note_id)

        return success

    async def check_download_complete(self, note_id):
        """Set ``download_flag`` for *note_id* from the files present on disk.

        The note is complete when every expected media file exists and is non-empty;
        a note without media counts as complete. Emits
        ``download_progress(note_id, True)`` when complete. Errors are printed, not raised.
        """
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT image_list, video_url
                FROM xhs_notes
                WHERE note_id = %s
            """, (note_id,))
            record = cursor.fetchone()
            if not record:
                return

            media_dir = self._media_dir(note_id)
            is_complete = all(
                os.path.exists(path) and os.path.getsize(path) > 0
                for _url, path, *_rest in self._media_targets(record, media_dir)
            )

            cursor.execute("""
                UPDATE xhs_notes
                SET download_flag = %s
                WHERE note_id = %s
            """, (is_complete, note_id))
            conn.commit()

            if is_complete:
                self.download_progress.emit(note_id, True)
        except Exception as e:
            print(f"检查下载状态时出错: {e}")
        finally:
            if conn is not None:
                conn.close()

    async def download_media(self):
        """Download media for every note whose source_keyword is self.keywords.

        - Files already on disk are skipped.
        - Downloads run BATCH_SIZE at a time.
        - Notes with nothing left to download get their download_flag re-evaluated
          at the end, so they are not left unflagged forever.

        Returns the selected records, or [] on failure.
        """
        try:
            self.progress.emit(70, "开始下载媒体文件...")
            records = self._fetch_keyword_notes()

            download_tasks = []
            notes_without_tasks = []
            async with self._create_session() as session:
                for record in records:
                    note_id = record['note_id']
                    self.status_update.emit(f"处理笔记: {note_id}")
                    base_dir = self._media_dir(note_id)
                    os.makedirs(base_dir, exist_ok=True)

                    pending = [
                        self.download_file_with_retry(session, url, save_path, note_id, is_video)
                        for url, save_path, is_video, _pos, _count in self._media_targets(record, base_dir)
                        if not os.path.exists(save_path)  # avoid re-downloading
                    ]
                    if pending:
                        download_tasks.extend(pending)
                    else:
                        notes_without_tasks.append(note_id)

                total_tasks = len(download_tasks)
                completed = 0
                for start in range(0, total_tasks, self.BATCH_SIZE):
                    batch = download_tasks[start:start + self.BATCH_SIZE]
                    results = await asyncio.gather(*batch, return_exceptions=True)
                    completed += len(batch)
                    success_count = sum(1 for r in results if r is True)

                    percent = int(70 + (completed / total_tasks) * 20)
                    self.progress.emit(percent, f"已下载: {completed}/{total_tasks}")
                    self.status_update.emit(
                        f"下载进度: {completed}/{total_tasks} (成功: {success_count})"
                    )
                    await asyncio.sleep(1)

            # No download ran for these notes, so nothing else would update their flag.
            for note_id in notes_without_tasks:
                await self.check_download_complete(note_id)

            self.progress.emit(90, "媒体文件下载完成")
            return records
        except Exception as e:
            self.error.emit(f"媒体下载失败: {str(e)}")
            return []
|
||
|
||
class BatchDownloadWorker(QThread):
    """Thread that downloads the media of several notes one after another.

    After each note, ``progress`` reports (current, total, success_count, failed_ids).
    """

    progress = pyqtSignal(int, int, int, list)  # current, total, success_count, failed_ids
    error = pyqtSignal(str)
    # Shadows QThread.finished with the same argument-less signature; emitted explicitly.
    finished = pyqtSignal()

    def __init__(self, note_ids):
        super().__init__()
        self.note_ids = note_ids

    @staticmethod
    def _is_marked_downloaded(note_id):
        """Return True when ``xhs_notes.download_flag`` is set for *note_id*."""
        conn = mysql.connector.connect(**db_config)
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT COALESCE(download_flag, FALSE) FROM xhs_notes WHERE note_id = %s",
                (note_id,),
            )
            row = cursor.fetchone()
            return bool(row and row[0])
        finally:
            conn.close()

    def run(self):
        """Download each note in turn on a private event loop and report progress."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        try:
            failed_ids = []
            success_count = 0
            total = len(self.note_ids)

            for i, note_id in enumerate(self.note_ids, 1):
                try:
                    # Reuse WorkerThread's single-note download (the thread itself is never started).
                    worker = WorkerThread(None)
                    worker.retry_note_id = note_id
                    records = loop.run_until_complete(worker.download_single_note(note_id))

                    # download_single_note returns the record whenever the note exists,
                    # even if files failed; download_flag is the actual outcome.
                    if records and self._is_marked_downloaded(note_id):
                        success_count += 1
                    else:
                        failed_ids.append(note_id)
                except Exception as e:
                    print(f"下载记录 {note_id} 失败: {str(e)}")
                    failed_ids.append(note_id)

                # Emit a copy: the GUI thread reads it while this list keeps growing.
                self.progress.emit(i, total, success_count, list(failed_ids))

            self.finished.emit()
        except Exception as e:
            self.error.emit(str(e))
        finally:
            loop.close()
|
||
|
||
class MainWindow(QMainWindow):
    """Main window with two tabs.

    - "数据展示" lists notes whose media is downloaded.
    - "搜索笔记" runs the crawler for a keyword and shows that keyword's notes.
    """

    # Entries of the note-type filter combo box, in index order.
    TYPE_FILTER_OPTIONS = ["全部笔记", "仅看图文", "仅看视频"]
    # Extra SQL condition per combo-box index; index 0 ("all notes") adds none.
    TYPE_FILTER_SQL = {1: " AND type = 'normal'", 2: " AND type = 'video'"}

    def __init__(self):
        super().__init__()
        self.setWindowTitle("小红书内容采集器")
        self.setMinimumSize(1200, 800)

        self.worker = None                 # current keyword-search WorkerThread
        self.batch_worker = None           # current BatchDownloadWorker
        self.last_search_keyword = ''      # keyword of the most recent search
        self._search_error = False         # whether the current search reported an error
        self._completed_retry_ids = set()  # note ids a retry reported as complete
        # Every started QThread stays referenced here until it finishes; a QThread
        # garbage-collected while running aborts the application.
        self._threads = []

        os.makedirs('./data/xhs/json/media', exist_ok=True)
        self.setup_ui()

        # Initial load: empty search text and the "all" filter list every downloaded note.
        self.filter_records()

    def _track_thread(self, thread):
        """Keep a reference to *thread*, dropping references to finished threads."""
        self._threads = [t for t in self._threads if not t.isFinished()]
        self._threads.append(thread)

    def _current_keyword(self):
        """Keyword shown in the search tab: the last searched one, else the input text."""
        return self.last_search_keyword or self.search_input.text().strip()

    def setup_ui(self):
        """Create the central tab widget and both tabs."""
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        self.tab_widget = QTabWidget()
        layout.addWidget(self.tab_widget)

        # Data tab
        self.setup_data_tab()

        # Search tab
        self.setup_search_tab()

    def setup_data_tab(self):
        """Build the "数据展示" tab: search/filter bar, counters and card list."""
        data_tab = QWidget()
        layout = QVBoxLayout(data_tab)

        # Search and filter bar
        filter_layout = QHBoxLayout()

        self.db_search_input = QLineEdit()
        self.db_search_input.setPlaceholderText("搜索标题或描述...")
        self.db_search_input.setStyleSheet(self.get_input_style())
        self.db_search_input.returnPressed.connect(self.filter_records)
        filter_layout.addWidget(self.db_search_input)

        # Note-type filter; filter_records() maps the selected index to SQL.
        self.type_filter = QComboBox()
        self.type_filter.addItems(self.TYPE_FILTER_OPTIONS)
        self.type_filter.setStyleSheet("""
            QComboBox {
                padding: 8px;
                border: 1px solid #ccc;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
        """)
        self.type_filter.currentIndexChanged.connect(self.filter_records)
        filter_layout.addWidget(self.type_filter)

        db_search_button = QPushButton("搜索")
        db_search_button.setStyleSheet(self.get_button_style())
        db_search_button.clicked.connect(self.filter_records)
        filter_layout.addWidget(db_search_button)

        # AI processing button (placeholder feature)
        ai_process_button = QPushButton("AI处理")
        ai_process_button.setStyleSheet("""
            QPushButton {
                padding: 8px 16px;
                background-color: #52c41a;
                color: white;
                border: none;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
            QPushButton:hover {
                background-color: #73d13d;
            }
        """)
        ai_process_button.clicked.connect(self.start_ai_process)
        filter_layout.addWidget(ai_process_button)

        layout.addLayout(filter_layout)

        # Record count / status line
        self.total_count_label = QLabel()
        self.total_count_label.setStyleSheet("color: #666; margin: 5px 0;")
        layout.addWidget(self.total_count_label)

        # Failed note ids of a batch download (written by update_batch_progress).
        self.failed_records_label = QLabel()
        self.failed_records_label.setStyleSheet("color: red; font-size: 12px;")
        self.failed_records_label.setWordWrap(True)
        self.failed_records_label.setVisible(False)
        layout.addWidget(self.failed_records_label)

        # Card list
        self.records_scroll = QScrollArea()
        self.records_scroll.setWidgetResizable(True)
        self.records_content = QWidget()
        self.records_layout = QVBoxLayout(self.records_content)
        self.records_scroll.setWidget(self.records_content)
        layout.addWidget(self.records_scroll)

        self.tab_widget.addTab(data_tab, "数据展示")

    def setup_search_tab(self):
        """Build the "搜索笔记" tab: keyword input, progress widgets and result list."""
        search_tab = QWidget()
        layout = QVBoxLayout(search_tab)

        # Keyword input
        search_layout = QHBoxLayout()
        self.search_input = QLineEdit()
        self.search_input.setPlaceholderText("输入搜索关键词")
        self.search_input.setStyleSheet(self.get_input_style())
        self.search_button = QPushButton("搜索")
        self.search_button.setStyleSheet(self.get_button_style())
        search_layout.addWidget(self.search_input)
        search_layout.addWidget(self.search_button)
        layout.addLayout(search_layout)

        # Progress widgets
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        self.status_label = QLabel()
        self.status_label.setVisible(False)
        layout.addWidget(self.status_label)

        self.detail_status_label = QLabel()
        self.detail_status_label.setStyleSheet("color: #666; font-size: 12px;")
        self.detail_status_label.setVisible(False)
        layout.addWidget(self.detail_status_label)

        # Results
        self.result_count_label = QLabel()
        self.result_count_label.setStyleSheet("color: #666; margin: 5px 0;")
        self.result_count_label.setVisible(False)
        layout.addWidget(self.result_count_label)

        self.search_scroll = QScrollArea()
        self.search_scroll.setWidgetResizable(True)
        self.search_content = QWidget()
        self.search_layout = QVBoxLayout(self.search_content)
        self.search_scroll.setWidget(self.search_content)
        layout.addWidget(self.search_scroll)

        self.tab_widget.addTab(search_tab, "搜索笔记")

        self.search_button.clicked.connect(self.start_search)

    def filter_records(self):
        """Reload the data tab with downloaded notes matching the text and type filter."""
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)

            search_text = self.db_search_input.text().strip()

            # Only notes whose media finished downloading are listed on this tab.
            query = """
                SELECT *,
                       COALESCE(download_flag, FALSE) as is_downloaded
                FROM xhs_notes
                WHERE download_flag = TRUE
            """
            params = []
            if search_text:
                query += " AND (title LIKE %s OR description LIKE %s)"
                pattern = f'%{search_text}%'
                params.extend([pattern, pattern])

            # Filter by index, not by display text, so labels can change safely.
            query += self.TYPE_FILTER_SQL.get(self.type_filter.currentIndex(), '')
            query += " ORDER BY time DESC"

            cursor.execute(query, params)
            records = cursor.fetchall()

            self.total_count_label.setText(f'找到 {len(records)} 条已下载记录')
            self.clear_records()
            for record in records:
                # Every record on this tab is downloaded.
                self.records_layout.addWidget(NoteCard(record, True))
        except Exception as e:
            self.show_error(f"过滤记录失败: {str(e)}")
        finally:
            if conn is not None:
                conn.close()

    def retry_download(self, note_id):
        """Re-download the media of *note_id* in a dedicated background thread."""
        card = self.find_card_by_note_id(note_id)
        if not card:
            return

        card.set_downloading(True)
        card.update_progress(0, "准备下载...")
        self._completed_retry_ids.discard(note_id)

        # A separate worker object: overwriting self.worker could drop the last
        # reference to a running search thread.
        worker = WorkerThread(None)
        worker.retry_note_id = note_id
        # Look the card up on every signal. Views are rebuilt (old cards deleted)
        # while a download runs, so a captured card reference can go stale.
        worker.progress.connect(lambda v, m: self._update_card_progress(note_id, v, m))
        worker.status_update.connect(lambda m: self._update_card_progress(note_id, -1, m))
        worker.error.connect(lambda e: self.on_retry_error(note_id, e))
        worker.download_progress.connect(self._on_retry_note_complete)
        worker.finished.connect(lambda _records: self._on_retry_finished(note_id))

        self._track_thread(worker)
        worker.start()

    def _update_card_progress(self, note_id, value, message):
        """Forward progress to the card of *note_id*, if it is currently shown."""
        card = self.find_card_by_note_id(note_id)
        if card:
            card.update_progress(value, message)

    def _on_retry_note_complete(self, note_id, is_complete):
        """Remember that a retry reported *note_id* as fully downloaded."""
        if is_complete:
            self._completed_retry_ids.add(note_id)

    def _on_retry_finished(self, note_id):
        """Finish a retry: mark it complete only if every file was downloaded."""
        if note_id in self._completed_retry_ids:
            self._completed_retry_ids.discard(note_id)
            self.on_retry_complete(note_id)
        else:
            # Incomplete or failed: re-enable the button and keep the last message visible.
            card = self.find_card_by_note_id(note_id)
            if card:
                card.set_downloading(False)

    def find_card_by_note_id(self, note_id):
        """Return the NoteCard for *note_id* from either tab, or None."""
        for card_layout in (self.records_layout, self.search_layout):
            for i in range(card_layout.count()):
                card = card_layout.itemAt(i).widget()
                if isinstance(card, NoteCard) and card.note_data['note_id'] == note_id:
                    return card
        return None

    def on_retry_complete(self, note_id):
        """Mark the retried card as downloaded and refresh both tabs."""
        card = self.find_card_by_note_id(note_id)
        if card:
            card.download_complete()
        self.filter_records()
        current_keyword = self._current_keyword()
        if current_keyword:
            self.show_search_results(current_keyword)

    def on_retry_error(self, note_id, error_message):
        """Show a retry error on the card of *note_id*."""
        card = self.find_card_by_note_id(note_id)
        if card:
            card.set_downloading(False)
            card.show_error(error_message)

    def get_button_style(self):
        """Stylesheet for primary buttons."""
        return """
            QPushButton {
                padding: 8px 16px;
                background-color: #1a73e8;
                color: white;
                border: none;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
            QPushButton:hover {
                background-color: #1557b0;
            }
        """

    def get_input_style(self):
        """Stylesheet for line edits."""
        return """
            QLineEdit {
                padding: 8px;
                border: 1px solid #ccc;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
        """

    def get_label_style(self):
        """Stylesheet for labels."""
        return """
            QLabel {
                font-size: 14px;
                color: #333;
                font-family: "Microsoft YaHei", Arial;
            }
        """

    def start_search(self):
        """Start crawling, importing and downloading for the entered keyword."""
        keywords = self.search_input.text().strip()
        if not keywords:
            return
        # Only one crawl at a time; the previous thread may still be winding down.
        if self.worker is not None and self.worker.isRunning():
            return

        self.last_search_keyword = keywords
        self._search_error = False
        self.clear_search_results()

        self.progress_bar.setValue(0)
        self.progress_bar.setVisible(True)
        self.status_label.setStyleSheet("")
        self.status_label.setVisible(True)
        self.search_button.setEnabled(False)

        worker = WorkerThread(keywords)
        worker.progress.connect(self.update_progress)
        worker.finished.connect(self.on_download_complete)
        worker.error.connect(self._on_search_error)
        worker.status_update.connect(self.update_status_detail)
        worker.download_progress.connect(self.update_download_status)
        worker.data_imported.connect(self.on_data_imported)

        self.worker = worker
        self._track_thread(worker)
        worker.start()

    def _on_search_error(self, message):
        """Report a search-worker error and hide the now-stale progress bar."""
        self._search_error = True
        self.progress_bar.setVisible(False)
        self.show_error(message)

    def clear_search_results(self):
        """Remove all cards from the search tab and hide the result count."""
        while self.search_layout.count():
            item = self.search_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()
        self.result_count_label.setVisible(False)

    def clear_records(self):
        """Remove all cards from the data tab."""
        while self.records_layout.count():
            item = self.records_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

    def update_progress(self, value, message):
        """Update the search progress bar and status text."""
        self.progress_bar.setValue(value)
        # Clear any red error styling left by show_error().
        self.status_label.setStyleSheet("")
        self.status_label.setText(message)
        self.status_label.setVisible(True)

    def update_status_detail(self, message):
        """Show the detailed status line."""
        self.detail_status_label.setText(message)
        self.detail_status_label.setVisible(True)

    def update_download_status(self, note_id, is_complete):
        """Refresh both tabs when a note's media finished downloading."""
        if is_complete:
            self.filter_records()
            current_keyword = self._current_keyword()
            if current_keyword:
                self.show_search_results(current_keyword)

    def on_data_imported(self):
        """Show the freshly imported notes of the searched keyword."""
        self.show_search_results(self._current_keyword())
        self.filter_records()

    def on_download_complete(self, records):
        """Reset the search controls after the search worker finished."""
        self.search_button.setEnabled(True)
        self.progress_bar.setVisible(False)
        # Keep an error message from this search visible instead of hiding it.
        if not self._search_error:
            self.status_label.setVisible(False)
        self.filter_records()

    def show_search_results(self, keyword):
        """List every note tagged with *keyword* on the search tab."""
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT *,
                       COALESCE(download_flag, FALSE) as is_downloaded
                FROM xhs_notes
                WHERE source_keyword = %s
                ORDER BY time DESC
            """, (keyword,))
            records = cursor.fetchall()

            # Clear first: clear_search_results() hides the count label.
            self.clear_search_results()
            self.result_count_label.setText(f'关键词 "{keyword}" 的搜索结果:{len(records)} 条')
            self.result_count_label.setVisible(True)

            for record in records:
                self.search_layout.addWidget(NoteCard(record, record['is_downloaded']))
        except Exception as e:
            self.show_error(f"获取数据失败: {str(e)}")
        finally:
            if conn is not None:
                conn.close()

    def show_error(self, message):
        """Show *message* in red on the search tab and re-enable the search button."""
        self.status_label.setText(f"错误: {message}")
        self.status_label.setStyleSheet("color: red;")
        self.status_label.setVisible(True)
        self.search_button.setEnabled(True)

    def start_batch_download(self):
        """Download media for every note whose download_flag is not set."""
        if self.batch_worker is not None and self.batch_worker.isRunning():
            return  # a batch is already in progress

        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id
                FROM xhs_notes
                WHERE COALESCE(download_flag, FALSE) = FALSE
            """)
            unfinished_records = cursor.fetchall()
        except Exception as e:
            self.show_error(f"获取未完成记录失败: {str(e)}")
            return
        finally:
            if conn is not None:
                conn.close()

        if not unfinished_records:
            self.show_status_message("没有需要下载的记录")
            return

        self.show_status_message(f"开始下载 {len(unfinished_records)} 条未完成记录...")
        self.failed_records_label.setVisible(False)

        batch_worker = BatchDownloadWorker([r['note_id'] for r in unfinished_records])
        batch_worker.progress.connect(self.update_batch_progress)
        batch_worker.error.connect(self.show_error)
        batch_worker.finished.connect(self.on_batch_download_complete)

        self.batch_worker = batch_worker
        self._track_thread(batch_worker)
        batch_worker.start()

    def update_batch_progress(self, current, total, success_count, failed_ids):
        """Show batch progress and the ids that failed so far."""
        self.show_status_message(f"正在下载: {current}/{total} (成功: {success_count})")
        if failed_ids:
            failed_text = "下载失败的记录:\n" + "\n".join(failed_ids)
            self.failed_records_label.setText(failed_text)
            self.failed_records_label.setVisible(True)

    def on_batch_download_complete(self):
        """Announce the end of a batch download and refresh the data tab."""
        self.show_status_message("批量下载完成")
        self.filter_records()

    def show_status_message(self, message):
        """Show *message* in the data tab's status line."""
        self.total_count_label.setText(message)

    def start_ai_process(self):
        """Placeholder for the AI processing feature."""
        # TODO: implement AI processing
        self.show_status_message("AI处理功能开发中...")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: create the Qt application, show the main window and
    # exit with the event loop's return code.
    qt_app = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    sys.exit(qt_app.exec())