add webrl mode

This commit is contained in:
QZH-777 2024-11-14 15:51:41 +08:00
parent 8d86a00e85
commit 521d7e999a
32 changed files with 5688 additions and 48 deletions

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from browser_env.actions import (
create_id_based_action, create_id_based_action,
create_none_action, create_none_action,
create_playwright_action, create_playwright_action,
create_webrl_id_based_action
) )
from browser_env.utils import Observation, StateInfo from browser_env.utils import Observation, StateInfo
from llms import ( from llms import (
@ -108,12 +109,14 @@ class PromptAgent(Agent):
lm_config: lm_config.LMConfig, lm_config: lm_config.LMConfig,
prompt_constructor: PromptConstructor, prompt_constructor: PromptConstructor,
captioning_fn = None, captioning_fn = None,
planner_ip = None
) -> None: ) -> None:
super().__init__() super().__init__()
self.lm_config = lm_config self.lm_config = lm_config
self.prompt_constructor = prompt_constructor self.prompt_constructor = prompt_constructor
self.action_set_tag = action_set_tag self.action_set_tag = action_set_tag
self.captioning_fn = captioning_fn self.captioning_fn = captioning_fn
self.planner_ip = planner_ip
# Check if the model is multimodal. # Check if the model is multimodal.
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor: if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:
@ -165,7 +168,10 @@ class PromptAgent(Agent):
lm_config = self.lm_config lm_config = self.lm_config
n = 0 n = 0
while True: while True:
response = call_llm(lm_config, prompt) if self.planner_ip is not None and self.planner_ip != "":
response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
else:
response = call_llm(lm_config, prompt)
force_prefix = self.prompt_constructor.instruction[ force_prefix = self.prompt_constructor.instruction[
"meta_data" "meta_data"
].get("force_prefix", "") ].get("force_prefix", "")
@ -183,6 +189,8 @@ class PromptAgent(Agent):
action = create_playwright_action(parsed_response) action = create_playwright_action(parsed_response)
elif self.action_set_tag == "som": elif self.action_set_tag == "som":
action = create_id_based_action(parsed_response) action = create_id_based_action(parsed_response)
elif self.action_set_tag == 'webrl_id':
action = create_webrl_id_based_action(parsed_response)
else: else:
raise ValueError( raise ValueError(
f"Unknown action type {self.action_set_tag}" f"Unknown action type {self.action_set_tag}"
@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
action_set_tag=args.action_set_tag, action_set_tag=args.action_set_tag,
lm_config=llm_config, lm_config=llm_config,
prompt_constructor=prompt_constructor, prompt_constructor=prompt_constructor,
captioning_fn=captioning_fn captioning_fn=captioning_fn,
planner_ip=args.planner_ip
) )
else: else:
raise NotImplementedError( raise NotImplementedError(

View File

@ -0,0 +1,319 @@
import json
import os
import re
import subprocess
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import requests
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
CDPSession,
Page,
Playwright,
ViewportSize,
expect,
sync_playwright,
)
DATASET = os.environ["DATASET"]
if DATASET == "visualwebarena":
from browser_env.env_config import (
CLASSIFIEDS,
CLASSIFIEDS_RESET_TOKEN,
)
from .actions import Action, execute_action, get_action_space, execute_action_webrl
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
AccessibilityTree,
DetachedPage,
Observation,
png_bytes_to_numpy,
)
@dataclass
class PlaywrightScript:
function: str # goto, get_by_role
destination: str # https://www.google.com/, combobox
name: str | None = None # Search, Avatar 2009
operation: str | None = None # click, fill, press
value: str | None = None # avatar movie, Enter
def parse_action(action: str) -> PlaywrightScript:
splitted = action.strip().split(" ")
assert len(splitted) >= 2
match splitted[:2]:
case ["goto", url]:
assert len(splitted) == 2
return PlaywrightScript("goto", url)
case ["get_by_role", destination]:
assert len(splitted) >= 4
match splitted[2:]:
case [name, operation]:
return PlaywrightScript(
"get_by_role", destination, name, operation
)
case [name, operation, value]:
return PlaywrightScript(
"get_by_role", destination, name, operation, value
)
case _:
raise ValueError("Invalid action")
case _:
raise ValueError(f"Invalid action {action}")
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
"""
The goal of this environment is to produce a prototype of a browser environment.
In the end, we want to support a fully configurable browser environment with wide
range of action spaces and observation spaces, both structured and unstructured.
But in this prototype, we just support action space specified by Playwright script,
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 8192,
headless: bool = True,
slow_mo: int = 0,
observation_type: str = "html",
current_viewport_only: bool = False,
viewport_size: ViewportSize = {"width": 1280, "height": 720},
save_trace_enabled: bool = False,
sleep_after_execution: float = 0.0,
captioning_fn=None,
):
# TODO: make Space[Action] = ActionSpace
self.action_space = get_action_space() # type: ignore[assignment]
self.headless = headless
self.slow_mo = slow_mo
self.current_viewport_only = current_viewport_only
self.reset_finished = False
self.viewport_size = viewport_size
self.save_trace_enabled = save_trace_enabled
self.sleep_after_execution = sleep_after_execution
match observation_type:
case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
self.text_observation_type = observation_type
self.image_observation_type = ""
self.main_observation_type = "text"
case "image":
self.image_observation_type = observation_type
self.text_observation_type = "" # type: ignore[assignment]
self.main_observation_type = "image"
case "image_som":
self.image_observation_type = observation_type
self.text_observation_type = observation_type # type: ignore[assignment]
self.main_observation_type = "image"
case _:
raise ValueError(
f"Unsupported observation type: {observation_type}"
)
self.observation_handler = ObservationHandler(
self.main_observation_type,
self.text_observation_type,
self.image_observation_type,
self.current_viewport_only,
self.viewport_size,
captioning_fn,
)
self.observation_space = (
self.observation_handler.get_observation_space()
)
@beartype
def setup(self, config_file: Path | None = None) -> None:
self.context_manager = sync_playwright()
self.playwright = self.context_manager.__enter__()
self.browser = self.playwright.chromium.launch(
headless=self.headless, slow_mo=self.slow_mo
)
if config_file:
with open(config_file, "r") as f:
instance_config = json.load(f)
else:
instance_config = {}
# Reset site if needed. Currently only supported for Classifieds.
# TODO(jykoh): Add reset functionality for Shopping/Reddit.
if instance_config.get("require_reset", False):
if "classifieds" in instance_config["sites"]:
# Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
response = requests.post(
f"{CLASSIFIEDS}/index.php?page=reset",
data={"token": CLASSIFIEDS_RESET_TOKEN},
)
# Check if the request was successful
if response.status_code == 200:
print("Reset Classifieds site.")
else:
print(
"Failed to reset Classifieds site:",
response.status_code,
)
else:
print(
"WARNING: Reset is not supported for this site. Please manually reset the site."
)
storage_state = instance_config.get("storage_state", None)
start_url = instance_config.get("start_url", None)
geolocation = instance_config.get("geolocation", None)
# Use custom viewport size if specified in the config, otherwise use the default.
viewport_size = self.viewport_size.copy()
viewport_size.update(instance_config.get("viewport_size", {}))
self.observation_handler.viewport_size = viewport_size
self.context = self.browser.new_context(
viewport=viewport_size,
storage_state=storage_state,
geolocation=geolocation,
device_scale_factor=1,
)
if self.save_trace_enabled:
self.context.tracing.start(screenshots=True, snapshots=True)
if start_url:
start_urls = start_url.split(" |AND| ")
for url in start_urls:
page = self.context.new_page()
if self.text_observation_type in [
"accessibility_tree",
"accessibility_tree_with_captioner",
]:
client = page.context.new_cdp_session(page)
client.send("Accessibility.enable")
client.detach()
page.goto(url)
# set the first page as the current page
self.page = self.context.pages[0]
self.page.bring_to_front()
else:
self.page = self.context.new_page()
if self.text_observation_type in [
"accessibility_tree",
"accessibility_tree_with_captioner",
]:
client = self.page.context.new_cdp_session(self.page)
client.send("Accessibility.enable")
client.detach()
def _get_obs(self) -> dict[str, Observation]:
obs = self.observation_handler.get_observation(self.page)
return obs
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
metadata = self.observation_handler.get_observation_metadata()
return metadata
@beartype
def reset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
Reset the environment.
:param options: options for the environment. The current supported options are:
- "storage_state": the storage state of the browser. It is a file path to a json file.
"""
super().reset(seed=seed, options=options)
if self.reset_finished:
self.context_manager.__exit__()
if options is not None and "config_file" in options:
config_file = Path(options["config_file"])
if config_file.exists():
self.setup(config_file=config_file)
else:
raise ValueError(f"Config file {config_file} does not exist.")
else:
self.setup()
self.reset_finished = True
timeout_in_ms = 120000
self.page.set_default_timeout(timeout_in_ms)
self.page.set_default_navigation_timeout(timeout_in_ms)
self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, ""),
"fail_error": "",
"observation_metadata": observation_metadata,
}
return (observation, info)
def save_trace(self, trace_path: str | Path) -> None:
if self.save_trace_enabled:
self.context.tracing.stop(path=trace_path)
def close(self) -> None:
if self.reset_finished:
self.context_manager.__exit__()
def step(
self, action: Action
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
if not self.reset_finished:
raise RuntimeError("Call reset first before calling step.")
success = False
fail_error = ""
try:
if self.text_observation_type == 'webrl':
self.page = execute_action_webrl(
action,
self.page,
self.context,
self.observation_handler.action_processor,
self.sleep_after_execution,
)
else:
self.page = execute_action(
action,
self.page,
self.context,
self.observation_handler.action_processor,
self.sleep_after_execution,
)
success = True
except Exception as e:
fail_error = str(e)
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, self.page.content()),
"fail_error": fail_error,
"observation_metadata": observation_metadata,
}
msg = (
observation,
float(success), # reward
False, # terminated
False, # truncated
info,
)
return msg

