AgentOccam/Agent_E/ae/server/api_routes.py
2025-01-22 11:32:35 -08:00

192 lines
8.8 KiB
Python

import asyncio
import json
import logging
import os
import uuid
from queue import Empty
from queue import Queue
from typing import Any
import uvicorn
from fastapi import FastAPI
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from pydantic import Field
import Agent_E.ae.core.playwright_manager as browserManager
from Agent_E.ae.config import SOURCE_LOG_FOLDER_PATH
from Agent_E.ae.core.agents_llm_config import AgentsLLMConfig
from Agent_E.ae.core.autogen_wrapper import AutogenWrapper
from Agent_E.ae.utils.formatting_helper import is_terminating_message
from Agent_E.ae.utils.ui_messagetype import MessageType
# Shared Playwright manager created at import time (non-headless); it is
# initialized asynchronously in the startup handler and handed to every task.
browser_manager = browserManager.PlaywrightManager(headless=False)

# Application metadata and server settings.
APP_VERSION = "1.0.0"
APP_NAME = "Agent-E Web API"
API_PREFIX = "/api"  # NOTE(review): not referenced in the visible part of this file
IS_DEBUG = False
HOST = os.environ.get("HOST", "0.0.0.0")
PORT = int(os.environ.get("PORT", 8080))
WORKERS = 1

# Identifier stamped onto every notification; when blank, the startup handler
# replaces it with a generated UUID.
container_id = os.environ.get("CONTAINER_ID", "")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger("uvicorn")
class CommandQueryModel(BaseModel):
    """Request body accepted by the ``/execute_task`` endpoint.

    Only ``command`` is required; every other field falls back to the
    default declared on it.
    """

    # Required: the instruction handed to the agents.
    command: str = Field(
        ...,
        description="The command related to web navigation to execute.",
    )
    # Optional per-request LLM configuration; None means "use the defaults".
    llm_config: dict[str, Any] | None = Field(
        default=None,
        description="The LLM configuration string to use for the agents.",
    )
    # Chat-round limits forwarded to the planner and browser navigation agents.
    planner_max_chat_round: int = Field(
        default=50,
        description="The maximum number of chat rounds for the planner.",
    )
    browser_nav_max_chat_round: int = Field(
        default=10,
        description="The maximum number of chat rounds for the browser navigation agent.",
    )
    # When provided, reused as the transaction id of the request.
    clientid: str | None = Field(
        default=None,
        description="Client identifier, optional",
    )
    # Echoed back in every notification streamed for the request.
    request_originator: str | None = Field(
        default=None,
        description="Optional id of the request originator",
    )
