import logging
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from openai import OpenAI
import json
import os
from rpa_client import rpa_prompt, api_key, base_url

# Configure application logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Configure a separate logger that records the conversation to a file
conversation_logger = logging.getLogger('conversation_logger')
conversation_logger.setLevel(logging.INFO)
file_handler = logging.FileHandler('rpa_conversation.log')
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
conversation_logger.addHandler(file_handler)

# Initialize the FastAPI application
app = FastAPI(title="RPA Chat API")

# Configure CORS (all origins allowed; tighten this for production use)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Initialize the OpenAI client (credentials and endpoint come from rpa_client)
client = OpenAI(api_key=api_key, base_url=base_url)

# Request model
class ChatRequest(BaseModel):
    message: str
    allowed_paths: Optional[str] = None

# Response model
class ChatResponse(BaseModel):
    response: str
    commands: List[dict]

# In-memory store of conversation histories, keyed by conversation ID
conversations = {}

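# Illustrative payload shapes for the models above (a sketch only; the actual
# command dictionaries depend on what rpa_prompt asks the model to emit, and
# the field names inside "commands" are assumptions):
#
#   POST /chat request body:
#       {"message": "open notepad", "allowed_paths": "C:/work"}
#   ChatResponse returned to the caller:
#       {"response": "Opening Notepad", "commands": [{"action": "open_app", "target": "notepad.exe"}]}
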
def get_conversation_history(conversation_id: str) -> list:
    """Return the conversation history for an ID, creating it if needed."""
    if conversation_id not in conversations:
        # Start a new conversation seeded with the system prompt
        conversations[conversation_id] = [
            {
                "role": "system",
                "content": rpa_prompt
            }
        ]
    return conversations[conversation_id]

async def interact_with_chatgpt(messages: list) -> str:
    """Send the conversation to the model and return the raw reply text."""
    try:
        # Note: this client call is synchronous, so it blocks the event loop
        # while waiting; openai.AsyncOpenAI could be used if concurrency matters.
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=500,
            temperature=0.1,
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"OpenAI API call failed: {e}")
        raise

def parse_chatgpt_response(response_text: str) -> tuple[str, list]:
    """Parse the model's reply into (response text, command list)."""
    try:
        # Strip any Markdown code fence wrapped around the JSON payload
        response_text = response_text.strip()
        if response_text.startswith("```json"):
            response_text = response_text.replace("```json", "", 1)
        if response_text.endswith("```"):
            # Drop the trailing fence (not the first occurrence of ```)
            response_text = response_text[:-3]

        # Parse the JSON body
        data = json.loads(response_text.strip())
        return data["response"], data.get("commands", [])
    except Exception as e:
        logger.error(f"Failed to parse model response: {e}")
        return "Error while parsing the model response", []

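# Example of a raw model reply that parse_chatgpt_response handles, with or
# without a Markdown fence (illustrative only; the real schema is whatever
# rpa_prompt instructs the model to produce):
#
#   ```json
#   {"response": "Done.", "commands": [{"action": "click", "target": "OK"}]}
#   ```
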
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Handle a chat request."""
    try:
        # Conversation ID (a fixed ID for simplicity; adjust as needed)
        conversation_id = "default"

        # Fetch the conversation history
        messages = get_conversation_history(conversation_id)

        # # If new allowed paths are provided, update the system prompt
        # if request.allowed_paths:
        #     messages[0]["content"] = rpa_prompt.replace(
        #         "{allowed_paths}",
        #         request.allowed_paths
        #     )

        # Append the user message
        messages.append({
            "role": "user",
            "content": request.message
        })

        # Keep the history at a reasonable length: retain the system message
        # and the most recent turns, and store the trimmed list back so the
        # shared history actually shrinks
        if len(messages) > 10:
            messages = [messages[0]] + messages[-9:]
            conversations[conversation_id] = messages

        # Log the user message
        conversation_logger.info(f"User: {request.message}")

        # Call ChatGPT
        chatgpt_response = await interact_with_chatgpt(messages)

        # Parse the reply
        response_text, commands = parse_chatgpt_response(chatgpt_response)

        # Log the reply
        conversation_logger.info(f"Assistant: {response_text}")
        conversation_logger.info(f"Commands: {commands}")

        # Append the assistant reply to the history
        messages.append({
            "role": "assistant",
            "content": chatgpt_response
        })

        return ChatResponse(
            response=response_text,
            commands=commands
        )

    except Exception as e:
        logger.error(f"Error while handling request: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the stored history for a conversation."""
    if conversation_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversations[conversation_id]

@app.delete("/conversations/{conversation_id}")
async def clear_conversation(conversation_id: str):
    """Clear a conversation's history, keeping only the system prompt."""
    if conversation_id in conversations:
        # Keep the system prompt
        system_prompt = conversations[conversation_id][0]
        conversations[conversation_id] = [system_prompt]
    return {"status": "success"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=11089)
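
# Example usage once the server is running (a hypothetical session, assuming
# the defaults above; the conversation ID is always "default" in this version):
#
#   curl -X POST http://localhost:11089/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "open notepad"}'
#
#   curl http://localhost:11089/conversations/default             # inspect history
#   curl -X DELETE http://localhost:11089/conversations/default   # reset history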