View File

@ -172,6 +172,10 @@ class StringEvaluator(Evaluator):
# prevent false positive (e.g, 0) # prevent false positive (e.g, 0)
if len(word_tokenize(clean_ref)) == 1: if len(word_tokenize(clean_ref)) == 1:
tok_pred = word_tokenize(clean_pred) tok_pred = word_tokenize(clean_pred)
for token in tok_pred:
if '/' in token:
sub_tokens = token.split('/')
tok_pred.extend(sub_tokens)
return float(clean_ref in tok_pred) return float(clean_ref in tok_pred)
else: else:
return float(clean_ref in clean_pred) return float(clean_ref in clean_pred)

View File

@ -0,0 +1,239 @@
import base64
import io
import json
import re
from pathlib import Path
from typing import Any
from PIL import Image
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ObservationMetadata,
StateInfo,
action2str,
)
HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
<style>
pre {{
white-space: pre-wrap;
word-wrap: break-word;
}}
</style>
</head>
<html>
<body>
{body}
</body>
</html>
"""
def get_render_action(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
) -> str:
"""Parse the predicted actions for rendering purpose. More comprehensive information"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
else:
node_content = "No match found"
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
case "som":
text_meta_data = observation_metadata["text"]
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]
else:
node_content = "No match found"
action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"
case "playwright":
action_str = action["pw_code"]
case "webrl_id":
action_str = action["raw_prediction"]
case _:
raise ValueError(
f"Unknown action type {action['action_type'], action_set_tag}"
)
return action_str
def get_action_description(
action: Action,
observation_metadata: dict[str, ObservationMetadata],
action_set_tag: str,
prompt_constructor: PromptConstructor | None,
) -> str:
"""Generate the text version of the predicted actions to store in action history for prompt use.
May contain hint information to recover from the failures"""
match action_set_tag:
case "id_accessibility_tree":
text_meta_data = observation_metadata["text"]
if action["action_type"] in [
ActionTypes.CLICK,
ActionTypes.HOVER,
ActionTypes.TYPE,
]:
action_name = str(action["action_type"]).split(".")[1].lower()
if action["element_id"] in text_meta_data["obs_nodes_info"]:
node_content = text_meta_data["obs_nodes_info"][
action["element_id"]
]["text"]
node_content = " ".join(node_content.split()[1:])
action_str = action2str(
action, action_set_tag, node_content
)
else:
action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
else:
if (
action["action_type"] == ActionTypes.NONE
and prompt_constructor is not None
):
action_splitter = prompt_constructor.instruction[
"meta_data"
]["action_splitter"]
action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
else:
action_str = action2str(action, action_set_tag, "")
case "som":
text_meta_data = observation_metadata["image"]
if action["action_type"] in [
ActionTypes.CLICK,
ActionTypes.HOVER,
ActionTypes.TYPE,
]:
action_name = str(action["action_type"]).split(".")[1].lower()
if action["element_id"] in text_meta_data["obs_nodes_info"]:
action_str = action2str(action, action_set_tag, "")
else:
print(
'action["element_id"], text_meta_data["obs_nodes_info"]',
action["element_id"],
text_meta_data["obs_nodes_info"],
)
action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
else:
if (
action["action_type"] == ActionTypes.NONE
and prompt_constructor is not None
):
action_splitter = prompt_constructor.instruction[
"meta_data"
]["action_splitter"]
action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
else:
action_str = action2str(action, action_set_tag, "")
case "playwright":
action_str = action["pw_code"]
case "webrl_id":
action_str = action["raw_prediction"]
case _:
raise ValueError(f"Unknown action type {action['action_type']}")
return action_str
class RenderHelper(object):
"""Helper class to render text and image observations and meta data in the trajectory"""
def __init__(
self, config_file: str, result_dir: str, action_set_tag: str
) -> None:
with open(config_file, "r") as f:
_config = json.load(f)
_config_str = ""
for k, v in _config.items():
_config_str += f"{k}: {v}\n"
_config_str = f"<pre>{_config_str}</pre>\n"
task_id = _config["task_id"]
self.action_set_tag = action_set_tag
self.render_file = open(
Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
)
self.render_file.truncate(0)
# write init template
self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
self.render_file.read()
self.render_file.flush()
def render(
self,
action: Action,
state_info: StateInfo,
meta_data: dict[str, Any],
render_screenshot: bool = False,
) -> None:
"""Render the trajectory"""
# text observation
observation = state_info["observation"]
text_obs = observation["text"]
info = state_info["info"]
new_content = f"<h2>New Page</h2>\n"
new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"
if render_screenshot:
# image observation
img_obs = observation["image"]
image = Image.fromarray(img_obs)
byte_io = io.BytesIO()
image.save(byte_io, format="PNG")
byte_io.seek(0)
image_bytes = base64.b64encode(byte_io.read())
image_str = image_bytes.decode("utf-8")
new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"
# meta data
new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"
# action
action_str = get_render_action(
action,
info["observation_metadata"],
action_set_tag=self.action_set_tag,
)
# with yellow background
action_str = f"<div class='predict_action'>{action_str}</div>"
new_content += f"{action_str}\n"
# add new content
self.render_file.seek(0)
html = self.render_file.read()
html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
html_body += new_content
html = HTML_TEMPLATE.format(body=html_body)
self.render_file.seek(0)
self.render_file.truncate()
self.render_file.write(html)
self.render_file.flush()
def close(self) -> None:
self.render_file.close()

View File

@ -0,0 +1,7 @@
from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .html_parser import HtmlParser
from .utils import print_html_object
from .configs import basic_attrs, mind2web_keep_attrs
from .fetch import get_parsed_html

View File

@ -0,0 +1,3 @@
from .html_prompt import prompts
from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs
from .config import config_meta

View File

@ -0,0 +1,56 @@
basic_attrs = [
'title',
'value',
'type',
'placeholder',
'selected',
'data-value',
'data-text',
'data-testid',
'data-label',
'data-bbox',
'data-status'
]
mind2web_keep_attrs = [
'alt',
'aria_description',
'aria_label',
'aria_role',
'input_checked',
'input_value',
'label',
'name',
'option_selected',
'placeholder',
'role',
'text_value',
'title',
'type',
'value',
]
miniwob_attrs = [
'id',
'type',
'value',
]
config_meta = """
======= Configs =======
Columns:
- id: {id_attr}
- label: {label_attr}
Position: {use_position}
- window: {window_size}
- rect_dict: {rect}
Keep:
- parents: {parent_chain}
- attrs: {keep_attrs}
- elems: {keep_elem}
- obs_elem: {obs_elem}
Generator:
- prompt: {prompt_name}
- label: {identifier_name}
========================
"""

View File

@ -0,0 +1,22 @@
refine_prompt = {
'dom': '<{tag}{label}|{attr}{content}{subtree} >',
'label': '[{label}]',
'attr': '{attr}',
'attr_splitter': '; ',
'subtree_splitter': ' ',
}
xml_prompt = {
'dom': '<{tag}{label}{attr}>{content}{subtree} </{tag}>',
'label': ' id="{label}"',
'attr': '{key}="{attr}"',
'attr_splitter': ' ',
'subtree_splitter': ' ',
}
prompts = {
'refine': refine_prompt,
'xml': xml_prompt,
'new_data': refine_prompt,
}

