web_search/main.py

import json
import datetime
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import aiomysql
import requests
from search_plus import search_and_save

# Database connection settings
db_config = {
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'host': '183.11.229.79',
    'port': 3316,
    'db': 'gptDB',
    'auth_plugin': 'mysql_native_password'
}

SERP_API_KEY = "8af097ae8b587bb0569425058e03e5ef33b4c7b8b1a505053764b62e7e4ab9d6"

app = FastAPI()
class SearchRequest(BaseModel):
    query: str


class SearchResult(BaseModel):
    results: str


# Request/response models for batch search
class BatchSearchRequest(BaseModel):
    queries: list[str]


class SearchResultItem(BaseModel):
    query: str
    engine: str
    results: dict


class BatchSearchResponse(BaseModel):
    status: str
    message: str
    total_processed: int
    search_results: list[SearchResultItem]
async def fetch_all_content(query):
    """Look up cached page summaries whose hash codes appear in the query."""
    # Each non-empty line of the query is expected to start with a hash code,
    # terminated by a '.'; quote the codes for the IN (...) clause below.
    query_list = query.split('\n')
    query_list = ["'" + item.split('.')[0] + "'" for item in query_list if item]
    placeholders = ', '.join(query_list)
    sql = f"SELECT url, gpt_output FROM saved_webpages WHERE hash_code IN ({placeholders});"
    pool = await aiomysql.create_pool(**db_config)
    contents = []
    counts = 0
    try:
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql)
                results = await cursor.fetchall()
                if not results:  # nothing cached for this query
                    return [], 0
                for result in results:
                    # Append the source URL ("相关链接" = "related link") to the summary.
                    contents.append(result['gpt_output'] + '\n相关链接:' + result['url'])
                counts = len(results)
    except Exception as e:
        print(f"Error fetching summaries: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
    return contents, counts
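
# Usage sketch (illustrative only): fetch_all_content expects a newline-separated
# query where the text before the first '.' on each line is a hash_code stored in
# saved_webpages. The hash values below are made up for the example.
#
#     import asyncio
#     contents, counts = asyncio.run(fetch_all_content("3f2a9c.first page\n7be041.second page"))
#     print(counts, contents)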
async def save_user_query(query: str, result_count: int):
    """Record a user query and its result count in the user_queries table."""
    pool = await aiomysql.create_pool(**db_config)
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                sql = """
                    INSERT INTO user_queries (query, query_datetime, result_count)
                    VALUES (%s, %s, %s)
                """
                await cursor.execute(sql, (query, datetime.datetime.now(), result_count))
                await conn.commit()
    except Exception as e:
        print(f"Error saving user query: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
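
# The INSERT above assumes a user_queries table with at least these columns
# (a sketch; the exact column types are not defined in this file):
#     query          -- the raw query string
#     query_datetime -- timestamp of the request
#     result_count   -- number of cached results returned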
@app.post("/search/", response_model=list[SearchResult])
async def search(request: SearchRequest):
search_results = []
query = request.query
# 首先尝试从saved_webpages获取结果
results, counts = await fetch_all_content(query)
# 如果没有找到结果使用search_plus进行搜索
# if counts == 0:
# web_results = await search_and_save(query)
# results = web_results
# counts = len(web_results)
# 无论是否找到结果,都保存用户的查询记录
try:
await save_user_query(query, counts)
except Exception as e:
print(f"Error saving user query: {e}")
# 处理结果
for result in results:
search_results.append(SearchResult(results=result))
if not search_results:
# 如果没有找到任何结果,返回空列表但仍然记录这次查询
await save_user_query(query, 0)
return []
return search_results
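
# Example request against /search/ (a sketch; assumes the service is running
# locally on port 8000 and that the hash codes exist in saved_webpages):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/search/",
#         json={"query": "3f2a9c.first page\n7be041.second page"},
#     )
#     # -> [{"results": "<gpt_output>\n相关链接:<url>"}, ...]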
@app.post("/batch_search/", response_model=BatchSearchResponse)
async def batch_search(request: BatchSearchRequest):
try:
processed_count = 0
search_results = []
try:
# 处理每个查询
for query in request.queries:
for engine in ["google", "bing", "baidu"]:
params = {
"api_key": SERP_API_KEY,
"engine": engine,
"q": query
}
response = requests.get('https://serpapi.com/search', params=params)
search = response.json()
search_metadata = search.get('search_metadata', {})
if search_metadata.get('status') == 'Success':
json_endpoint = search_metadata.get('json_endpoint')
response = requests.get(json_endpoint)
if response.status_code == 200:
data = response.json()
json_record_str = json.dumps(data)
# 添加到搜索结果列表
search_results.append(
SearchResultItem(
query=query,
engine=engine,
results=data
)
)
finally:
return BatchSearchResponse(
status="success",
message=f"成功处理了 {processed_count} 条搜索请求",
total_processed=processed_count,
search_results=search_results
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"批量搜索失败: {str(e)}")
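
# Minimal way to run the service directly (a sketch: assumes uvicorn is installed;
# the real deployment entry point may differ).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example batch request (illustrative):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:8000/batch_search/",
#         json={"queries": ["fastapi tutorial", "aiomysql connection pool"]},
#     )
#     print(resp.json()["total_processed"])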