xhs_crawler/xhs_crawler_gui.py

1121 lines
44 KiB
Python
Raw Normal View History

2024-12-17 08:14:10 +00:00
import sys
import os
import json
import asyncio
import subprocess
from datetime import datetime
from PyQt6.QtWidgets import (QApplication, QMainWindow, QWidget, QVBoxLayout,
QHBoxLayout, QLineEdit, QPushButton, QProgressBar,
QScrollArea, QLabel, QFrame, QTabWidget, QComboBox)
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QUrl
from PyQt6.QtGui import QDesktopServices
import mysql.connector
import aiohttp
import aiofiles
from urllib.parse import urlparse
# Database connection settings passed verbatim to mysql.connector.connect().
#
# SECURITY NOTE(review): the fallback values below are real credentials that
# are committed to source control. Every value can be overridden through an
# environment variable so deployments do not need to rely on them; the
# hard-coded defaults are kept only for backward compatibility and should be
# rotated and removed.
db_config = {
    'user': os.environ.get('XHS_DB_USER', 'root'),
    'password': os.environ.get('XHS_DB_PASSWORD', 'zaq12wsx@9Xin'),
    'host': os.environ.get('XHS_DB_HOST', '183.11.229.79'),
    'port': int(os.environ.get('XHS_DB_PORT', '3316')),
    'database': os.environ.get('XHS_DB_NAME', '9Xin'),
    'auth_plugin': 'mysql_native_password',
}
class NoteCard(QFrame):
    """Card widget that renders a single note record.

    The card shows the note type, download status, title, description,
    engagement counters and tags.  It also owns a (initially hidden) progress
    label / progress bar pair and, for notes that are not downloaded yet, a
    "retry download" button.

    Args:
        note_data: Mapping holding the note fields read by this widget
            (``type``, ``title``, ``description``, ``liked_count``,
            ``collected_count``, ``comment_count``, ``share_count``,
            ``tag_list``, ``note_id``, ``note_url``).
        is_downloaded: Truthy when the note's media is already downloaded.
        parent: Optional parent widget.
    """

    # Default style of the small progress label (restored after an error).
    _PROGRESS_LABEL_STYLE = "color: #666; font-size: 12px;"

    def __init__(self, note_data, is_downloaded=False, parent=None):
        super().__init__(parent)
        self.note_data = note_data
        self.is_downloaded = is_downloaded
        self.setup_ui()

    def setup_ui(self):
        """Build all child widgets of the card."""
        self.setFrameStyle(QFrame.Shape.Box | QFrame.Shadow.Raised)

        # Background colours depend on the download state.
        bg_color = "#f0f9ff" if self.is_downloaded else "white"
        hover_color = "#e1f3ff" if self.is_downloaded else "#f0f0f0"
        self.setStyleSheet(f"""
            NoteCard {{
                background-color: {bg_color};
                border-radius: 10px;
                margin: 5px;
                padding: 10px;
                font-family: "Microsoft YaHei", Arial;
            }}
            NoteCard:hover {{
                background-color: {hover_color};
            }}
        """)

        layout = QVBoxLayout(self)

        # Header row: note type on the left, download status on the right.
        header_layout = QHBoxLayout()
        type_label = QLabel("视频笔记" if self.note_data['type'] == 'video' else "图文笔记")
        type_label.setStyleSheet("color: #666; font-size: 12px;")
        header_layout.addWidget(type_label)

        # Kept as an attribute so download_complete() can refresh it.
        self.status_label = QLabel("✓ 已下载" if self.is_downloaded else "⋯ 下载中")
        self.status_label.setStyleSheet(
            "color: #4CAF50;" if self.is_downloaded else "color: #FFA500;"
        )
        header_layout.addWidget(self.status_label, alignment=Qt.AlignmentFlag.AlignRight)
        layout.addLayout(header_layout)

        # Title
        title = QLabel(self.note_data['title'])
        title.setStyleSheet("font-size: 16px; font-weight: bold;")
        title.setWordWrap(True)
        layout.addWidget(title)

        # Description
        desc = QLabel(self.note_data['description'])
        desc.setWordWrap(True)
        desc.setStyleSheet("color: #333;")
        layout.addWidget(desc)

        # Engagement counters
        stats_layout = QHBoxLayout()
        stats = [
            f"❤️ {self.note_data['liked_count']}",
            f"{self.note_data['collected_count']}",
            f"💬 {self.note_data['comment_count']}",
            f"↗️ {self.note_data['share_count']}"
        ]
        for stat in stats:
            stat_label = QLabel(stat)
            stat_label.setStyleSheet("color: #666; font-size: 12px;")
            stats_layout.addWidget(stat_label)
        layout.addLayout(stats_layout)

        # Tags: comma-separated string; surrounding whitespace and empty
        # entries are dropped so they do not render as a bare '#'.
        if self.note_data['tag_list']:
            tags = [tag.strip() for tag in self.note_data['tag_list'].split(',')]
            tags_text = ' '.join(f'#{tag}' for tag in tags if tag)
            if tags_text:
                tags_label = QLabel(tags_text)
                tags_label.setStyleSheet("color: #0066cc; font-size: 12px;")
                tags_label.setWordWrap(True)
                layout.addWidget(tags_label)

        # Download status / progress area (hidden until a download starts).
        self.status_area = QWidget()
        status_layout = QHBoxLayout(self.status_area)

        self.progress_label = QLabel()
        self.progress_label.setStyleSheet(self._PROGRESS_LABEL_STYLE)
        self.progress_label.setVisible(False)
        status_layout.addWidget(self.progress_label)

        self.progress_bar = QProgressBar()
        self.progress_bar.setStyleSheet("""
            QProgressBar {
                border: 1px solid #ddd;
                border-radius: 3px;
                text-align: center;
                height: 10px;
            }
            QProgressBar::chunk {
                background-color: #1a73e8;
                border-radius: 2px;
            }
        """)
        self.progress_bar.setVisible(False)
        status_layout.addWidget(self.progress_bar)
        layout.addWidget(self.status_area)

        # Retry button, only for notes that are not downloaded yet.
        if not self.is_downloaded:
            self.retry_button = QPushButton("重新下载")
            self.retry_button.setStyleSheet("""
                QPushButton {
                    padding: 5px 10px;
                    background-color: #ff4d4f;
                    color: white;
                    border: none;
                    border-radius: 3px;
                    font-size: 12px;
                }
                QPushButton:hover {
                    background-color: #ff7875;
                }
                QPushButton:disabled {
                    background-color: #ffccc7;
                }
            """)
            self.retry_button.clicked.connect(self.retry_download)
            layout.addWidget(self.retry_button)

    def retry_download(self):
        """Ask the enclosing MainWindow to re-download this note's media."""
        # Walk up the parent chain until the main window is found.
        ancestor = self.parent()
        while ancestor and not isinstance(ancestor, MainWindow):
            ancestor = ancestor.parent()
        if ancestor:
            ancestor.retry_download(self.note_data['note_id'])

    def update_progress(self, value, message):
        """Show the progress widgets with the given value and message."""
        self.progress_bar.setValue(value)
        self.progress_label.setText(message)
        # Undo the red style a previous show_error() may have applied.
        self.progress_label.setStyleSheet(self._PROGRESS_LABEL_STYLE)
        self.progress_bar.setVisible(True)
        self.progress_label.setVisible(True)

    def set_downloading(self, is_downloading):
        """Enable/disable the retry button according to the download state."""
        if hasattr(self, 'retry_button'):
            self.retry_button.setEnabled(not is_downloading)
            self.retry_button.setText("下载中..." if is_downloading else "重新下载")

    def show_error(self, message):
        """Display an error message in place of the progress bar."""
        self.progress_label.setText(f"错误: {message}")
        self.progress_label.setStyleSheet("color: red;")
        self.progress_label.setVisible(True)
        self.progress_bar.setVisible(False)

    def download_complete(self):
        """Switch the card to its 'downloaded' state."""
        self.progress_label.setText("下载完成")
        self.progress_label.setStyleSheet("color: #4CAF50;")
        self.progress_label.setVisible(True)
        self.progress_bar.setVisible(False)
        self.is_downloaded = True
        # Keep the header status in sync with the new state.
        self.status_label.setText("✓ 已下载")
        self.status_label.setStyleSheet("color: #4CAF50;")
        if hasattr(self, 'retry_button'):
            self.retry_button.setVisible(False)

    def mousePressEvent(self, event):
        """Open the note URL in the system browser on a left click."""
        note_url = self.note_data.get('note_url')
        if event.button() == Qt.MouseButton.LeftButton and note_url:
            QDesktopServices.openUrl(QUrl(note_url))
        else:
            # Nothing to open: let Qt apply its default handling.
            super().mousePressEvent(event)
