# web_search/main.py
import datetime

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import aiomysql
import requests
from search_plus import search_and_save  # used by the disabled fallback in search()
# Database connection settings. Note: credentials are hardcoded here; loading
# them from the environment would be safer in production.
db_config = {
    'user': 'root',
    'password': 'zaq12wsx@9Xin',
    'host': '183.11.229.79',
    'port': 3316,
    'db': 'gptDB',
    'auth_plugin': 'mysql_native_password'
}
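
# For reference, a minimal sketch of the two tables this service touches,
# inferred from the queries below; the real schema in gptDB may differ:
#
#   CREATE TABLE saved_webpages (
#       hash_code  VARCHAR(64) PRIMARY KEY,
#       url        TEXT,
#       gpt_output TEXT
#   );
#
#   CREATE TABLE user_queries (
#       id             INT AUTO_INCREMENT PRIMARY KEY,
#       query          TEXT,
#       query_datetime DATETIME,
#       result_count   INT
#   );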
SERP_API_KEY = "8af097ae8b587bb0569425058e03e5ef33b4c7b8b1a505053764b62e7e4ab9d6"

app = FastAPI()


class SearchRequest(BaseModel):
    query: str


class SearchResult(BaseModel):
    results: str


# Request/response models for the batch search endpoint
class BatchSearchRequest(BaseModel):
    queries: list[str]


class SearchResultItem(BaseModel):
    query: str
    engine: str
    results: dict


class BatchSearchResponse(BaseModel):
    status: str
    message: str
    total_processed: int
    search_results: list[SearchResultItem]
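
# Illustrative request/response shapes for the endpoints below (all values
# are hypothetical):
#
#   POST /search/
#     {"query": "a1b2c3.note one\nd4e5f6.note two"}
#     -> [{"results": "<gpt_output>\n相关链接:<url>"}, ...]
#
#   POST /batch_search/
#     {"queries": ["fastapi tutorial"]}
#     -> {"status": "success", "message": "...", "total_processed": 3,
#         "search_results": [{"query": "...", "engine": "google", "results": {...}}]}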

async def fetch_all_content(query):
    """Look up cached summaries in saved_webpages by the hash codes embedded
    in the query (one "hash_code.<text>" entry per line)."""
    query_list = query.split('\n')
    hash_codes = [item.split('.')[0] for item in query_list if item]
    if not hash_codes:
        return [], 0
    # Use parameterized placeholders rather than string-concatenated quoted
    # values, which were vulnerable to SQL injection.
    placeholders = ', '.join(['%s'] * len(hash_codes))
    sql = f"SELECT url, gpt_output FROM saved_webpages WHERE hash_code IN ({placeholders});"
    pool = await aiomysql.create_pool(**db_config)
    contents = []
    counts = 0
    try:
        async with pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, hash_codes)
                results = await cursor.fetchall()
                if not results:  # No cached results found
                    return [], 0
                for result in results:
                    # '相关链接' means "related link"
                    contents.append(result['gpt_output'] + '\n相关链接:' + result['url'])
                counts = len(results)
    except Exception as e:
        print(f"Error fetching summaries: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
    return contents, counts
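
# Hypothetical usage of fetch_all_content: each line of the query carries a
# hash_code before the first dot, e.g.
#
#   contents, counts = await fetch_all_content("a1b2c3.note one\nd4e5f6.note two")
#   # looks up hash_codes 'a1b2c3' and 'd4e5f6' in saved_webpages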

async def save_user_query(query: str, result_count: int):
    """Record a user query and its result count in the user_queries table."""
    pool = await aiomysql.create_pool(**db_config)
    try:
        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                sql = """
                    INSERT INTO user_queries (query, query_datetime, result_count)
                    VALUES (%s, %s, %s)
                """
                await cursor.execute(sql, (query, datetime.datetime.now(), result_count))
                await conn.commit()
    except Exception as e:
        print(f"Error saving user query: {e}")
    finally:
        pool.close()
        await pool.wait_closed()
@app.post("/search/", response_model=list[SearchResult])
async def search(request: SearchRequest):
search_results = []
query = request.query
# 首先尝试从saved_webpages获取结果
results, counts = await fetch_all_content(query)
# 如果没有找到结果使用search_plus进行搜索
# if counts == 0:
# web_results = await search_and_save(query)
# results = web_results
# counts = len(web_results)
# 无论是否找到结果,都保存用户的查询记录
try:
await save_user_query(query, counts)
except Exception as e:
print(f"Error saving user query: {e}")
# 处理结果
for result in results:
search_results.append(SearchResult(results=result))
if not search_results:
# 如果没有找到任何结果,返回空列表但仍然记录这次查询
await save_user_query(query, 0)
return []
return search_results
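
# A quick smoke test against a local instance (hypothetical hash code; adjust
# host/port to match your deployment):
#
#   curl -X POST http://127.0.0.1:8000/search/ \
#        -H "Content-Type: application/json" \
#        -d '{"query": "a1b2c3.example"}'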
@app.post("/batch_search/", response_model=BatchSearchResponse)
async def batch_search(request: BatchSearchRequest):
try:
processed_count = 0
search_results = []
try:
# 处理每个查询
for query in request.queries:
for engine in ["google", "bing", "baidu"]:
params = {
"api_key": SERP_API_KEY,
"engine": engine,
"q": query
}
response = requests.get('https://serpapi.com/search', params=params)
search = response.json()
search_metadata = search.get('search_metadata', {})
if search_metadata.get('status') == 'Success':
json_endpoint = search_metadata.get('json_endpoint')
response = requests.get(json_endpoint)
if response.status_code == 200:
data = response.json()
json_record_str = json.dumps(data)
# 添加到搜索结果列表
search_results.append(
SearchResultItem(
query=query,
engine=engine,
results=data
)
)
finally:
return BatchSearchResponse(
status="success",
message=f"成功处理了 {processed_count} 条搜索请求",
total_processed=processed_count,
search_results=search_results
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"批量搜索失败: {str(e)}")
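
# This excerpt does not show how the app is served; a minimal entry point
# sketch, assuming uvicorn is installed (host/port are placeholders):
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)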