add webrl mode
parent: 8d86a00e85
commit: 521d7e999a

2006  VAB-WebArena-Lite/new/actions.py  (normal file)
File diff suppressed because it is too large.
@@ -14,6 +14,7 @@ from browser_env.actions import (
    create_id_based_action,
    create_none_action,
    create_playwright_action,
    create_webrl_id_based_action
)
from browser_env.utils import Observation, StateInfo
from llms import (

@@ -108,12 +109,14 @@ class PromptAgent(Agent):
        lm_config: lm_config.LMConfig,
        prompt_constructor: PromptConstructor,
        captioning_fn = None,
        planner_ip = None
    ) -> None:
        super().__init__()
        self.lm_config = lm_config
        self.prompt_constructor = prompt_constructor
        self.action_set_tag = action_set_tag
        self.captioning_fn = captioning_fn
        self.planner_ip = planner_ip

        # Check if the model is multimodal.
        if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:

@@ -165,6 +168,9 @@ class PromptAgent(Agent):
        lm_config = self.lm_config
        n = 0
        while True:
            if self.planner_ip is not None and self.planner_ip != "":
                response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
            else:
                response = call_llm(lm_config, prompt)
            force_prefix = self.prompt_constructor.instruction[
                "meta_data"

@@ -183,6 +189,8 @@ class PromptAgent(Agent):
                action = create_playwright_action(parsed_response)
            elif self.action_set_tag == "som":
                action = create_id_based_action(parsed_response)
            elif self.action_set_tag == 'webrl_id':
                action = create_webrl_id_based_action(parsed_response)
            else:
                raise ValueError(
                    f"Unknown action type {self.action_set_tag}"

@@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
            action_set_tag=args.action_set_tag,
            lm_config=llm_config,
            prompt_constructor=prompt_constructor,
            captioning_fn=captioning_fn
            captioning_fn=captioning_fn,
            planner_ip=args.planner_ip
        )
    else:
        raise NotImplementedError(
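For reference, a minimal sketch of how the new webrl_id branch is exercised; the raw response string below is an assumed WebRL-style output, not something defined in this commit.

from browser_env.actions import create_webrl_id_based_action

raw_response = 'do(action="Click", element="13")'    # assumed response format
action = create_webrl_id_based_action(raw_response)  # same call used by the new elif branch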
319  VAB-WebArena-Lite/new/envs.py  (normal file)
@@ -0,0 +1,319 @@
import json
import os
import re
import subprocess
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
import numpy.typing as npt
import requests
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
    CDPSession,
    Page,
    Playwright,
    ViewportSize,
    expect,
    sync_playwright,
)

DATASET = os.environ["DATASET"]
if DATASET == "visualwebarena":
    from browser_env.env_config import (
        CLASSIFIEDS,
        CLASSIFIEDS_RESET_TOKEN,
    )

from .actions import Action, execute_action, get_action_space, execute_action_webrl
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
    AccessibilityTree,
    DetachedPage,
    Observation,
    png_bytes_to_numpy,
)


@dataclass
class PlaywrightScript:
    function: str  # goto, get_by_role
    destination: str  # https://www.google.com/, combobox
    name: str | None = None  # Search, Avatar 2009
    operation: str | None = None  # click, fill, press
    value: str | None = None  # avatar movie, Enter


def parse_action(action: str) -> PlaywrightScript:
    splitted = action.strip().split(" ")
    assert len(splitted) >= 2
    match splitted[:2]:
        case ["goto", url]:
            assert len(splitted) == 2
            return PlaywrightScript("goto", url)
        case ["get_by_role", destination]:
            assert len(splitted) >= 4
            match splitted[2:]:
                case [name, operation]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation
                    )
                case [name, operation, value]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation, value
                    )
                case _:
                    raise ValueError("Invalid action")
        case _:
            raise ValueError(f"Invalid action {action}")


class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
    """
    The goal of this environment is to produce a prototype of a browser environment.
    In the end, we want to support a fully configurable browser environment with a wide
    range of action spaces and observation spaces, both structured and unstructured.
    In this prototype, we only support an action space specified by Playwright scripts,
    and the observation space is the HTML content of the page.
    """

    @beartype
    def __init__(
        self,
        max_page_length: int = 8192,
        headless: bool = True,
        slow_mo: int = 0,
        observation_type: str = "html",
        current_viewport_only: bool = False,
        viewport_size: ViewportSize = {"width": 1280, "height": 720},
        save_trace_enabled: bool = False,
        sleep_after_execution: float = 0.0,
        captioning_fn=None,
    ):
        # TODO: make Space[Action] = ActionSpace
        self.action_space = get_action_space()  # type: ignore[assignment]
        self.headless = headless
        self.slow_mo = slow_mo
        self.current_viewport_only = current_viewport_only
        self.reset_finished = False
        self.viewport_size = viewport_size
        self.save_trace_enabled = save_trace_enabled
        self.sleep_after_execution = sleep_after_execution

        match observation_type:
            case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
                self.text_observation_type = observation_type
                self.image_observation_type = ""
                self.main_observation_type = "text"
            case "image":
                self.image_observation_type = observation_type
                self.text_observation_type = ""  # type: ignore[assignment]
                self.main_observation_type = "image"
            case "image_som":
                self.image_observation_type = observation_type
                self.text_observation_type = observation_type  # type: ignore[assignment]
                self.main_observation_type = "image"
            case _:
                raise ValueError(
                    f"Unsupported observation type: {observation_type}"
                )

        self.observation_handler = ObservationHandler(
            self.main_observation_type,
            self.text_observation_type,
            self.image_observation_type,
            self.current_viewport_only,
            self.viewport_size,
            captioning_fn,
        )

        self.observation_space = (
            self.observation_handler.get_observation_space()
        )

    @beartype
    def setup(self, config_file: Path | None = None) -> None:
        self.context_manager = sync_playwright()
        self.playwright = self.context_manager.__enter__()
        self.browser = self.playwright.chromium.launch(
            headless=self.headless, slow_mo=self.slow_mo
        )

        if config_file:
            with open(config_file, "r") as f:
                instance_config = json.load(f)
        else:
            instance_config = {}

        # Reset site if needed. Currently only supported for Classifieds.
        # TODO(jykoh): Add reset functionality for Shopping/Reddit.
        if instance_config.get("require_reset", False):
            if "classifieds" in instance_config["sites"]:
                # Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
                response = requests.post(
                    f"{CLASSIFIEDS}/index.php?page=reset",
                    data={"token": CLASSIFIEDS_RESET_TOKEN},
                )

                # Check if the request was successful
                if response.status_code == 200:
                    print("Reset Classifieds site.")
                else:
                    print(
                        "Failed to reset Classifieds site:",
                        response.status_code,
                    )
            else:
                print(
                    "WARNING: Reset is not supported for this site. Please manually reset the site."
                )

        storage_state = instance_config.get("storage_state", None)
        start_url = instance_config.get("start_url", None)
        geolocation = instance_config.get("geolocation", None)

        # Use custom viewport size if specified in the config, otherwise use the default.
        viewport_size = self.viewport_size.copy()
        viewport_size.update(instance_config.get("viewport_size", {}))
        self.observation_handler.viewport_size = viewport_size

        self.context = self.browser.new_context(
            viewport=viewport_size,
            storage_state=storage_state,
            geolocation=geolocation,
            device_scale_factor=1,
        )
        if self.save_trace_enabled:
            self.context.tracing.start(screenshots=True, snapshots=True)

        if start_url:
            start_urls = start_url.split(" |AND| ")
            for url in start_urls:
                page = self.context.new_page()
                if self.text_observation_type in [
                    "accessibility_tree",
                    "accessibility_tree_with_captioner",
                ]:
                    client = page.context.new_cdp_session(page)
                    client.send("Accessibility.enable")
                    client.detach()
                page.goto(url)
            # set the first page as the current page
            self.page = self.context.pages[0]
            self.page.bring_to_front()
        else:
            self.page = self.context.new_page()
            if self.text_observation_type in [
                "accessibility_tree",
                "accessibility_tree_with_captioner",
            ]:
                client = self.page.context.new_cdp_session(self.page)
                client.send("Accessibility.enable")
                client.detach()

    def _get_obs(self) -> dict[str, Observation]:
        obs = self.observation_handler.get_observation(self.page)
        return obs

    def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
        metadata = self.observation_handler.get_observation_metadata()
        return metadata

    @beartype
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[dict[str, Observation], dict[str, Any]]:
        """
        Reset the environment.
        :param options: options for the environment. The current supported options are:
            - "storage_state": the storage state of the browser. It is a file path to a json file.
        """
        super().reset(seed=seed, options=options)
        if self.reset_finished:
            self.context_manager.__exit__()

        if options is not None and "config_file" in options:
            config_file = Path(options["config_file"])
            if config_file.exists():
                self.setup(config_file=config_file)
            else:
                raise ValueError(f"Config file {config_file} does not exist.")
        else:
            self.setup()
        self.reset_finished = True
        timeout_in_ms = 120000
        self.page.set_default_timeout(timeout_in_ms)
        self.page.set_default_navigation_timeout(timeout_in_ms)
        self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()
        info = {
            "page": DetachedPage(self.page.url, ""),
            "fail_error": "",
            "observation_metadata": observation_metadata,
        }

        return (observation, info)

    def save_trace(self, trace_path: str | Path) -> None:
        if self.save_trace_enabled:
            self.context.tracing.stop(path=trace_path)

    def close(self) -> None:
        if self.reset_finished:
            self.context_manager.__exit__()

    def step(
        self, action: Action
    ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
        if not self.reset_finished:
            raise RuntimeError("Call reset first before calling step.")

        success = False
        fail_error = ""
        try:
            if self.text_observation_type == 'webrl':
                self.page = execute_action_webrl(
                    action,
                    self.page,
                    self.context,
                    self.observation_handler.action_processor,
                    self.sleep_after_execution,
                )
            else:
                self.page = execute_action(
                    action,
                    self.page,
                    self.context,
                    self.observation_handler.action_processor,
                    self.sleep_after_execution,
                )
            success = True
        except Exception as e:
            fail_error = str(e)

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()

        info = {
            "page": DetachedPage(self.page.url, self.page.content()),
            "fail_error": fail_error,
            "observation_metadata": observation_metadata,
        }
        msg = (
            observation,
            float(success),  # reward
            False,  # terminated
            False,  # truncated
            info,
        )
        return msg
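A minimal usage sketch for the new 'webrl' observation path; the DATASET value, config file path, and the action string format are assumptions for illustration, not taken from this commit.

import os
os.environ.setdefault("DATASET", "webarena")          # assumed value; must be set before importing browser_env
from browser_env import ScriptBrowserEnv              # assumed import path
from browser_env.actions import create_webrl_id_based_action

env = ScriptBrowserEnv(observation_type="webrl", headless=True)
obs, info = env.reset(options={"config_file": "config_files/0.json"})  # placeholder path
action = create_webrl_id_based_action('do(action="Click", element="13")')  # assumed response format
obs, reward, terminated, truncated, info = env.step(action)
env.close()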
@@ -172,6 +172,10 @@ class StringEvaluator(Evaluator):
            # prevent false positive (e.g., 0)
            if len(word_tokenize(clean_ref)) == 1:
                tok_pred = word_tokenize(clean_pred)
                for token in tok_pred:
                    if '/' in token:
                        sub_tokens = token.split('/')
                        tok_pred.extend(sub_tokens)
                return float(clean_ref in tok_pred)
            else:
                return float(clean_ref in clean_pred)
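The added branch splits slash-joined tokens before the membership test, so a reference like "3" can match a prediction containing "3/4". A standalone illustration of the idea (plain str.split stands in for the evaluator's tokenizer):

pred_tokens = "the price is 3/4 of the original".split()
for token in list(pred_tokens):
    if '/' in token:
        pred_tokens.extend(token.split('/'))
print(float("3" in pred_tokens))  # 1.0 once "3/4" is split into "3" and "4"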
239  VAB-WebArena-Lite/new/helper_functions_browser.py  (normal file)
@@ -0,0 +1,239 @@
import base64
import io
import json
import re
from pathlib import Path
from typing import Any

from PIL import Image

from agent.prompts import *
from browser_env import (
    Action,
    ActionTypes,
    ObservationMetadata,
    StateInfo,
    action2str,
)

HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
    <style>
        pre {{
            white-space: pre-wrap;
            word-wrap: break-word;
        }}
    </style>
</head>
<html>
    <body>
    {body}
    </body>
</html>
"""


def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted actions for rendering purposes. More comprehensive information."""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "som":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]
            else:
                node_content = "No match found"
            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case "webrl_id":
            action_str = action["raw_prediction"]
        case _:
            raise ValueError(
                f"Unknown action type {action['action_type'], action_set_tag}"
            )
    return action_str


def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "som":
            text_meta_data = observation_metadata["image"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    action_str = action2str(action, action_set_tag, "")
                else:
                    print(
                        'action["element_id"], text_meta_data["obs_nodes_info"]',
                        action["element_id"],
                        text_meta_data["obs_nodes_info"],
                    )
                    action_str = f"Attempt to perform \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case "webrl_id":
            action_str = action["raw_prediction"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()
7  VAB-WebArena-Lite/new/html_tools/__init__.py  (executable file)
@@ -0,0 +1,7 @@
from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .html_parser import HtmlParser

from .utils import print_html_object
from .configs import basic_attrs, mind2web_keep_attrs
from .fetch import get_parsed_html
3  VAB-WebArena-Lite/new/html_tools/configs/__init__.py  (executable file)
@@ -0,0 +1,3 @@
from .html_prompt import prompts
from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs
from .config import config_meta
56  VAB-WebArena-Lite/new/html_tools/configs/config.py  (executable file)
@@ -0,0 +1,56 @@
basic_attrs = [
    'title',
    'value',
    'type',
    'placeholder',
    'selected',
    'data-value',
    'data-text',
    'data-testid',
    'data-label',
    'data-bbox',
    'data-status'
]

mind2web_keep_attrs = [
    'alt',
    'aria_description',
    'aria_label',
    'aria_role',
    'input_checked',
    'input_value',
    'label',
    'name',
    'option_selected',
    'placeholder',
    'role',
    'text_value',
    'title',
    'type',
    'value',
]

miniwob_attrs = [
    'id',
    'type',
    'value',
]

config_meta = """
======= Configs =======
Columns:
    - id: {id_attr}
    - label: {label_attr}
Position: {use_position}
    - window: {window_size}
    - rect_dict: {rect}
Keep:
    - parents: {parent_chain}
    - attrs: {keep_attrs}
    - elems: {keep_elem}
    - obs_elem: {obs_elem}
Generator:
    - prompt: {prompt_name}
    - label: {identifier_name}
========================
"""
22  VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py  (executable file)
@@ -0,0 +1,22 @@
refine_prompt = {
    'dom': '<{tag}{label}|{attr}{content}{subtree} >',
    'label': '[{label}]',
    'attr': '{attr}',
    'attr_splitter': '; ',
    'subtree_splitter': ' ',
}

xml_prompt = {
    'dom': '<{tag}{label}{attr}>{content}{subtree} </{tag}>',
    'label': ' id="{label}"',
    'attr': '{key}="{attr}"',
    'attr_splitter': ' ',
    'subtree_splitter': ' ',
}

prompts = {
    'refine': refine_prompt,
    'xml': xml_prompt,
    'new_data': refine_prompt,
}
108  VAB-WebArena-Lite/new/html_tools/fetch.py  (executable file)
@@ -0,0 +1,108 @@
import os
import json
import base64
from .html_parser import HtmlParser
from .configs import basic_attrs
from .scripts import *

def get_window(page):
    x = page.evaluate("window.scrollX")
    y = page.evaluate("window.scrollY")
    w = page.evaluate("window.innerWidth")
    h = page.evaluate("window.innerHeight")
    return (x, y, w, h)

def modify_page(page):
    page.wait_for_timeout(500)

    try:
        page.evaluate(remove_id_script)
    except:
        pass

    packet = {
        "raw_html": page.evaluate("document.documentElement.outerHTML"),
        "window": get_window(page)
    }

    page.evaluate(prepare_script)
    page.wait_for_timeout(100)

    img_bytes = page.screenshot(path="debug_info/screenshot_raw.png")
    raw_image = base64.b64encode(img_bytes).decode()

    page.evaluate(clickable_checker_script)
    page.wait_for_timeout(50)

    # get all clickable elements
    start_id = 0
    items, start_id = page.evaluate(label_script, {
        "selector": ".possible-clickable-element",
        "startIndex": start_id
    })
    page.wait_for_timeout(50)

    # mark our own labels and get the images
    items = page.evaluate(label_marker_script, items)
    page.wait_for_timeout(100)
    img_bytes = page.screenshot(path="debug_info/marked.png")
    marked_image = base64.b64encode(img_bytes).decode()

    # remove markers on the page
    page.evaluate(remove_label_mark_script)

    packet.update({
        "raw_image": raw_image,
        "marked_image": marked_image,
        "modified_html": page.evaluate("document.documentElement.outerHTML")
    })

    # element_info, include "all_elements" and "clickable_elements"
    element_info = page.evaluate(element_info_script)
    page.wait_for_timeout(100)
    packet.update(element_info)
    return packet

def save_debug_info(packet):
    with open("debug_info/raw.html", "w") as f:
        f.write(packet["modified_html"])
    with open("debug_info/parsed.html", "w") as f:
        f.write(packet["html"])
    with open("debug_info/all_element.json", "w") as f:
        f.write(json.dumps(packet["all_elements"]))

def get_parsed_html(page):
    if not os.path.exists("debug_info"):
        os.makedirs("debug_info")

    print("parsing html...")

    packet = modify_page(page)
    raw_html = packet["modified_html"]

    args = {
        "use_position": True,
        "rect_dict": {},
        "window_size": packet["window"],
        "id-attr": "data-backend-node-id",
        "label_attr": "data-label-id",
        "label_generator": "order",
        "regenerate_label": False,
        "attr_list": basic_attrs,
        "prompt": "xml",
        "dataset": "pipeline"
    }

    hp = HtmlParser(raw_html, args)
    res = hp.parse_tree()
    page_html = res.get("html", "")

    packet["html"] = page_html

    # for debug
    save_debug_info(packet)

    print("parsing finished.")

    return packet
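A sketch of how get_parsed_html might be called from a synchronous Playwright session; the target URL is a placeholder and the import path is an assumption.

from playwright.sync_api import sync_playwright
from html_tools import get_parsed_html  # assumes the package is importable as html_tools

with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_page()
    page.goto("https://example.com")  # placeholder URL
    packet = get_parsed_html(page)
    print(packet["html"][:500])       # pruned, labeled HTML prompt text
    browser.close()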
447  VAB-WebArena-Lite/new/html_tools/html_parser.py  (executable file)
@@ -0,0 +1,447 @@
from lxml import html
import time, copy, random
import json, re, os

from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .configs import config_meta
from .utils import get_xpath_top_down, rect2tuple

class HtmlParser():
    def __init__(self, ctx: str, args: dict[str]={}) -> None:
        stt = time.time()
        self.dom_tree = self.ctx2tree(ctx)
        # tool related
        self.bids2label = {}
        self.bids2xpath = {}
        self.used_labels = {}

        # parse args
        self.parse_args(args)
        self.init_time = time.time() - stt

    def parse_args(self, args: dict[str]={}) -> None:
        def attr_check(attr, type_model='str'):
            if attr is None:
                return False
            attr_type = type(attr)
            if attr_type != type(type_model):
                return False
            if attr_type == type('str') and len(attr) == 0:
                return False
            return True

        args = {} if args is None else args

        # [Position] use_pos: False -> use full page, otherwise use window_size
        dataset = args.get('dataset', '')
        use_position = args.get('use_position', False)
        window_size = args.get('window_size', None)
        rect = args.get('rect_dict', None)
        if use_position:
            if not attr_check(window_size, ()):
                raise ValueError('window_size must be set when use_position is True')
            if not attr_check(rect, {}):
                raise ValueError('rect_dict must be set when use_position is True')

        if not attr_check(rect, {}):
            rect = {}

        # [Label] for vimium is temp_clickable_label, otherwise keep all of it
        label_attr = args.get('label_attr', '')
        get_new_label = args.get('regenerate_label', False)
        label_method = args.get('label_generator', None)
        regen_label = not attr_check(label_method)

        # [id] for mind2web is backend_node_id, for normal website use our method
        id_attr = args.get('id_attr', '')
        regen_id = not attr_check(id_attr)

        if regen_id:
            id_attr = 'temp_id'

        # [attributes]
        keep_attrs = args.get('attr_list', [])
        if not attr_check(keep_attrs, []):
            keep_attrs = []

        # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
        parent_chain = args.get('parent_chain', False)
        keep_elem = args.get('keep_elem', [])
        obs_elem = args.get('obs_elem', [])

        # sanity check
        self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset)

        # [Prompt]
        prompt = args.get('prompt', None)
        self.prompt = HtmlPrompt(prompt)

        # traverse and get special data
        if regen_id or regen_label:
            self.mark_id()

        if get_new_label:
            self.used_labels = {}

        self.identifier = IdentifierTool(label_method, self.used_labels)

    def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='',
                 id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[],
                 parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None:

        self.use_position = use_position
        self.window_size = window_size
        self.rect = rect_dict
        self.label_attr = label_attr
        self.id_attr = id_attr
        self.keep_attrs = keep_attrs
        self.keep = keep_elem
        self.obs = obs_elem
        self.parent_chain = parent_chain
        self.get_new_label = get_new_label
        self.dataset = dataset

    def get_config(self):
        config = {
            'id_attr': self.id_attr,
            'keep_attrs': self.keep_attrs[:5],
            'label_attr': self.label_attr,
            'use_position': self.use_position,
            'window_size': self.window_size,
            'rect': dict(list(self.rect.items())[:3]),
            'keep_elem': self.keep[:5],
            'obs_elem': self.obs[:5],
            'parent_chain': self.parent_chain,
            'prompt_name': self.prompt.name,
            'identifier_name': self.identifier.name
        }

        return config, config_meta.format(**config)

    def update_rect_dict(self, rect_dict: dict[str]={}) -> None:
        self.rect = rect_dict

    @staticmethod
    def ctx2tree(ctx: str) -> html.HtmlElement:
        # remove useless tags, eg. style and script
        ctx = re.sub('<!--[\W\w]*?-->', '', ctx)
        ctx = re.sub('<style[\W\w]*?>[\W\w]*?</style>', '', ctx)
        ctx = re.sub('<script[\W\w]*?>[\W\w]*?</script>', '', ctx)
        ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip()
        dom_tree = html.fromstring(ctx.encode('utf-8'))
        match = re.search('<meta charset="([^"]*)"', ctx)
        if match:
            charset = match.group(1)
            ctx = ctx.encode(charset)
            print(charset)
        else:
            print("Charset not found")
        return dom_tree

    @staticmethod
    def get_root(tree: html.HtmlElement) -> html.HtmlElement:
        node = tree.xpath('//*')[0]
        while True:
            parent = node.getparent()
            if parent is None:
                break
            node = parent
        return node

    def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement:
        nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]')
        if len(nodes) == 0:
            return None
        return nodes[0]

    def id_label_converter(self, label: str) -> str:
        return self.bids2label.get(label, '')

    def id_xpath_converter(self, label: str) -> str:
        return self.bids2xpath.get(label, '')

    def mark_id(self) -> None:
        root = self.get_root(self.dom_tree)
        _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
        self.used_labels = used_labels
        self.bids2xpath = i2xpath

    def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]:
        def get_text(str: str) -> str:
            return '' if str is None else str.strip()[:500]

        def check_attr(attr: str, node: html.HtmlElement) -> bool:
            tag = node.tag
            if (
                ( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] )
                or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' )
                or ( attr == 'aria-hidden' and node.attrib.get(attr, '') == 'true')
                # or ( attr == 'value' and tag in ['option'] )
            ):
                return False
            return True

        def is_visible(node: html.HtmlElement, bid: str) -> bool:
            if self.dataset == 'mind2web':
                bound = node.attrib.get('bounding_box_rect', None)
                self.rect[bid] = rect2tuple(bound)

            if self.dataset == 'pipeline':
                bound = node.attrib.get('data-bbox', None)
                self.rect[bid] = rect2tuple(bound)

            if node.attrib.get('aria-hidden', 'false') == 'true':
                return False

            if not self.use_position:
                return True

            rect = self.rect.get(bid, None)
            if rect is None:
                return False

            if self.window_size is None:
                return True

            # get window size
            wx, wy, ww, wh = self.window_size
            wx, wy = 0, 0
            x, y, w, h = rect
            if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
                return False

            return True

        def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[],
                 parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]):
            # basic information
            bid = node.attrib.get(self.id_attr, '')
            tag = node.tag
            label = node.attrib.get(self.label_attr, '')

            # an element that is kept is treated as visible
            visible = is_visible(node, bid)
            in_keep_list = bid in keep
            in_obs_list = (bid in obs or len(label) > 0) and visible
            par_keep = par_keep and tag == "option"
            keep_element = in_keep_list or in_obs_list or visible or par_keep

            if label:
                keep_element = True

            # mark label
            bids2label, labeled_elems = {}, []
            have_label = False
            if in_keep_list or in_obs_list:
                if label is None or len(label) == 0 or get_new_label:
                    label = self.identifier.generate()
                    node.attrib[self.label_attr] = label
                bids2label[bid] = label
                bids2label[label] = bid
                have_label = True

            # get text or alt_text of current element
            text = get_text(node.text)

            classes = {}
            # keep attributes if needed
            keep_all_attrs = len(self.keep_attrs) == 0
            keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs

            # traverse attributes
            for attr in keep_attrs:
                if attr not in node.attrib or not check_attr(attr, node):
                    continue
                if attr in [self.id_attr, self.label_attr]:
                    continue
                val = get_text(node.attrib[attr])
                if len(val) > 0 or keep_all_attrs:
                    classes[attr] = val

            have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0
            par_keep = keep_element and tag == 'select'

            parts = []
            clickable_count = 0
            children = node.getchildren()
            for child in children:
                cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep)
                clickable_count += 1 if cmsg.get('have_clickable', False) else 0
                bids2label.update(cmsg.get('bids2label', {}))
                labeled_elems.extend(cmsg.get('label_element', []))
                if len(cres) != 0:
                    parts.append(cres)

            dom = self.prompt.subtree_constructor(parts)

            # remove <text|> if all children are text
            keep_as_all_text = (dom.count('<') == dom.count('<text|')) and dom.count('<') > 0
            if keep_as_all_text:
                matches = re.findall(r'<text\| ([^>]+) >', dom)
                dom = self.prompt.subtree_constructor(matches)

            keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
            keep_as_parent = len(dom) > 0 and parent_chain
            if in_keep_list or keep_element or keep_as_parent:
                dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)

            if have_label:
                labeled_elems.append(bid)

            control_msg = {
                'have_clickable': bool(clickable_count or have_text),
                'bids2label': bids2label,
                'label_element': labeled_elems,
            }

            return dom, control_msg

        dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label)
        return dom, cmsg

    def parse_tree(self) -> dict[str]:
        # start from here
        stt = time.time()
        root = self.get_root(self.dom_tree)
        dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
        self.bids2label = cmsg.get('bids2label', {})
        self.keep = list(set(self.keep + cmsg.get('label_element', [])))

        obj = {
            'html': dom,
            'parse_time': time.time() - stt
        }

        return obj

    # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
    def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int,
                          max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]:
        def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
            if current_depth > max_depth:
                return []

            anscendants = []
            parent = node.getparent()
            if parent is not None:
                anscendants.append(parent)
                anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1))

            return anscendants

        def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
            if current_depth > max_depth:
                return []

            descendants = []
            for child in node:
                descendants.append(child)
                descendants.extend(get_descendants(child, max_depth, current_depth + 1))

            return descendants

        to_keep = set(copy.deepcopy(keep))
        nodes_to_keep = set()

        for _ in range(max(1, dfs_count)):
            for bid in to_keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue

                nodes_to_keep.add(candidate_node.attrib[self.id_attr])
                # get all ancestors or with max depth
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)])

                # get descendants with max depth
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children])
                # get siblings within range
                parent = candidate_node.getparent()
                if parent is None:
                    continue

                siblings = [x for x in parent.getchildren() if x.tag != 'text']
                if candidate_node not in siblings:
                    continue

                idx_in_sibling = siblings.index(candidate_node)
                nodes_to_keep.update([x.attrib.get(self.id_attr, '')
                                      for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]])

            max_children = int(max_children * 0.5)
            max_depth = int(max_depth * 0.5)
            max_sibling = int(max_sibling * 0.7)

            to_keep = copy.deepcopy(nodes_to_keep)

        if keep_parent:
            for bid in keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")])

        return list(nodes_to_keep)

    def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement:
        # remove nodes not in nodes_to_keep
        for node in tree.xpath('//*')[::-1]:
            if node.tag != 'text':
                is_keep = node.attrib.get(self.id_attr, '') in nodes_to_keep
                is_candidate = node.attrib.get(self.id_attr, '') in self.keep
            else:
                is_keep = (node.getparent().attrib.get(self.id_attr, '') in nodes_to_keep)
                is_candidate = (node.getparent().attrib.get(self.id_attr, '') in self.keep)

            if not is_keep and node.getparent() is not None:
                # insert all children into parent
                for child in node.getchildren():
                    node.addprevious(child)
                node.getparent().remove(node)
            else:
                # if not is_candidate or node.tag == 'text':
                #     node.attrib.pop(self.id_attr, None)
                if (
                    len(node.attrib) == 0
                    and not any([x.tag == 'text' for x in node.getchildren()])
                    and node.getparent() is not None
                    and node.tag != "text"
                    and len(node.getchildren()) <= 1
                ):
                    # insert all children into parent
                    for child in node.getchildren():
                        node.addprevious(child)
                    node.getparent().remove(node)

        return tree

    def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30,
                   max_sibling: int=3, keep_parent: bool=False) -> None:
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent)
        new_tree = self.prune(new_tree, nodes_to_keep)

        self.dom_tree = new_tree

    def get_segment(self, bid: str) -> str:
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
        new_tree = self.prune(new_tree, nodes_to_keep)
        dom, _ = self.parse(new_tree, self.keep, [], False)
        return dom

    def get_rect_data(self, bids: list[str]) -> list[dict[str]]:
        res = []
        for bid in bids:
            label = self.bids2label.get(bid, '')
            rect = self.rect.get(bid, None)
            res.append({
                'bid': bid,
                'label': label,
                'rect': rect
            })
        return res
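A small sketch of driving HtmlParser directly on a static HTML string; the argument values mirror the defaults used in fetch.py, the HTML snippet is made up, and the import path is an assumption.

from html_tools import HtmlParser  # assumed import path

raw_html = '<html><body><button data-bbox="0,0,100,30">Submit</button></body></html>'  # made-up input
hp = HtmlParser(raw_html, {
    "attr_list": ["title", "value", "type", "data-bbox"],
    "label_generator": "order",
    "prompt": "xml",
})
result = hp.parse_tree()
print(result["html"], result["parse_time"])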
64  VAB-WebArena-Lite/new/html_tools/identifier.py  (executable file)
@@ -0,0 +1,64 @@
import secrets

class IdentifierTool:
    def __init__(self, method: str='order', existing_labels: dict[str]={}) -> None:
        self.methods = {
            'order': self.get_identifier_in_order,
            'random': self.get_random_identifier,
        }

        if method is None:
            method = 'order'

        self.func = self.methods.get(method, None)
        self.name = method
        if self.func is None:
            raise ValueError(f'Invalid method for identifier: {method}')

        self.reset(existing_labels)

    def reset(self, exists: dict[str]={}) -> None:
        self.identifier = -1
        self.exists = {} if exists is None else exists

    def get_identifier_in_order(self) -> str:
        def id2str(id: int) -> str:
            if id < 26:
                return chr(id + 65)
            id -= 26
            c0 = id // 676
            c1 = (id // 26) % 26
            c2 = id % 26
            label = f'{chr(c1 + 65)}{chr(c2 + 65)}'
            return label if c0 == 0 else f'{chr(c0 + 64)}{label}'

        self.identifier += 1
        label = id2str(self.identifier)

        while label in self.exists:
            self.identifier += 1
            label = id2str(self.identifier)

        self.exists[label] = True
        return label

    def get_random_identifier(self) -> str:
        secret_generator = secrets.SystemRandom()

        def get_random_label(n: int=2) -> str:
            tmp = ''
            for _ in range(n):
                tmp += chr(secret_generator.randint(65, 90))
            return tmp

        wc = 3 if len(self.exists) > 280 else 2

        label = get_random_label(wc)
        while label in self.exists:
            label = get_random_label(wc)

        self.exists[label] = True
        return label

    def generate(self):
        return self.func()
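A quick check of the ordered label scheme (single letters first, then two-letter pairs); the import path is an assumption.

from html_tools.identifier import IdentifierTool  # assumed import path

tool = IdentifierTool('order')
print([tool.generate() for _ in range(28)])  # ['A', 'B', ..., 'Z', 'AA', 'AB']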
97  VAB-WebArena-Lite/new/html_tools/prompt.py  (executable file)
@@ -0,0 +1,97 @@
from .configs import prompts

class HtmlPrompt:
    def __init__(self, prompt: str='') -> None:
        prompt = self.extract(prompt, 'xml')
        if prompt not in prompts:
            raise Exception('Unknown prompt: ' + prompt)

        constructors = {
            'refine': self.normal_prompt_constructor,
            'xml': self.normal_prompt_constructor,
            'new_data': self.new_data_prompt_constructor,
        }

        self.name = prompt
        self.prompt = prompts[prompt]
        self.constructor = constructors[prompt]

    @staticmethod
    def extract(data, default=''):
        return data if data is not None else default

    def subtree_constructor(self, subtree: list[str]=[]) -> str:
        return self.prompt['subtree_splitter'].join(subtree)

    def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
        def add_prefix(data, prefix):
            return prefix + data if len(data) > 0 else ''

        tag = self.extract(tag)
        label = self.extract(label)
        content = self.extract(content)
        subtree_str = self.extract(subtree_str, '')
        class_dict = self.extract(class_dict, {})

        label_str = ''
        if len(label) > 0:
            label_str = self.prompt['label'].format(label=label)

        classes = []
        values = set()
        for key, val in class_dict.items():
            if val in values:
                continue
            values.add(val)
            classes.append(self.prompt['attr'].format(key=key, attr=val))
        classes_str = self.prompt['attr_splitter'].join(classes)

        content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
        classes_str = add_prefix(classes_str, ' ')
        content_str = add_prefix(content, content_splitter)
        subtree_str = add_prefix(subtree_str, ' ')

        return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)

    def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
        def add_prefix(data, prefix):
            return prefix + data if len(data) > 0 else ''

        tag = self.extract(tag)
        label = self.extract(label)
        content = self.extract(content)
        subtree_str = self.extract(subtree_str, '')
        class_dict = self.extract(class_dict, {})

        label_str = ''
        if len(label) > 0:
            label_str = self.prompt['label'].format(label=label)

        classes = []
        values = set()

        message = []
        for key, val in class_dict.items():
            if val == '':
                message.append(key)
                continue
            if val in values:
                continue
            values.add(val)
            classes.append(self.prompt['attr'].format(key=key, attr=val))

        if len(message) > 0:
            message_str = ' '.join(message)
            classes.append(self.prompt['attr'].format(key='message', attr=message_str))

        classes_str = self.prompt['attr_splitter'].join(classes)

        content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
        classes_str = add_prefix(classes_str, ' ')
        content_str = add_prefix(content, content_splitter)
        subtree_str = add_prefix(subtree_str, ' ')

        return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)

    def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
        return self.constructor(tag, label, content, subtree_str, class_dict)
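A short illustration of what the 'xml' prompt style produces for a single labeled element; the input values are invented and the import path is an assumption.

from html_tools.prompt import HtmlPrompt  # assumed import path

hp = HtmlPrompt('xml')
line = hp.prompt_constructor(
    tag='button', label='A', content='Submit',
    subtree_str='', class_dict={'type': 'submit'},
)
print(line)  # roughly: <button id="A" type="submit"> Submit </button>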
43  VAB-WebArena-Lite/new/html_tools/scripts/__init__.py  (executable file)
@@ -0,0 +1,43 @@
import os
from pathlib import Path
rootdir = Path(__file__).parent

with open(os.path.join(rootdir, 'prepare.js'), 'r') as f:
    prepare_script = f.read()

with open(os.path.join(rootdir, 'clickable_checker.js'), 'r') as f:
    clickable_checker_script = f.read()

with open(os.path.join(rootdir, 'label.js'), 'r') as f:
    label_script = f.read()

with open(os.path.join(rootdir, 'element_info.js'), 'r') as f:
    element_info_script = f.read()

# draw label on page
with open(os.path.join(rootdir, 'label_marker.js'), 'r') as f:
    label_marker_script = f.read()

# remove label draw on page
remove_label_mark_script = """
() => {
    document.querySelectorAll(".our-dom-marker").forEach(item => {
        document.body.removeChild(item);
    });
}
"""

remove_id_script = """
() => {
    Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => {
        element.classList.remove('possible-clickable-element');
        element.removeAttribute('data-value');
        element.removeAttribute('data-text');
        element.removeAttribute('data-label');
        element.removeAttribute('data-bbox');
        element.removeAttribute('data-status');
        element.removeAttribute('data-backend-node-id');
        element.removeAttribute('data-label-id');
    });
}
"""
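These strings are meant to be passed straight to Playwright's page.evaluate, which is exactly what fetch.modify_page does end to end; a minimal hedged sketch (URL is a placeholder, import path is an assumption):

from playwright.sync_api import sync_playwright
from html_tools.scripts import element_info_script, remove_id_script  # assumed import path

with sync_playwright() as p:
    page = p.chromium.launch(headless=True).new_page()
    page.goto("https://example.com")  # placeholder URL
    info = page.evaluate(element_info_script)   # returns all_elements / clickable_elements
    page.evaluate(remove_id_script)             # strips the helper attributes again
    print(len(info["all_elements"]))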
148
VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
Executable file
148
VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
Executable file
|
@ -0,0 +1,148 @@
|
|||
() => {
|
||||
var items = Array.prototype.slice.call(
    document.querySelectorAll('*')
).map(function(element) {
    var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
    var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

    var rects = [...element.getClientRects()].filter(bb => {
        var center_x = bb.left + bb.width / 2;
        var center_y = bb.top + bb.height / 2;
        var elAtCenter = document.elementFromPoint(center_x, center_y);

        if (!elAtCenter) return false;
        return elAtCenter === element || element.contains(elAtCenter)
    }).map(bb => {
        const rect = {
            left: Math.max(0, bb.left),
            top: Math.max(0, bb.top),
            right: Math.min(vw, bb.right),
            bottom: Math.min(vh, bb.bottom)
        };
        return {
            ...rect,
            width: rect.right - rect.left,
            height: rect.bottom - rect.top
        }
    });

    // var rects = [];
    var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);

    const tagName = element.tagName.toLowerCase?.() || "";
    let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer");

    // Insert area elements that provide click functionality to an img.
    if (tagName === "img") {
        let mapName = element.getAttribute("usemap");
        if (mapName) {
            const imgClientRects = element.getClientRects();
            mapName = mapName.replace(/^#/, "").replace('"', '\\"');
            const map = document.querySelector(`map[name=\"${mapName}\"]`);
            if (map && (imgClientRects.length > 0)) isClickable = true;
        }
    }

    if (!isClickable) {
        const role = element.getAttribute("role");
        const clickableRoles = [
            "button",
            "tab",
            "link",
            "checkbox",
            "menuitem",
            "menuitemcheckbox",
            "menuitemradio",
            "radio",
        ];
        if (role != null && clickableRoles.includes(role.toLowerCase())) {
            isClickable = true;
        } else {
            const contentEditable = element.getAttribute("contentEditable");
            if (
                contentEditable != null &&
                ["", "contenteditable", "true"].includes(contentEditable.toLowerCase())
            ) {
                isClickable = true;
            }
        }
    }

    // Check for jsaction event listeners on the element.
    if (!isClickable && element.hasAttribute("jsaction")) {
        const jsactionRules = element.getAttribute("jsaction").split(";");
        for (let jsactionRule of jsactionRules) {
            const ruleSplit = jsactionRule.trim().split(":");
            if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) {
                const [eventType, namespace, actionName] = ruleSplit.length === 1
                    ? ["click", ...ruleSplit[0].trim().split("."), "_"]
                    : [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"];
                if (!isClickable) {
                    isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_");
                }
            }
        }
    }

    if (!isClickable) {
        const clickableTags = [
            "input",
            "textarea",
            "select",
            "button",
            "a",
            "iframe",
            "video",
            "object",
            "embed",
            "details"
        ];
        isClickable = clickableTags.includes(tagName);
    }

    if (!isClickable) {
        if (tagName === "label")
            isClickable = (element.control != null) && !element.control.disabled;
        else if (tagName === "img")
            isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor);
    }

    // An element with a class name containing the text "button" might be clickable. However, real
    // clickables are often wrapped in elements with such class names. So, when we find clickables
    // based only on their class name, we mark them as unreliable.
    const className = element.getAttribute("class");
    if (!isClickable && className && className.toLowerCase().includes("button")) {
        isClickable = true;
    }

    // Elements with tabindex are sometimes useful, but usually not. We can treat them as second
    // class citizens when it improves UX, so take special note of them.
    const tabIndexValue = element.getAttribute("tabindex");
    const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1;
    if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) {
        isClickable = true;
    }

    const idValue = element.getAttribute("id");
    const id = idValue ? idValue.toLowerCase() : "";
    if (isClickable && area == 0) {
        const textValue = element.textContent.trim().replace(/\s{2,}/g, ' ');
        clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}`
    }

    return {
        element: element,
        include: isClickable,
        area,
        rects,
        text: element.textContent.trim().replace(/\s{2,}/g, ' ')
    };
}).filter(item =>
    item.include && (item.area >= 1)
);

items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))

items.forEach(item => {
    item.element.classList.add('possible-clickable-element');
});
}
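Editor's note: the script above only tags candidate clickable elements with the possible-clickable-element class; nothing in this diff shows how it is injected. The sketch below is one plausible way to evaluate it from Python with Playwright's sync API, assuming the script is wrapped as a zero-argument function; the file name used here is hypothetical.

# Hedged sketch (not part of this commit): evaluating the page script above via Playwright.
from pathlib import Path
from playwright.sync_api import sync_playwright

SCRIPT_PATH = Path("html_tools/scripts/get_clickable_elements.js")  # hypothetical file name

with sync_playwright() as p:
    browser = p.chromium.launch()
    page = browser.new_page(viewport={"width": 1280, "height": 720})
    page.goto("https://example.com")
    # page.evaluate runs the wrapped function; the script itself only adds a CSS class.
    page.evaluate(SCRIPT_PATH.read_text())
    tagged = page.query_selector_all(".possible-clickable-element")
    print(f"{len(tagged)} candidate clickable elements")
    browser.close()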
34
VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
Executable file
@@ -0,0 +1,34 @@
() => {
    function getElementInfo(element) {
        return {
            "bid": element.getAttribute("data-backend-node-id") || "",
            "label": element.getAttribute("data-label-id") || "",
            "tag": element.tagName.toLowerCase?.() || "",
            "area": JSON.parse("[" + (element.getAttribute("data-bbox") || "") + "]"),
            "text": element.innerText?.trim().replace(/\s{2,}/g, " ") || "",
            "id": element.getAttribute("id") || "",
            "role": element.getAttribute("role") || "",
            "aria-label": element.getAttribute("aria-label") || "",
            "href": element.getAttribute("href") || "",
        };
    }

    var all_items = Array.prototype.slice.call(
        document.querySelectorAll("*")
    ).map((element) => {
        return getElementInfo(element);
    });

    var clickable_items = Array.prototype.slice.call(
        document.querySelectorAll("*")
    ).filter(
        element => element.getAttribute("data-label-id")
    ).map((element) => {
        return getElementInfo(element);
    });

    return {
        all_elements: all_items,
        clickable_elements: clickable_items
    };
}
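element_info.js returns a plain object with all_elements and clickable_elements, where each entry is built from the attributes stamped earlier (data-backend-node-id, data-label-id, data-bbox). Below is a minimal post-processing sketch on the Python side; the evaluate call and the minimum-area threshold are illustrative assumptions.

# Illustrative only: filtering the dict returned by element_info.js.
from typing import Any

def visible_clickables(info: dict[str, Any], min_area: float = 1.0) -> list[dict[str, Any]]:
    """Keep labelled elements whose data-bbox area is at least min_area square pixels."""
    kept = []
    for item in info["clickable_elements"]:
        bbox = item["area"]  # parsed from data-bbox as [left, top, width, height]
        if len(bbox) == 4 and bbox[2] * bbox[3] >= min_area:
            kept.append(item)
    return kept

# info = page.evaluate(open("html_tools/scripts/element_info.js").read())  # page: Playwright Page
# for item in visible_clickables(info):
#     print(item["label"], item["tag"], item["text"][:40])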
74
VAB-WebArena-Lite/new/html_tools/scripts/label.js
Executable file
@@ -0,0 +1,74 @@
(packet) => {
    function int2str(index) {
        var str = "";
        while (index >= 0) {
            str = String.fromCharCode(65 + index % 26) + str;
            index = Math.floor(index / 26) - 1;
        }
        return str;
    };

    selector = packet.selector
    index = packet.startIndex
    var items = Array.prototype.slice.call(
        document.querySelectorAll(selector)
    );

    var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
    var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

    items = items.filter(
        x => !items.some(y => x.contains(y) && !(x == y))
    ).map(element => {
        var bb = element.getClientRects();
        var rect = {
            left: 0,
            top: 0,
            right: 0,
            bottom: 0,
            width: 0,
            height: 0
        };
        var keep = false;
        var text = "", id = -1;
        if (bb.length > 0) {
            bb = bb[0];
            rect = {
                left: Math.max(0, bb.left),
                top: Math.max(0, bb.top),
                right: Math.min(vw, bb.right),
                bottom: Math.min(vh, bb.bottom)
            };
            rect = {
                ...rect,
                width: rect.right - rect.left,
                height: rect.bottom - rect.top
            };
            if (rect.width > 0 || rect.height > 0) {
                keep = true;
                if (index >= 0) {
                    // id = int2str(index++);
                    id = index++;
                    element.setAttribute("data-label-id", id);
                }
                var childNodes = element.childNodes;

                for (var i = 0; i < childNodes.length; i++) {
                    if (childNodes[i].nodeType == Node.TEXT_NODE) {
                        text += childNodes[i].textContent;
                    }
                }
            }
        }

        return {
            keep: true,
            id,
            rects: rect,
            tag: element.tagName.toLowerCase?.() || "",
            text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || ""
        };
    }).filter(x => x.keep);

    return [items, index];
}
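label.js takes a packet holding a CSS selector and a start index, stamps data-label-id on the innermost matching elements that have a visible box, and returns the collected items together with the next free index. A small driver sketch follows, showing how that index might be threaded across successive calls; the selector list and file path are assumptions.

# Hypothetical driver for label.js: reuse the returned index so labels stay unique.
LABEL_JS = open("html_tools/scripts/label.js").read()  # assumed path

def label_elements(page, selectors: list[str]) -> list[dict]:
    items, index = [], 0
    for selector in selectors:
        new_items, index = page.evaluate(LABEL_JS, {"selector": selector, "startIndex": index})
        items.extend(new_items)
    return items

# items = label_elements(page, ["input, textarea, select", "a, button"])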
65
VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
Executable file
@@ -0,0 +1,65 @@
(items) => {
    function getRandomColor() {
        var letters = '0123456789ABCDEF';
        var color = '#';
        for (var i = 0; i < 6; i++) {
            color += letters[Math.floor(Math.random() * 16)];
        }
        return color;
    }

    items.filter(
        item => item.id != ""
    ).forEach((item) => {
        const bbox = item.rects;
        const id_string = `dom-marker-id-${index}`;

        index = item.id;

        outerElement = document.createElement("div");
        outerElement.classList.add("our-dom-marker");
        // var borderColor = getRandomColor();
        var borderColor = "#FFFF00";
        outerElement.style.outline = `2px dashed ${borderColor}`;
        outerElement.style.position = "fixed";
        outerElement.style.left = bbox.left - 2 + "px";
        outerElement.style.top = bbox.top - 2 + "px";
        outerElement.style.width = bbox.width + 4 + "px";
        outerElement.style.height = bbox.height + 4 + "px";
        outerElement.style.pointerEvents = "none";
        outerElement.style.boxSizing = "border-box";
        outerElement.style.zIndex = 2147483647;

        innerElement = document.createElement("div");
        innerElement.classList.add("our-dom-marker");
        innerElement.style.outline = `2px dashed #222288`;
        innerElement.style.position = "fixed";
        innerElement.style.left = bbox.left + "px";
        innerElement.style.top = bbox.top + "px";
        innerElement.style.width = bbox.width + "px";
        innerElement.style.height = bbox.height + "px";
        innerElement.style.pointerEvents = "none";
        innerElement.style.boxSizing = "border-box";
        innerElement.style.zIndex = 2147483647;

        // Add floating label at the corner
        var label = document.createElement("span");
        var topPosition = 25;
        if (bbox.top < 25) topPosition = bbox.top;
        label.textContent = index;
        label.style.position = "absolute";
        label.style.top = `-${topPosition}px`;
        label.style.left = "0px";
        label.style.background = borderColor;
        label.style.color = "black";
        label.style.padding = "2px 4px";
        label.style.fontSize = "16px";
        label.style.borderRadius = "2px";
        label.style.fontWeight = "bold";
        outerElement.appendChild(label);

        document.body.appendChild(outerElement);
        document.body.appendChild(innerElement);
    })
    return items;
}
83
VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
Executable file
@@ -0,0 +1,83 @@
() => {
    // mark backend node id
    var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
    var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);

    var backendId = 0;
    Array.prototype.slice.call(
        document.querySelectorAll("*")
    ).forEach((element) => {
        element.setAttribute("data-backend-node-id", backendId);
        backendId++;

        var tag = element.tagName.toLowerCase?.() || "";
        var bb = element.getClientRects();
        var rect = {
            left: 0,
            top: 0,
            right: 0,
            bottom: 0,
            width: 0,
            height: 0
        };

        if (bb.length > 0) {
            bb = bb[0];
            // rect = {
            //     left: Math.round(Math.max(0, bb.left) * 100) / 100,
            //     top: Math.round(Math.max(0, bb.top) * 100) / 100,
            //     right: Math.round(Math.min(vw, bb.right) * 100) / 100,
            //     bottom: Math.round(Math.min(vh, bb.bottom) * 100) / 100
            // };
            rect = {
                left: (Math.round(bb.left) * 100) / 100,
                top: (Math.round(bb.top) * 100) / 100,
                right: (Math.round(bb.right) * 100) / 100,
                bottom: (Math.round(bb.bottom) * 100) / 100
            };
            rect = {
                ...rect,
                width: Math.round((rect.right - rect.left) * 100) / 100,
                height: Math.round((rect.bottom - rect.top) * 100) / 100
            };

            element.setAttribute("data-bbox", `${rect.left},${rect.top},${rect.width},${rect.height}`);
        }

        if (element.hasChildNodes()) {
            let children = Array.prototype.slice.call(element.childNodes);
            var texts = children.filter(
                (node) => node.nodeType == Node.TEXT_NODE
            ).map(
                (node) => node.textContent.trim().replace(/\s{2,}/g, " ") || ""
            ).filter(
                (text) => text.length > 0
            )
            element.setAttribute("data-text", texts.join(","));
        }

        // fix select issue
        if (tag == "select") {
            var value = element.value;
            var text = element.options[element.selectedIndex]?.text || "";
            element.setAttribute("data-value", value);
            element.setAttribute("data-text", text);
            element.options[element.selectedIndex]?.setAttribute("data-status", "selected");
        }

        if (tag == "input") {
            var input_type = element.getAttribute("type") || "";
            if (input_type == "checkbox") {
                var status = element.checked ? "checked" : "not-checked";
                element.setAttribute("data-status", status);
            }
        }
    });

    // fix input and textarea issue
    Array.prototype.slice.call(
        document.querySelectorAll("input, textarea")
    ).forEach(element => {
        element.setAttribute("data-value", element.value);
    });
}
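prepare.js is the first annotation pass: it stamps data-backend-node-id, data-bbox, data-text, data-value and data-status so that label.js and element_info.js can read them back. The ordering sketch below is an assumption about how the three scripts are chained; the paths and the placeholder selector are illustrative only.

# Assumed pipeline order for the annotation scripts (illustrative, not from this commit).
from pathlib import Path

SCRIPT_DIR = Path("html_tools/scripts")

def annotate_page(page) -> dict:
    """Run the annotation scripts in dependency order on a Playwright Page."""
    page.evaluate((SCRIPT_DIR / "prepare.js").read_text())  # stamp data-* attributes
    page.evaluate((SCRIPT_DIR / "label.js").read_text(),    # assign data-label-id values
                  {"selector": "*", "startIndex": 0})       # placeholder selector
    return page.evaluate((SCRIPT_DIR / "element_info.js").read_text())  # collect element records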
101
VAB-WebArena-Lite/new/html_tools/utils.py
Executable file
@@ -0,0 +1,101 @@
from lxml import html

def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0,
                       in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str, bool]]:
    used_labels, i2xpath = {}, {}
    # path
    tag = element.tag.lower()
    in_svg = in_svg or (tag == 'svg')

    if not in_svg and 'id' in element.attrib:
        node_id = element.attrib['id']
        path = f'//*[@id="{node_id}"]'
    else:
        suffix = f'[{order}]' if order > 0 else ''
        prefix = f'*[name()="{tag}"]' if in_svg else tag
        path = path + '/' + prefix + suffix

    # add temp id
    element.attrib[id_column] = str(temp_id)
    ori_label = element.attrib.get(label_column, '')
    if ori_label != '':
        used_labels[ori_label] = True

    bid = str(temp_id)
    i2xpath[bid] = path
    i2xpath[path] = bid
    i2xpath[f'xpath/{path}'] = bid
    i2xpath[f'xpath=/{path}'] = bid

    temp_id += 1

    # traverse node
    children = element.getchildren()
    tag_dict = {}
    id_list = []
    for child in children:
        ctag = child.tag.lower()
        if ctag not in tag_dict:
            tag_dict[ctag] = 0
        tag_dict[ctag] += 1
        id_list.append(tag_dict[ctag])

    for cid, child in zip(id_list, children):
        ctag = child.tag.lower()
        cod = cid if tag_dict[ctag] > 1 else 0
        temp_id, i2x, ulabels = get_xpath_top_down(child, id_column, label_column, path, cod, in_svg, temp_id)
        i2xpath.update(i2x)
        used_labels.update(ulabels)

    return temp_id, i2xpath, used_labels

def print_html_object(obj: str='') -> str:
    tab_cnt = 0
    result, content, sep = '', '', ''
    last_is_left, last_is_right = False, False
    for ch in obj:
        if ch == '<':
            result += '\n'
            if len(content.strip()) > 0:
                result += sep + content.strip() + '\n'
            result += sep + '<'

            tab_cnt += 1
            sep = ' ' * tab_cnt

            content = ''
            last_is_right = False
            last_is_left = True
        elif ch == '>':
            if last_is_left:
                result += content
            else:
                if last_is_right:
                    result += '\n'
                if len(content.strip()) > 0:
                    result += sep + content.strip() + '\n'

            tab_cnt -= 1
            sep = ' ' * tab_cnt

            if not last_is_left:
                result += sep

            result += '>'
            content = ''

            last_is_right = True
            last_is_left = False
        else:
            content += ch

    return result

def rect2tuple(rect: str) -> tuple[int, int, int, int]:
    if rect is None or type(rect) != type('str'):
        return None
    rect = rect.strip()
    if rect.count(',') != 3:
        return None
    rect = rect.split(',')
    rect = [float(r) for r in rect]
    return tuple(rect)
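get_xpath_top_down walks an lxml tree once, assigns a temporary id to every element, and returns the element count, a bidirectional map between those ids and XPath strings, and the set of label ids already in use. A short usage sketch with the default column names follows; it assumes the function above is importable.

# Usage sketch for get_xpath_top_down (default id/label columns).
from lxml import html

doc = html.fromstring("""
<html><body>
  <div id="nav"><a href="/a">A</a><a href="/b">B</a></div>
</body></html>
""")

n_nodes, i2xpath, used_labels = get_xpath_top_down(doc)
print(n_nodes)                     # number of elements visited
print(i2xpath["3"])                # xpath of the element whose temp_id is 3
print(i2xpath['//*[@id="nav"]'])   # reverse lookup: xpath -> temp_id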
@@ -36,7 +36,6 @@ def retry_with_exponential_backoff(  # type: ignore
    # Initialize variables
    num_retries = 0
    delay = initial_delay

    # Loop until a successful response or max_retries is hit or an exception is raised
    while True:
        try:

@@ -142,27 +141,32 @@ async def agenerate_from_openai_completion(
@retry_with_exponential_backoff
def generate_from_openai_completion(
    prompt: str,
    engine: str,
    model: str,
    temperature: float,
    max_tokens: int,
    top_p: float,
    context_length: int,
    stop_token: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None
) -> str:
    if "OPENAI_API_KEY" not in os.environ:
    if "OPENAI_API_KEY" not in os.environ and api_key is None:
        raise ValueError(
            "OPENAI_API_KEY environment variable must be set when using OpenAI API."
        )

    if api_key is not None:
        client = OpenAI(api_key=api_key, base_url=base_url)
    response = client.completions.create(
        prompt=prompt,
        engine=engine,
        model=model,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=[stop_token],
    )
    try:
        answer: str = response["choices"][0]["text"]
    except:
        answer: str = response.choices[0].text
    return answer

13
VAB-WebArena-Lite/new/p_webrl.json
Normal file
@@ -0,0 +1,13 @@
{
    "intro": "",
    "examples": [],
    "template": "",
    "meta_data": {
        "observation": "webrl",
        "action_type": "webrl_id",
        "keywords": [],
        "prompt_constructor": "WebRLPromptConstructor",
        "answer_phrase": "",
        "action_splitter": ""
    }
}
1351
VAB-WebArena-Lite/new/processors.py
Normal file
File diff suppressed because it is too large
Load Diff

@@ -499,3 +499,59 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor):
            raise NotImplementedError(
                f"Provider {self.lm_config.provider} not implemented"
            )

class WebRLPromptConstructor(PromptConstructor):
    """The agent directly predicts the next action."""

    def __init__(
        self,
        instruction_path: str | Path,
        lm_config: lm_config.LMConfig,
        tokenizer: Tokenizer,
    ):
        super().__init__(instruction_path, lm_config, tokenizer)

    def construct(
        self,
        trajectory: Trajectory,
        intent: str,
        meta_data: dict[str, Any] = {},
    ) -> APIInput:
        """Construct the prompt given the trajectory."""
        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]

        obs = state_info["observation"][self.obs_modality]
        max_obs_length = self.lm_config.gen_config["max_obs_length"]
        if max_obs_length:
            if self.lm_config.provider == "google":
                print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
                obs = obs[:max_obs_length]
            else:
                try:
                    obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]
                except:
                    print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
                    obs = obs[:max_obs_length]

        turn_num = len(meta_data["action_history"])
        if turn_num == 1:
            previous_action_str = []
        else:
            previous_action_str = meta_data["action_history"][1:]

        index = turn_num - 1
        history = ""
        for i in range(index - 1, -1, -1):
            if i == 0:
                history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
            else:
                history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n** Simplified html **\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
        if len(history) + len(obs) > (16384 - 512):
            obs = obs[:(16384 - 512) - len(history)]
        current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"

        return prompt

    def extract_action(self, response: str) -> str:
        return response

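WebRLPromptConstructor flattens the task intent, the truncated action history, and the current (simplified HTML) observation into a single Llama-3-style completion prompt. The toy reproduction below shows the resulting layout for a third step; the action strings and observation are placeholders, not output of the real pipeline.

# Toy reproduction of the prompt layout produced by WebRLPromptConstructor.construct().
intent = "Find the cheapest laptop"
action_history = ["None", 'do(action="Click", element="3")', 'do(action="Scroll Down")']  # placeholders
obs = "<html> simplified html of the current page </html>"

previous = action_history[1:]
index = len(action_history) - 1
history = ""
for i in range(index - 1, -1, -1):
    user_part = intent if i == 0 else "** Simplified html **"
    history = (f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_part}\n\n"
               f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous[i]}\n\n") + history
current_turn = (f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n"
                f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n")
print(f"Task Instruction: {intent}\n\n{history}{current_turn}")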
@@ -13,11 +13,13 @@ import tempfile
import time
from pathlib import Path
from typing import List
import cv2
import shutil

import openai
import requests
import torch
from PIL import Image
from PIL import Image, ImageDraw, ImageFont

from agent import (
    PromptAgent,

@@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

def text_wrap(text, font, max_width):
    lines = []
    paragraphs = text.split('\n')  # split the text into paragraphs on \n
    for paragraph in paragraphs:
        words = paragraph.split(' ')
        line = ''
        for word in words:
            # tentative line
            test_line = f"{line} {word}".strip()
            # measure the width of the tentative line
            test_line_bbox = font.getbbox(test_line)
            test_line_width = test_line_bbox[2] - test_line_bbox[0]
            if test_line_width <= max_width:
                # if the tentative line still fits within the image width, keep appending words
                line = test_line
            else:
                # otherwise save the current line and start a new one
                lines.append(line)
                line = word
        # append the last line of each paragraph
        if line:
            lines.append(line)
        # add an empty line after each paragraph to preserve paragraph breaks
        lines.append('')
    # remove the trailing empty line (no extra blank line needed)
    if lines[-1] == '':
        lines.pop()
    return lines

def config() -> argparse.Namespace:
    parser = argparse.ArgumentParser(

@@ -88,6 +118,7 @@ def config() -> argparse.Namespace:
            "html",
            "image",
            "image_som",
            "webrl",
        ],
        default="accessibility_tree",
        help="Observation type",

@@ -176,6 +207,9 @@ def config() -> argparse.Namespace:

    # logging related
    parser.add_argument("--result_dir", type=str, default="")

    # if using a self-deployed model
    parser.add_argument("--planner_ip", type=str, default=None)
    args = parser.parse_args()

    # check whether the action space is compatible with the observation space

@@ -196,7 +230,7 @@ def config() -> argparse.Namespace:


def early_stop(
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None
) -> tuple[bool, str]:
    """Check whether we need to stop early."""


@@ -228,8 +262,8 @@ def early_stop(
    if len(action_seq) == 0:
        return False, ""

    if actions is None:
        last_action: Action = action_seq[-1]

        if last_action["action_type"] != ActionTypes.TYPE:
            if len(last_k_actions) >= k:
                if all(

@@ -239,7 +273,6 @@ def early_stop(
                    ]
                ):
                    return True, f"Same action for {k} times"

        else:
            # check the action sequence
            if (

@@ -247,9 +280,20 @@ def early_stop(
                >= k
            ):
                return True, f"Same typing action for {k} times"

        return False, ""

    else:
        last_k_actions = actions[-k:]
        last_action = actions[-1]
        if len(last_k_actions) >= k:
            if all(
                [
                    action == last_action
                    for action in last_k_actions
                ]
            ):
                return True, f"Same action for {k} times"
        return False, ""

def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
    obj = {

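With the new actions argument, repetition is judged on the raw action strings recorded by the WebRL loop; when actions is None the function behaves exactly as before. The stand-alone snippet below isolates the added branch (k plays the role of the repeating-action threshold):

# Stand-alone illustration of the new `actions` branch of early_stop.
def same_action_k_times(actions: list[str], k: int) -> tuple[bool, str]:
    last_k = actions[-k:]
    if len(last_k) >= k and all(a == last_k[-1] for a in last_k):
        return True, f"Same action for {k} times"
    return False, ""

print(same_action_k_times(['do(action="Click", element="3")'] * 3, 3))  # (True, 'Same action for 3 times')
print(same_action_k_times(['scroll', 'click', 'scroll'], 3))            # (False, '')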
@@ -387,16 +431,20 @@ def test(
        obs, info = env.reset(options={"config_file": config_file})
        state_info: StateInfo = {"observation": obs, "info": info}
        trajectory.append(state_info)

        meta_data = {"action_history": ["None"]}
        out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
        actions = []

        os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True)
        if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")):
            shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
        os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))

        while True:
            update_action_history(out_path, task_id, actions=actions)

            # If no actions variable is passed, the behavior of early_stop is the same as the original one.
            early_stop_flag, stop_info = early_stop(
                trajectory, max_steps, early_stop_thresholds
                trajectory, max_steps, early_stop_thresholds, actions
            )

            if early_stop_flag:

@@ -407,7 +455,7 @@ def test(
                    trajectory,
                    intent,
                    images=images,
                    meta_data=meta_data,
                    meta_data=meta_data
                )
            except ValueError as e:
                # get the error message

@@ -426,9 +474,46 @@ def test(
            render_helper.render(
                action, state_info, meta_data, args.render_screenshot
            )

            current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png")
            _ = env.page.viewport_size
            env.page.screenshot(path="/dev/null")
            env.page.screenshot(path=current_screenshot)
            element_id = action["element_id"]
            if element_id != "":
                element = env.page.query_selector(f"[data-label-id='{element_id}']")
                if element:
                    bbox = element.bounding_box()
                    bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']), int(bbox['height'])]
                    image = cv2.imread(current_screenshot)
                    cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
                    cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
                    cv2.imwrite(current_screenshot, image)
            font_path = "/workspace/qzh/Arial.ttf"  # use the FreeSans font
            font_size = 30  # adjust this value to enlarge or shrink the text
            image = Image.open(current_screenshot)
            font = ImageFont.truetype(font_path, font_size)
            draw = ImageDraw.Draw(image)
            image_width = image.width
            wrapped_text = text_wrap(action_str, font, image_width)
            line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1]
            text_height = line_height * len(wrapped_text)
            new_image_height = image.height + text_height + 20  # 20 is extra white space
            new_image = Image.new('RGB', (image.width, new_image_height), color=(255, 255, 255))  # white background
            draw_new = ImageDraw.Draw(new_image)
            y_text = 10  # initial position of text from top
            for line in wrapped_text:
                text_bbox = draw_new.textbbox((0, 0), line, font=font)
                text_width = text_bbox[2] - text_bbox[0]
                text_position = ((image_width - text_width) // 2, y_text)  # center text horizontally
                draw_new.text(text_position, line, font=font, fill=(0, 0, 0))  # black text
                y_text += line_height  # move to next line
            new_image.paste(image, (0, text_height + 20))
            new_image.save(current_screenshot)

            meta_data["action_history"].append(action_str)
            actions.append(action_str)
            print(action_str)
            print('Action String: ', action_str)

            if action["action_type"] == ActionTypes.STOP:
                break

@@ -540,7 +625,7 @@ if __name__ == "__main__":
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    args = config()
    args.sleep_after_execution = 2.5
    args.sleep_after_execution = 3.0
    prepare(args)

    test_config_base_dir = args.test_config_base_dir

86
VAB-WebArena-Lite/new/score.py
Normal file
@@ -0,0 +1,86 @@
import os, json, sys, copy

USE_TASKS = [i for i in range(165)]

def get_result(res_dict, src="all"):
    if len(res_dict) == 0:
        return ''

    success_id = [k for k, v in res_dict.items() if v >= 1.0]
    score = len(success_id)
    finish_count = len(res_dict)
    pacc, acc = score / finish_count * 100, score / TASKS * 100

    print(sorted(success_id))

    meta = """
--------
src file: {}
succeeded: {:3} / {:4} (812)
partial accuracy: {:7}
overall accuracy: {:7}
--------
""".format(src, int(score), finish_count, round(pacc, 2), round(acc, 2))

    print(meta)

def export_result(res_dict, src=".", note=["1.0", "0.0"], show_all=False):
    out_string = ""
    for id in USE_TASKS:
        # with open(f"Pipeline/config_files/{id}.json", "r") as f:
        #     jd = json.load(f)

        # if "map" in jd["sites"]:
        #     continue
        if id in res_dict:
            if res_dict[id] >= 1.0:
                out_string += note[0]
            else:
                out_string += note[1]
        elif show_all:
            out_string += note[1]
        out_string += "\n"

    with open(os.path.join(src, 'export.txt'), 'w') as f:
        f.write(out_string)

TASKS = 165

files = sys.argv[1]
file_list = files.split(',')

all_result = {}

for src in file_list:
    path = os.path.join(src, 'actions')

    result = {}
    finished = os.listdir(path)

    for file in finished:
        if not file.endswith('.json'):
            continue
        with open(os.path.join(path, file), 'r') as f:
            data = json.load(f)

        if not isinstance(data, dict):
            continue

        task_id = data.get('task_id', 1000)
        # if task_id >= TASKS:
        #     continue

        task_score = data.get('score', 0)
        if task_score < 0:
            continue

        result[task_id] = task_score
        if task_id not in all_result or task_score > all_result[task_id]:
            all_result[task_id] = task_score

    get_result(result, src)
    export_result(result, src=src)

if len(file_list) > 1:
    get_result(all_result)
    export_result(all_result, show_all=True)

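score.py aggregates the per-task records that run.py writes to <result_dir>/actions/<task_id>.json through update_action_history; from the reads above, each record must expose at least task_id and score, and negative scores (the -0.1 default, meaning not yet finished) are skipped. Below is a hedged sketch of one such record; the actions key is an assumption based on update_action_history's signature.

# Hypothetical writer for the record format score.py reads back.
import json, os

def write_action_record(result_dir: str, task_id: int, actions: list[str], score: float = -0.1) -> None:
    os.makedirs(os.path.join(result_dir, "actions"), exist_ok=True)
    record = {"task_id": task_id, "actions": actions, "score": score}  # keys partly assumed
    with open(os.path.join(result_dir, "actions", f"{task_id}.json"), "w") as f:
        json.dump(record, f, indent=2)

write_action_record("results/webrl", 7, ['do(action="Click", element="3")', 'exit(message="done")'], score=1.0)

Invoked as python score.py dir1,dir2, the script prints per-directory accuracy and, when several result directories are given, keeps the best score per task across them.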
@@ -187,7 +187,7 @@
        "string_match"
    ],
    "reference_answers": {
        "must_include": [
        "fuzzy_match": [
            "0"
        ]
    },

@@ -774,8 +774,8 @@
        "string_match"
    ],
    "reference_answers": {
        "fuzzy_match": [
            "914km"
        "must_include": [
            "914"
        ]
    },
    "reference_url": "",

@@ -1139,7 +1139,7 @@
        "string_match"
    ],
    "reference_answers": {
        "must_include": [
        "fuzzy_match": [
            "0"
        ]
    },

@@ -1679,7 +1679,7 @@
        "string_match"
    ],
    "reference_answers": {
        "exact_match": "N/A"
        "fuzzy_match": ["N/A"]
    },
    "reference_url": "",
    "program_html": [],

@@ -2605,7 +2605,7 @@
        "string_match"
    ],
    "reference_answers": {
        "exact_match": "yjlou"
        "must_include": "yjlou"
    },
    "reference_url": "",
    "program_html": [],

@@ -3712,7 +3712,7 @@
        "string_match"
    ],
    "reference_answers": {
        "exact_match": "N/A"
        "fuzzy_match": "N/A"
    },
    "reference_url": "",
    "program_html": [],

@@ -1,15 +1,18 @@
from typing import Any

import tiktoken
from transformers import LlamaTokenizer  # type: ignore
from transformers import LlamaTokenizer, AutoTokenizer  # type: ignore


class Tokenizer(object):
    def __init__(self, provider: str, model_name: str) -> None:
        if provider == "openai":
            try:
                self.tokenizer = tiktoken.encoding_for_model(model_name)
            except:  # The provider is in openai format but the model is a finetuned model
                self.tokenizer = None
        elif provider == "huggingface":
            self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
            self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
            # turn off adding special tokens automatically
            self.tokenizer.add_special_tokens = False  # type: ignore[attr-defined]
            self.tokenizer.add_bos_token = False  # type: ignore[attr-defined]

@@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any]
def call_llm(
    lm_config: lm_config.LMConfig,
    prompt: APIInput,
    api_key = None,
    base_url = None
) -> str:
    response: str
    if lm_config.provider == "openai":

@@ -39,11 +41,13 @@ def call_llm(
        assert isinstance(prompt, str)
        response = generate_from_openai_completion(
            prompt=prompt,
            engine=lm_config.model,
            model=lm_config.model,
            temperature=lm_config.gen_config["temperature"],
            max_tokens=lm_config.gen_config["max_tokens"],
            top_p=lm_config.gen_config["top_p"],
            stop_token=lm_config.gen_config["stop_token"],
            api_key=api_key,
            base_url=base_url
        )
    else:
        raise ValueError(
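In the WebRL setup the agent passes its planner_ip down through these new api_key/base_url parameters, so the completion request goes to a self-deployed, OpenAI-compatible server instead of api.openai.com. The stand-alone sketch below shows the equivalent raw request; the port, the /v1 suffix, and the model name are assumptions, and 'EMPTY' simply mirrors the placeholder key used by the agent.

# Hedged, stand-alone equivalent of the api_key/base_url plumbing (e.g. against vLLM).
from openai import OpenAI

client = OpenAI(api_key="EMPTY", base_url="http://127.0.0.1:8000/v1")  # assumed endpoint shape
completion = client.completions.create(
    model="webrl-llama-3.1-8b",      # hypothetical served model name
    prompt="Task Instruction: ...",  # the string built by WebRLPromptConstructor
    temperature=0.0,
    max_tokens=2048,
    stop=["<|eot_id|>"],
)
print(completion.choices[0].text)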
97
VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
Normal file
@@ -0,0 +1,97 @@
#!/bin/bash
DATASET='webarena'  # TODO: select from ['webarena', 'visualwebarena']
result_dir=''  # TODO: set your result_dir
provider='openai'  # TODO: select from ['openai', 'finetune', ...]
model=''  # TODO: assign model name, which is used for action generation
planner_ip=''  # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
instruction_path='agent/prompts/jsons/p_webrl.json'  # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
test_config_base_dir='config_files/wa/test_webarena_lite'  # e.g., config_files/wa/test_webarena_lite
temperature=0.0

SERVER=''  # TODO: your server address
MAP_SERVER=''  # TODO: the server address for MAP tasks
OPENAI_API_KEY=''  # TODO: if you test OpenAI APIs
OPENAI_ORGANIZATION=''
CONDA_ENV_NAME=''  # TODO: the name of your conda environment for testing WebArena

ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"

# get the number of tmux panes
num_panes=$(tmux list-panes | wc -l)

# calculate how many panes need to be created
let "panes_to_create = 7 - num_panes"

# array of tmux commands to create each pane
tmux_commands=(
    'tmux split-window -h'
    'tmux split-window -v'
    'tmux select-pane -t 0; tmux split-window -v'
    'tmux split-window -v'
    'tmux select-pane -t 3; tmux split-window -v'
    'tmux select-pane -t 5; tmux split-window -v'
)

# create panes up to 7
for ((i=0; i<$panes_to_create; i++)); do
    eval ${tmux_commands[$i]}
done

#!/bin/bash

# Function to run a job
run_job() {
    tmux select-pane -t $1
    COMMAND="python run.py \
        --instruction_path ${instruction_path} \
        --test_start_idx $2 \
        --test_end_idx $3 \
        --result_dir ${result_dir} \
        --test_config_base_dir ${test_config_base_dir} \
        --provider ${provider} \
        --model ${model} \
        --mode completion \
        --planner_ip ${planner_ip} \
        --stop_token \"<|eot_id|>\" \
        --temperature ${temperature} \
        --max_obs_length 0 \
        --max_tokens 2048 \
        --viewport_width 1280 \
        --viewport_height 720 \
        --parsing_failure_th 5 \
        --repeating_action_failure_th 5 \
        --action_set_tag webrl_id --observation_type webrl"
    tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
    sleep 3
}

TOLERANCE=2
run_batch() {
    args=("$@")  # save all arguments in an array
    num_jobs=${#args[@]}  # get number of arguments

    for ((i=1; i<$num_jobs; i++)); do
        run_job $i ${args[i-1]} ${args[i]}
    done

    # Wait for all jobs to finish
    while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
        sleep 100  # wait 100 seconds before checking again
    done

    # Run checker
    while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
        echo "Check failed, rerunning jobs..."
        for ((i=1; i<$num_jobs; i++)); do
            run_job $i ${args[i-1]} ${args[i]}
        done

        # Wait for all jobs to finish
        while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
            sleep 100  # wait 100 seconds before checking again
        done
    done

}
run_batch 0 28 56 84 112 140 165

@@ -14,6 +14,14 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json

# browser_env
cp -f new/actions.py visualwebarena/browser_env/actions.py
cp -f new/envs.py visualwebarena/browser_env/envs.py
cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py
cp -f new/processors.py visualwebarena/browser_env/processors.py
cp -rf new/html_tools visualwebarena/browser_env/

# llms
cp -f new/utils.py visualwebarena/llms/utils.py

@@ -22,15 +30,19 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
cp -f new/utils.py visualwebarena/llms/utils.py

# eval
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py

# misc
cp -f README.md visualwebarena/README.md
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh

cp -f new/score.py visualwebarena/score.py
cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh

# 3. remove temporary files
mv visualwebarena/* .
rm -rf new