import argparse
import json
from typing import Any, Optional

import tiktoken
from beartype import beartype
from PIL import Image

from agent.prompts import *
from browser_env import Trajectory
from browser_env.actions import (
    Action,
    ActionParsingError,
    create_id_based_action,
    create_none_action,
    create_playwright_action,
)
from browser_env.utils import Observation, StateInfo
from llms import (
    call_llm,
    generate_from_huggingface_completion,
    generate_from_openai_chat_completion,
    generate_from_openai_completion,
    lm_config,
)
from llms.tokenizers import Tokenizer


class Agent:
    """Base class for the agent"""

    def __init__(self, *args: Any) -> None:
        pass

    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: Any
    ) -> Action:
        """Predict the next action given the observation"""
        raise NotImplementedError

    def reset(self, test_config_file: str) -> None:
        raise NotImplementedError


class TeacherForcingAgent(Agent):
    """Agent that follows a pre-defined action sequence"""

    def __init__(self) -> None:
        super().__init__()

    def set_action_set_tag(self, tag: str) -> None:
        self.action_set_tag = tag

    def set_actions(self, action_seq: str | list[str]) -> None:
        if isinstance(action_seq, str):
            action_strs = action_seq.strip().split("\n")
        else:
            action_strs = action_seq
        action_strs = [a.strip() for a in action_strs]

        actions = []
        for a_str in action_strs:
            try:
                if self.action_set_tag == "playwright":
                    cur_action = create_playwright_action(a_str)
                elif self.action_set_tag == "id_accessibility_tree":
                    cur_action = create_id_based_action(a_str)
                else:
                    raise ValueError(
                        f"Unknown action type {self.action_set_tag}"
                    )
            except ActionParsingError:
                # Fall back to a no-op action if the string cannot be parsed.
                cur_action = create_none_action()
            cur_action["raw_prediction"] = a_str
            actions.append(cur_action)

        self.actions: list[Action] = actions

    def next_action(
        self, trajectory: Trajectory, intent: str, meta_data: Any
    ) -> Action:
        """Return the next action in the pre-defined sequence"""
        return self.actions.pop(0)

    def reset(self, test_config_file: str) -> None:
        with open(test_config_file) as f:
            ref_actions = json.load(f)["reference_action_sequence"]
        self.set_action_set_tag(ref_actions["action_set_tag"])
        self.set_actions(ref_actions["action_sequence"])


class PromptAgent(Agent):
    """Prompt-based agent that emits the next action given the history"""

    @beartype
    def __init__(
        self,
        action_set_tag: str,
        lm_config: lm_config.LMConfig,
        prompt_constructor: PromptConstructor,
        captioning_fn=None,
    ) -> None:
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = prompt_constructor
        self.action_set_tag = action_set_tag
        self.captioning_fn = captioning_fn

        # The agent feeds images to the model only when the model is
        # multimodal-capable (Gemini, a GPT-4 vision variant, or an
        # API/finetuned provider) and it is paired with the multimodal
        # prompt constructor.
        self.multimodal_inputs = (
            "gemini" in lm_config.model
            or ("gpt-4" in lm_config.model and "vision" in lm_config.model)
            or lm_config.provider in ["api", "finetune"]
        ) and isinstance(prompt_constructor, MultimodalCoTPromptConstructor)

    def set_action_set_tag(self, tag: str) -> None:
        self.action_set_tag = tag

    @beartype
    def next_action(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any],
        images: Optional[list[Image.Image]] = None,
        output_response: bool = False,
    ) -> Action:
        # Create the current page screenshot for multimodal models.
        if self.multimodal_inputs:
            page_screenshot_arr = trajectory[-1]["observation"]["image"]
            page_screenshot_img = Image.fromarray(
                page_screenshot_arr
            )  # size = (viewport_width, viewport_height)

        # Caption the input images, if provided.
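        # The captions are folded into the textual intent so that even a
        # text-only model can reason about the image inputs. The resulting
        # intent has roughly this shape (the values are illustrative only):
        #   Input image 1: "a red dress", input image 2: "a blue dress"
        #   Intent: Buy the cheaper of the two dresses.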
        if images is not None and len(images) > 0:
            if self.captioning_fn is not None:
                image_input_caption = ""
                for image_i, image in enumerate(images):
                    if image_i == 0:
                        image_input_caption += f'Input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
                    else:
                        image_input_caption += f'input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
                    # Separate consecutive captions with ", " (but do not
                    # leave a dangling separator after the last one).
                    if image_i < len(images) - 1:
                        image_input_caption += ", "
                # Update the intent to include captions of the input images.
                intent = f"{image_input_caption}\nIntent: {intent}"
            elif not self.multimodal_inputs:
                print(
                    "WARNING: Input image provided but no image captioner available."
                )

        if self.multimodal_inputs:
            prompt = self.prompt_constructor.construct(
                trajectory, intent, page_screenshot_img, images, meta_data
            )
        else:
            prompt = self.prompt_constructor.construct(
                trajectory, intent, meta_data
            )

        lm_config = self.lm_config
        n = 0
        while True:
            response = call_llm(lm_config, prompt)
            force_prefix = self.prompt_constructor.instruction["meta_data"].get(
                "force_prefix", ""
            )
            response = f"{force_prefix}{response}"
            if output_response:
                print(f"Agent: {response}", flush=True)
            n += 1
            try:
                parsed_response = self.prompt_constructor.extract_action(response)
                if self.action_set_tag == "id_accessibility_tree":
                    action = create_id_based_action(parsed_response)
                elif self.action_set_tag == "playwright":
                    action = create_playwright_action(parsed_response)
                elif self.action_set_tag == "som":
                    # Set-of-marks IDs parse the same way as accessibility-tree IDs.
                    action = create_id_based_action(parsed_response)
                else:
                    raise ValueError(
                        f"Unknown action type {self.action_set_tag}"
                    )
                action["raw_prediction"] = response
                break
            except ActionParsingError:
                # Retry up to max_retry times; after that, fall back to a no-op.
                if n >= lm_config.gen_config["max_retry"]:
                    action = create_none_action()
                    action["raw_prediction"] = response
                    break

        return action

    def reset(self, test_config_file: str) -> None:
        pass


def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
    llm_config = lm_config.construct_llm_config(args)

    agent: Agent
    if args.agent_type == "teacher_forcing":
        agent = TeacherForcingAgent()
    elif args.agent_type == "prompt":
        with open(args.instruction_path) as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = Tokenizer(args.provider, args.model)
        # Resolve the prompt constructor class named in the instruction file.
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        agent = PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
            captioning_fn=captioning_fn,
        )
    else:
        raise NotImplementedError(
            f"agent type {args.agent_type} not implemented"
        )
    return agent
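

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module's public surface: build a
    # PromptAgent from hand-rolled arguments. The field names below mirror
    # the run script's CLI flags, and the concrete values (instruction path,
    # model name, generation settings) are assumptions for illustration, not
    # canonical defaults.
    demo_args = argparse.Namespace(
        agent_type="prompt",
        instruction_path="agent/prompts/jsons/p_cot_id_actree_2s.json",  # assumed path
        provider="openai",
        model="gpt-3.5-turbo",
        mode="chat",
        action_set_tag="id_accessibility_tree",
        temperature=1.0,
        top_p=0.9,
        context_length=0,
        max_tokens=384,
        stop_token=None,
        max_retry=1,
        max_obs_length=1920,
    )
    demo_agent = construct_agent(demo_args)
    print(f"Constructed agent: {type(demo_agent).__name__}")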