JIAP/rpa_api.py
2024-12-19 10:15:23 +08:00

172 lines
5.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import asyncio
import json
import logging
import os
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from pydantic import BaseModel

from rpa_client import rpa_prompt, api_key, base_url
# Application logging: basicConfig attaches a stderr handler to the root logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Conversation logging: records written to this logger go to rpa_conversation.log.
# NOTE(review): `propagate` is left at its default (True), so these records are
# also emitted through the root handler configured above; set
# `conversation_logger.propagate = False` if chat contents should stay off stderr.
conversation_logger = logging.getLogger('conversation_logger')
conversation_logger.setLevel(logging.INFO)
# Explicit UTF-8: without it the file is opened with the platform's locale
# encoding, which on some systems (e.g. cp1252 on Windows) cannot encode
# non-ASCII chat text such as Chinese, so those records would fail to write.
file_handler = logging.FileHandler('rpa_conversation.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
conversation_logger.addHandler(file_handler)
# FastAPI application instance served by this module.
app = FastAPI(title="RPA Chat API")

# CORS: accept every origin, method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is a
# risky pairing (see the FastAPI/Starlette CORS docs); list explicit origins
# before exposing this service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# One shared OpenAI client, configured from the credentials in rpa_client.
client = OpenAI(base_url=base_url, api_key=api_key)
# Request body accepted by POST /chat.
class ChatRequest(BaseModel):
    """One incoming chat turn from the client."""

    # The user's message text.
    message: str
    # Optional allowed-paths string; not currently read by the /chat handler.
    allowed_paths: Optional[str] = None
# Response body returned by POST /chat.
class ChatResponse(BaseModel):
    """Parsed model reply sent back to the client."""

    # Reply text taken from the model's JSON output.
    response: str
    # Command objects taken from the model's JSON output.
    commands: list[dict]
# In-process conversation store: conversation ID -> list of role/content
# message dicts. It lives only in this process's memory (lost on restart).
conversations: dict[str, list[dict]] = {}
def get_conversation_history(conversation_id: str) -> list:
    """Return the stored message list for *conversation_id*, creating it if needed.

    A conversation seen for the first time is seeded with a single system
    message whose content is ``rpa_prompt``. The returned list is the very
    object held in ``conversations``, so appending to it updates the store.
    """
    try:
        return conversations[conversation_id]
    except KeyError:
        seeded = [{"role": "system", "content": rpa_prompt}]
        conversations[conversation_id] = seeded
        return seeded
async def interact_with_chatgpt(messages: list) -> str:
    """Send *messages* to the chat-completions endpoint and return the reply text.

    The OpenAI client used here is synchronous, so the request is executed in
    a worker thread via ``asyncio.to_thread``. Calling it directly inside this
    coroutine blocked the event loop, stalling every other request for the
    whole network round trip.

    Args:
        messages: chat history as role/content dicts, passed through to
            ``client.chat.completions.create``. A shallow copy is sent so that
            other handlers appending to the same list while the request is in
            flight cannot change it mid-call.

    Returns:
        ``message.content`` of the first choice.
        NOTE(review): the SDK may return None here for some responses -- confirm.

    Raises:
        Any exception from the client call; it is logged and re-raised unchanged.
    """
    try:
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=list(messages),
            max_tokens=500,
            temperature=0.1,
        )
        return response.choices[0].message.content
    except Exception as e:
        logger.error(f"OpenAI API 调用错误: {e}")
        raise
def parse_chatgpt_response(response_text: str) -> tuple[str, list]:
    """Parse the model's reply into a ``(reply_text, commands)`` pair.

    The reply is expected to be a JSON object with a ``"response"`` key and an
    optional ``"commands"`` key. It may be wrapped in a Markdown code fence,
    either with a ``json`` tag (```` ```json ... ``` ````) or without one
    (```` ``` ... ``` ````).

    Args:
        response_text: raw model output. A non-string (e.g. None) is treated
            as a parse failure.

    Returns:
        ``(data["response"], commands)``, where ``commands`` is
        ``data["commands"]``, or ``[]`` when that key is missing or null.
        On a parse failure the error is logged and a fixed error reply is
        returned together with an empty command list.
    """
    try:
        text = response_text.strip()
        # Drop an opening fence, with or without the "json" language tag.
        for fence in ("```json", "```"):
            if text.startswith(fence):
                text = text[len(fence):]
                break
        # Drop the closing fence by slicing the END of the string. The previous
        # str.replace("```", "", 1) removed the FIRST fence instead, which broke
        # untagged fences and replies whose text itself contained ```.
        if text.endswith("```"):
            text = text[:-3]
        data = json.loads(text.strip())
        reply = data["response"]
        commands = data.get("commands")
        # A JSON null would otherwise reach ChatResponse (List[dict]) and fail validation.
        return reply, ([] if commands is None else commands)
    except (AttributeError, TypeError, KeyError, ValueError) as e:
        # ValueError covers json.JSONDecodeError; TypeError/AttributeError cover
        # non-string input and JSON that is not an object.
        logger.error(f"解析响应失败: {e}")
        return "解析响应时出错", []
@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Handle one chat turn and return the model's parsed reply.

    The user's message is appended to the conversation history, the history
    is trimmed, and the model is queried. Its reply is parsed into text plus
    a command list, and the raw reply is stored back into the history.

    Args:
        request: the incoming turn; ``allowed_paths`` is accepted but not
            currently used.

    Returns:
        ChatResponse carrying the parsed reply text and commands.

    Raises:
        HTTPException: status 500 wrapping any error raised while handling the
            request (detail is the error's string form).
    """
    try:
        # All callers share one fixed conversation ID; derive a per-client ID
        # here if separate users must not share history.
        conversation_id = "default"
        messages = get_conversation_history(conversation_id)
        # NOTE: request.allowed_paths is currently ignored. A disabled earlier
        # version substituted it for "{allowed_paths}" in rpa_prompt and stored
        # the result in messages[0]["content"].
        messages.append({
            "role": "user",
            "content": request.message
        })
        # Bound the history: keep the system prompt (index 0) plus the 9 most
        # recent messages. Trim IN PLACE -- rebinding `messages` to a new list
        # left the stored history growing without limit and made the assistant
        # reply below land on a throwaway copy instead of the stored history.
        if len(messages) > 10:
            del messages[1:-9]
        conversation_logger.info(f"User: {request.message}")
        chatgpt_response = await interact_with_chatgpt(messages)
        response_text, commands = parse_chatgpt_response(chatgpt_response)
        conversation_logger.info(f"Assistant: {response_text}")
        conversation_logger.info(f"Commands: {commands}")
        # Store the raw (unparsed) model output in the history.
        messages.append({
            "role": "assistant",
            "content": chatgpt_response
        })
        return ChatResponse(
            response=response_text,
            commands=commands
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error discarded.
        logger.exception(f"处理请求时出错: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the stored message list for a conversation.

    Raises:
        HTTPException: 404 when no conversation exists under this ID.
    """
    try:
        history = conversations[conversation_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Conversation not found") from None
    return history
@app.delete("/conversations/{conversation_id}")
async def clear_conversation(conversation_id: str):
    """Reset a conversation so that only its first (system) message remains.

    Unknown conversation IDs are left alone; success is reported either way.
    """
    history = conversations.get(conversation_id)
    if history is not None:
        # Store a fresh one-element list rather than truncating the old one.
        conversations[conversation_id] = [history[0]]
    return {"status": "success"}
if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all IPv4 interfaces.
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=11089)