class WorkerThread(QThread):
    """Background thread that crawls, imports and downloads note media.

    When ``retry_note_id`` is set before ``start()``, the thread only
    re-downloads the media of that note.  Otherwise it runs the full
    pipeline for ``keywords``: run the crawler subprocess, import the
    produced JSON into the database, then download all media files.

    Signals:
        progress(int, str): overall percentage and a short message.
        status_update(str): detailed status line.
        download_progress(str, bool): note id and whether all of its media
            files are present on disk.
        finished(list): records that were processed (empty on failure).
        error(str): human readable error message.
        data_imported(): emitted after the JSON import succeeded.
    """

    progress = pyqtSignal(int, str)
    status_update = pyqtSignal(str)
    download_progress = pyqtSignal(str, bool)
    finished = pyqtSignal(list)
    error = pyqtSignal(str)
    data_imported = pyqtSignal()

    MAX_RETRIES = 5         # maximum attempts per file
    DOWNLOAD_TIMEOUT = 300  # total timeout per request, in seconds
    RETRY_DELAY = 2         # pause between attempts, in seconds
    CHUNK_SIZE = 8192       # streaming chunk size, in bytes

    def __init__(self, keywords):
        super().__init__()
        self.keywords = keywords
        self.retry_note_id = None  # set by the caller to retry a single note

    def run(self):
        """Thread entry point: drive the coroutines on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            if self.retry_note_id:
                # Re-download a single note.
                records = loop.run_until_complete(
                    self.download_single_note(self.retry_note_id)
                )
                # Do not overwrite an error message with a success message.
                if records:
                    self.progress.emit(100, "处理完成")
                self.finished.emit(records)
            else:
                # Full search pipeline.
                if loop.run_until_complete(self.run_crawler()):
                    if loop.run_until_complete(self.import_to_db()):
                        self.data_imported.emit()
                        records = loop.run_until_complete(self.download_media())
                        if records:
                            self.progress.emit(100, "处理完成")
                        self.finished.emit(records)
        finally:
            loop.close()

    @staticmethod
    def _media_targets(note_id, image_list, video_url):
        """Return ``(base_dir, targets)`` for a note.

        ``targets`` is a list of ``(url, save_path, is_video)`` tuples.
        Images are named ``image_<n><ext>`` where ``n`` is the 1-based
        position in the comma-separated ``image_list`` (blank entries keep
        their slot); the video is named ``video<ext>``.
        """
        base_dir = f'./data/xhs/json/media/{note_id}'
        targets = []
        if image_list:
            for position, raw_url in enumerate(image_list.split(','), 1):
                url = raw_url.strip()
                if not url:
                    continue
                ext = os.path.splitext(urlparse(url).path)[1] or '.jpg'
                targets.append(
                    (url, os.path.join(base_dir, f'image_{position}{ext}'), False)
                )
        video = (video_url or '').strip()
        if video:
            ext = os.path.splitext(urlparse(video).path)[1] or '.mp4'
            targets.append((video, os.path.join(base_dir, f'video{ext}'), True))
        return base_dir, targets

    def _create_session(self):
        """Create the aiohttp session used for downloads.

        Certificate verification is disabled through the connector:
        ``ClientSession`` itself does not accept an ``ssl`` argument.
        Must be called from inside a running event loop.
        """
        return aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT),
            connector=aiohttp.TCPConnector(ssl=False),
        )

    async def download_single_note(self, note_id):
        """Download all media files of one note.

        Returns:
            ``[record]`` when every file was downloaded (or the note has no
            media); ``[]`` when the note is unknown, an error occurred or at
            least one file could not be downloaded (``error`` is emitted).
        """
        conn = None
        try:
            self.progress.emit(0, "开始下载...")
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id, image_list, video_url
                FROM xhs_notes
                WHERE note_id = %s
            """, (note_id,))
            record = cursor.fetchone()
            if not record:
                self.error.emit("找不到指定的笔记")
                return []

            base_dir, targets = self._media_targets(
                record['note_id'], record['image_list'], record['video_url']
            )
            os.makedirs(base_dir, exist_ok=True)

            total_files = len(targets)
            if total_files == 0:
                self.progress.emit(100, "没有需要下载的文件")
                return [record]

            image_total = sum(1 for _, _, is_video in targets if not is_video)
            image_index = 0
            completed_files = 0
            async with self._create_session() as session:
                for url, save_path, is_video in targets:
                    if is_video:
                        self.status_update.emit("下载视频...")
                    else:
                        image_index += 1
                        self.status_update.emit(f"下载图片 {image_index}/{image_total}")
                    if await self.download_file_with_retry(
                        session, url, save_path, note_id, is_video
                    ):
                        completed_files += 1
                        percent = int(completed_files / total_files * 100)
                        self.progress.emit(
                            percent, f"已完成: {completed_files}/{total_files}"
                        )

            # Re-check the files on disk and update the database flag.
            await self.check_download_complete(note_id)

            if completed_files < total_files:
                # Report partial failure instead of claiming success.
                self.error.emit(f"部分文件下载失败: {completed_files}/{total_files}")
                return []
            self.progress.emit(100, "下载完成")
            return [record]
        except Exception as e:
            self.error.emit(str(e))
            return []
        finally:
            if conn is not None:
                conn.close()

    async def run_crawler(self):
        """Run the crawler subprocess and forward its output as status lines.

        Returns:
            True when the subprocess exited with code 0, False otherwise.
        """
        try:
            self.progress.emit(10, "启动爬虫...")
            process = await asyncio.create_subprocess_exec(
                'python', 'main.py', '--platform', 'xhs', '--lt', 'qrcode',
                '--keywords', self.keywords,
                stdout=asyncio.subprocess.PIPE,
                stderr=asyncio.subprocess.PIPE,
                limit=1024 * 1024  # 1 MB stream buffer
            )

            async def read_stream(stream):
                # Buffer raw bytes and decode whole lines only, so a
                # multi-byte UTF-8 character split across two chunks is not
                # dropped by errors='ignore'.
                pending = b""
                while True:
                    try:
                        chunk = await stream.read(8192)
                    except Exception as e:
                        # A failing pipe will not recover; stop reading
                        # instead of spinning forever.
                        print(f"读取爬虫输出错误: {str(e)}")
                        break
                    if not chunk:
                        break
                    pending += chunk
                    *complete_lines, pending = pending.split(b'\n')
                    for raw_line in complete_lines:
                        line = raw_line.decode('utf-8', errors='ignore').strip()
                        if line:
                            self.status_update.emit(f"爬虫进度: {line}")
                # Flush whatever is left after EOF.
                tail = pending.decode('utf-8', errors='ignore').strip()
                if tail:
                    self.status_update.emit(f"爬虫进度: {tail}")

            # Drain stdout and stderr concurrently so neither pipe fills up.
            await asyncio.gather(
                read_stream(process.stdout),
                read_stream(process.stderr)
            )
            return_code = await process.wait()
            if return_code != 0:
                self.error.emit(f"爬虫执行失败: 退出码 {return_code}")
                return False
            self.progress.emit(30, "爬虫数据获取完成")
            return True
        except Exception as e:
            self.error.emit(f"爬虫执行失败: {str(e)}")
            return False

    async def import_to_db(self):
        """Import today's crawler JSON output into the ``xhs_notes`` table.

        Notes whose ``note_id`` already exists are skipped.

        Returns:
            True on success, False on any error (``error`` is emitted).
        """
        conn = None
        try:
            self.progress.emit(40, "导入数据到数据库...")
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)

            json_path = f'./data/xhs/json/search_contents_{datetime.now().strftime("%Y-%m-%d")}.json'
            with open(json_path, 'r', encoding='utf-8') as f:
                data = json.load(f)

            insert_query = """INSERT INTO xhs_notes (
                note_id, type, title, description, video_url, time,
                last_update_time, user_id, nickname, avatar,
                liked_count, collected_count, comment_count, share_count,
                ip_location, image_list, tag_list, last_modify_ts,
                note_url, source_keyword
            ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
                      %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)"""

            for item in data:
                # Skip notes that are already stored.
                cursor.execute(
                    "SELECT COUNT(*) AS cnt FROM xhs_notes WHERE note_id = %s",
                    (item['note_id'],)
                )
                if cursor.fetchone()['cnt'] != 0:
                    continue
                values = (
                    item.get('note_id'), item.get('type'), item.get('title'),
                    item.get('desc'), item.get('video_url'), item.get('time'),
                    item.get('last_update_time'), item.get('user_id'),
                    item.get('nickname'), item.get('avatar'),
                    item.get('liked_count'), item.get('collected_count'),
                    item.get('comment_count'), item.get('share_count'),
                    item.get('ip_location'), item.get('image_list'),
                    item.get('tag_list'), item.get('last_modify_ts'),
                    item.get('note_url'), self.keywords
                )
                cursor.execute(insert_query, values)

            conn.commit()
            self.progress.emit(60, "数据导入完成")
            return True
        except Exception as e:
            self.error.emit(f"数据导入失败: {str(e)}")
            return False
        finally:
            if conn is not None:
                conn.close()

    async def _pause_before_retry(self, attempt):
        """Sleep RETRY_DELAY seconds unless this was the last attempt."""
        if attempt < self.MAX_RETRIES - 1:
            await asyncio.sleep(self.RETRY_DELAY)

    async def download_file_with_retry(self, session, url, save_path, note_id, is_video=False):
        """Download ``url`` to ``save_path`` with up to MAX_RETRIES attempts.

        The body is streamed into ``<save_path>.part`` and only moved to
        ``save_path`` after the transfer finished and produced a non-empty
        file, so an interrupted download never looks like a complete one.

        Args:
            session: Open aiohttp client session.
            url: File URL.
            save_path: Final destination path.
            note_id: Note the file belongs to; its completion state is
                re-checked after a successful download.
            is_video: Accepted for interface compatibility; not used here.

        Returns:
            True when the file was downloaded, False otherwise.
        """
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Referer': 'https://www.xiaohongshu.com/',
        }
        request_timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT)
        temp_path = f"{save_path}.part"
        success = False

        for attempt in range(self.MAX_RETRIES):
            try:
                async with session.get(url, headers=headers, timeout=request_timeout) as response:
                    if response.status == 200:
                        # A video URL answered with a non-video payload may
                        # carry the real address in a Location header.
                        content_type = response.headers.get('content-type', '')
                        if 'video' not in content_type.lower() and url.endswith(('.mp4', '.m3u8')):
                            location = response.headers.get('location')
                            if location:
                                url = location
                                continue

                        # A missing Content-Length (e.g. chunked transfer) is
                        # only worth a warning; the size check below catches
                        # genuinely empty bodies.
                        if int(response.headers.get('content-length', 0)) == 0:
                            print(f"警告:{url} 的内容长度为0")

                        async with aiofiles.open(temp_path, 'wb') as f:
                            async for chunk in response.content.iter_chunked(self.CHUNK_SIZE):
                                await f.write(chunk)

                        if os.path.getsize(temp_path) > 0:
                            os.replace(temp_path, save_path)
                            success = True
                            break
                        print(f"下载的文件大小为0: {url}")
                    elif response.status in (301, 302, 303, 307, 308):
                        # Follow the redirect on the next attempt.
                        url = str(response.url)
                        continue
                    else:
                        print(f"下载失败,状态码: {response.status}, URL: {url}")
                await self._pause_before_retry(attempt)
            except asyncio.TimeoutError:
                print(f"下载超时 {url}, 尝试次数 {attempt + 1}/{self.MAX_RETRIES}")
                await self._pause_before_retry(attempt)
            except Exception as e:
                print(f"下载出错 {url}: {str(e)}, 尝试次数 {attempt + 1}/{self.MAX_RETRIES}")
                await self._pause_before_retry(attempt)

        if not success:
            # Drop any partial data left behind by a failed attempt.
            try:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
            except OSError as e:
                print(f"下载出错 {url}: {str(e)}, 尝试次数 {self.MAX_RETRIES}/{self.MAX_RETRIES}")
        else:
            # Only after a successful download: check whether the whole note
            # is now complete.
            await self.check_download_complete(note_id)
        return success

    async def check_download_complete(self, note_id):
        """Update ``download_flag`` according to the files present on disk.

        A note is complete when every expected image/video file exists and
        is non-empty.  ``download_progress`` is emitted when it is complete.
        """
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT image_list, video_url
                FROM xhs_notes
                WHERE note_id = %s
            """, (note_id,))
            record = cursor.fetchone()
            if not record:
                return

            _, targets = self._media_targets(
                note_id, record['image_list'], record['video_url']
            )
            is_complete = all(
                os.path.exists(path) and os.path.getsize(path) > 0
                for _, path, _ in targets
            )

            cursor.execute("""
                UPDATE xhs_notes
                SET download_flag = %s
                WHERE note_id = %s
            """, (is_complete, note_id))
            conn.commit()

            if is_complete:
                self.download_progress.emit(note_id, True)
        except Exception as e:
            print(f"检查下载状态时出错: {e}")
        finally:
            if conn is not None:
                conn.close()

    async def download_media(self):
        """Download the media of every note stored for ``self.keywords``.

        Files that already exist (non-empty) are skipped.  Downloads run in
        small concurrent batches.

        Returns:
            The list of note records, or ``[]`` on error.
        """
        conn = None
        try:
            self.progress.emit(70, "开始下载媒体文件...")
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id, image_list, video_url
                FROM xhs_notes
                WHERE source_keyword = %s
            """, (self.keywords,))
            records = cursor.fetchall()

            # Collect plain job tuples first; coroutines are created per
            # batch so none is left un-awaited if something fails midway.
            jobs = []
            for record in records:
                self.status_update.emit(f"处理笔记: {record['note_id']}")
                base_dir, targets = self._media_targets(
                    record['note_id'], record['image_list'], record['video_url']
                )
                os.makedirs(base_dir, exist_ok=True)
                for url, save_path, is_video in targets:
                    already_there = (
                        os.path.exists(save_path) and os.path.getsize(save_path) > 0
                    )
                    if not already_there:  # avoid downloading twice
                        jobs.append((url, save_path, record['note_id'], is_video))

            batch_size = 3  # keep concurrency low
            total_tasks = len(jobs)
            completed = 0
            success_count = 0
            async with self._create_session() as session:
                for start in range(0, total_tasks, batch_size):
                    batch = [
                        self.download_file_with_retry(session, *job)
                        for job in jobs[start:start + batch_size]
                    ]
                    results = await asyncio.gather(*batch, return_exceptions=True)
                    completed += len(batch)
                    success_count += sum(1 for r in results if r is True)
                    percent = int(70 + (completed / total_tasks) * 20)
                    self.progress.emit(percent, f"已下载: {completed}/{total_tasks}")
                    self.status_update.emit(
                        f"下载进度: {completed}/{total_tasks} (成功: {success_count})"
                    )
                    await asyncio.sleep(1)

            self.progress.emit(90, "媒体文件下载完成")
            return records
        except Exception as e:
            self.error.emit(f"媒体下载失败: {str(e)}")
            return []
        finally:
            if conn is not None:
                conn.close()