View File

@ -0,0 +1,108 @@
import os
import json
import base64
from .html_parser import HtmlParser
from .configs import basic_attrs
from .scripts import *
def get_window(page):
x = page.evaluate("window.scrollX")
y = page.evaluate("window.scrollY")
w = page.evaluate("window.innerWidth")
h = page.evaluate("window.innerHeight")
return (x, y, w, h)
def modify_page(page):
page.wait_for_timeout(500)
try:
page.evaluate(remove_id_script)
except:
pass
packet = {
"raw_html": page.evaluate("document.documentElement.outerHTML"),
"window": get_window(page)
}
page.evaluate(prepare_script)
page.wait_for_timeout(100)
img_bytes = page.screenshot(path="debug_info/screenshot_raw.png")
raw_image = base64.b64encode(img_bytes).decode()
page.evaluate(clickable_checker_script)
page.wait_for_timeout(50)
# get all clickable elements
start_id = 0
items, start_id = page.evaluate(label_script, {
"selector": ".possible-clickable-element",
"startIndex": start_id
})
page.wait_for_timeout(50)
# mark our own labels and get the images
items = page.evaluate(label_marker_script, items)
page.wait_for_timeout(100)
img_bytes = page.screenshot(path="debug_info/marked.png")
marked_image = base64.b64encode(img_bytes).decode()
# remove markers on the page
page.evaluate(remove_label_mark_script)
packet.update({
"raw_image": raw_image,
"marked_image": marked_image,
"modified_html": page.evaluate("document.documentElement.outerHTML")
})
# element_info, include "all_elements" and "clickable_elements"
element_info = page.evaluate(element_info_script)
page.wait_for_timeout(100)
packet.update(element_info)
return packet
def save_debug_info(packet):
with open("debug_info/raw.html", "w") as f:
f.write(packet["modified_html"])
with open("debug_info/parsed.html", "w") as f:
f.write(packet["html"])
with open("debug_info/all_element.json", "w") as f:
f.write(json.dumps(packet["all_elements"]))
def get_parsed_html(page):
if not os.path.exists("debug_info"):
os.makedirs("debug_info")
print("parsing html...")
packet = modify_page(page)
raw_html = packet["modified_html"]
args = {
"use_position": True,
"rect_dict": {},
"window_size": packet["window"],
"id-attr": "data-backend-node-id",
"label_attr": "data-label-id",
"label_generator": "order",
"regenerate_label": False,
"attr_list": basic_attrs,
"prompt": "xml",
"dataset": "pipeline"
}
hp = HtmlParser(raw_html, args)
res = hp.parse_tree()
page_html = res.get("html", "")
packet["html"] = page_html
# for debug
save_debug_info(packet)
print("parsing finished.")
return packet

View File

