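"""FastAPI search service.

POST /search/ returns cached page summaries stored in MySQL; POST /batch_search/
forwards queries to SerpAPI (Google, Bing and Baidu engines).
"""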
import asyncio
import datetime
import json
import os

import aiomysql
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from search_plus import search_and_save

# Database connection settings.
# Credentials come from the environment so they are not committed to source
# control. The password and SerpAPI key that were previously hard-coded here
# should be considered leaked and rotated.
db_config = {
    'user': os.environ.get('DB_USER', 'root'),
    'password': os.environ.get('DB_PASSWORD', ''),
    'host': os.environ.get('DB_HOST', '183.11.229.79'),
    'port': int(os.environ.get('DB_PORT', '3316')),
    'db': os.environ.get('DB_NAME', 'gptDB'),
    'auth_plugin': 'mysql_native_password',
}

# SerpAPI key used by /batch_search/; must be provided via the environment.
SERP_API_KEY = os.environ.get('SERP_API_KEY', '')

app = FastAPI()


class SearchRequest(BaseModel):
    # Request body for POST /search/.
    # `query` holds newline-separated entries; fetch_all_content matches the
    # text before each entry's first '.' against saved_webpages.hash_code.
    query: str


class SearchResult(BaseModel):
    # One element of the POST /search/ response list.
    # `results` is a single formatted string built by fetch_all_content:
    # a cached gpt_output followed by a "related link" line holding its URL.
    results: str


# Request body for POST /batch_search/.
class BatchSearchRequest(BaseModel):
    # Search strings; batch_search sends each one to every SerpAPI engine it uses.
    queries: list[str]


class SearchResultItem(BaseModel):
    # Result of one (query, engine) SerpAPI search within a batch.
    query: str  # the query string that was searched
    engine: str  # SerpAPI engine name, e.g. "google", "bing" or "baidu"
    results: dict  # decoded JSON document fetched from SerpAPI's json_endpoint


class BatchSearchResponse(BaseModel):
    # Response body for POST /batch_search/.
    status: str  # overall status string (batch_search sets "success")
    message: str  # human-readable summary of the batch
    total_processed: int  # processed count reported by batch_search
    search_results: list[SearchResultItem]  # one entry per successful (query, engine) search


def _parse_hash_codes(query: str) -> list[str]:
    """Extract page hash codes from a newline-separated query string.

    Each non-blank line is stripped of surrounding whitespace (including a
    trailing carriage return from CRLF input) and cut at its first '.', so an
    entry like "abc123.html" yields "abc123". Empty results are dropped.
    """
    hash_codes = []
    for line in query.splitlines():
        code = line.strip().split('.')[0]
        if code:
            hash_codes.append(code)
    return hash_codes


async def fetch_all_content(query):
    """Look up cached page summaries for the hash codes listed in *query*.

    Args:
        query: Newline-separated entries. The text before each entry's first
            '.' is matched against ``saved_webpages.hash_code``.

    Returns:
        A tuple ``(contents, count)``. Each item of ``contents`` is a row's
        ``gpt_output`` followed by a "related link" line with its ``url``.
        ``count`` equals ``len(contents)``. Returns ``([], 0)`` when the query
        contains no hash codes, when nothing matches, or when the database
        query fails (the error is printed).

    Raises:
        Any error from ``aiomysql.create_pool`` (connection/config failures)
        propagates to the caller.
    """
    hash_codes = _parse_hash_codes(query)
    if not hash_codes:
        # An empty "IN ()" is invalid SQL, and there is nothing to look up anyway.
        return [], 0

    # One %s placeholder per value, so the driver escapes the user-supplied
    # codes instead of them being spliced into the SQL text (SQL injection).
    placeholders = ', '.join(['%s'] * len(hash_codes))
    sql = (
        "SELECT url, gpt_output FROM saved_webpages "
        f"WHERE hash_code IN ({placeholders});"
    )

    contents = []
    pool = await aiomysql.create_pool(**db_config)
    try:
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, tuple(hash_codes))
                rows = await cursor.fetchall()
        # `or ''` keeps a NULL column from aborting the whole result set.
        contents = [
            (row['gpt_output'] or '') + '\n相关链接:' + (row['url'] or '')
            for row in rows
        ]
    except Exception as e:
        # Best-effort lookup: report the error and fall back to "no results"
        # rather than returning a partial list with a mismatched count.
        print(f"Error fetching summaries: {e}")
        contents = []
    finally:
        pool.close()
        await pool.wait_closed()

    return contents, len(contents)


async def save_user_query(query: str, result_count: int):
    """Record one user search in the ``user_queries`` table (best effort).

    Inserts the query text, the current local time (naive
    ``datetime.datetime.now()``) and the number of results, then commits.

    Any failure, whether creating the connection pool or running the insert,
    is printed and swallowed, so logging a query can never break the caller's
    request.
    """
    sql = """
        INSERT INTO user_queries (query, query_datetime, result_count)
        VALUES (%s, %s, %s)
    """
    pool = None
    try:
        # Pool creation is inside the try so connection/config errors honour
        # the best-effort contract instead of escaping to the caller.
        pool = await aiomysql.create_pool(**db_config)
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(sql, (query, datetime.datetime.now(), result_count))
            await conn.commit()
    except Exception as e:
        print(f"Error saving user query: {e}")
    finally:
        if pool is not None:
            pool.close()
            await pool.wait_closed()


@app.post("/search/", response_model=list[SearchResult])
async def search(request: SearchRequest):
    """Return cached page summaries for the hash codes in ``request.query``.

    Each call is recorded exactly once in ``user_queries`` together with the
    number of results found. An empty list is returned when nothing matches.
    """
    query = request.query

    # Look up previously saved pages first.
    results, counts = await fetch_all_content(query)

    # TODO: when counts == 0, fall back to a live web search via
    # search_plus.search_and_save(query); that fallback is currently disabled.

    # Record the query once, whether or not anything was found. An empty
    # result set must not trigger a second insert.
    try:
        await save_user_query(query, counts)
    except Exception as e:
        print(f"Error saving user query: {e}")

    return [SearchResult(results=result) for result in results]


# Engines queried through SerpAPI for every query in a batch.
_SERPAPI_ENGINES = ("google", "bing", "baidu")
# Per-request timeout in seconds, so a stalled SerpAPI call cannot hang the batch.
_SERPAPI_TIMEOUT = 30


def _fetch_serpapi_results(query: str, engine: str):
    """Run one SerpAPI search and return its full JSON result, or ``None``.

    Issues the search. When ``search_metadata.status`` is ``'Success'``, it then
    downloads the document at ``search_metadata.json_endpoint``. Returns
    ``None`` if the search did not succeed, no endpoint was given, or the
    endpoint did not answer with HTTP 200. Network errors
    (``requests.RequestException``) and JSON decoding errors (``ValueError``)
    propagate.

    This call is blocking; run it via ``asyncio.to_thread`` from async code.
    """
    params = {
        "api_key": SERP_API_KEY,
        "engine": engine,
        "q": query,
    }
    response = requests.get(
        'https://serpapi.com/search', params=params, timeout=_SERPAPI_TIMEOUT
    )
    search_metadata = response.json().get('search_metadata', {})
    if search_metadata.get('status') != 'Success':
        return None

    json_endpoint = search_metadata.get('json_endpoint')
    if not json_endpoint:
        return None

    response = requests.get(json_endpoint, timeout=_SERPAPI_TIMEOUT)
    if response.status_code != 200:
        return None
    return response.json()


@app.post("/batch_search/", response_model=BatchSearchResponse)
async def batch_search(request: BatchSearchRequest):
    """Search every query on Google, Bing and Baidu through SerpAPI.

    The (query, engine) pairs are searched one after another. A pair whose
    request fails (network error, undecodable JSON) is reported and skipped,
    so the rest of the batch still completes.

    Returns:
        BatchSearchResponse. ``total_processed`` and the message report the
        number of (query, engine) searches that produced results.

    Raises:
        HTTPException: status 500 for any other unexpected error. The previous
        ``return`` inside ``finally`` silently discarded such errors.
    """
    search_results = []
    try:
        for query in request.queries:
            for engine in _SERPAPI_ENGINES:
                try:
                    # requests is blocking; keep it off the event loop.
                    data = await asyncio.to_thread(_fetch_serpapi_results, query, engine)
                except (requests.RequestException, ValueError) as e:
                    print(f"SerpAPI search failed for {engine!r} / {query!r}: {e}")
                    continue
                if data is not None:
                    search_results.append(
                        SearchResultItem(query=query, engine=engine, results=data)
                    )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"批量搜索失败: {str(e)}") from e

    processed_count = len(search_results)
    return BatchSearchResponse(
        status="success",
        message=f"成功处理了 {processed_count} 条搜索请求",
        total_processed=processed_count,
        search_results=search_results,
    )