class BatchDownloadWorker(QThread):
    """Background thread that downloads the media of many notes in sequence.

    Signals:
        progress(int, int, int, list): current index (1-based), total number
            of notes, number of successes so far and a snapshot of the ids
            that failed so far.
        error(str): emitted when the batch aborts with an unexpected error.
        finished(): emitted after the last note was processed.
    """

    progress = pyqtSignal(int, int, int, list)  # current, total, success_count, failed_ids
    error = pyqtSignal(str)
    finished = pyqtSignal()

    def __init__(self, note_ids):
        super().__init__()
        self.note_ids = note_ids

    def run(self):
        """Thread entry point: download each note on a private event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            failed_ids = []
            success_count = 0
            total = len(self.note_ids)
            # Used only for its download coroutine; never started as a thread.
            downloader = WorkerThread(None)

            for position, note_id in enumerate(self.note_ids, 1):
                try:
                    downloader.retry_note_id = note_id
                    records = loop.run_until_complete(
                        downloader.download_single_note(note_id)
                    )
                except Exception as e:
                    print(f"下载记录 {note_id} 失败: {str(e)}")
                    records = None

                if records:
                    success_count += 1
                else:
                    failed_ids.append(note_id)

                # Emit a copy: the receiver lives in another thread and must
                # not observe later appends to the working list.
                self.progress.emit(position, total, success_count, list(failed_ids))

            self.finished.emit()
        except Exception as e:
            self.error.emit(str(e))
        finally:
            loop.close()
class MainWindow(QMainWindow):
    """Main application window.

    The window has two tabs: a data tab listing the notes whose media is
    fully downloaded (with text/type filters) and a search tab that starts
    a crawl for a keyword and shows the resulting notes.
    """

    def __init__(self):
        super().__init__()
        self.worker = None        # current search / retry worker thread
        self.batch_worker = None  # current batch download worker thread
        self.setWindowTitle("小红书内容采集器")
        self.setMinimumSize(1200, 800)
        os.makedirs('./data/xhs/json/media', exist_ok=True)
        self.setup_ui()
        # Load all downloaded records initially (empty search, "all" filter).
        self.filter_records()

    def setup_ui(self):
        """Create the central widget and both tabs."""
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        layout = QVBoxLayout(central_widget)

        self.tab_widget = QTabWidget()
        layout.addWidget(self.tab_widget)

        self.setup_data_tab()
        self.setup_search_tab()

    def setup_data_tab(self):
        """Build the tab that lists downloaded records."""
        data_tab = QWidget()
        layout = QVBoxLayout(data_tab)

        # Search and filter row
        filter_layout = QHBoxLayout()

        self.db_search_input = QLineEdit()
        self.db_search_input.setPlaceholderText("搜索标题或描述...")
        self.db_search_input.setStyleSheet(self.get_input_style())
        self.db_search_input.returnPressed.connect(self.filter_records)
        filter_layout.addWidget(self.db_search_input)

        # The item texts must match the comparisons in filter_records().
        self.type_filter = QComboBox()
        self.type_filter.addItems(["全部笔记", "仅看图文", "仅看视频"])
        self.type_filter.setStyleSheet("""
            QComboBox {
                padding: 8px;
                border: 1px solid #ccc;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
        """)
        self.type_filter.currentIndexChanged.connect(self.filter_records)
        filter_layout.addWidget(self.type_filter)

        db_search_button = QPushButton("搜索")
        db_search_button.setStyleSheet(self.get_button_style())
        db_search_button.clicked.connect(self.filter_records)
        filter_layout.addWidget(db_search_button)

        # AI processing button (feature placeholder)
        ai_process_button = QPushButton("AI处理")
        ai_process_button.setStyleSheet("""
            QPushButton {
                padding: 8px 16px;
                background-color: #52c41a;
                color: white;
                border: none;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
            QPushButton:hover {
                background-color: #73d13d;
            }
        """)
        ai_process_button.clicked.connect(self.start_ai_process)
        filter_layout.addWidget(ai_process_button)
        layout.addLayout(filter_layout)

        # Record count / status message
        self.total_count_label = QLabel()
        self.total_count_label.setStyleSheet("color: #666; margin: 5px 0;")
        layout.addWidget(self.total_count_label)

        # Failed ids of a batch download; filled by update_batch_progress().
        self.failed_records_label = QLabel()
        self.failed_records_label.setStyleSheet("color: red; font-size: 12px;")
        self.failed_records_label.setWordWrap(True)
        self.failed_records_label.setVisible(False)
        layout.addWidget(self.failed_records_label)

        # Scrollable record list
        self.records_scroll = QScrollArea()
        self.records_scroll.setWidgetResizable(True)
        self.records_content = QWidget()
        self.records_layout = QVBoxLayout(self.records_content)
        self.records_scroll.setWidget(self.records_content)
        layout.addWidget(self.records_scroll)

        self.tab_widget.addTab(data_tab, "数据展示")

    def setup_search_tab(self):
        """Build the tab that starts a crawl and shows its results."""
        search_tab = QWidget()
        layout = QVBoxLayout(search_tab)

        # Keyword input row
        search_layout = QHBoxLayout()
        self.search_input = QLineEdit()
        self.search_input.setPlaceholderText("输入搜索关键词")
        self.search_input.setStyleSheet(self.get_input_style())
        self.search_button = QPushButton("搜索")
        self.search_button.setStyleSheet(self.get_button_style())
        search_layout.addWidget(self.search_input)
        search_layout.addWidget(self.search_button)
        layout.addLayout(search_layout)

        # Progress information
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        layout.addWidget(self.progress_bar)

        self.status_label = QLabel()
        self.status_label.setVisible(False)
        layout.addWidget(self.status_label)

        self.detail_status_label = QLabel()
        self.detail_status_label.setStyleSheet("color: #666; font-size: 12px;")
        self.detail_status_label.setVisible(False)
        layout.addWidget(self.detail_status_label)

        # Search results
        self.result_count_label = QLabel()
        self.result_count_label.setStyleSheet("color: #666; margin: 5px 0;")
        self.result_count_label.setVisible(False)
        layout.addWidget(self.result_count_label)

        self.search_scroll = QScrollArea()
        self.search_scroll.setWidgetResizable(True)
        self.search_content = QWidget()
        self.search_layout = QVBoxLayout(self.search_content)
        self.search_scroll.setWidget(self.search_content)
        layout.addWidget(self.search_scroll)

        self.tab_widget.addTab(search_tab, "搜索笔记")

        self.search_button.clicked.connect(self.start_search)

    def filter_records(self):
        """Reload the data tab according to the search text and type filter."""
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)

            search_text = self.db_search_input.text().strip()
            type_filter = self.type_filter.currentText()

            # Only fully downloaded records are listed on this tab.
            query = """
                SELECT *,
                    COALESCE(download_flag, FALSE) as is_downloaded
                FROM xhs_notes
                WHERE download_flag = TRUE
            """
            params = []
            if search_text:
                query += " AND (title LIKE %s OR description LIKE %s)"
                params.extend([f'%{search_text}%', f'%{search_text}%'])
            if type_filter == "仅看图文":
                query += " AND type = 'normal'"
            elif type_filter == "仅看视频":
                query += " AND type = 'video'"
            query += " ORDER BY time DESC"

            cursor.execute(query, params)
            records = cursor.fetchall()

            self.total_count_label.setText(f'找到 {len(records)} 条已下载记录')
            self.clear_records()
            for record in records:
                # Every record shown here is downloaded.
                self.records_layout.addWidget(NoteCard(record, True))
        except Exception as e:
            self.show_error(f"过滤记录失败: {str(e)}")
        finally:
            if conn is not None:
                conn.close()

    def _is_worker_running(self):
        """Return True while a search/retry worker thread is still running."""
        return self.worker is not None and self.worker.isRunning()

    def _update_card_progress(self, note_id, value, message):
        """Forward a progress update to the card of ``note_id``, if present.

        The card is looked up on every call because list refreshes replace
        the card widgets.
        """
        card = self.find_card_by_note_id(note_id)
        if card:
            card.update_progress(value, message)

    def _on_retry_finished(self, note_id, records):
        """Handle the retry worker's ``finished`` signal.

        An empty ``records`` list means the retry failed and the error has
        already been reported, so the card must not be marked as complete.
        """
        if records:
            self.on_retry_complete(note_id)
        else:
            card = self.find_card_by_note_id(note_id)
            if card:
                card.set_downloading(False)

    def retry_download(self, note_id):
        """Re-download the media files of the given note."""
        card = self.find_card_by_note_id(note_id)
        if not card:
            return

        # Replacing self.worker while its thread runs would destroy a running
        # QThread, so refuse to start a second task.
        if self._is_worker_running():
            card.show_error("已有任务正在运行,请稍后重试")
            return

        card.set_downloading(True)
        card.update_progress(0, "准备下载...")

        self.worker = WorkerThread(None)
        self.worker.progress.connect(
            lambda v, m: self._update_card_progress(note_id, v, m)
        )
        self.worker.finished.connect(
            lambda records: self._on_retry_finished(note_id, records)
        )
        self.worker.error.connect(lambda e: self.on_retry_error(note_id, e))
        self.worker.status_update.connect(
            lambda m: self._update_card_progress(note_id, -1, m)
        )

        self.worker.retry_note_id = note_id
        self.worker.start()

    def find_card_by_note_id(self, note_id):
        """Return the NoteCard for ``note_id`` from either tab, or None."""
        for layout in (self.records_layout, self.search_layout):
            for i in range(layout.count()):
                card = layout.itemAt(i).widget()
                if isinstance(card, NoteCard) and card.note_data['note_id'] == note_id:
                    return card
        return None

    def on_retry_complete(self, note_id):
        """Mark the card as complete and refresh both tabs."""
        card = self.find_card_by_note_id(note_id)
        if card:
            card.download_complete()
        self.filter_records()
        current_keyword = self.search_input.text().strip()
        if current_keyword:
            self.show_search_results(current_keyword)

    def on_retry_error(self, note_id, error_message):
        """Show a retry error on the card of ``note_id``."""
        card = self.find_card_by_note_id(note_id)
        if card:
            card.set_downloading(False)
            card.show_error(error_message)

    def get_button_style(self):
        """Return the style sheet for primary buttons."""
        return """
            QPushButton {
                padding: 8px 16px;
                background-color: #1a73e8;
                color: white;
                border: none;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
            QPushButton:hover {
                background-color: #1557b0;
            }
        """

    def get_input_style(self):
        """Return the style sheet for line edits."""
        return """
            QLineEdit {
                padding: 8px;
                border: 1px solid #ccc;
                border-radius: 4px;
                font-size: 14px;
                font-family: "Microsoft YaHei", Arial;
            }
        """

    def get_label_style(self):
        """Return the style sheet for labels."""
        return """
            QLabel {
                font-size: 14px;
                color: #333;
                font-family: "Microsoft YaHei", Arial;
            }
        """

    def start_search(self):
        """Start the crawl / import / download pipeline for the keyword."""
        keywords = self.search_input.text().strip()
        if not keywords:
            return
        # See retry_download(): never replace a running worker thread.
        if self._is_worker_running():
            self.update_status_detail("已有任务正在运行,请稍后重试")
            return

        self.clear_search_results()

        self.progress_bar.setVisible(True)
        # Clear the red style a previous error may have left behind.
        self.status_label.setStyleSheet("")
        self.status_label.setVisible(True)
        self.search_button.setEnabled(False)

        self.worker = WorkerThread(keywords)
        self.worker.progress.connect(self.update_progress)
        self.worker.finished.connect(self.on_download_complete)
        self.worker.error.connect(self.show_error)
        self.worker.status_update.connect(self.update_status_detail)
        self.worker.download_progress.connect(self.update_download_status)
        self.worker.data_imported.connect(self.on_data_imported)
        self.worker.start()

    def clear_search_results(self):
        """Remove all cards from the search tab."""
        while self.search_layout.count():
            item = self.search_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()
        self.result_count_label.setVisible(False)

    def clear_records(self):
        """Remove all cards from the data tab."""
        while self.records_layout.count():
            item = self.records_layout.takeAt(0)
            if item.widget():
                item.widget().deleteLater()

    def update_progress(self, value, message):
        """Update the search tab's progress bar and status text."""
        self.progress_bar.setValue(value)
        self.status_label.setText(message)
        self.status_label.setVisible(True)

    def update_status_detail(self, message):
        """Update the detailed status line of the search tab."""
        self.detail_status_label.setText(message)
        self.detail_status_label.setVisible(True)

    def update_download_status(self, note_id, is_complete):
        """Refresh both tabs when a note finished downloading."""
        if is_complete:
            self.filter_records()
            current_keyword = self.search_input.text().strip()
            if current_keyword:
                self.show_search_results(current_keyword)

    def on_data_imported(self):
        """Show the imported notes as soon as the import finished."""
        current_keyword = self.search_input.text().strip()
        self.show_search_results(current_keyword)
        self.filter_records()

    def on_download_complete(self, records):
        """Reset the search controls after the pipeline finished."""
        self.search_button.setEnabled(True)
        self.progress_bar.setVisible(False)
        # An empty result may follow an error; keep that message visible.
        if records:
            self.status_label.setVisible(False)
        self.filter_records()

    def show_search_results(self, keyword):
        """List all stored notes that were found for ``keyword``."""
        conn = None
        try:
            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT *,
                    COALESCE(download_flag, FALSE) as is_downloaded
                FROM xhs_notes
                WHERE source_keyword = %s
                ORDER BY time DESC
            """, (keyword,))
            records = cursor.fetchall()

            self.result_count_label.setText(f'关键词 "{keyword}" 的搜索结果:{len(records)}')
            # clear_search_results() hides the count label, so clear first.
            self.clear_search_results()
            self.result_count_label.setVisible(True)

            for record in records:
                self.search_layout.addWidget(NoteCard(record, record['is_downloaded']))
        except Exception as e:
            self.show_error(f"获取数据失败: {str(e)}")
        finally:
            if conn is not None:
                conn.close()

    def show_error(self, message):
        """Show an error in the search tab and re-enable the search button."""
        self.status_label.setText(f"错误: {message}")
        self.status_label.setStyleSheet("color: red;")
        self.status_label.setVisible(True)
        self.search_button.setEnabled(True)

    def start_batch_download(self):
        """Start downloading every record whose media is not complete yet."""
        conn = None
        try:
            if self.batch_worker is not None and self.batch_worker.isRunning():
                self.show_status_message("批量下载正在进行中...")
                return

            conn = mysql.connector.connect(**db_config)
            cursor = conn.cursor(dictionary=True)
            cursor.execute("""
                SELECT note_id
                FROM xhs_notes
                WHERE COALESCE(download_flag, FALSE) = FALSE
            """)
            unfinished_records = cursor.fetchall()
            if not unfinished_records:
                self.show_status_message("没有需要下载的记录")
                return

            self.show_status_message(f"开始下载 {len(unfinished_records)} 条未完成记录...")
            self.failed_records_label.setVisible(False)

            self.batch_worker = BatchDownloadWorker([r['note_id'] for r in unfinished_records])
            self.batch_worker.progress.connect(self.update_batch_progress)
            self.batch_worker.error.connect(self.show_error)
            self.batch_worker.finished.connect(self.on_batch_download_complete)
            self.batch_worker.start()
        except Exception as e:
            self.show_error(f"获取未完成记录失败: {str(e)}")
        finally:
            if conn is not None:
                conn.close()

    def update_batch_progress(self, current, total, success_count, failed_ids):
        """Show the batch download progress and the ids that failed so far."""
        self.show_status_message(f"正在下载: {current}/{total} (成功: {success_count})")
        if failed_ids:
            failed_text = "下载失败的记录:\n" + "\n".join(str(i) for i in failed_ids)
            self.failed_records_label.setText(failed_text)
            self.failed_records_label.setVisible(True)

    def on_batch_download_complete(self):
        """Refresh the data tab after a batch download."""
        self.show_status_message("批量下载完成")
        self.filter_records()

    def show_status_message(self, message):
        """Show a status message in the data tab's count label."""
        self.total_count_label.setText(message)

    def start_ai_process(self):
        """Placeholder for the AI processing feature."""
        # TODO: implement AI processing.
        self.show_status_message("AI处理功能开发中...")
def main():
    """Create the application and main window; return the Qt exit code."""
    application = QApplication(sys.argv)
    main_window = MainWindow()
    main_window.show()
    return application.exec()


if __name__ == "__main__":
    sys.exit(main())