def get_app() -> FastAPI:
    """Build the FastAPI application and attach its CORS middleware.

    Returns:
        FastAPI: A new application created with ``APP_NAME``, ``APP_VERSION``
        and ``IS_DEBUG``, with ``CORSMiddleware`` installed using wildcard
        origins, methods and headers and ``allow_credentials=True``.
    """
    cors_options = {
        "allow_origins": ["*"],
        "allow_credentials": True,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
    application = FastAPI(title=APP_NAME, version=APP_VERSION, debug=IS_DEBUG)
    # NOTE(review): wildcard origins together with allow_credentials=True is
    # discouraged by the FastAPI CORS documentation - confirm this is intended
    # before exposing the service beyond a trusted network.
    application.add_middleware(CORSMiddleware, **cors_options)
    return application
# Module-level application instance; the route and event decorators below attach to it.
app = get_app()
@app.on_event("startup")  # type: ignore
async def startup_event():
    """
    Prepare process-wide state when the server starts.

    If the module-level ``container_id`` is blank, a random UUID is generated,
    stored in ``container_id`` and exported as the ``CONTAINER_ID`` environment
    variable. The shared browser manager is then initialized asynchronously.
    """
    global container_id

    has_container_id = bool(container_id.strip())
    if not has_container_id:
        container_id = str(uuid.uuid4())
        os.environ["CONTAINER_ID"] = container_id

    await browser_manager.async_initialize()
@app.post("/execute_task", description="Execute a given command related to web navigation and return the result.")
async def execute_task(request: Request, query_model: CommandQueryModel):
    """
    Start a command and stream its notifications back as server-sent events.

    Args:
        request (Request): The incoming request, passed on so the stream can detect a client disconnect.
        query_model (CommandQueryModel): The parsed request body.

    Returns:
        StreamingResponse: A ``text/event-stream`` response fed by ``run_task``.
    """
    # The caller-supplied client id doubles as the transaction id; otherwise generate one.
    if query_model.clientid is None:
        transaction_id = str(uuid.uuid4())
    else:
        transaction_id = query_model.clientid

    # Each request gets its own queue, subscribed to the shared notification manager.
    notification_queue = Queue()  # type: ignore
    register_notification_listener(notification_queue)

    event_stream = run_task(
        request,
        transaction_id,
        query_model.command,
        browser_manager,
        notification_queue,
        query_model.request_originator,
        query_model.llm_config,
        planner_max_chat_round=query_model.planner_max_chat_round,
        browser_nav_max_chat_round=query_model.browser_nav_max_chat_round,
    )
    return StreamingResponse(event_stream, media_type="text/event-stream")
def run_task(request: Request, transaction_id: str, command: str, playwright_manager: browserManager.PlaywrightManager, notification_queue: Queue, request_originator: str|None = None, llm_config: dict[str,Any]|None = None, # type: ignore
             planner_max_chat_round: int = 50, browser_nav_max_chat_round: int = 10):
    """
    Build the server-sent-event stream for one command.

    The command is started as a background asyncio task running ``process_command``.
    While that task is running, or while notifications are still queued, each
    notification is tagged with the transaction id and request originator and
    emitted as one SSE ``data:`` record. If the client disconnects, the
    background task is cancelled and the stream ends.

    Args:
        request (Request): The request object to detect client disconnect.
        transaction_id (str): The transaction ID to identify the request.
        command (str): The command to execute.
        playwright_manager (PlaywrightManager): The manager handling browser interactions and notifications.
        notification_queue (Queue): The queue to hold notifications for this request.
        request_originator (str|None): The originator of the request.
        llm_config (dict[str,Any]|None): The LLM configuration to use for the agents.
        planner_max_chat_round (int, optional): The maximum number of chat rounds for the planner. Defaults to 50.
        browser_nav_max_chat_round (int, optional): The maximum number of chat rounds for the browser navigation agent. Defaults to 10.

    Returns:
        An async generator yielding SSE-formatted strings, one JSON-encoded
        notification per record. An exception raised by ``process_command``
        propagates out of the generator once the queue has been drained.
    """

    async def event_generator():
        task = asyncio.create_task(process_command(command, playwright_manager, planner_max_chat_round, browser_nav_max_chat_round, llm_config))
        task_detail = f"transaction_id={transaction_id}, request_originator={request_originator}, command={command}"
        # True once this generator has cancelled the task itself because the client went away.
        cancelled_for_disconnect = False

        try:
            while not task.done() or not notification_queue.empty():
                if await request.is_disconnected():
                    logger.info(f"Client disconnected. Cancelling the task: {task_detail}")
                    cancelled_for_disconnect = True
                    task.cancel()
                    break
                try:
                    notification = notification_queue.get_nowait()  # type: ignore
                    notification["transaction_id"] = transaction_id  # Include the transaction ID in the notification
                    notification["request_originator"] = request_originator  # Include the request originator in the notification
                    yield f"data: {json.dumps(notification)}\n\n"  # Using 'data: ' to follow the SSE format
                except Empty:
                    # Nothing queued yet: back off briefly instead of busy-waiting.
                    await asyncio.sleep(0.1)
                except Exception as e:
                    # Best effort: a notification that cannot be delivered is logged and skipped.
                    # CancelledError is not an Exception subclass, so cancellation is not
                    # swallowed here and reaches the outer handler.
                    logger.error(f"An error occurred while processing task: {task_detail}. Error: {e}")

            # Surfaces the task's outcome; raises CancelledError if the task was cancelled above.
            await task
        except asyncio.CancelledError:
            logger.info(f"Task was cancelled due to client disconnection. {task_detail}")
            # Awaiting the cancelled task again would only raise CancelledError a second time.
            # When the cancellation was requested here, the stream simply ends; when it came
            # from outside (the consumer cancelling this generator), it must keep propagating.
            if not cancelled_for_disconnect:
                raise
        finally:
            # Never leave the background task running once the stream is gone, whatever
            # the exit path (disconnect, outside cancellation, generator close, error).
            if not task.done():
                task.cancel()

    return event_generator()
async def process_command(command: str, playwright_manager: browserManager.PlaywrightManager, planner_max_chat_round: int, browser_nav_max_chat_round: int, llm_config:dict[str,Any]|None = None):
    """
    Run one command through the agents and notify the user of the outcome.

    In order: go to the homepage, capture the current URL, send a
    "Processing command" notification, build the agent configuration, run the
    command, write the planner agent's chat messages to ``chat_messages.json``
    under ``SOURCE_LOG_FOLDER_PATH``, and send a final DONE or
    MAX_TURNS_REACHED notification.

    Args:
        command (str): The command to process.
        playwright_manager (PlaywrightManager): The manager handling browser interactions and notifications.
        planner_max_chat_round (int): The maximum number of chat rounds for the planner.
        browser_nav_max_chat_round (int): The maximum number of chat rounds for the browser navigation agent.
        llm_config (dict[str,Any]|None): LLM configuration received with the request; when None, ``AgentsLLMConfig`` is built without arguments.
    """
    # Go to the homepage before processing the command
    await playwright_manager.go_to_homepage()
    start_url = await playwright_manager.get_current_url()
    await playwright_manager.notify_user("Processing command", MessageType.INFO)

    # Load the configuration using AgentsLLMConfig
    if llm_config is not None:
        agents_config = AgentsLLMConfig(llm_config=llm_config)
        logger.info("Applied LLM config received via API.")
    else:
        agents_config = AgentsLLMConfig()

    # Planner config is fetched first, then the browser navigation config.
    wrapper = await AutogenWrapper.create(
        agents_config.get_planner_agent_config(),
        agents_config.get_browser_nav_agent_config(),
        planner_max_chat_round=planner_max_chat_round,
        browser_nav_max_chat_round=browser_nav_max_chat_round,
    )
    result = await wrapper.process_command(command, start_url)  # type: ignore

    # Keys are passed through str() so the mapping can be written as JSON object keys.
    planner_messages = wrapper.agents_map["planner_agent"].chat_messages
    serializable_messages = {}
    for key, value in planner_messages.items():  # type: ignore
        serializable_messages[str(key)] = value

    log_path = os.path.join(SOURCE_LOG_FOLDER_PATH, 'chat_messages.json')
    with open(log_path, 'w', encoding='utf-8') as log_file:
        json.dump(serializable_messages, log_file, ensure_ascii=False, indent=4)
    logger.debug("Chat messages saved")

    # NOTE(review): assumes the wrapper always returns an object with a
    # ``summary`` attribute - confirm it can never return None.
    if is_terminating_message(result.summary):
        final_text, final_type = "DONE", MessageType.DONE
    else:
        final_text, final_type = "Max turns reached", MessageType.MAX_TURNS_REACHED
    await playwright_manager.notify_user(final_text, final_type)
def register_notification_listener(notification_queue: Queue): # type: ignore
    """
    Subscribe a queue to the shared browser manager's notifications.

    A callback is registered with ``browser_manager.notification_manager``.
    For each notification it receives, the callback stamps the module's
    current ``container_id`` onto it and puts it on ``notification_queue``.

    Args:
        notification_queue (Queue): The queue that receives the notifications.
    """
    def enqueue_notification(notification: dict[str, str]) -> None:
        # container_id is read at call time, so a value generated at startup is picked up.
        notification["container_id"] = container_id
        notification_queue.put(notification)  # type: ignore

    # NOTE(review): nothing visible in this file unregisters the callback -
    # confirm whether listeners (and their queues) accumulate across requests.
    notifier = browser_manager.notification_manager
    notifier.register_listener(enqueue_notification)
if __name__ == "__main__":
    # Script entry point: announce start-up, then hand control to uvicorn.
    server_options = {
        "host": HOST,
        "port": PORT,
        "workers": WORKERS,
        "reload": IS_DEBUG,
        "log_level": "info",
    }
    logger.info("**********Application Started**********")
    # NOTE(review): "main:app" makes uvicorn import a module named ``main`` -
    # confirm such a module exposes ``app``, since this file is not named main.py.
    uvicorn.run("main:app", **server_options)