map网站使用socks5连接到官网openstreetmap; 模型改成gpt-4o; llm fuzzy match也使用4o; 增加generate_test_data.py脚本
143 lines
5.8 KiB
Python
143 lines
5.8 KiB
Python
import os
|
|
import time
|
|
import re
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
|
|
from AgentOccam.env import WebArenaEnvironmentWrapper
|
|
|
|
from AgentOccam.AgentOccam import AgentOccam
|
|
from webagents_step.utils.data_prep import *
|
|
from webagents_step.agents.step_agent import StepAgent
|
|
|
|
from AgentOccam.prompts import AgentOccam_prompt
|
|
from webagents_step.prompts.webarena import step_fewshot_template_adapted, step_fewshot_template
|
|
|
|
from AgentOccam.utils import EVALUATOR_DIR
|
|
|
|
def _task_config_paths(config):
    """Return the list of task config JSON paths selected by ``config.env``.

    ``config.env.task_ids`` is either an explicit list of task ids or the
    marker ``"all"`` / ``["all"]``, in which case every ``*.json`` file found
    in ``config_files/<relative_task_dir>`` is used.  ``relative_task_dir``
    defaults to ``"tasks"`` when the config does not define it.
    """
    suffix = ".json"
    relative_task_dir = getattr(config.env, "relative_task_dir", "tasks")
    task_dir = f"config_files/{relative_task_dir}"

    task_ids = config.env.task_ids
    if task_ids == "all" or task_ids == ["all"]:
        task_ids = [
            filename[:-len(suffix)]
            for filename in os.listdir(task_dir)
            if filename.endswith(suffix)
        ]
    return [f"{task_dir}/{task_id}{suffix}" for task_id in task_ids]


def _build_agent_factory(config):
    """Return a zero-argument callable that creates a fresh agent.

    The agent class and prompt set are chosen from ``config.agent.type``.
    Constructor arguments are read from ``config`` each time the factory is
    called, not when the factory is built.

    Raises:
        NotImplementedError: if ``config.agent.type`` is not one of
            ``"AgentOccam"``, ``"AgentOccam-SteP"`` or ``"SteP-replication"``.
    """
    agent_type = config.agent.type

    if agent_type == "AgentOccam":
        def make_occam_agent():
            return AgentOccam(
                # Only the dict-valued module attributes are prompt definitions.
                prompt_dict={k: v for k, v in AgentOccam_prompt.__dict__.items() if isinstance(v, dict)},
                config=config.agent,
            )
        return make_occam_agent

    # Both SteP variants share the same constructor call and differ only in
    # the few-shot template module they draw their prompts from.
    if agent_type == "AgentOccam-SteP":
        template_module = step_fewshot_template_adapted
    elif agent_type == "SteP-replication":
        template_module = step_fewshot_template
    else:
        raise NotImplementedError(f"{agent_type} not implemented")

    def make_step_agent():
        return StepAgent(
            root_action=config.agent.root_action,
            action_to_prompt_dict={k: v for k, v in template_module.__dict__.items() if isinstance(v, dict)},
            low_level_action_list=config.agent.low_level_action_list,
            max_actions=config.env.max_env_steps,
            verbose=config.verbose,
            logging=config.logging,
            debug=config.debug,
            model=config.agent.model_name,
            prompt_mode=config.agent.prompt_mode,
        )
    return make_step_agent