@ -0,0 +1,447 @@
from lxml import html
import time, copy, random
import json, re, os
from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .configs import config_meta
from .utils import get_xpath_top_down, rect2tuple
class HtmlParser():
def __init__(self, ctx: str, args: dict[str]={}) -> None:
stt = time.time()
self.dom_tree = self.ctx2tree(ctx)
# tool related
self.bids2label = {}
self.bids2xpath = {}
self.used_labels = {}
# parse args
self.parse_args(args)
self.init_time = time.time() - stt
def parse_args(self, args: dict[str]={}) -> None:
def attr_check(attr, type_model='str'):
if attr is None:
return False
attr_type = type(attr)
if attr_type != type(type_model):
return False
if attr_type == type('str') and len(attr) == 0:
return False
return True
args = {} if args is None else args
# [Position] use_pos: False -> use full page, otherwise use window_size
dataset = args.get('dataset', '')
use_position = args.get('use_position', False)
window_size = args.get('window_size', None)
rect = args.get('rect_dict', None)
if use_position:
if not attr_check(window_size, ()):
raise ValueError('window_size must be set when use_position is True')
if not attr_check(rect, {}):
raise ValueError('rect_dict must be set when use_position is True')
if not attr_check(rect, {}):
rect = {}
# [Label] for vimium is temp_clickable_label, otherwise keep all of it
label_attr = args.get('label_attr', '')
get_new_label = args.get('regenerate_label', False)
label_method = args.get('label_generator', None)
regen_label = not attr_check(label_method)
# [id] for mind2web is backend_node_id, for normal website use our method
id_attr = args.get('id_attr', '')
regen_id = not attr_check(id_attr)
if regen_id:
id_attr = 'temp_id'
# [attributes]
keep_attrs = args.get('attr_list', [])
if not attr_check(keep_attrs, []):
keep_attrs = []
# [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
parent_chain = args.get('parent_chain', False)
keep_elem = args.get('keep_elem', [])
obs_elem = args.get('obs_elem', [])
# sanity check
self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset)
# [Prompt]
prompt = args.get('prompt', None)
self.prompt = HtmlPrompt(prompt)
# traverse and get special data
if regen_id or regen_label:
self.mark_id()
if get_new_label:
self.used_labels = {}
self.identifier = IdentifierTool(label_method, self.used_labels)
def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='',
id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[],
parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None:
self.use_position = use_position
self.window_size = window_size
self.rect = rect_dict
self.label_attr = label_attr
self.id_attr = id_attr
self.keep_attrs = keep_attrs
self.keep = keep_elem
self.obs = obs_elem
self.parent_chain = parent_chain
self.get_new_label = get_new_label
self.dataset = dataset
def get_config(self):
config = {
'id_attr': self.id_attr,
'keep_attrs': self.keep_attrs[:5],
'label_attr': self.label_attr,
'use_position': self.use_position,
'window_size': self.window_size,
'rect': dict(list(self.rect.items())[:3]),
'keep_elem': self.keep[:5],
'obs_elem': self.obs[:5],
'parent_chain': self.parent_chain,
'prompt_name': self.prompt.name,
'identifier_name': self.identifier.name
}
return config, config_meta.format(**config)
def update_rect_dict(self, rect_dict: dict[str]={}) -> None:
self.rect = rect_dict
@staticmethod
def ctx2tree(ctx: str) -> html.HtmlElement:
# remove useless tags, eg. style and script
ctx = re.sub('<!--[\W\w]*?-->', '', ctx)
ctx = re.sub('<style[\W\w]*?>[\W\w]*?</style>', '', ctx)
ctx = re.sub('<script[\W\w]*?>[\W\w]*?</script>', '', ctx)
ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip()
dom_tree = html.fromstring(ctx.encode('utf-8'))
match = re.search('<meta charset="([^"]*)"', ctx)
if match:
charset = match.group(1)
ctx = ctx.encode(charset)
print(charset)
else:
print("Charset not found")
return dom_tree
@staticmethod
def get_root(tree: html.HtmlElement) -> html.HtmlElement:
node = tree.xpath('//*')[0]
while True:
parent = node.getparent()
if parent is None:
break
node = parent
return node
def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement:
nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]')
if len(nodes) == 0:
return None
return nodes[0]
def id_label_converter(self, label: str) -> str:
return self.bids2label.get(label, '')
def id_xpath_converter(self, label: str) -> str:
return self.bids2xpath.get(label, '')
def mark_id(self) -> None:
root = self.get_root(self.dom_tree)
_, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
self.used_labels = used_labels
self.bids2xpath = i2xpath
def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]:
def get_text(str: str) -> str:
return '' if str is None else str.strip()[:500]
def check_attr(attr: str, node: html.HtmlElement) -> bool:
tag = node.tag
if (
( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] )
or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' )
or ( attr == 'aria-hidden' and node.attrib.get(attr, '') == 'true')
# or ( attr == 'value' and tag in ['option'] )
):
return False
return True
def is_visible(node: html.HtmlElement, bid: str) -> bool:
if self.dataset == 'mind2web':
bound = node.attrib.get('bounding_box_rect', None)
self.rect[bid] = rect2tuple(bound)
if self.dataset == 'pipeline':
bound = node.attrib.get('data-bbox', None)
self.rect[bid] = rect2tuple(bound)
if node.attrib.get('aria-hidden', 'false') == 'true':
return False
if not self.use_position:
return True
rect = self.rect.get(bid, None)
if rect is None:
return False
if self.window_size is None:
return True
# get window size
wx, wy, ww, wh = self.window_size
wx, wy = 0, 0
x, y, w, h = rect
if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
return False
return True
def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[],
parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]):
# basic information
bid = node.attrib.get(self.id_attr, '')
tag = node.tag
label = node.attrib.get(self.label_attr, '')
# element which is keeped equivalent to visible
visible = is_visible(node, bid)
in_keep_list = bid in keep
in_obs_list = (bid in obs or len(label) > 0) and visible
par_keep = par_keep and tag == "option"
keep_element = in_keep_list or in_obs_list or visible or par_keep
if label:
keep_element = True
# mark label
bids2label, labeled_elems = {}, []
have_label = False
if in_keep_list or in_obs_list:
if label is None or len(label) == 0 or get_new_label:
label = self.identifier.generate()
node.attrib[self.label_attr] = label
bids2label[bid] = label
bids2label[label] = bid
have_label = True
# get text or alt_text of current element
text = get_text(node.text)
classes = {}
# keep attributes if needed
keep_all_attrs = len(self.keep_attrs) == 0
keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs
# traverse attributes
for attr in keep_attrs:
if attr not in node.attrib or not check_attr(attr, node):
continue
if attr in [self.id_attr, self.label_attr]:
continue
val = get_text(node.attrib[attr])
if len(val) > 0 or keep_all_attrs:
classes[attr] = val
have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0
par_keep = keep_element and tag == 'select'
parts = []
clickable_count = 0
children = node.getchildren()
for child in children:
cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep)
clickable_count += 1 if cmsg.get('have_clickable', False) else 0
bids2label.update(cmsg.get('bids2label', {}))
labeled_elems.extend(cmsg.get('label_element', []))
if len(cres) != 0:
parts.append(cres)
dom = self.prompt.subtree_constructor(parts)
# remove <text|> if all children are text
keep_as_all_text = (dom.count('<') == dom.count('<text|')) and dom.count('<') > 0
if keep_as_all_text:
matches = re.findall(r'<text\| ([^>]+) >', dom)
dom = self.prompt.subtree_constructor(matches)
keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
keep_as_parent = len(dom) > 0 and parent_chain
if in_keep_list or keep_element or keep_as_parent:
dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)
if have_label:
labeled_elems.append(bid)
control_msg = {
'have_clickable': bool(clickable_count or have_text),
'bids2label': bids2label,
'label_element': labeled_elems,
}
return dom, control_msg
dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label)
return dom, cmsg
def parse_tree(self) -> dict[str]:
# start from here
stt = time.time()
root = self.get_root(self.dom_tree)
dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
self.bids2label = cmsg.get('bids2label', {})
self.keep = list(set(self.keep + cmsg.get('label_element', [])))
obj = {
'html': dom,
'parse_time': time.time() - stt
}
return obj
# From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int,
max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]:
def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
if current_depth > max_depth:
return []
anscendants = []
parent = node.getparent()
if parent is not None:
anscendants.append(parent)
anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1))
return anscendants
def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
if current_depth > max_depth:
return []
descendants = []
for child in node:
descendants.append(child)
descendants.extend(get_descendants(child, max_depth, current_depth + 1))
return descendants
to_keep = set(copy.deepcopy(keep))
nodes_to_keep = set()
for _ in range(max(1, dfs_count)):
for bid in to_keep:
candidate_node = self.get_node_by_bid(tree, bid)
if candidate_node is None:
continue
nodes_to_keep.add(candidate_node.attrib[self.id_attr])
# get all ancestors or with max depth
nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)])
# get descendants with max depth
nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children])
# get siblings within range
parent = candidate_node.getparent()
if parent is None:
continue
siblings = [x for x in parent.getchildren() if x.tag != 'text']
if candidate_node not in siblings:
continue
idx_in_sibling = siblings.index(candidate_node)
nodes_to_keep.update([x.attrib.get(self.id_attr, '')
for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]])
max_children = int(max_children * 0.5)
max_depth = int(max_depth * 0.5)
max_sibling = int(max_sibling * 0.7)
to_keep = copy.deepcopy(nodes_to_keep)
if keep_parent:
for bid in keep:
candidate_node = self.get_node_by_bid(tree, bid)
if candidate_node is None:
continue
nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")])
return list(nodes_to_keep)
def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement:
# remove nodes not in nodes_to_keep
for node in tree.xpath('//*')[::-1]:
if node.tag != 'text':
is_keep = node.attrib.get(self.id_attr, '') in nodes_to_keep
is_candidate = node.attrib.get(self.id_attr, '') in self.keep
else:
is_keep = (node.getparent().attrib.get(self.id_attr, '') in nodes_to_keep)
is_candidate = (node.getparent().attrib.get(self.id_attr, '') in self.keep)
if not is_keep and node.getparent() is not None:
# insert all children into parent
for child in node.getchildren():
node.addprevious(child)
node.getparent().remove(node)
else:
# if not is_candidate or node.tag == 'text':
# node.attrib.pop(self.id_attr, None)
if (
len(node.attrib) == 0
and not any([x.tag == 'text' for x in node.getchildren()])
and node.getparent() is not None
and node.tag != "text"
and len(node.getchildren()) <= 1
):
# insert all children into parent
for child in node.getchildren():
node.addprevious(child)
node.getparent().remove(node)
return tree
def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30,
max_sibling: int=3, keep_parent: bool=False) -> None:
# clone the tree
new_tree = copy.deepcopy(self.dom_tree)
nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent)
new_tree = self.prune(new_tree, nodes_to_keep)
self.dom_tree = new_tree
def get_segment(self, bid: str) -> str:
# clone the tree
new_tree = copy.deepcopy(self.dom_tree)
nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
new_tree = self.prune(new_tree, nodes_to_keep)
dom, _ = self.parse(new_tree, self.keep, [], False)
return dom
def get_rect_data(self, bids: list[str]) -> list[dict[str]]:
res = []
for bid in bids:
label = self.bids2label.get(bid, '')
rect = self.rect.get(bid, None)
res.append({
'bid': bid,
'label': label,
'rect': rect
})
return res

View File

