import asyncio
import json
import logging
import os
from collections import defaultdict
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from openai import OpenAI
from pydantic import BaseModel

from rpa_client import rpa_prompt, api_key, base_url

# Application logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s'
)
logger = logging.getLogger(__name__)

# Conversation logger: records every user/assistant turn in rpa_conversation.log.
# Records also propagate to the root logger configured above.
conversation_logger = logging.getLogger('conversation_logger')
conversation_logger.setLevel(logging.INFO)
# Explicit UTF-8: logged messages may contain non-ASCII (e.g. Chinese) text that
# a platform-default encoding such as cp1252 cannot encode.
file_handler = logging.FileHandler('rpa_conversation.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
conversation_logger.addHandler(file_handler)

# FastAPI application.
app = FastAPI(title="RPA Chat API")

# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website call this API from a browser; restrict allow_origins before exposing
# the service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Synchronous OpenAI client; key and endpoint come from rpa_client.
client = OpenAI(api_key=api_key, base_url=base_url)


class ChatRequest(BaseModel):
    """Request body of POST /chat."""
    message: str
    # Accepted for API compatibility but not applied yet (see chat()).
    allowed_paths: Optional[str] = None


class ChatResponse(BaseModel):
    """Response body of POST /chat: assistant text plus parsed RPA commands."""
    response: str
    commands: List[dict]


# Maximum number of messages (system prompt included) kept per conversation;
# older user/assistant messages beyond this are dropped.
MAX_HISTORY_MESSAGES = 10

# In-memory conversation store: conversation_id -> list of chat messages whose
# first element is always the system prompt. Not persisted across restarts.
conversations = {}

# One lock per conversation so concurrent /chat requests on the same
# conversation cannot interleave their messages in the shared history.
_conversation_locks = defaultdict(asyncio.Lock)


def get_conversation_history(conversation_id: str) -> list:
    """Return the stored message list for a conversation, creating it if needed.

    A new conversation starts with a single system message holding
    ``rpa_prompt``. The returned list is the stored object itself, so
    mutations by the caller are persisted.
    """
    if conversation_id not in conversations:
        conversations[conversation_id] = [
            {
                "role": "system",
                "content": rpa_prompt
            }
        ]
    return conversations[conversation_id]


async def interact_with_chatgpt(messages: list) -> str:
    """Send ``messages`` to the chat model and return the reply text.

    The OpenAI client used here is synchronous, so the call runs in a worker
    thread; awaiting it keeps the event loop free to serve other requests.

    Returns:
        The assistant message content, or ``""`` if the API returned none.

    Raises:
        Exception: any error from the OpenAI client is logged and re-raised.
    """
    try:
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model="gpt-4o-mini",
            messages=messages,
            max_tokens=500,
            temperature=0.1,
        )
    except Exception as e:
        logger.error(f"OpenAI API 调用错误: {e}")
        raise
    # content may be None (e.g. an empty completion); keep the declared str type.
    return response.choices[0].message.content or ""


def parse_chatgpt_response(response_text: str) -> tuple[str, list]:
    """Parse the model's JSON reply into ``(response, commands)``.

    The reply is expected to be a JSON object with a string ``"response"`` key
    and an optional ``"commands"`` list of objects. The reply may be wrapped
    in a Markdown code fence (```json ... ``` or ``` ... ```).

    Returns:
        ``(response, commands)``; ``commands`` is ``[]`` when absent or null.
        On any parse or validation failure the error is logged and
        ``("解析响应时出错", [])`` is returned instead of raising.
    """
    try:
        text = (response_text or "").strip()
        # Strip an optional Markdown fence. removeprefix/removesuffix only touch
        # the ends, so the closing fence is removed rather than the first "```"
        # occurrence, and a plain "```" opening fence is handled too.
        if text.startswith("```json"):
            text = text.removeprefix("```json")
        elif text.startswith("```"):
            text = text.removeprefix("```")
        text = text.removesuffix("```")

        data = json.loads(text.strip())
        response = data["response"]
        commands = data.get("commands") or []
        # Validate here so a malformed reply yields the parse-error text
        # instead of a ChatResponse validation failure (HTTP 500) later.
        if not isinstance(response, str):
            raise ValueError(f"'response' must be a string, got {type(response).__name__}")
        if not isinstance(commands, list) or not all(isinstance(c, dict) for c in commands):
            raise ValueError("'commands' must be a list of objects")
        return response, commands
    except Exception as e:
        logger.error(f"解析响应失败: {e}")
        return "解析响应时出错", []


@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    """Handle one chat turn: record the user message, query the model, reply.

    Raises:
        HTTPException: status 500 with the error text if any step fails. The
            unanswered user message is then removed from the history.
    """
    # A single shared conversation for now; derive the ID from the request to
    # support independent sessions.
    conversation_id = "default"

    # Serialize turns on the same conversation: the model call awaits in a
    # worker thread, so concurrent requests could otherwise interleave.
    async with _conversation_locks[conversation_id]:
        messages = get_conversation_history(conversation_id)

        # TODO: request.allowed_paths is accepted but not applied. An earlier
        # draft substituted it for "{allowed_paths}" in the system prompt.

        user_message = {
            "role": "user",
            "content": request.message
        }
        messages.append(user_message)

        try:
            # Trim in place (slice assignment) so the *stored* history shrinks
            # and later appends land in it; rebinding the local name would not.
            if len(messages) > MAX_HISTORY_MESSAGES:
                messages[:] = [messages[0]] + messages[-(MAX_HISTORY_MESSAGES - 1):]

            conversation_logger.info(f"User: {request.message}")

            chatgpt_response = await interact_with_chatgpt(messages)
            response_text, commands = parse_chatgpt_response(chatgpt_response)

            conversation_logger.info(f"Assistant: {response_text}")
            conversation_logger.info(f"Commands: {commands}")

            # Store the raw model output so later turns see exactly what it said.
            messages.append({
                "role": "assistant",
                "content": chatgpt_response
            })

            return ChatResponse(
                response=response_text,
                commands=commands
            )
        except Exception as e:
            # Roll back the unanswered user turn so roles keep alternating.
            if messages and messages[-1] is user_message:
                messages.pop()
            logger.exception(f"处理请求时出错: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/conversations/{conversation_id}")
async def get_conversation(conversation_id: str):
    """Return the full stored message list of a conversation (404 if unknown)."""
    if conversation_id not in conversations:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conversations[conversation_id]


@app.delete("/conversations/{conversation_id}")
async def clear_conversation(conversation_id: str):
    """Reset a conversation to just its system prompt.

    Always reports success, including for unknown conversation IDs.
    """
    if conversation_id in conversations:
        system_prompt = conversations[conversation_id][0]
        conversations[conversation_id] = [system_prompt]
    return {"status": "success"}


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=11089)