def run():
    """Run the configured agent on every selected WebArena task.

    Command line:
        --config PATH   YAML config file (required).

    For each task config file the function creates a
    ``WebArenaEnvironmentWrapper``, builds a fresh agent, lets it act on the
    task objective and closes the environment.  When ``config.logging`` is
    truthy, the YAML config is copied into the log directory, a per-task JSON
    log and a ``summary.csv`` row are written through ``log_run``, and tasks
    that already have a JSON log in that directory are skipped.

    NOTE(review): ``yaml``, ``json``, ``random``, ``DotDict`` and ``log_run``
    are not imported explicitly in this file; presumably they are provided by
    ``from webagents_step.utils.data_prep import *`` -- verify.

    Raises:
        NotImplementedError: if ``config.agent.type`` is not supported.
    """
    parser = argparse.ArgumentParser(
        description="Only the config file argument should be passed"
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml config file location"
    )
    args = parser.parse_args()
    with open(args.config, "r") as file:
        config = DotDict(yaml.safe_load(file))

    # dstdir stays None when logging is disabled so that later code can test
    # for it instead of hitting an unbound local.
    dstdir = None
    if config.logging:
        run_name = config.logname if config.logname else time.strftime("%Y%m%d-%H%M%S")
        dstdir = f"{config.logdir}/{run_name}"
        os.makedirs(dstdir, exist_ok=True)
        # Keep a copy of the config next to the logs it produced.
        shutil.copyfile(args.config, os.path.join(dstdir, os.path.basename(args.config)))
    random.seed(42)

    config_file_list = _task_config_paths(config)

    fullpage = config.env.fullpage if hasattr(config.env, "fullpage") else True
    current_viewport_only = not fullpage

    agent_init = _build_agent_factory(config)

    # Task ids announced below as "Reddit post task"; each one is preceded by
    # a 30 minute pause (presumably to stay under a posting rate limit --
    # TODO confirm).
    reddit_post_task_ranges = (range(600, 650), range(681, 689))

    for config_file in config_file_list:
        with open(config_file, "r") as f:
            task_config = json.load(f)
        task_id = task_config["task_id"]
        print(f"Task {task_id}.")

        # An existing per-task log means the task already ran; only
        # meaningful when a log directory exists.
        if dstdir is not None and os.path.exists(os.path.join(dstdir, f"{task_id}.json")):
            print(f"Skip {task_id}.")
            continue

        if any(task_id in id_range for id_range in reddit_post_task_ranges):
            print("Reddit post task. Sleep 30 mins.")
            time.sleep(1800)

        env = WebArenaEnvironmentWrapper(
            config_file=config_file,
            max_browser_rows=config.env.max_browser_rows,
            max_steps=config.max_steps,
            slow_mo=1,
            observation_type="accessibility_tree",
            current_viewport_only=current_viewport_only,
            viewport_size={"width": 1920, "height": 1080},
            headless=config.env.headless,
            global_config=config,
            proxy_url=config.env.proxy_url,
        )
        try:
            agent = agent_init()
            objective = env.get_objective()
            status = agent.act(objective=objective, env=env)
        finally:
            # Always release the environment, even if the agent raised.
            env.close()

        if config.logging and dstdir is not None:
            # The task config is re-read after the run, as the original code
            # did.  NOTE(review): likely redundant unless the run rewrites
            # the file -- confirm before removing.
            with open(config_file, "r") as f:
                task_config = json.load(f)
            task_id = task_config["task_id"]

            if hasattr(config.agent, "actor"):
                model_name = config.agent.actor.model
            else:
                model_name = config.agent.model_name

            log_file = os.path.join(dstdir, f"{task_id}.json")
            # Record the log location as "<run dir>/<task>.json"; fall back to
            # the full path when it does not have that shape.
            relative_log = re.search(r"/([^/]+/[^/]+\.json)$", log_file)

            log_data = {
                "task": config_file,
                "id": task_id,
                "model": model_name,
                "type": config.agent.type,
                "trajectory": agent.get_trajectory(),
            }
            summary_file = os.path.join(dstdir, "summary.csv")
            summary_data = {
                "task": config_file,
                "task_id": task_id,
                "model": model_name,
                "type": config.agent.type,
                "logfile": relative_log.group(1) if relative_log else log_file,
            }
            if status:
                summary_data.update(status)
            log_run(
                log_file=log_file,
                log_data=log_data,
                summary_file=summary_file,
                summary_data=summary_data,
            )
|
|
|
|
# Script entry point: start the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run()
|