webrl/VAB-WebArena-Lite/agent/agent.py
2025-04-23 17:01:18 +08:00

237 lines
8.2 KiB
Python

import argparse
import json
from typing import Any, Optional
import tiktoken
from beartype import beartype
from PIL import Image
from agent.prompts import *
from browser_env import Trajectory
from browser_env.actions import (
Action,
ActionParsingError,
create_id_based_action,
create_none_action,
create_playwright_action,
create_webrl_id_based_action
)
from browser_env.utils import Observation, StateInfo
from llms import (
call_llm,
generate_from_huggingface_completion,
generate_from_openai_chat_completion,
generate_from_openai_completion,
lm_config,
)
from llms.tokenizers import Tokenizer
class Agent:
    """Common interface shared by every agent implementation.

    Subclasses are expected to override ``next_action`` and ``reset``; the
    versions defined here only raise ``NotImplementedError``.
    """

    def __init__(self, *args: Any) -> None:
        """Accept arbitrary positional arguments and ignore them."""

    def next_action(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: Any,
    ) -> Action:
        """Return the next action for the given trajectory and intent.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def reset(self, test_config_file: str) -> None:
        """Re-initialise the agent for the task in ``test_config_file``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()
class TeacherForcingAgent(Agent):
"""Agent that follows a pre-defined action sequence"""
def __init__(self) -> None:
super().__init__()
def set_action_set_tag(self, tag: str) -> None:
self.action_set_tag = tag
def set_actions(self, action_seq: str | list[str]) -> None:
if isinstance(action_seq, str):
action_strs = action_seq.strip().split("\n")
else:
action_strs = action_seq
action_strs = [a.strip() for a in action_strs]
actions = []
for a_str in action_strs:
try:
if self.action_set_tag == "playwright":
cur_action = create_playwright_action(a_str)
elif self.action_set_tag == "id_accessibility_tree":
cur_action = create_id_based_action(a_str)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
)
except ActionParsingError as e:
cur_action = create_none_action()
cur_action["raw_prediction"] = a_str
actions.append(cur_action)
self.actions: list[Action] = actions
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: Any
) -> Action:
"""Predict the next action given the observation"""
return self.actions.pop(0)
def reset(
self,
test_config_file: str,
) -> None:
with open(test_config_file) as f:
ref_actions = json.load(f)["reference_action_sequence"]
tag = ref_actions["action_set_tag"]
action_seq = ref_actions["action_sequence"]
self.set_action_set_tag(tag)
self.set_actions(action_seq)
class PromptAgent(Agent):
    """Prompt-based agent that emits an action given the history.

    Each call to ``next_action`` builds a prompt with the configured prompt
    constructor, sends it to the language model through ``call_llm`` and
    parses the reply into an ``Action`` according to ``action_set_tag``.
    """

    @beartype
    def __init__(
        self,
        action_set_tag: str,
        lm_config: lm_config.LMConfig,
        prompt_constructor: PromptConstructor,
        captioning_fn=None,
        planner_ip=None,
    ) -> None:
        """Store the configuration and decide whether screenshots are sent.

        Args:
            action_set_tag: name of the action grammar used to parse replies
                ("id_accessibility_tree", "playwright", "som" or "webrl_id").
            lm_config: language-model configuration passed to ``call_llm``.
            prompt_constructor: object that builds prompts and extracts the
                action text from model replies.
            captioning_fn: optional callable that receives a list of images
                and returns an indexable collection of captions.
            planner_ip: optional endpoint; when it is neither ``None`` nor
                the empty string it is forwarded to ``call_llm``.
        """
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = prompt_constructor
        self.action_set_tag = action_set_tag
        self.captioning_fn = captioning_fn
        self.planner_ip = planner_ip

        # Check if the model is multimodal. The grouping is spelled out
        # explicitly; it matches how Python already parsed the original mixed
        # `and`/`or` expression (`and` binds tighter than `or`).
        model_name = lm_config.model
        model_accepts_images = (
            "gemini" in model_name
            or ("gpt-4" in model_name and "vision" in model_name)
            or lm_config.provider in ["api", "finetune"]
        )
        # Exact type match on purpose: subclasses of the multimodal
        # constructor are not treated as multimodal here.
        self.multimodal_inputs = (
            model_accepts_images
            and type(prompt_constructor) is MultimodalCoTPromptConstructor
        )

    def set_action_set_tag(self, tag: str) -> None:
        """Change the action grammar used to parse model replies."""
        self.action_set_tag = tag

    def _caption_prefix(self, images: list) -> str:
        """Build the caption text that is prepended to the intent.

        The first entry starts with "Input image", later ones with
        "input image". When more than one image is given, every entry,
        including the last, is followed by ", ".
        """
        # NOTE(review): the trailing ", " after the final caption looks
        # unintended, but it is kept so the prompt text stays byte-identical.
        separator = ", " if len(images) > 1 else ""
        entries = []
        for position, image in enumerate(images, start=1):
            label = "Input" if position == 1 else "input"
            caption = self.captioning_fn([image])[0]
            entries.append(f'{label} image {position}: "{caption}"{separator}')
        return "".join(entries)

    def _parse_response(self, response: str) -> Action:
        """Extract the action text from ``response`` and parse it.

        Raises:
            ActionParsingError: propagated from the prompt constructor or the
                action parser when the reply cannot be parsed.
            ValueError: when ``action_set_tag`` is not a supported grammar.
        """
        parsed_response = self.prompt_constructor.extract_action(response)
        tag = self.action_set_tag
        # "som" shares the id-based parser with the accessibility tree.
        if tag in ("id_accessibility_tree", "som"):
            return create_id_based_action(parsed_response)
        if tag == "playwright":
            return create_playwright_action(parsed_response)
        if tag == "webrl_id":
            return create_webrl_id_based_action(parsed_response)
        raise ValueError(f"Unknown action type {self.action_set_tag}")

    @beartype
    def next_action(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any],
        images: Optional[list[Image.Image]] = None,
        output_response: bool = False,
    ) -> Action:
        """Query the language model and return the parsed next action.

        Args:
            trajectory: interaction history; for multimodal agents the last
                element must provide ``["observation"]["image"]``.
            intent: task description given to the prompt constructor.
            meta_data: extra information forwarded to the prompt constructor.
            images: optional input images. They are captioned when a
                captioning function is configured, and are also passed to
                the prompt constructor for multimodal agents.
            output_response: when true, print every raw model reply.

        Returns:
            The parsed action, with the raw model reply stored under
            ``"raw_prediction"``. If parsing keeps failing until
            ``gen_config["max_retry"]`` attempts have been made, the "none"
            action is returned instead.

        Raises:
            ValueError: when ``action_set_tag`` is not a supported grammar.
        """
        # Create page screenshot image for multimodal models.
        if self.multimodal_inputs:
            page_screenshot_arr = trajectory[-1]["observation"]["image"]
            # PIL reports the size as (width, height).
            page_screenshot_img = Image.fromarray(page_screenshot_arr)

        # Caption the input image, if provided.
        if images is not None and len(images) > 0:
            if self.captioning_fn is not None:
                # Update intent to include captions of input images.
                image_input_caption = self._caption_prefix(images)
                intent = f"{image_input_caption}\nIntent: {intent}"
            elif not self.multimodal_inputs:
                print(
                    "WARNING: Input image provided but no image captioner available."
                )

        if self.multimodal_inputs:
            prompt = self.prompt_constructor.construct(
                trajectory, intent, page_screenshot_img, images, meta_data
            )
        else:
            prompt = self.prompt_constructor.construct(
                trajectory, intent, meta_data
            )

        lm_config = self.lm_config

        # Loop-invariant settings, resolved once instead of on every retry.
        use_planner = self.planner_ip is not None and self.planner_ip != ""
        force_prefix = self.prompt_constructor.instruction["meta_data"].get(
            "force_prefix", ""
        )

        attempts = 0
        while True:
            if use_planner:
                response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
            else:
                response = call_llm(lm_config, prompt)
            response = f"{force_prefix}{response}"
            if output_response:
                print(f'Agent: {response}', flush=True)
            attempts += 1

            try:
                action = self._parse_response(response)
            except ActionParsingError:
                if attempts < lm_config.gen_config["max_retry"]:
                    continue
                # Out of retries: fall back to the "none" action.
                action = create_none_action()

            action["raw_prediction"] = response
            return action

    def reset(self, test_config_file: str) -> None:
        """No per-task state to reset; the argument is ignored."""
        pass