@ -0,0 +1,64 @@
import secrets
class IdentifierTool:
def __init__(self, method: str='order', existing_labels: dict[str]={}) -> None:
self.methods = {
'order': self.get_identifier_in_order,
'random': self.get_random_identifier,
}
if method is None:
method = 'order'
self.func = self.methods.get(method, None)
self.name = method
if self.func is None:
raise ValueError(f'Invalid method for identifier: {method}')
self.reset(existing_labels)
def reset(self, exists: dict[str]={}) -> None:
self.identifier = -1
self.exists = {} if exists is None else exists
def get_identifier_in_order(self) -> str:
def id2str(id: int) -> str:
if id < 26:
return chr(id + 65)
id -= 26
c0 = id // 676
c1 = (id // 26) % 26
c2 = id % 26
label = f'{chr(c1 + 65)}{chr(c2 + 65)}'
return label if c0 == 0 else f'{chr(c0 + 64)}{label}'
self.identifier += 1
label = id2str(self.identifier)
while label in self.exists:
self.identifier += 1
label = id2str(self.identifier)
self.exists[label] = True
return label
def get_random_identifier(self) -> str:
secret_generator = secrets.SystemRandom()
def get_random_label(n: int=2) -> str:
tmp = ''
for _ in range(n):
tmp += chr(secret_generator.randint(65, 90))
return tmp
wc = 3 if len(self.exists) > 280 else 2
label = get_random_label(wc)
while label in self.exists:
label = get_random_label(wc)
self.exists[label] = True
return label
def generate(self):
return self.func()

View File

@ -0,0 +1,97 @@
from .configs import prompts
class HtmlPrompt:
def __init__(self, prompt: str='') -> None:
prompt = self.extract(prompt, 'xml')
if prompt not in prompts:
raise Exception('Unknown prompt: ' + prompt)
constructors = {
'refine': self.normal_prompt_constructor,
'xml': self.normal_prompt_constructor,
'new_data': self.new_data_prompt_constructor,
}
self.name = prompt
self.prompt = prompts[prompt]
self.constructor = constructors[prompt]
@staticmethod
def extract(data, default=''):
return data if data is not None else default
def subtree_constructor(self, subtree: list[str]=[]) -> str:
return self.prompt['subtree_splitter'].join(subtree)
def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
def add_prefix(data, prefix):
return prefix + data if len(data) > 0 else ''
tag = self.extract(tag)
label = self.extract(label)
content = self.extract(content)
subtree_str = self.extract(subtree_str, '')
class_dict = self.extract(class_dict, {})
label_str = ''
if len(label) > 0:
label_str = self.prompt['label'].format(label=label)
classes = []
values = set()
for key, val in class_dict.items():
if val in values:
continue
values.add(val)
classes.append(self.prompt['attr'].format(key=key, attr=val))
classes_str = self.prompt['attr_splitter'].join(classes)
content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
classes_str = add_prefix(classes_str, ' ')
content_str = add_prefix(content, content_splitter)
subtree_str = add_prefix(subtree_str, ' ')
return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
def add_prefix(data, prefix):
return prefix + data if len(data) > 0 else ''
tag = self.extract(tag)
label = self.extract(label)
content = self.extract(content)
subtree_str = self.extract(subtree_str, '')
class_dict = self.extract(class_dict, {})
label_str = ''
if len(label) > 0:
label_str = self.prompt['label'].format(label=label)
classes = []
values = set()
message = []
for key, val in class_dict.items():
if val == '':
message.append(key)
continue
if val in values:
continue
values.add(val)
classes.append(self.prompt['attr'].format(key=key, attr=val))
if len(message) > 0:
message_str = ' '.join(message)
classes.append(self.prompt['attr'].format(key='message', attr=message_str))
classes_str = self.prompt['attr_splitter'].join(classes)
content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
classes_str = add_prefix(classes_str, ' ')
content_str = add_prefix(content, content_splitter)
subtree_str = add_prefix(subtree_str, ' ')
return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
return self.constructor(tag, label, content, subtree_str, class_dict)

View File

@ -0,0 +1,43 @@
import os
from pathlib import Path
rootdir = Path(__file__).parent
with open(os.path.join(rootdir,'prepare.js'), 'r') as f:
prepare_script = f.read()
with open(os.path.join(rootdir, 'clickable_checker.js'), 'r') as f:
clickable_checker_script = f.read()
with open(os.path.join(rootdir, 'label.js'), 'r') as f:
label_script = f.read()
with open(os.path.join(rootdir, 'element_info.js'), 'r') as f:
element_info_script = f.read()
# draw label on page
with open(os.path.join(rootdir, 'label_marker.js'), 'r') as f:
label_marker_script = f.read()
# remove label draw on page
remove_label_mark_script = """
() => {
document.querySelectorAll(".our-dom-marker").forEach(item => {
document.body.removeChild(item);
});
}
"""
remove_id_script = """
() => {
Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => {
element.classList.remove('possible-clickable-element');
element.removeAttribute('data-value');
element.removeAttribute('data-text');
element.removeAttribute('data-label');
element.removeAttribute('data-bbox');
element.removeAttribute('data-status');
element.removeAttribute('data-backend-node-id');
element.removeAttribute('data-label-id');
});
}
"""

View File

@ -0,0 +1,148 @@
() => {
var items = Array.prototype.slice.call(
document.querySelectorAll('*')
).map(function(element) {
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
var rects = [...element.getClientRects()].filter(bb => {
var center_x = bb.left + bb.width / 2;
var center_y = bb.top + bb.height / 2;
var elAtCenter = document.elementFromPoint(center_x, center_y);
if (!elAtCenter) return false;
return elAtCenter === element || element.contains(elAtCenter)
}).map(bb => {
const rect = {
left: Math.max(0, bb.left),
top: Math.max(0, bb.top),
right: Math.min(vw, bb.right),
bottom: Math.min(vh, bb.bottom)
};
return {
...rect,
width: rect.right - rect.left,
height: rect.bottom - rect.top
}
});
// var rects = [];
var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
const tagName = element.tagName.toLowerCase?.() || "";
let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer");
// Insert area elements that provide click functionality to an img.
if (tagName === "img") {
let mapName = element.getAttribute("usemap");
if (mapName) {
const imgClientRects = element.getClientRects();
mapName = mapName.replace(/^#/, "").replace('"', '\\"');
const map = document.querySelector(`map[name=\"${mapName}\"]`);
if (map && (imgClientRects.length > 0)) isClickable = true;
}
}
if (!isClickable) {
const role = element.getAttribute("role");
const clickableRoles = [
"button",
"tab",
"link",
"checkbox",
"menuitem",
"menuitemcheckbox",
"menuitemradio",
"radio",
];
if (role != null && clickableRoles.includes(role.toLowerCase())) {
isClickable = true;
} else {
const contentEditable = element.getAttribute("contentEditable");
if (
contentEditable != null &&
["", "contenteditable", "true"].includes(contentEditable.toLowerCase())
) {
isClickable = true;
}
}
}
// Check for jsaction event listeners on the element.
if (!isClickable && element.hasAttribute("jsaction")) {
const jsactionRules = element.getAttribute("jsaction").split(";");
for (let jsactionRule of jsactionRules) {
const ruleSplit = jsactionRule.trim().split(":");
if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) {
const [eventType, namespace, actionName] = ruleSplit.length === 1
? ["click", ...ruleSplit[0].trim().split("."), "_"]
: [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"];
if (!isClickable) {
isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_");
}
}
}
}
if (!isClickable) {
const clickableTags = [
"input",
"textarea",
"select",
"button",
"a",
"iframe",
"video",
"object",
"embed",
"details"
];
isClickable = clickableTags.includes(tagName);
}
if (!isClickable) {
if (tagName === "label")
isClickable = (element.control != null) && !element.control.disabled;
else if (tagName === "img")
isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor);
}
// An element with a class name containing the text "button" might be clickable. However, real
// clickables are often wrapped in elements with such class names. So, when we find clickables
// based only on their class name, we mark them as unreliable.
const className = element.getAttribute("class");
if (!isClickable && className && className.toLowerCase().includes("button")) {
isClickable = true;
}
// Elements with tabindex are sometimes useful, but usually not. We can treat them as second
// class citizens when it improves UX, so take special note of them.
const tabIndexValue = element.getAttribute("tabindex");
const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1;
if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) {
isClickable = true;
}
const idValue = element.getAttribute("id");
const id = idValue ? idValue.toLowerCase() : "";
if (isClickable && area == 0) {
const textValue = element.textContent.trim().replace(/\s{2,}/g, ' ');
clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}`
}
return {
element: element,
include: isClickable,
area,
rects,
text: element.textContent.trim().replace(/\s{2,}/g, ' ')
};
}).filter(item =>
item.include && (item.area >= 1)
);
items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))
items.forEach(item => {
item.element.classList.add('possible-clickable-element');
});
}

View File

@ -0,0 +1,34 @@
() => {
function getElementInfo(element) {
return {
"bid": element.getAttribute("data-backend-node-id") || "",
"label": element.getAttribute("data-label-id") || "",
"tag": element.tagName.toLowerCase?.() || "",
"area": JSON.parse("[" + (element.getAttribute("data-bbox") || "") + "]"),
"text": element.innerText?.trim().replace(/\s{2,}/g, " ") || "",
"id": element.getAttribute("id") || "",
"role": element.getAttribute("role") || "",
"aria-label": element.getAttribute("aria-label") || "",
"href": element.getAttribute("href") || "",
};
}
var all_items = Array.prototype.slice.call(
document.querySelectorAll("*")
).map((element) => {
return getElementInfo(element);
});
var clickable_items = Array.prototype.slice.call(
document.querySelectorAll("*")
).filter(
element => element.getAttribute("data-label-id")
).map((element) => {
return getElementInfo(element);
});
return {
all_elements: all_items,
clickable_elements: clickable_items
};
}

View File

@ -0,0 +1,74 @@
(packet) => {
function int2str(index) {
var str = "";
while (index >= 0) {
str = String.fromCharCode(65 + index % 26) + str;
index = Math.floor(index / 26) - 1;
}
return str;
};
selector = packet.selector
index = packet.startIndex
var items = Array.prototype.slice.call(
document.querySelectorAll(selector)
);
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
items = items.filter(
x => !items.some(y => x.contains(y) && !(x == y))
).map(element => {
var bb = element.getClientRects();
var rect = {
left: 0,
top: 0,
right: 0,
bottom: 0,
width: 0,
height: 0
};
var keep = false;
var text = "", id = -1;
if (bb.length > 0) {
bb = bb[0];
rect = {
left: Math.max(0, bb.left),
top: Math.max(0, bb.top),
right: Math.min(vw, bb.right),
bottom: Math.min(vh, bb.bottom)
};
rect = {
...rect,
width: rect.right - rect.left,
height: rect.bottom - rect.top
};
if (rect.width > 0 || rect.height > 0) {
keep = true;
if (index >= 0) {
// id = int2str(index++);
id = index++;
element.setAttribute("data-label-id", id);
}
var childNodes = element.childNodes;
for (var i = 0; i < childNodes.length; i++) {
if (childNodes[i].nodeType == Node.TEXT_NODE) {
text += childNodes[i].textContent;
}
}
}
}
return {
keep: true,
id,
rects: rect,
tag: element.tagName.toLowerCase?.() || "",
text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || ""
};
}).filter(x => x.keep);
return [items, index];
}

View File

@ -0,0 +1,65 @@
(items) => {
function getRandomColor() {
var letters = '0123456789ABCDEF';
var color = '#';
for (var i = 0; i < 6; i++) {
color += letters[Math.floor(Math.random() * 16)];
}
return color;
}
items.filter(
item => item.id != ""
).forEach((item) => {
const bbox = item.rects;
const id_string = `dom-marker-id-${index}`;
index = item.id;
outerElement = document.createElement("div");
outerElement.classList.add("our-dom-marker");
// var borderColor = getRandomColor();
var borderColor = "#FFFF00";
outerElement.style.outline = `2px dashed ${borderColor}`;
outerElement.style.position = "fixed";
outerElement.style.left = bbox.left - 2 + "px";
outerElement.style.top = bbox.top - 2 + "px";
outerElement.style.width = bbox.width + 4 + "px";
outerElement.style.height = bbox.height + 4 + "px";
outerElement.style.pointerEvents = "none";
outerElement.style.boxSizing = "border-box";
outerElement.style.zIndex = 2147483647;
innerElement = document.createElement("div");
innerElement.classList.add("our-dom-marker");
innerElement.style.outline = `2px dashed #222288`;
innerElement.style.position = "fixed";
innerElement.style.left = bbox.left + "px";
innerElement.style.top = bbox.top + "px";
innerElement.style.width = bbox.width + "px";
innerElement.style.height = bbox.height + "px";
innerElement.style.pointerEvents = "none";
innerElement.style.boxSizing = "border-box";
innerElement.style.zIndex = 2147483647;
// Add floating label at the corner
var label = document.createElement("span");
var topPosition = 25;
if (bbox.top < 25) topPosition = bbox.top;
label.textContent = index;
label.style.position = "absolute";
label.style.top = `-${topPosition}px`;
label.style.left = "0px";
label.style.background = borderColor;
label.style.color = "black";
label.style.padding = "2px 4px";
label.style.fontSize = "16px";
label.style.borderRadius = "2px";
label.style.fontWeight = "bold";
outerElement.appendChild(label);
document.body.appendChild(outerElement);
document.body.appendChild(innerElement);
})
return items;
}

