228 lines
7.3 KiB
Python
228 lines
7.3 KiB
Python
import argparse
|
|
import json
|
|
from typing import Any
|
|
|
|
import tiktoken
|
|
from beartype import beartype
|
|
|
|
from agent.prompts import *
|
|
from browser_env import Trajectory
|
|
from browser_env.actions import (
|
|
Action,
|
|
ActionParsingError,
|
|
create_id_based_action,
|
|
create_none_action,
|
|
create_playwright_action,
|
|
)
|
|
from browser_env.utils import Observation, StateInfo
|
|
from llms import (
|
|
call_llm,
|
|
generate_from_huggingface_completion,
|
|
generate_from_openai_chat_completion,
|
|
generate_from_openai_completion,
|
|
lm_config,
|
|
)
|
|
from llms.tokenizers import Tokenizer
|
|
|
|
|
|
class Agent:
|
|
"""Base class for the agent"""
|
|
|
|
def __init__(self, *args: Any) -> None:
|
|
pass
|
|
|
|
def next_action(
|
|
self, trajectory: Trajectory, intent: str, meta_data: Any
|
|
) -> Action:
|
|
"""Predict the next action given the observation"""
|
|
raise NotImplementedError
|
|
|
|
def check_action(
|
|
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any], target_action: str
|
|
) -> Action:
|
|
"""Predict the next action given the observation"""
|
|
raise NotImplementedError
|
|
|
|
def reset(
|
|
self,
|
|
test_config_file: str,
|
|
) -> None:
|
|
raise NotImplementedError
|
|
|
|
|
|
class TeacherForcingAgent(Agent):
|
|
"""Agent that follows a pre-defined action sequence"""
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def set_action_set_tag(self, tag: str) -> None:
|
|
self.action_set_tag = tag
|
|
|
|
def set_actions(self, action_seq: str | list[str]) -> None:
|
|
if isinstance(action_seq, str):
|
|
action_strs = action_seq.strip().split("\n")
|
|
else:
|
|
action_strs = action_seq
|
|
action_strs = [a.strip() for a in action_strs]
|
|
|
|
actions = []
|
|
for a_str in action_strs:
|
|
try:
|
|
if self.action_set_tag == "playwright":
|
|
cur_action = create_playwright_action(a_str)
|
|
elif self.action_set_tag == "id_accessibility_tree":
|
|
cur_action = create_id_based_action(a_str)
|
|
else:
|
|
raise ValueError(
|
|
f"Unknown action type {self.action_set_tag}"
|
|
)
|
|
except ActionParsingError as e:
|
|
cur_action = create_none_action()
|
|
|
|
cur_action["raw_prediction"] = a_str
|
|
actions.append(cur_action)
|
|
|
|
self.actions: list[Action] = actions
|
|
|
|
def next_action(
|
|
self, trajectory: Trajectory, intent: str, meta_data: Any
|
|
) -> Action:
|
|
"""Predict the next action given the observation"""
|
|
return self.actions.pop(0)
|
|
|
|
def reset(
|
|
self,
|
|
test_config_file: str,
|
|
) -> None:
|
|
with open(test_config_file) as f:
|
|
ref_actions = json.load(f)["reference_action_sequence"]
|
|
tag = ref_actions["action_set_tag"]
|
|
action_seq = ref_actions["action_sequence"]
|
|
self.set_action_set_tag(tag)
|
|
self.set_actions(action_seq)
|
|
|
|
|
|
class PromptAgent(Agent):
|
|
"""prompt-based agent that emits action given the history"""
|
|
|
|
@beartype
|
|
def __init__(
|
|
self,
|
|
action_set_tag: str,
|
|
lm_config: lm_config.LMConfig,
|
|
prompt_constructor: PromptConstructor,
|
|
) -> None:
|
|
super().__init__()
|
|
self.lm_config = lm_config
|
|
self.prompt_constructor = prompt_constructor
|
|
self.action_set_tag = action_set_tag
|
|
|
|
def set_action_set_tag(self, tag: str) -> None:
|
|
self.action_set_tag = tag
|
|
|
|
@beartype
|
|
def next_action(
|
|
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any]
|
|
) -> Action:
|
|
prompt = self.prompt_constructor.construct(
|
|
trajectory, intent, meta_data
|
|
)
|
|
lm_config = self.lm_config
|
|
n = 0
|
|
while True:
|
|
response = call_llm(lm_config, prompt)
|
|
force_prefix = self.prompt_constructor.instruction[
|
|
"meta_data"
|
|
].get("force_prefix", "")
|
|
response = f"{force_prefix}{response}"
|
|
n += 1
|
|
try:
|
|
parsed_response = self.prompt_constructor.extract_action(
|
|
response
|
|
)
|
|
if self.action_set_tag in ["id_html_tree", "id_html_nasc_tree", "id_accessibility_tree"]:
|
|
action = create_id_based_action(parsed_response)
|
|
elif self.action_set_tag == "playwright":
|
|
action = create_playwright_action(parsed_response)
|
|
else:
|
|
raise ValueError(
|
|
f"Unknown action type {self.action_set_tag}"
|
|
)
|
|
action["raw_prediction"] = response
|
|
break
|
|
except ActionParsingError as e:
|
|
if n >= lm_config.gen_config["max_retry"]:
|
|
action = create_none_action()
|
|
action["raw_prediction"] = response
|
|
break
|
|
|
|
return action
|
|
|
|
def check_action(
|
|
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any], target_action: str
|
|
) -> Action:
|
|
prompt = self.prompt_constructor.construct(
|
|
trajectory, intent, meta_data
|
|
)
|
|
lm_config = self.lm_config
|
|
n = 0
|
|
|
|
# agent will retry if the action is not parsed correctly
|
|
while True:
|
|
response = target_action
|
|
force_prefix = self.prompt_constructor.instruction[
|
|
"meta_data"
|
|
].get("force_prefix", "")
|
|
response = f"{force_prefix}{response}"
|
|
n += 1
|
|
try:
|
|
parsed_response = self.prompt_constructor.extract_action(
|
|
response
|
|
)
|
|
if self.action_set_tag in ["id_accessibility_tree", "id_html_tree", "id_html_nasc_tree"]:
|
|
action = create_id_based_action(parsed_response)
|
|
elif self.action_set_tag == "playwright":
|
|
action = create_playwright_action(parsed_response)
|
|
else:
|
|
raise ValueError(
|
|
f"Unknown action type {self.action_set_tag}"
|
|
)
|
|
action["raw_prediction"] = response
|
|
break
|
|
except ActionParsingError as e:
|
|
if n >= lm_config.gen_config["max_retry"]:
|
|
action = create_none_action()
|
|
action["raw_prediction"] = response
|
|
break
|
|
|
|
return prompt, action
|
|
|
|
def reset(self, test_config_file: str) -> None:
|
|
pass
|
|
|
|
|
|
def construct_agent(args: argparse.Namespace) -> Agent:
|
|
llm_config = lm_config.construct_llm_config(args)
|
|
|
|
agent: Agent
|
|
if args.agent_type == "teacher_forcing":
|
|
agent = TeacherForcingAgent()
|
|
elif args.agent_type == "prompt":
|
|
with open(args.instruction_path) as f:
|
|
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
|
|
tokenizer = Tokenizer(args.provider, args.model)
|
|
prompt_constructor = eval(constructor_type)(
|
|
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
|
|
)
|
|
agent = PromptAgent(
|
|
action_set_tag=args.action_set_tag,
|
|
lm_config=llm_config,
|
|
prompt_constructor=prompt_constructor,
|
|
)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"agent type {args.agent_type} not implemented"
|
|
)
|
|
return agent
|