def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
    """Build the agent selected by ``args.agent_type``.

    Args:
        args: parsed command-line options. ``agent_type`` is always read.
            For the "prompt" agent, ``instruction_path``, ``provider``,
            ``model`` and ``action_set_tag`` are read as well, plus the
            optional ``planner_ip``.
        captioning_fn: optional image-captioning callable handed to
            ``PromptAgent``.

    Returns:
        A ``TeacherForcingAgent`` for "teacher_forcing" or a ``PromptAgent``
        for "prompt".

    Raises:
        NotImplementedError: for any other ``agent_type``.
    """
    # Built before branching, as before, so configuration errors surface for
    # every agent type.
    llm_config = lm_config.construct_llm_config(args)

    if args.agent_type == "teacher_forcing":
        return TeacherForcingAgent()

    if args.agent_type == "prompt":
        # Read the instruction JSON as UTF-8 regardless of the locale default.
        with open(args.instruction_path, encoding="utf-8") as f:
            constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
        tokenizer = Tokenizer(args.provider, args.model)
        # SECURITY NOTE: `constructor_type` comes straight from the
        # instruction file and is passed to eval(), so that file must be
        # trusted. Anyone who can edit it can run arbitrary code here.
        prompt_constructor = eval(constructor_type)(
            args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
        )
        return PromptAgent(
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
            captioning_fn=captioning_fn,
            # Namespaces without a `planner_ip` option behave as "no planner"
            # instead of raising AttributeError.
            planner_ip=getattr(args, "planner_ip", None),
        )

    raise NotImplementedError(
        f"agent type {args.agent_type} not implemented"
    )