View File

@ -0,0 +1,83 @@
() => {
// mark backend node id
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
var backendId = 0;
Array.prototype.slice.call(
document.querySelectorAll("*")
).forEach((element) => {
element.setAttribute("data-backend-node-id", backendId);
backendId++;
var tag = element.tagName.toLowerCase?.() || "";
var bb = element.getClientRects();
var rect = {
left: 0,
top: 0,
right: 0,
bottom: 0,
width: 0,
height: 0
};
if (bb.length > 0) {
bb = bb[0];
// rect = {
// left: Math.round(Math.max(0, bb.left) * 100) / 100,
// top: Math.round(Math.max(0, bb.top) * 100) / 100,
// right: Math.round(Math.min(vw, bb.right) * 100) / 100,
// bottom: Math.round(Math.min(vh, bb.bottom) * 100) / 100
// };
rect = {
left: (Math.round(bb.left) * 100) / 100,
top: (Math.round(bb.top) * 100) / 100,
right: (Math.round(bb.right) * 100) / 100,
bottom: (Math.round(bb.bottom) * 100) / 100
};
rect = {
...rect,
width: Math.round((rect.right - rect.left) * 100) / 100,
height: Math.round((rect.bottom - rect.top) * 100) / 100
};
element.setAttribute("data-bbox", `${rect.left},${rect.top},${rect.width},${rect.height}`);
}
if (element.hasChildNodes()) {
let children = Array.prototype.slice.call(element.childNodes);
var texts = children.filter(
(node) => node.nodeType == Node.TEXT_NODE
).map(
(node) => node.textContent.trim().replace(/\s{2,}/g, " ") || ""
).filter(
(text) => text.length > 0
)
element.setAttribute("data-text", texts.join(","));
}
// fix select issue
if (tag == "select") {
var value = element.value;
var text = element.options[element.selectedIndex]?.text || "";
element.setAttribute("data-value", value);
element.setAttribute("data-text", text);
element.options[element.selectedIndex]?.setAttribute("data-status", "selected");
}
if (tag == "input") {
var input_type = element.getAttribute("type") || "";
if (input_type == "checkbox") {
var status = element.checked? "checked" : "not-checked";
element.setAttribute("data-status", status);
}
}
});
// fix input and textarea issue
Array.prototype.slice.call(
document.querySelectorAll("input, textarea")
).forEach(element => {
element.setAttribute("data-value", element.value);
});
}

View File

@ -0,0 +1,101 @@
from lxml import html
def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0,
in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str]]:
used_labels, i2xpath = {}, {}
# path
tag = element.tag.lower()
in_svg = in_svg or (tag == 'svg')
if not in_svg and 'id' in element.attrib:
node_id = element.attrib['id']
path = f'//*[@id="{node_id}"]'
else:
suffix = f'[{order}]' if order > 0 else ''
prefix = f'*[name()="{tag}"]' if in_svg else tag
path = path + '/' + prefix + suffix
# add temp id
element.attrib[id_column] = str(temp_id)
ori_label = element.attrib.get(label_column, '')
if ori_label != '':
used_labels[ori_label] = True
bid = str(temp_id)
i2xpath[bid] = path
i2xpath[path] = bid
i2xpath[f'xpath/{path}'] = bid
i2xpath[f'xpath=/{path}'] = bid
temp_id += 1
# traverse node
children = element.getchildren()
tag_dict = {}
id_list = []
for child in children:
ctag = child.tag.lower()
if ctag not in tag_dict:
tag_dict[ctag] = 0
tag_dict[ctag] += 1
id_list.append(tag_dict[ctag])
for cid, child in zip(id_list, children):
ctag = child.tag.lower()
cod = cid if tag_dict[ctag] > 1 else 0
temp_id, i2x, ulabels = get_xpath_top_down(child, id_column, label_column, path, cod, in_svg, temp_id)
i2xpath.update(i2x)
used_labels.update(ulabels)
return temp_id, i2xpath, used_labels
def print_html_object(obj: str='') -> str:
tab_cnt = 0
result, content, sep = '', '', ''
last_is_left, last_is_right = False, False
for ch in obj:
if ch == '<':
result += '\n'
if len(content.strip()) > 0:
result += sep + content.strip() + '\n'
result += sep + '<'
tab_cnt += 1
sep = ' ' * tab_cnt
content = ''
last_is_right = False
last_is_left = True
elif ch == '>':
if last_is_left:
result += content
else:
if last_is_right:
result += '\n'
if len(content.strip()) > 0:
result += sep + content.strip() + '\n'
tab_cnt -= 1
sep = ' ' * tab_cnt
if not last_is_left:
result += sep
result += '>'
content = ''
last_is_right = True
last_is_left = False
else:
content += ch
return result
def rect2tuple(rect: str) -> tuple[int, int, int, int]:
if rect is None or type(rect) != type('str'):
return None
rect = rect.strip()
if rect.count(',') != 3:
return None
rect = rect.split(',')
rect = [float(r) for r in rect]
return tuple(rect)

View File

