import json
import datetime

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import aiomysql
import requests

from search_plus import search_and_save

# Database connection settings
db_config = {
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'host': '183.11.229.79',
    'port': 3316,
    'db': 'gptDB',
    'auth_plugin': 'mysql_native_password'
}

SERP_API_KEY = "8af097ae8b587bb0569425058e03e5ef33b4c7b8b1a505053764b62e7e4ab9d6"

app = FastAPI()
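
# A minimal sketch of loading the secrets above from the environment instead of
# hardcoding them (assumes DB_PASSWORD and SERP_API_KEY are exported before the
# service starts; the variable names are illustrative):
#
#   import os
#   db_config['password'] = os.environ.get('DB_PASSWORD', db_config['password'])
#   SERP_API_KEY = os.environ.get('SERP_API_KEY', SERP_API_KEY)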


class SearchRequest(BaseModel):
    query: str


class SearchResult(BaseModel):
    results: str


# Request/response models for batch search
class BatchSearchRequest(BaseModel):
    queries: list[str]


class SearchResultItem(BaseModel):
    query: str
    engine: str
    results: dict


class BatchSearchResponse(BaseModel):
    status: str
    message: str
    total_processed: int
    search_results: list[SearchResultItem]
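
# Illustrative request bodies for the two endpoints below (values are made up;
# /search/ expects newline-separated "<hash_code>.<ext>" entries, see
# fetch_all_content):
#
#   POST /search/        {"query": "abc123.html\ndef456.html"}
#   POST /batch_search/  {"queries": ["fastapi tutorial", "aiomysql pool"]}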


async def fetch_all_content(query):
    """Look up cached page summaries in saved_webpages by hash code."""
    # Each non-empty line of the query is "<hash_code>.<ext>"; keep the hash code.
    hash_codes = [item.split('.')[0] for item in query.split('\n') if item]
    if not hash_codes:  # an empty IN () clause would be a SQL syntax error
        return [], 0

    # Parameterized IN clause instead of quoting values into the SQL string,
    # which was vulnerable to SQL injection
    placeholders = ', '.join(['%s'] * len(hash_codes))
    sql = f"SELECT url, gpt_output FROM saved_webpages WHERE hash_code IN ({placeholders});"
    pool = await aiomysql.create_pool(**db_config)
    contents = []
    counts = 0
    try:
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, hash_codes)
                results = await cursor.fetchall()
                if not results:  # nothing cached for these hash codes
                    return [], 0

                for result in results:
                    contents.append(result['gpt_output'] + '\nRelated link: ' + result['url'])
                counts = len(results)
    except Exception as e:
        print(f"Error fetching summaries: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
    return contents, counts
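
# Assumed shape of the saved_webpages table (illustrative DDL inferred from the
# SELECT above; the real schema may differ):
#
#   CREATE TABLE saved_webpages (
#       hash_code VARCHAR(64) PRIMARY KEY,
#       url TEXT,
#       gpt_output TEXT
#   );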


async def save_user_query(query: str, result_count: int):
    """Save a user query record to the database."""
    pool = await aiomysql.create_pool(**db_config)
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                sql = """
                    INSERT INTO user_queries (query, query_datetime, result_count)
                    VALUES (%s, %s, %s)
                """
                await cursor.execute(sql, (query, datetime.datetime.now(), result_count))
                await conn.commit()
    except Exception as e:
        print(f"Error saving user query: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
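
# Assumed shape of the user_queries table (illustrative DDL inferred from the
# INSERT above; the real schema may differ):
#
#   CREATE TABLE user_queries (
#       id INT AUTO_INCREMENT PRIMARY KEY,
#       query TEXT,
#       query_datetime DATETIME,
#       result_count INT
#   );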


@app.post("/search/", response_model=list[SearchResult])
async def search(request: SearchRequest):
    search_results = []
    query = request.query

    # First try to serve the query from the saved_webpages cache
    results, counts = await fetch_all_content(query)

    # If nothing was cached, fall back to a live search via search_plus
    # if counts == 0:
    #     web_results = await search_and_save(query)
    #     results = web_results
    #     counts = len(web_results)

    # Record the user's query whether or not anything was found
    try:
        await save_user_query(query, counts)
    except Exception as e:
        print(f"Error saving user query: {e}")

    for result in results:
        search_results.append(SearchResult(results=result))

    # No cached results; the query was already recorded above
    if not search_results:
        return []

    return search_results
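
# Example call (the hash code is illustrative):
#   curl -X POST http://localhost:8000/search/ \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "abc123.html"}'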


@app.post("/batch_search/", response_model=BatchSearchResponse)
async def batch_search(request: BatchSearchRequest):
    try:
        processed_count = 0
        search_results = []

        # Run each query against every engine. Note that requests is blocking,
        # so these calls stall the event loop; fine for light use, but an async
        # HTTP client would behave better under load.
        for query in request.queries:
            for engine in ["google", "bing", "baidu"]:
                params = {
                    "api_key": SERP_API_KEY,
                    "engine": engine,
                    "q": query
                }
                response = requests.get('https://serpapi.com/search', params=params)
                search = response.json()
                search_metadata = search.get('search_metadata', {})

                if search_metadata.get('status') == 'Success':
                    json_endpoint = search_metadata.get('json_endpoint')
                    response = requests.get(json_endpoint)

                    if response.status_code == 200:
                        data = response.json()
                        search_results.append(
                            SearchResultItem(
                                query=query,
                                engine=engine,
                                results=data
                            )
                        )
                        processed_count += 1

        return BatchSearchResponse(
            status="success",
            message=f"Successfully processed {processed_count} search requests",
            total_processed=processed_count,
            search_results=search_results
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Batch search failed: {str(e)}")
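

# A minimal way to run the service locally (assumes uvicorn is installed;
# host and port are illustrative):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)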