@ -36,7 +36,6 @@ def retry_with_exponential_backoff( # type: ignore
# Initialize variables # Initialize variables
num_retries = 0 num_retries = 0
delay = initial_delay delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised # Loop until a successful response or max_retries is hit or an exception is raised
while True: while True:
try: try:
@ -142,27 +141,32 @@ async def agenerate_from_openai_completion(
@retry_with_exponential_backoff @retry_with_exponential_backoff
def generate_from_openai_completion( def generate_from_openai_completion(
prompt: str, prompt: str,
engine: str, model: str,
temperature: float, temperature: float,
max_tokens: int, max_tokens: int,
top_p: float, top_p: float,
context_length: int,
stop_token: str | None = None, stop_token: str | None = None,
api_key: str | None = None,
base_url: str | None = None
) -> str: ) -> str:
if "OPENAI_API_KEY" not in os.environ: if "OPENAI_API_KEY" not in os.environ and api_key is None:
raise ValueError( raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API." "OPENAI_API_KEY environment variable must be set when using OpenAI API."
) )
if api_key is not None:
client = OpenAI(api_key=api_key, base_url=base_url)
response = client.completions.create( response = client.completions.create(
prompt=prompt, prompt=prompt,
engine=engine, model=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
stop=[stop_token], stop=[stop_token],
) )
answer: str = response["choices"][0]["text"] try:
answer: str = response["choices"][0]["text"]
except:
answer: str = response.choices[0].text
return answer return answer

View File

@ -0,0 +1,13 @@
{
"intro": "",
"examples": [],
"template": "",
"meta_data": {
"observation": "webrl",
"action_type": "webrl_id",
"keywords": [],
"prompt_constructor": "WebRLPromptConstructor",
"answer_phrase": "",
"action_splitter": ""
}
}

File diff suppressed because it is too large Load Diff

View File

@ -499,3 +499,59 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor):
raise NotImplementedError( raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented" f"Provider {self.lm_config.provider} not implemented"
) )
class WebRLPromptConstructor(PromptConstructor):
"""The agent will direct predict the action"""
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)
def construct(
self,
trajectory: Trajectory,
intent: str,
meta_data: dict[str, Any] = {},
) -> APIInput:
"""Construct prompt given the trajectory"""
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
obs = state_info["observation"][self.obs_modality]
max_obs_length = self.lm_config.gen_config["max_obs_length"]
if max_obs_length:
if self.lm_config.provider == "google":
print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
obs = obs[:max_obs_length]
else:
try:
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
except:
print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
obs = obs[:max_obs_length]
turn_num = len(meta_data["action_history"])
if turn_num == 1:
previous_action_str = []
else:
previous_action_str = meta_data["action_history"][1:]
index = turn_num - 1
history = ""
for i in range(index - 1, -1, -1):
if i == 0:
history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
else:
history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n** Simplified html **\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
if len(history) + len(obs) > (16384 - 512):
obs = obs[:(16384 - 512)-len(history)]
current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"
return prompt
def extract_action(self, response: str) -> str:
return response

View File

@ -13,11 +13,13 @@ import tempfile
import time import time
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import cv2
import shutil
import openai import openai
import requests import requests
import torch import torch
from PIL import Image from PIL import Image, ImageDraw, ImageFont
from agent import ( from agent import (
PromptAgent, PromptAgent,
@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter) console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter) file_handler.setFormatter(formatter)
def text_wrap(text, font, max_width):
lines = []
paragraphs = text.split('\n') # 按照 \n 分割文本为段落
for paragraph in paragraphs:
words = paragraph.split(' ')
line = ''
for word in words:
# 临时行
test_line = f"{line} {word}".strip()
# 获取临时行的宽度
test_line_bbox = font.getbbox(test_line)
test_line_width = test_line_bbox[2] - test_line_bbox[0]
if test_line_width <= max_width:
# 如果临时行的宽度不超过图片宽度,继续添加单词
line = test_line
else:
# 如果超过了最大宽度,保存当前行,开始新的一行
lines.append(line)
line = word
# 添加每段的最后一行
if line:
lines.append(line)
# 每个段落后添加一个空行,以保留段落的换行
lines.append('')
# 移除最后一个空行(不需要额外的空行)
if lines[-1] == '':
lines.pop()
return lines
def config() -> argparse.Namespace: def config() -> argparse.Namespace:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -88,6 +118,7 @@ def config() -> argparse.Namespace:
"html", "html",
"image", "image",
"image_som", "image_som",
"webrl",
], ],
default="accessibility_tree", default="accessibility_tree",
help="Observation type", help="Observation type",
@ -176,6 +207,9 @@ def config() -> argparse.Namespace:
# logging related # logging related
parser.add_argument("--result_dir", type=str, default="") parser.add_argument("--result_dir", type=str, default="")
# if use self-deployed model
parser.add_argument("--planner_ip", type=str, default=None)
args = parser.parse_args() args = parser.parse_args()
# check the whether the action space is compatible with the observation space # check the whether the action space is compatible with the observation space
@ -196,7 +230,7 @@ def config() -> argparse.Namespace:
def early_stop( def early_stop(
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""Check whether need to stop early""" """Check whether need to stop early"""
@ -228,28 +262,38 @@ def early_stop(
if len(action_seq) == 0: if len(action_seq) == 0:
return False, "" return False, ""
last_action: Action = action_seq[-1] if actions is None:
last_action: Action = action_seq[-1]
if last_action["action_type"] != ActionTypes.TYPE:
if len(last_k_actions) >= k:
if all(
[
is_equivalent(action, last_action)
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
else:
# check the action sequence
if (
sum([is_equivalent(action, last_action) for action in action_seq])
>= k
):
return True, f"Same typing action for {k} times"
return False, ""
if last_action["action_type"] != ActionTypes.TYPE: else:
last_k_actions = actions[-k:]
last_action = actions[-1]
if len(last_k_actions) >= k: if len(last_k_actions) >= k:
if all( if all(
[ [
is_equivalent(action, last_action) action == last_action
for action in last_k_actions for action in last_k_actions
] ]
): ):
return True, f"Same action for {k} times" return True, f"Same action for {k} times"
return False, ""
else:
# check the action sequence
if (
sum([is_equivalent(action, last_action) for action in action_seq])
>= k
):
return True, f"Same typing action for {k} times"
return False, ""
def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1): def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
obj = { obj = {
@ -387,16 +431,20 @@ def test(
obs, info = env.reset(options={"config_file": config_file}) obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info} state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info) trajectory.append(state_info)
meta_data = {"action_history": ["None"]} meta_data = {"action_history": ["None"]}
out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json") out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
actions = [] actions = []
os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True)
if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")):
shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
while True: while True:
update_action_history(out_path, task_id, actions=actions) update_action_history(out_path, task_id, actions=actions)
# If no actions variable is passed, the behavior of early_stop is the same as the original one.
early_stop_flag, stop_info = early_stop( early_stop_flag, stop_info = early_stop(
trajectory, max_steps, early_stop_thresholds trajectory, max_steps, early_stop_thresholds, actions
) )
if early_stop_flag: if early_stop_flag:
@ -407,7 +455,7 @@ def test(
trajectory, trajectory,
intent, intent,
images=images, images=images,
meta_data=meta_data, meta_data=meta_data
) )
except ValueError as e: except ValueError as e:
# get the error message # get the error message
@ -426,9 +474,46 @@ def test(
render_helper.render( render_helper.render(
action, state_info, meta_data, args.render_screenshot action, state_info, meta_data, args.render_screenshot
) )
current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png")
_ = env.page.viewport_size
env.page.screenshot(path="/dev/null")
env.page.screenshot(path=current_screenshot)
element_id = action["element_id"]
if element_id != "":
element = env.page.query_selector(f"[data-label-id='{element_id}']")
if element:
bbox = element.bounding_box()
bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']),int(bbox['height'])]
image = cv2.imread(current_screenshot)
cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
cv2.imwrite(current_screenshot, image)
font_path = "/workspace/qzh/Arial.ttf" # 使用FreeSans字体
font_size = 30 # 你可以调整这个值来增大或缩小字体
image = Image.open(current_screenshot)
font = ImageFont.truetype(font_path, font_size)
draw = ImageDraw.Draw(image)
image_width = image.width
wrapped_text = text_wrap(action_str, font, image_width)
line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1]
text_height = line_height * len(wrapped_text)
new_image_height = image.height + text_height + 20 # 20 is extra white space
new_image = Image.new('RGB', (image.width, new_image_height), color=(255, 255, 255)) # white background
draw_new = ImageDraw.Draw(new_image)
y_text = 10 # Initial position of text from top
for line in wrapped_text:
text_bbox = draw_new.textbbox((0, 0), line, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_position = ((image_width - text_width) // 2, y_text) # Center text horizontally
draw_new.text(text_position, line, font=font, fill=(0, 0, 0)) # black text
y_text += line_height # move to next line
new_image.paste(image, (0, text_height + 20))
new_image.save(current_screenshot)
meta_data["action_history"].append(action_str) meta_data["action_history"].append(action_str)
actions.append(action_str) actions.append(action_str)
print(action_str) print('Action String: ', action_str)
if action["action_type"] == ActionTypes.STOP: if action["action_type"] == ActionTypes.STOP:
break break
@ -540,7 +625,7 @@ if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config() args = config()
args.sleep_after_execution = 2.5 args.sleep_after_execution = 3.0
prepare(args) prepare(args)
test_config_base_dir = args.test_config_base_dir test_config_base_dir = args.test_config_base_dir

View File

@ -0,0 +1,86 @@
import os, json, sys, copy
USE_TASKS = [i for i in range(165)]
def get_result(res_dict, src="all"):
if len(res_dict) == 0:
return ''
success_id = [k for k, v in res_dict.items() if v >= 1.0]
score = len(success_id)
finish_count = len(res_dict)
pacc, acc = score / finish_count * 100, score / TASKS * 100
print(sorted(success_id))
meta = """
--------
src file: {}
successed: {:3} / {:4} (812)
partial accuracy: {:7}
overall accuracy: {:7}
--------
""".format(src, int(score), finish_count, round(pacc, 2), round(acc, 2))
print(meta)
def export_result(res_dict, src=".", note=["1.0", "0.0"], show_all=False):
out_string = ""
for id in USE_TASKS:
# with open(f"Pipeline/config_files/{id}.json", "r") as f:
# jd = json.load(f)
# if "map" in jd["sites"]:
# continue
if id in res_dict:
if res_dict[id] >= 1.0:
out_string += note[0]
else:
out_string += note[1]
elif show_all:
out_string += note[1]
out_string += "\n"
with open(os.path.join(src, 'export.txt'), 'w') as f:
f.write(out_string)
TASKS = 165
files = sys.argv[1]
file_list = files.split(',')
all_result = {}
for src in file_list:
path = os.path.join(src, 'actions')
result = {}
finished = os.listdir(path)
for file in finished:
if not file.endswith('.json'):
continue
with open(os.path.join(path, file), 'r') as f:
data = json.load(f)
if not isinstance(data, dict):
continue
task_id = data.get('task_id', 1000)
# if task_id >= TASKS:
# continue
task_score = data.get('score', 0)
if task_score < 0:
continue
result[task_id] = task_score
if task_id not in all_result or task_score > all_result[task_id]:
all_result[task_id] = task_score
get_result(result, src)
export_result(result, src=src)
if len(file_list) > 1:
get_result(all_result)
export_result(all_result, show_all=True)

View File

@ -187,7 +187,7 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"must_include": [ "fuzzy_match": [
"0" "0"
] ]
}, },
@ -774,8 +774,8 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"fuzzy_match": [ "must_include": [
"914km" "914"
] ]
}, },
"reference_url": "", "reference_url": "",
@ -1139,7 +1139,7 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"must_include": [ "fuzzy_match": [
"0" "0"
] ]
}, },
@ -1679,7 +1679,7 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"exact_match": "N/A" "fuzzy_match": ["N/A"]
}, },
"reference_url": "", "reference_url": "",
"program_html": [], "program_html": [],
@ -2605,7 +2605,7 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"exact_match": "yjlou" "must_include": "yjlou"
}, },
"reference_url": "", "reference_url": "",
"program_html": [], "program_html": [],
@ -3712,7 +3712,7 @@
"string_match" "string_match"
], ],
"reference_answers": { "reference_answers": {
"exact_match": "N/A" "fuzzy_match": "N/A"
}, },
"reference_url": "", "reference_url": "",
"program_html": [], "program_html": [],

View File

@ -1,15 +1,18 @@
from typing import Any from typing import Any
import tiktoken import tiktoken
from transformers import LlamaTokenizer # type: ignore from transformers import LlamaTokenizer, AutoTokenizer # type: ignore
class Tokenizer(object): class Tokenizer(object):
def __init__(self, provider: str, model_name: str) -> None: def __init__(self, provider: str, model_name: str) -> None:
if provider == "openai": if provider == "openai":
self.tokenizer = tiktoken.encoding_for_model(model_name) try:
self.tokenizer = tiktoken.encoding_for_model(model_name)
except: # The provider is in openai format but the model is a finetuned model
self.tokenizer = None
elif provider == "huggingface": elif provider == "huggingface":
self.tokenizer = LlamaTokenizer.from_pretrained(model_name) self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
# turn off adding special tokens automatically # turn off adding special tokens automatically
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined] self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
self.tokenizer.add_bos_token = False # type: ignore[attr-defined] self.tokenizer.add_bos_token = False # type: ignore[attr-defined]

View File

@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any]
def call_llm( def call_llm(
lm_config: lm_config.LMConfig, lm_config: lm_config.LMConfig,
prompt: APIInput, prompt: APIInput,
api_key = None,
base_url = None
) -> str: ) -> str:
response: str response: str
if lm_config.provider == "openai": if lm_config.provider == "openai":
@ -39,11 +41,13 @@ def call_llm(
assert isinstance(prompt, str) assert isinstance(prompt, str)
response = generate_from_openai_completion( response = generate_from_openai_completion(
prompt=prompt, prompt=prompt,
engine=lm_config.model, model=lm_config.model,
temperature=lm_config.gen_config["temperature"], temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"], max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"], top_p=lm_config.gen_config["top_p"],
stop_token=lm_config.gen_config["stop_token"], stop_token=lm_config.gen_config["stop_token"],
api_key=api_key,
base_url=base_url
) )
else: else:
raise ValueError( raise ValueError(

View File

@ -0,0 +1,97 @@
#!/bin/bash
DATASET='webarena' # TODO: select from ['webarena', 'visualwebarena']
result_dir='' # TODO: set your result_dir
provider='openai' # TODO: select from ['openai', 'finetune', ...]
model='' # TODO: assign model name, which is used for action generation
planner_ip='' # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
instruction_path='agent/prompts/jsons/p_webrl.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
temperature=0.0
SERVER='' # TODO: your server address
MAP_SERVER='' # TODO: the server address for MAP tasks
OPENAI_API_KEY='' # TODO: if you test OpenAI APIs
OPENAI_ORGANIZATION=''
CONDA_ENV_NAME='' # TODO: the name of your conda environment for testing WebArena
ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
# get the number of tmux panes
num_panes=$(tmux list-panes | wc -l)
# calculate how many panes need to be created
let "panes_to_create = 7 - num_panes"
# array of tmux commands to create each pane
tmux_commands=(
'tmux split-window -h'
'tmux split-window -v'
'tmux select-pane -t 0; tmux split-window -v'
'tmux split-window -v'
'tmux select-pane -t 3; tmux split-window -v'
'tmux select-pane -t 5; tmux split-window -v'
)
# create panes up to 7
for ((i=0; i<$panes_to_create; i++)); do
eval ${tmux_commands[$i]}
done
#!/bin/bash
# Function to run a job
run_job() {
tmux select-pane -t $1
COMMAND="python run.py \
--instruction_path ${instruction_path} \
--test_start_idx $2 \
--test_end_idx $3 \
--result_dir ${result_dir} \
--test_config_base_dir ${test_config_base_dir} \
--provider ${provider} \
--model ${model} \
--mode completion \
--planner_ip ${planner_ip} \
--stop_token \"<|eot_id|>\" \
--temperature ${temperature} \
--max_obs_length 0 \
--max_tokens 2048 \
--viewport_width 1280 \
--viewport_height 720 \
--parsing_failure_th 5 \
--repeating_action_failure_th 5 \
--action_set_tag webrl_id --observation_type webrl"
tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
sleep 3
}
TOLERANCE=2
run_batch() {
args=("$@") # save all arguments in an array
num_jobs=${#args[@]} # get number of arguments
for ((i=1; i<$num_jobs; i++)); do
run_job $i ${args[i-1]} ${args[i]}
done
# Wait for all jobs to finish
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
sleep 100 # wait for 10 seconds before checking again
done
# Run checker
while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
echo "Check failed, rerunning jobs..."
for ((i=1; i<$num_jobs; i++)); do
run_job $i ${args[i-1]} ${args[i]}
done
# Wait for all jobs to finish
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
sleep 100 # wait for 10 seconds before checking again
done
done
}
run_batch 0 28 56 84 112 140 165

View File

@ -14,6 +14,14 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
cp -f new/run.py visualwebarena/run.py cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json
# browser_env
cp -f new/actions.py visualwebarena/browser_env/actions.py
cp -f new/envs.py visualwebarena/browser_env/envs.py
cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py
cp -f new/processors.py visualwebarena/browser_env/processors.py
cp -rf new/html_tools visualwebarena/browser_env/
# llms # llms
cp -f new/utils.py visualwebarena/llms/utils.py cp -f new/utils.py visualwebarena/llms/utils.py
@ -22,15 +30,19 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
cp -f new/utils.py visualwebarena/llms/utils.py
# eval # eval
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py
# misc # misc
cp -f README.md visualwebarena/README.md cp -f README.md visualwebarena/README.md
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
cp -f new/score.py visualwebarena/score.py
cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh
# 3. remove temporary files # 3. remove temporary files
mv visualwebarena/* . mv visualwebarena/* .
rm -rf new rm -rf new