From 521d7e999af0b88034411bfd534a8a82db67cd2f Mon Sep 17 00:00:00 2001
From: QZH-777 <1961710177@qq.com>
Date: Thu, 14 Nov 2024 15:51:41 +0800
Subject: [PATCH] add webrl mode
---
VAB-WebArena-Lite/new/actions.py | 2006 +++++++++++++++++
VAB-WebArena-Lite/new/agent.py | 15 +-
VAB-WebArena-Lite/new/envs.py | 319 +++
VAB-WebArena-Lite/new/evaluators.py | 4 +
.../new/helper_functions_browser.py | 239 ++
..._functions.py => helper_functions_eval.py} | 0
VAB-WebArena-Lite/new/html_tools/__init__.py | 7 +
.../new/html_tools/configs/__init__.py | 3 +
.../new/html_tools/configs/config.py | 56 +
.../new/html_tools/configs/html_prompt.py | 22 +
VAB-WebArena-Lite/new/html_tools/fetch.py | 108 +
.../new/html_tools/html_parser.py | 447 ++++
.../new/html_tools/identifier.py | 64 +
VAB-WebArena-Lite/new/html_tools/prompt.py | 97 +
.../new/html_tools/scripts/__init__.py | 43 +
.../html_tools/scripts/clickable_checker.js | 148 ++
.../new/html_tools/scripts/element_info.js | 34 +
.../new/html_tools/scripts/label.js | 74 +
.../new/html_tools/scripts/label_marker.js | 65 +
.../new/html_tools/scripts/prepare.js | 83 +
VAB-WebArena-Lite/new/html_tools/utils.py | 101 +
VAB-WebArena-Lite/new/openai_utils.py | 18 +-
VAB-WebArena-Lite/new/p_webrl.json | 13 +
VAB-WebArena-Lite/new/processors.py | 1351 +++++++++++
VAB-WebArena-Lite/new/prompt_constructor.py | 56 +
VAB-WebArena-Lite/new/run.py | 137 +-
VAB-WebArena-Lite/new/score.py | 86 +
.../new/test_webarena_lite.raw.json | 14 +-
VAB-WebArena-Lite/new/tokenizers.py | 9 +-
VAB-WebArena-Lite/new/utils.py | 6 +-
.../new/wa_parallel_run_webrl.sh | 97 +
VAB-WebArena-Lite/replace.sh | 14 +-
32 files changed, 5688 insertions(+), 48 deletions(-)
create mode 100644 VAB-WebArena-Lite/new/actions.py
create mode 100644 VAB-WebArena-Lite/new/envs.py
create mode 100644 VAB-WebArena-Lite/new/helper_functions_browser.py
rename VAB-WebArena-Lite/new/{helper_functions.py => helper_functions_eval.py} (100%)
create mode 100755 VAB-WebArena-Lite/new/html_tools/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/config.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/fetch.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/html_parser.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/identifier.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/prompt.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/utils.py
create mode 100644 VAB-WebArena-Lite/new/p_webrl.json
create mode 100644 VAB-WebArena-Lite/new/processors.py
create mode 100644 VAB-WebArena-Lite/new/score.py
create mode 100644 VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
diff --git a/VAB-WebArena-Lite/new/actions.py b/VAB-WebArena-Lite/new/actions.py
new file mode 100644
index 0000000..eedb327
--- /dev/null
+++ b/VAB-WebArena-Lite/new/actions.py
@@ -0,0 +1,2006 @@
+"""
+Browser Env action space.
+Inspired by Farama-Foundation/miniwob-plusplus
+"""
+import ast
+import random
+import re
+import string
+from enum import IntEnum
+from itertools import chain
+from typing import Any, TypedDict, Union, cast
+import time
+
+import numpy as np
+import numpy.typing as npt
+from beartype import beartype
+from beartype.door import is_bearable
+from gymnasium import spaces
+from playwright._impl._api_structures import ViewportSize
+from playwright.async_api import BrowserContext as ABrowserContext
+from playwright.async_api import Locator as ALocator
+from playwright.async_api import Page as APage
+from playwright.sync_api import BrowserContext, Locator, Page
+
+from browser_env.constants import (
+ ASCII_CHARSET,
+ FREQ_UNICODE_CHARSET,
+ MAX_ANSWER_LENGTH,
+ MAX_ELEMENT_ID,
+ MAX_ELEMENT_INDEX_IN_VIEWPORT,
+ MAX_PAGE_NUMBER,
+ MAX_VANILLA_STR_LENGTH,
+ PLAYWRIGHT_ACTIONS,
+ PLAYWRIGHT_LOCATORS,
+ ROLES,
+ SPECIAL_KEY_MAPPINGS,
+ SPECIAL_KEYS,
+ SPECIAL_LOCATORS,
+ TEXT_MAX_LENGTH,
+ TYPING_MAX_LENGTH,
+ URL_MAX_LENGTH,
+ RolesType,
+)
+from browser_env.processors import ObservationProcessor
+
+
+class ParsedPlaywrightCode(TypedDict):
+ function_name: str
+ arguments: list[str]
+ keywords: dict[str, Any]
+
+
+from browser_env.processors import (
+ ObservationProcessor,
+ TextObervationProcessor,
+)
+
+
+@beartype
+def is_in_viewport(
+    element: Locator, viewport: ViewportSize, threshold: float = 0.3
+) -> bool:
+    """Return True if more than `threshold` of the element's box is visible.
+
+    A zero-area box counts as not visible instead of dividing by zero.
+    """
+    box = element.bounding_box()
+    assert box is not None
+    area = box["width"] * box["height"]
+    if area <= 0:
+        return False
+    # Overlap of the box with the viewport rectangle anchored at (0, 0).
+    overlap_w = min(box["x"] + box["width"], viewport["width"]) - max(box["x"], 0)
+    overlap_h = min(box["y"] + box["height"], viewport["height"]) - max(box["y"], 0)
+    inter = max(0, overlap_w) * max(0, overlap_h)
+    return inter / area > threshold
+
+
+@beartype
+async def async_is_in_viewport(
+    element: ALocator, viewport: ViewportSize, threshold: float = 0.3
+) -> bool:
+    """Async check that more than `threshold` of the element's box is visible.
+
+    A zero-area box counts as not visible instead of dividing by zero."""
+    box = await element.bounding_box()
+    assert box is not None
+    area = box["width"] * box["height"]
+    if area <= 0:
+        return False
+    # Overlap of the box with the viewport rectangle anchored at (0, 0).
+    overlap_w = min(box["x"] + box["width"], viewport["width"]) - max(box["x"], 0)
+    overlap_h = min(box["y"] + box["height"], viewport["height"]) - max(box["y"], 0)
+    inter = max(0, overlap_w) * max(0, overlap_h)
+    return inter / area > threshold
+
+
+class Action(TypedDict):
+ action_type: int
+ coords: npt.NDArray[np.float32]
+ element_role: int
+ element_name: str
+ text: list[int]
+ page_number: int
+ url: str
+ nth: int
+ element_id: str
+ direction: str
+ key_comb: str
+ pw_code: str
+ answer: str
+ raw_prediction: str # raw prediction from the model
+
+
+@beartype
+def action2str(
+ action: Action, action_set_tag: str, semantic_element: str = ""
+) -> str:
+ """Return the string representation of an action
+
+    semantic_element: the semantic information of the element
+ such as a line in an accessibility tree
+ """
+ if action_set_tag in [
+ "id_accessibility_tree",
+ "id_accessibility_tree_with_captioner",
+ ]:
+ element_id = action["element_id"]
+ match action["action_type"]:
+ case ActionTypes.CLICK:
+ # [ID=X] xxxxx
+ action_str = f"click [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.TYPE:
+ text = "".join([_id2key[i] for i in action["text"]])
+ action_str = f"type [{element_id}] [{text}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.HOVER:
+ action_str = f"hover [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.SCROLL:
+ action_str = f"scroll [{action['direction']}]"
+ case ActionTypes.KEY_PRESS:
+ action_str = f"press [{action['key_comb']}]"
+ case ActionTypes.GOTO_URL:
+ action_str = f"goto [{action['url']}]"
+ case ActionTypes.NEW_TAB:
+ action_str = "new_tab"
+ case ActionTypes.PAGE_CLOSE:
+ action_str = "close_tab"
+ case ActionTypes.GO_BACK:
+ action_str = "go_back"
+ case ActionTypes.GO_FORWARD:
+ action_str = "go_forward"
+ case ActionTypes.PAGE_FOCUS:
+ action_str = f"page_focus [{action['page_number']}]"
+ case ActionTypes.CLEAR:
+ action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.STOP:
+ action_str = f"stop [{action['answer']}]"
+ case ActionTypes.NONE:
+ action_str = "none"
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type']}"
+ )
+ elif action_set_tag == "som":
+ element_id = action["element_id"]
+ match action["action_type"]:
+ case ActionTypes.CLICK:
+ # [ID=X] xxxxx
+ action_str = f"click [{element_id}] where [{element_id}]"
+ case ActionTypes.CLEAR:
+ action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.TYPE:
+ text = "".join([_id2key[i] for i in action["text"]])
+ action_str = (
+ f"type [{element_id}] [{text}] where [{element_id}]"
+ )
+ case ActionTypes.HOVER:
+ action_str = f"hover [{element_id}] where [{element_id}]"
+ case ActionTypes.SCROLL:
+ action_str = f"scroll [{action['direction']}]"
+ case ActionTypes.KEY_PRESS:
+ action_str = f"press [{action['key_comb']}]"
+ case ActionTypes.GOTO_URL:
+ action_str = f"goto [{action['url']}]"
+ case ActionTypes.NEW_TAB:
+ action_str = "new_tab"
+ case ActionTypes.PAGE_CLOSE:
+ action_str = "close_tab"
+ case ActionTypes.GO_BACK:
+ action_str = "go_back"
+ case ActionTypes.GO_FORWARD:
+ action_str = "go_forward"
+ case ActionTypes.PAGE_FOCUS:
+ action_str = f"page_focus [{action['page_number']}]"
+ case ActionTypes.STOP:
+ action_str = f"stop [{action['answer']}]"
+ case ActionTypes.NONE:
+ action_str = "none"
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type']}"
+ )
+ else:
+ raise NotImplementedError(f"Unknown action set tag {action_set_tag}")
+
+ return action_str
+
+
+def action2create_function(action: Action) -> str:
+ match (action["action_type"]):
+ case ActionTypes.NONE:
+ return "create_none_action()"
+ # mouse wheel and keyboard action
+ case ActionTypes.SCROLL:
+ direction = "up" if "up" in action["direction"] else "down"
+ return f"create_scroll_action({repr(direction)})"
+ case ActionTypes.KEY_PRESS:
+ return f"create_key_press_action({repr(action['key_comb'])})"
+ # inter-page actions
+ case ActionTypes.PAGE_FOCUS:
+ return f"create_page_focus_action({action['page_number']})"
+ case ActionTypes.NEW_TAB:
+ return "create_new_tab_action()"
+ case ActionTypes.GO_BACK:
+ return "create_go_back_action()"
+ case ActionTypes.GO_FORWARD:
+ return "create_go_forward_action()"
+ case ActionTypes.GOTO_URL:
+ return f"create_goto_url_action({repr(action['url'])})"
+ case ActionTypes.PAGE_CLOSE:
+ return "create_page_close_action()"
+
+ # low-level keyboard and mouse actions
+ case ActionTypes.MOUSE_CLICK:
+ return f"create_mouse_click_action({action['coords'][0]}, {action['coords'][1]})"
+ case ActionTypes.MOUSE_HOVER:
+ return f"create_mouse_hover_action({action['coords'][0]}, {action['coords'][1]})"
+ case ActionTypes.KEYBOARD_TYPE:
+ return f"create_keyboard_type_action({list(map(lambda x: _id2key[x], action['text']))})"
+
+ # mid-level keyboard and mouse actions
+ case ActionTypes.CLICK:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_click_action({args_str})"
+ case ActionTypes.CLEAR:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_clear_action({args_str})"
+ case ActionTypes.HOVER:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_hover_action({args_str})"
+ case ActionTypes.TYPE:
+ args = []
+ text = "".join(map(lambda x: _id2key[x], action["text"]))
+ args.append(f"text={repr(text)}")
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_type_action({args_str})"
+
+ # high-level actions, only support locators from playwright
+ case ActionTypes.CHECK:
+ return f"create_check_action(pw_code={repr(action['pw_code'])})"
+ case ActionTypes.SELECT_OPTION:
+ return f"create_select_option_action(pw_code={repr(action['pw_code'])})"
+ case ActionTypes.STOP:
+ return f'create_stop_action({repr(action["answer"])})'
+
+ raise ValueError(f"Invalid action type: {action['action_type']}")
+
+
+class ActionTypes(IntEnum):
+    """Valid action types for browser env (integer-valued enum)."""
+
+    NONE = 0
+    # mouse wheel and keyboard, universal across all action spaces
+    SCROLL = 1
+    KEY_PRESS = 2
+
+    # low level mouse and keyboard actions
+    MOUSE_CLICK = 3
+    KEYBOARD_TYPE = 4
+    MOUSE_HOVER = 5
+
+    # mid level mouse and keyboard actions
+    CLICK = 6
+    TYPE = 7
+    HOVER = 8
+
+    # page level actions, universal across all action spaces
+    PAGE_FOCUS = 9
+    NEW_TAB = 10
+    GO_BACK = 11
+    GO_FORWARD = 12
+    GOTO_URL = 13
+    PAGE_CLOSE = 14
+
+    # high-level actions that playwright supports
+    CHECK = 15
+    SELECT_OPTION = 16
+
+    STOP = 17
+    CLEAR = 18
+
+    # webrl actions
+    SEARCH = 19
+    SELECT_DROPDOWN_OPTION = 20
+
+    def __str__(self) -> str:
+        return f"ACTION_TYPES.{self.name}"
+
+
+@beartype
+def is_equivalent(a: Action, b: Action) -> bool:
+    """Return True if two actions have the same type and the same target/payload."""
+    if a["action_type"] != b["action_type"]:
+        return False
+    match (a["action_type"]):
+        case (
+            ActionTypes.NONE | ActionTypes.NEW_TAB | ActionTypes.GO_BACK
+            | ActionTypes.GO_FORWARD | ActionTypes.PAGE_CLOSE
+        ):
+            return True
+        case ActionTypes.SCROLL:
+            da = "up" if "up" in a["direction"] else "down"
+            db = "up" if "up" in b["direction"] else "down"
+            return da == db
+        case ActionTypes.KEY_PRESS:
+            return a["key_comb"] == b["key_comb"]
+        case ActionTypes.MOUSE_CLICK | ActionTypes.MOUSE_HOVER:
+            return np.allclose(a["coords"], b["coords"])
+        case ActionTypes.KEYBOARD_TYPE:
+            return a["text"] == b["text"]
+        # Element-targeted actions compare targets only. TODO: can be further optimized
+        case (
+            ActionTypes.CLICK | ActionTypes.HOVER | ActionTypes.TYPE
+            | ActionTypes.CLEAR | ActionTypes.SEARCH
+            | ActionTypes.SELECT_DROPDOWN_OPTION
+        ):
+            if a["element_id"] and b["element_id"]:
+                return a["element_id"] == b["element_id"]
+            elif a["element_role"] and b["element_role"]:
+                return (
+                    a["element_role"] == b["element_role"]
+                    and a["element_name"] == b["element_name"]
+                )
+            elif a["pw_code"] and b["pw_code"]:
+                return a["pw_code"] == b["pw_code"]
+            else:
+                return False
+        case ActionTypes.PAGE_FOCUS:
+            return a["page_number"] == b["page_number"]
+        case ActionTypes.GOTO_URL:
+            return a["url"] == b["url"]
+        case ActionTypes.CHECK | ActionTypes.SELECT_OPTION:
+            return a["pw_code"] == b["pw_code"]
+        case ActionTypes.STOP:
+            return a["answer"] == b["answer"]
+        case _:
+            raise ValueError(f"Unknown action type: {a['action_type']}")
+
+
+_key2id: dict[str, int] = {
+ key: i
+ for i, key in enumerate(
+ chain(SPECIAL_KEYS, ASCII_CHARSET, FREQ_UNICODE_CHARSET, ["\n"])
+ )
+}
+_id2key: list[str] = sorted(_key2id, key=_key2id.get) # type: ignore[arg-type]
+_role2id: dict[RolesType, int] = {
+ cast(RolesType, role): i
+ for i, role in enumerate(chain(ROLES, SPECIAL_LOCATORS))
+}
+_id2role: list[RolesType] = sorted(_role2id, key=_role2id.get) # type: ignore[arg-type]
+
+
+@beartype
+def _keys2ids(keys: list[int | str] | str) -> list[int]:
+    """Map keys to ids; unknown string keys fall back to the id of " "."""
+    # The fallback must be an id, not the " " character itself, or the
+    # result would violate the declared list[int] return type.
+    fallback = _key2id[" "]  # assumes " " is in ASCII_CHARSET -- TODO confirm
+    return [
+        _key2id.get(key, fallback) if isinstance(key, str) else int(key)
+        for key in keys
+    ]
+
+
+def get_action_space() -> spaces.Dict:
+ """Return the space of serialized actions."""
+ space = spaces.Dict(
+ {
+ "action_type": spaces.Discrete(len(ActionTypes)),
+ # coords (left, top) is used for COORD_CLICK
+ "coords": spaces.Box(
+ np.array([0.0, 0.0], dtype=np.float32),
+ np.array([1.0, 1.0], dtype=np.float32),
+ ),
+ # element role is used for FOCUS_AND_CLICK and FOCUS_AND_TYPE
+ "element_role": spaces.Discrete(
+ len(ROLES) + len(SPECIAL_LOCATORS)
+ ),
+ # element name is used with element role
+ "element_name": spaces.Text(TEXT_MAX_LENGTH),
+ "element_id": spaces.Text(TEXT_MAX_LENGTH),
+ # text is only used for TYPE and FOCUS_AND_TYPE
+ "text": spaces.MultiDiscrete(
+ [
+ len(ASCII_CHARSET)
+ + len(SPECIAL_KEYS)
+ + len(FREQ_UNICODE_CHARSET)
+ ]
+ * TYPING_MAX_LENGTH
+ ),
+ "page_number": spaces.Discrete(MAX_PAGE_NUMBER),
+ "url": spaces.Text(URL_MAX_LENGTH),
+ "nth": spaces.Discrete(MAX_ELEMENT_INDEX_IN_VIEWPORT),
+ "key_comb": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "direction": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "pw_code": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "answer": spaces.Text(MAX_ANSWER_LENGTH),
+ }
+ )
+ return space
+
+
+def create_random_action() -> Action:
+ """Return a random action."""
+ return {
+ "action_type": np.random.randint(len(ActionTypes)),
+ "coords": np.random.rand(2).astype(np.float32),
+ "element_role": np.random.randint(len(ROLES) + len(SPECIAL_LOCATORS)),
+ "element_name": "".join(
+ random.choices(ASCII_CHARSET, k=np.random.randint(TEXT_MAX_LENGTH))
+ ),
+ "text": list(
+ random.choices(
+ list(range(len(ASCII_CHARSET))),
+ k=np.random.randint(TYPING_MAX_LENGTH),
+ )
+ ),
+ "page_number": np.random.randint(MAX_PAGE_NUMBER),
+ "url": "".join(
+ random.choices(ASCII_CHARSET, k=np.random.randint(URL_MAX_LENGTH))
+ ),
+ "nth": np.random.randint(MAX_ELEMENT_INDEX_IN_VIEWPORT),
+ "element_id": str(np.random.randint(MAX_ELEMENT_ID)),
+ "key_comb": "+".join(
+ random.choices(SPECIAL_KEYS, k=np.random.randint(3))
+ ),
+ "direction": random.choice(["up", "down"]),
+ "pw_code": "".join(
+ random.choices(
+ string.ascii_uppercase + string.digits,
+ k=np.random.randint(MAX_VANILLA_STR_LENGTH),
+ )
+ ),
+ "answer": str(np.random.randint(MAX_ANSWER_LENGTH)),
+ "raw_prediction": str(np.random.randint(MAX_ANSWER_LENGTH)),
+ }
+
+
+@beartype
+def create_none_action() -> Action:
+ """Return a valid action object that does nothing."""
+ return {
+ "action_type": ActionTypes.NONE,
+ "coords": np.zeros(2, dtype=np.float32),
+ "element_role": 0,
+ "element_name": "",
+ "text": [],
+ "page_number": 0,
+ "url": "",
+ "nth": 0,
+ "pw_code": "", # str that requires further processing
+ "element_id": "",
+ "key_comb": "",
+ "direction": "",
+ "answer": "",
+ "raw_prediction": "",
+ }
+
+
+@beartype
+def create_stop_action(answer: str) -> Action:
+ action = create_none_action()
+ action.update({"action_type": ActionTypes.STOP, "answer": answer})
+ return action
+
+
+@beartype
+def create_scroll_action(direction: str) -> Action:
+ """Return the playwright action"""
+ assert direction in ["up", "down"]
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SCROLL,
+ "direction": direction,
+ }
+ )
+ return action
+
+
+@beartype
+def create_mouse_hover_action(
+    left: float | None = None, top: float | None = None
+) -> Action:
+    """Return a valid action object with type MOUSE_HOVER."""
+    action = create_none_action()
+    action.update(
+        {
+            "action_type": ActionTypes.MOUSE_HOVER,
+            "coords": np.array([left, top], dtype=np.float32),
+        }
+    )
+    return action
+
+
+@beartype
+def create_key_press_action(key_comb: str) -> Action:
+ """Return the key press action"""
+
+ def map_keys(key_comb: str) -> str:
+ keys = key_comb.split("+")
+ mapped_keys = []
+ for key in keys:
+ mapped_key = SPECIAL_KEY_MAPPINGS.get(key.lower(), key)
+ mapped_keys.append(mapped_key)
+ return "+".join(mapped_keys)
+
+ action = create_none_action()
+ mapped_key_comb = map_keys(key_comb)
+ action.update(
+ {
+ "action_type": ActionTypes.KEY_PRESS,
+ "key_comb": mapped_key_comb,
+ }
+ )
+ return action
+
+
+@beartype
+def create_page_focus_action(page_number: int) -> Action:
+ """Return a valid action object with type PAGE_FOCUS."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.PAGE_FOCUS,
+ "page_number": page_number,
+ }
+ )
+ return action
+
+
+@beartype
+def create_new_tab_action() -> Action:
+ """Return a valid action object with type NEW_TAB."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.NEW_TAB,
+ }
+ )
+ return action
+
+
+@beartype
+def create_go_back_action() -> Action:
+ """Return a valid action object with type GO_BACK."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GO_BACK,
+ }
+ )
+ return action
+
+
+@beartype
+def create_go_forward_action() -> Action:
+ """Return a valid action object with type GO_FORWARD."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GO_FORWARD,
+ }
+ )
+ return action
+
+
+@beartype
+def create_goto_url_action(url: str) -> Action:
+ """Return a valid action object with type GOTO_URL."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GOTO_URL,
+ "url": url,
+ }
+ )
+ return action
+
+
+@beartype
+def create_page_close_action() -> Action:
+ """Return a valid action object with type PAGE_CLOSE."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.PAGE_CLOSE,
+ }
+ )
+ return action
+
+
+@beartype
+def create_mouse_click_action(
+    left: float | None = None, top: float | None = None
+) -> Action:
+    """Return a MOUSE_CLICK action at (left, top), or CLICK if both are None."""
+    action = create_none_action()
+    # Compare against None explicitly: 0.0 is a valid coordinate but falsy,
+    # so a truthiness test would turn a click at (0.0, 0.0) into CLICK and
+    # reject a click on the left or top edge with ValueError.
+    if left is not None and top is not None:
+        action.update(
+            {
+                "action_type": ActionTypes.MOUSE_CLICK,
+                "coords": np.array([left, top], dtype=np.float32),
+            }
+        )
+    elif left is None and top is None:
+        # No coordinates given: fall back to the element-level CLICK type.
+        action.update({"action_type": ActionTypes.CLICK})
+    else:
+        raise ValueError("left and top must be both None or both not None")
+    return action
+
+
+@beartype
+def create_clear_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLEAR,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
+    """Return a valid action object with type KEYBOARD_TYPE."""
+    action = create_none_action()
+    action.update(
+        {
+            "action_type": ActionTypes.KEYBOARD_TYPE,
+            "text": _keys2ids(keys),
+        }
+    )
+    return action
+
+
+@beartype
+def create_click_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_hover_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.HOVER,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_type_action(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": _keys2ids(text),
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_type_action_webrl(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": text,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_search_action(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SEARCH,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": text,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_select_dropdown_option_action(
+ argument: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ """Return a valid action object with type SELECT_DROPDOWN_OPTION."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SELECT_DROPDOWN_OPTION,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "argument": argument,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_check_action(pw_code: str) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CHECK,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_select_option_action(
+ pw_code: str,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SELECT_OPTION,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_action(
+ element_role: RolesType, element_name: str = "", nth: int = 0
+) -> Action:
+ """Return a valid action object with type CLICK.
+
+ Keep compatible with the old version."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_and_click_action(
+ element_role: RolesType, element_name: str = "", nth: int = 0
+) -> Action:
+ """Return a valid action object with type CLICK.
+
+ Keep compatible with the old version."""
+
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_and_type_action(
+ keys: list[int | str] | str,
+ element_role: RolesType,
+ element_name: str = "",
+ nth: int = 0,
+) -> Action:
+ """Return a valid action object with type TYPE.
+
+ Keep compatible with the old version."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "text": _keys2ids(keys),
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def execute_scroll(direction: str, page: Page) -> None:
+ # perform the action
+ # code from natbot
+ if direction == "up":
+ page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
+ )
+ elif direction == "down":
+ page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
+ )
+
+@beartype
+def execute_scroll_webrl(direction: str, page: Page) -> None:
+    """Wheel-scroll "up"/"down" by 2/3 of the viewport height; else do nothing."""
+    if direction == "up":
+        page.mouse.wheel(0, -page.viewport_size['height'] * 2.0 / 3)
+    elif direction == "down":
+        page.mouse.wheel(0, page.viewport_size['height'] * 2.0 / 3)
+
+@beartype
+async def aexecute_scroll(direction: str, page: APage) -> None:
+ # perform the action
+ # code from natbot
+ if direction == "up":
+ await page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
+ )
+ elif direction == "down":
+ await page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
+ )
+
+
+@beartype
+def execute_key_press(key: str, page: Page) -> None:
+ """Press a key."""
+ if "Meta" in key and "Mac" not in page.evaluate("navigator.platform"):
+ key = key.replace("Meta", "Control")
+ page.keyboard.press(key)
+
+
+@beartype
+async def aexecute_key_press(key: str, page: APage) -> None:
+ """Press a key."""
+ if "Meta" in key and "Mac" not in await page.evaluate(
+ "navigator.platform"
+ ):
+ key = key.replace("Meta", "Control")
+ await page.keyboard.press(key)
+
+
+@beartype
+def execute_mouse_hover(left: float, top: float, page: Page) -> None:
+    """Move the mouse to (left, top), given as fractions of the viewport size."""
+    viewport_size = page.viewport_size
+    assert viewport_size
+    page.mouse.move(
+        left * viewport_size["width"], top * viewport_size["height"]
+    )
+
+
+@beartype
+async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None:
+    """Move the mouse to (left, top), given as fractions of the viewport size."""
+    viewport_size = page.viewport_size
+    assert viewport_size
+    await page.mouse.move(
+        left * viewport_size["width"], top * viewport_size["height"]
+    )
+
+
+def execute_mouse_click(left: float, top: float, page: Page) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ page.mouse.click(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+@beartype
+async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ await page.mouse.click(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+@beartype
+def execute_keyboard_type(text: str, page: Page) -> None:
+ """Fill the focused element with text."""
+ page.keyboard.type(text)
+
+
+@beartype
+async def aexecute_keyboard_type(text: str, page: APage) -> None:
+ """Fill the focused element with text."""
+ await page.keyboard.type(text)
+
+
+@beartype
+def execute_click_current(page: Page) -> None:
+    """Click the focused element, falling back to the page's other frames."""
+    locators = page.locator("*:focus")
+    if not locators.count():
+        for frame in page.frames[1:]:
+            locators = frame.locator("*:focus")
+            if locators.count():
+                break
+    locators.click()
+
+
+@beartype
+async def aexecute_click_current(page: APage) -> None:
+    """Click the focused element (other frames as fallback), then await load."""
+    locators = page.locator("*:focus")
+    locator_count = await locators.count()
+    if not locator_count:
+        for frame in page.frames[1:]:
+            locators = frame.locator("*:focus")
+            locator_count = await locators.count()
+            if locator_count:
+                break
+    await locators.click()
+    await page.wait_for_load_state("load")
+
+
+@beartype
+def execute_type(keys: list[int], page: Page) -> None:
+ """Send keystrokes to the focused element."""
+ text = "".join([_id2key[key] for key in keys])
+ page.keyboard.type(text)
+
+
+@beartype
+async def aexecute_type(keys: list[int], page: APage) -> None:
+ """Send keystrokes to the focused element."""
+ text = "".join([_id2key[key] for key in keys])
+ await page.keyboard.type(text)
+
+
+@beartype
+def execute_focus(
+    element_role: int, element_name: str, nth: int, page: Page
+) -> None:
+    """Focus the nth in-viewport element matching the given role and name."""
+    element_role_str = _id2role[element_role]
+    if page.viewport_size is None:
+        raise ValueError("Viewport size is not set for the current page")
+    element_location_list: list[tuple[Locator, float, float]] = []
+    for frame in page.frames:
+        match element_role_str:
+            case "alt_text":
+                locators = frame.get_by_alt_text(element_name)
+            case "label":
+                locators = frame.get_by_label(element_name)
+            case "placeholder":
+                locators = frame.get_by_placeholder(element_name)
+            case _:
+                locators = frame.get_by_role(
+                    role=element_role_str, name=element_name
+                )
+        for locator_idx in range(locators.count()):
+            locator = locators.nth(locator_idx)
+            if is_in_viewport(locator, page.viewport_size):
+                bounding_box = locator.bounding_box()
+                assert bounding_box
+                element_location_list.append(
+                    (locator, bounding_box["x"], bounding_box["y"])
+                )
+    if len(element_location_list) <= nth:
+        raise ValueError(
+            f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested"
+        )
+    element_location_list.sort(key=lambda x: (x[2], x[1]))  # row major order
+    element_location_list[nth][0].focus()
+
+
+@beartype
+async def aexecute_focus(
+    element_role: int, element_name: str, nth: int, page: APage
+) -> None:
+    """Async variant: focus the ``nth`` in-viewport element matching the role/name (row-major)."""
+    element_role_str = _id2role[element_role]
+    if page.viewport_size is None:
+        raise ValueError("Viewport size is not set for the current page")
+    element_location_list: list[tuple[ALocator, float, float]] = []  # (locator, x, y)
+    for frame in page.frames:  # candidates are collected from every frame of the page
+        match element_role_str:
+            case "alt_text":
+                locators = frame.get_by_alt_text(element_name)
+            case "label":
+                locators = frame.get_by_label(element_name)
+            case "placeholder":
+                locators = frame.get_by_placeholder(element_name)
+            case _:  # any other role string is passed to get_by_role
+                locators = frame.get_by_role(
+                    role=element_role_str, name=element_name
+                )
+        for locator_idx in range(await locators.count()):
+            locator = locators.nth(locator_idx)
+            if await async_is_in_viewport(locator, page.viewport_size):
+                bounding_box = await locator.bounding_box()
+                assert bounding_box
+                element_location_list.append(
+                    (locator, bounding_box["x"], bounding_box["y"])
+                )
+    if len(element_location_list) <= nth:  # nth is a 0-based index into the candidates
+        raise ValueError(
+            f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested"
+        )
+    element_location_list.sort(key=lambda x: (x[2], x[1]))  # row major order: by y, then by x
+    await element_location_list[nth][0].focus()
+
+
+@beartype
+def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
+    """Resolve a parsed locator chain by applying each call to the previous result."""
+    target = page
+    for step in locator_calls:
+        method = getattr(target, step["function_name"])
+        # each step's result becomes the receiver of the next call
+        target = method(*step["arguments"], **step["keywords"])
+    return target  # type: ignore[return-value]
+
+
+@beartype
+async def alocate(
+    locator_calls: list[ParsedPlaywrightCode], page: APage
+) -> ALocator:
+    """Resolve a parsed locator chain on an async page, awaiting only results that are awaitable."""
+    locator = page
+    for call in locator_calls:
+        result = getattr(locator, call["function_name"])(*call["arguments"], **call["keywords"])
+        # Locator factories (get_by_*, locator, ...) return a Locator directly in the async API.
+        locator = await result if hasattr(result, "__await__") else result
+    return locator  # type: ignore[return-value]
+
+
+@beartype
+def execute_playwright_click(
+    locator_code: list[ParsedPlaywrightCode],
+    page: Page,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Click the element resolved from ``locator_code``, forwarding the extra click arguments."""
+    # The mutable defaults are only unpacked below, never mutated.
+    target = locate(locator_code, page)
+    target.click(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_click(
+    locator_code: list[ParsedPlaywrightCode],
+    page: APage,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Async variant: click the element resolved from ``locator_code``."""
+    # The mutable defaults are only unpacked below, never mutated.
+    target = await alocate(locator_code, page)
+    await target.click(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_hover(
+    locator_code: list[ParsedPlaywrightCode], page: Page
+) -> None:
+    """Hover over the element resolved from ``locator_code``."""
+    # resolve the locator chain first, then hover the result
+    target = locate(locator_code, page)
+    target.hover()
+
+
+@beartype
+async def aexecute_playwright_hover(
+    locator_code: list[ParsedPlaywrightCode], page: APage
+) -> None:
+    """Async variant: hover over the element resolved from ``locator_code``."""
+    # resolve the locator chain first, then hover the result
+    target = await alocate(locator_code, page)
+    await target.hover()
+
+
+@beartype
+def execute_playwright_type(
+    text: str,
+    locator_code: list[ParsedPlaywrightCode],
+    page: Page,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Type ``text`` into the element resolved from ``locator_code``."""
+    target = locate(locator_code, page)
+    # ``text`` is the first positional argument; the default list is only unpacked
+    target.type(text, *pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_type(
+    text: str,
+    locator_code: list[ParsedPlaywrightCode],
+    page: APage,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Async variant: type ``text`` into the element resolved from ``locator_code``."""
+    target = await alocate(locator_code, page)
+    # ``text`` is the first positional argument; the default list is only unpacked
+    await target.type(text, *pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_select_option(
+    locator_code: list[ParsedPlaywrightCode],
+    page: Page,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Call ``select_option`` on the element resolved from ``locator_code``."""
+    target = locate(locator_code, page)
+    target.select_option(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_select_option(
+    locator_code: list[ParsedPlaywrightCode],
+    page: APage,
+    pw_action_args: list[str] = [],
+    pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+    """Async variant: call ``select_option`` on the element resolved from ``locator_code``."""
+    target = await alocate(locator_code, page)
+    await target.select_option(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_check(
+    locator_code: list[ParsedPlaywrightCode], page: Page
+) -> None:
+    """Call ``check`` on the element resolved from ``locator_code``."""
+    target = locate(locator_code, page)
+    target.check()
+
+
+@beartype
+async def aexecute_playwright_check(
+    locator_code: list[ParsedPlaywrightCode], page: APage
+) -> None:
+    """Async variant: call ``check`` on the element resolved from ``locator_code``."""
+    target = await alocate(locator_code, page)
+    await target.check()
+
+
+@beartype
+def execute_action_webrl(
+    action: Action,
+    page: Page,
+    browser_ctx: BrowserContext,
+    obseration_processor: ObservationProcessor,
+    sleep_after_execution: float = 0.0,
+) -> Page:
+    """Execute a WebRL action on the Playwright page and return the page that is active afterwards."""
+    action_type = action["action_type"]
+    num_tabs_before = len(browser_ctx.pages)  # compared after the action to detect a new tab
+    match action_type:
+        case ActionTypes.NONE:
+            pass
+        case ActionTypes.SCROLL:
+            direction = "up" if "up" in action["direction"] else "down"
+            execute_scroll_webrl(direction, page)
+        case ActionTypes.KEY_PRESS:
+            keys = action["key_comb"]
+            execute_key_press(keys, page)
+        case ActionTypes.MOUSE_CLICK:
+            execute_mouse_click(action["coords"][0], action["coords"][1], page)
+        case ActionTypes.CLICK:
+            # Elements are addressed by id; the id is resolved to a center point and clicked.
+            # NOTE(review): only CLICK and SELECT_DROPDOWN_OPTION pass ``page`` here -- confirm intended.
+            element_id = action["element_id"]
+            element_center = obseration_processor.get_element_center(element_id, page)  # type: ignore[attr-defined]
+            execute_mouse_click(element_center[0], element_center[1], page)
+        case ActionTypes.HOVER:
+            element_id = action["element_id"]
+            element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+            execute_mouse_hover(element_center[0], element_center[1], page)
+        case ActionTypes.TYPE:
+            element_id = action["element_id"]
+            element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+            execute_mouse_click(element_center[0], element_center[1], page)
+            execute_key_press("Meta+A", page)
+            execute_key_press('Backspace', page)
+            # The select-all + Backspace above clears the field before the new text is typed.
+            text = _keys2ids(action["text"])
+            execute_type(text, page)
+        case ActionTypes.SEARCH:
+            element_id = action["element_id"]
+            element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+            execute_mouse_click(element_center[0], element_center[1], page)
+            execute_key_press("Meta+A", page)
+            execute_key_press('Backspace', page)
+            # Same as TYPE (clear, then type), followed by a pause and Enter to submit.
+            text = _keys2ids(action["text"])
+            execute_type(text, page)
+            time.sleep(2)  # fixed 2 s pause before Enter is pressed
+            execute_key_press("Enter", page)
+        case ActionTypes.GO_BACK:
+            page.go_back()
+        case ActionTypes.GO_FORWARD:
+            page.go_forward()
+        case ActionTypes.GOTO_URL:
+            page.goto(action["url"])
+        case ActionTypes.SELECT_DROPDOWN_OPTION:
+            # Click the target element first.
+            element_id = action["element_id"]
+            argument = action["argument"]
+            element_center = obseration_processor.get_element_center(element_id, page)  # type: ignore[attr-defined]
+            execute_mouse_click(element_center[0], element_center[1], page)
+            # Look the element up again via document.elementFromPoint at its center.
+            device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+            center_x, center_y = element_center[0] * page.viewport_size["width"], element_center[1] * page.viewport_size["height"]
+            last_turn_element = page.evaluate_handle(f"""() => document.elementFromPoint({center_x / device_pixel_ratio}, {center_y / device_pixel_ratio})""")
+            # Collect the value and stripped text of every <option> under that element.
+            select_element_options = [{"value": option.get_attribute('value'), "text": option.text_content().strip(' \n')} for option in
+                                      last_turn_element.query_selector_all("option")]
+            selector_option_dict = dict((o["text"].lower(), o["value"]) for o in select_element_options)
+            value = None
+            for key in selector_option_dict.keys():
+                if argument.lower() in key.lower():  # case-insensitive substring match, first hit wins
+                    value = selector_option_dict[key]
+                    break
+            if value is not None:  # when no option matches, nothing is selected and no error is raised
+                last_turn_element.select_option(value=value)
+        case _:
+            raise ValueError(f"Unknown action type: {action_type}")
+
+    page.wait_for_timeout(int(sleep_after_execution * 1000))
+    num_tabs_now = len(browser_ctx.pages)
+    # if a new tab is opened by clicking, switch to the new tab
+    if num_tabs_now > num_tabs_before:
+        page = browser_ctx.pages[-1]
+        page.bring_to_front()
+
+    return page
+
+@beartype
+def execute_action(
+    action: Action,
+    page: Page,
+    browser_ctx: BrowserContext,
+    obseration_processor: ObservationProcessor,
+    sleep_after_execution: float = 0.0,
+) -> Page:
+    """Execute ``action`` on the Playwright page and return the page that is active afterwards."""
+    action_type = action["action_type"]
+    num_tabs_before = len(browser_ctx.pages)  # compared after the action to detect a new tab
+    match action_type:
+        case ActionTypes.NONE:
+            pass
+
+        case ActionTypes.SCROLL:
+            direction = "up" if "up" in action["direction"] else "down"
+            execute_scroll(direction, page)
+        case ActionTypes.KEY_PRESS:
+            keys = action["key_comb"]
+            execute_key_press(keys, page)
+
+        case ActionTypes.MOUSE_CLICK:
+            execute_mouse_click(action["coords"][0], action["coords"][1], page)
+        case ActionTypes.CLEAR:
+            element_id = action["element_id"]
+            element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+            execute_mouse_click(element_center[0], element_center[1], page)
+            execute_key_press("Meta+A", page)  # select all, then delete the selection
+            execute_key_press('Backspace', page)
+        case ActionTypes.MOUSE_HOVER:
+            execute_mouse_hover(action["coords"][0], action["coords"][1], page)
+        case ActionTypes.KEYBOARD_TYPE:
+            execute_type(action["text"], page)
+        case ActionTypes.CLICK:
+            # check each kind of locator in order
+            # TODO[shuyanzh]: order is temp now
+            if action["element_id"]:
+                element_id = action["element_id"]
+                element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+                execute_mouse_click(element_center[0], element_center[1], page)
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                execute_focus(element_role, element_name, nth, page)
+                execute_click_current(page)
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]  # everything but the trailing action call
+                # [shuyanzh], don't support action args and kwargs now
+                execute_playwright_click(locator_code=locator_code, page=page)
+            else:
+                raise ValueError("No proper locator found for click action")
+        case ActionTypes.HOVER:
+            if action["element_id"]:
+                element_id = action["element_id"]
+                element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+                execute_mouse_hover(element_center[0], element_center[1], page)
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                execute_focus(element_role, element_name, nth, page)  # role/name targets are focused, not mouse-hovered
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                # [shuyanzh], don't support action args and kwargs now
+                execute_playwright_hover(locator_code=locator_code, page=page)
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for hover action"
+                )
+        case ActionTypes.TYPE:
+            if action["element_id"]:
+                element_id = action["element_id"]
+                element_center = obseration_processor.get_element_center(element_id)  # type: ignore[attr-defined]
+                execute_mouse_click(element_center[0], element_center[1], page)
+                execute_type(action["text"], page)
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                execute_focus(element_role, element_name, nth, page)
+                execute_type(action["text"], page)
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                text = parsed_code[-1]["arguments"][0]  # text comes from the trailing type/fill call
+                # [shuyanzh], don't support action args and kwargs now
+                execute_playwright_type(
+                    text=text, locator_code=locator_code, page=page
+                )
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for type action"
+                )
+
+        case ActionTypes.PAGE_FOCUS:
+            page = browser_ctx.pages[action["page_number"]]
+            page.bring_to_front()
+        case ActionTypes.NEW_TAB:
+            page = browser_ctx.new_page()
+        case ActionTypes.GO_BACK:
+            page.go_back()
+        case ActionTypes.GO_FORWARD:
+            page.go_forward()
+        case ActionTypes.GOTO_URL:
+            page.goto(action["url"])
+        case ActionTypes.PAGE_CLOSE:
+            page.close()
+            if len(browser_ctx.pages) > 0:
+                page = browser_ctx.pages[-1]
+            else:  # the last tab was closed: open a fresh one so a page is always returned
+                page = browser_ctx.new_page()
+
+        case ActionTypes.SELECT_OPTION:
+            if action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                execute_playwright_select_option(locator_code, page)
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for select option action"
+                )
+        case ActionTypes.CHECK:
+            if action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                execute_playwright_check(locator_code, page)
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for check action"
+                )
+
+        case _:
+            raise ValueError(f"Unknown action type: {action_type}")
+
+    page.wait_for_timeout(int(sleep_after_execution * 1000))
+    num_tabs_now = len(browser_ctx.pages)
+    # if a new tab is opened by clicking, switch to the new tab
+    if num_tabs_now > num_tabs_before:
+        page = browser_ctx.pages[-1]
+        page.bring_to_front()
+
+    return page
+
+
+
+@beartype
+async def aexecute_action(
+    action: Action, page: APage, browser_ctx: ABrowserContext
+) -> APage:
+    """Execute ``action`` on the async Playwright page and return the page that is active afterwards."""
+    action_type = action["action_type"]
+    match action_type:
+        case ActionTypes.NONE:
+            pass
+        case ActionTypes.SCROLL:
+            direction = "up" if "up" in action["direction"] else "down"
+            await aexecute_scroll(direction, page)
+        case ActionTypes.KEY_PRESS:
+            keys = action["key_comb"]
+            await aexecute_key_press(keys, page)
+
+        case ActionTypes.MOUSE_CLICK:
+            await aexecute_mouse_click(
+                action["coords"][0], action["coords"][1], page
+            )
+        case ActionTypes.CLEAR:
+            # Clearing needs an observation processor to turn the element id into
+            # coordinates, and none is passed to the async executor. Fail explicitly,
+            # as the id-based CLICK/HOVER/TYPE branches below do, rather than with a
+            # NameError on an undefined processor.
+            raise NotImplementedError("Clear action is not supported by the async executor")
+        case ActionTypes.MOUSE_HOVER:
+            await aexecute_mouse_hover(
+                action["coords"][0], action["coords"][1], page
+            )
+        case ActionTypes.KEYBOARD_TYPE:
+            await aexecute_type(action["text"], page)
+
+        case ActionTypes.CLICK:
+            # check each kind of locator in order
+            # TODO[shuyanzh]: order is temp now
+            if action["element_id"]:  # id-based targets are not supported here
+                raise NotImplementedError
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                await aexecute_focus(element_role, element_name, nth, page)
+                await aexecute_click_current(page)
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]  # everything but the trailing action call
+                # [shuyanzh], don't support action args and kwargs now
+                await aexecute_playwright_click(
+                    locator_code=locator_code, page=page
+                )
+            else:
+                raise ValueError("No proper locator found for click action")
+        case ActionTypes.HOVER:
+            if action["element_id"]:  # id-based targets are not supported here
+                raise NotImplementedError
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                await aexecute_focus(element_role, element_name, nth, page)  # focused, not mouse-hovered
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                # [shuyanzh], don't support action args and kwargs now
+                await aexecute_playwright_hover(
+                    locator_code=locator_code, page=page
+                )
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for hover action"
+                )
+        case ActionTypes.TYPE:
+            if action["element_id"]:  # id-based targets are not supported here
+                raise NotImplementedError
+            elif action["element_role"] and action["element_name"]:
+                element_role = int(action["element_role"])
+                element_name = action["element_name"]
+                nth = action["nth"]
+                await aexecute_focus(element_role, element_name, nth, page)
+                await aexecute_type(action["text"], page)
+            elif action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                text = parsed_code[-1]["arguments"][0]  # text comes from the trailing type/fill call
+                # [shuyanzh], don't support action args and kwargs now
+                await aexecute_playwright_type(
+                    text=text, locator_code=locator_code, page=page
+                )
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for type action"
+                )
+
+        case ActionTypes.PAGE_FOCUS:
+            page = browser_ctx.pages[action["page_number"]]
+            await page.bring_to_front()
+        case ActionTypes.NEW_TAB:
+            page = await browser_ctx.new_page()
+        case ActionTypes.GO_BACK:
+            await page.go_back()
+        case ActionTypes.GO_FORWARD:
+            await page.go_forward()
+        case ActionTypes.GOTO_URL:
+            await page.goto(action["url"])
+        case ActionTypes.PAGE_CLOSE:
+            await page.close()
+            if len(browser_ctx.pages) > 0:
+                page = browser_ctx.pages[-1]
+            else:  # the last tab was closed: open a fresh one so a page is always returned
+                page = await browser_ctx.new_page()
+
+        case ActionTypes.SELECT_OPTION:
+            if action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                await aexecute_playwright_select_option(locator_code, page)
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for select option action"
+                )
+        case ActionTypes.CHECK:
+            if action["pw_code"]:
+                parsed_code = parse_playwright_code(action["pw_code"])
+                locator_code = parsed_code[:-1]
+                await aexecute_playwright_check(locator_code, page)
+            else:
+                raise NotImplementedError(
+                    "No proper locator found for check action"
+                )
+
+        case _:
+            raise ValueError(f"Unknown action type: {action_type}")
+
+    return page
+
+
+@beartype
+def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
+    """Split ``page.<call>.<call>...`` into parsed calls; raise ValueError unless the last is an action."""
+    if not code.startswith("page."):
+        raise ValueError(
+            f'Playwright action must start with "page.", but got {code}'
+        )
+
+    regex = r"\.(?![^\(\)]*\))"  # split on dots that are outside a call's parentheses
+    chain = re.split(regex, code)[1:]  # drop the leading "page"
+
+    parsed_chain = []
+
+    for item in chain:
+        tree = ast.parse(item)
+        funcs = []
+        for node in ast.walk(tree):
+            if isinstance(node, ast.Call):
+                function_name = node.func.id  # type: ignore[attr-defined]
+                arguments = [  # string constants are unwrapped (ast.Str is deprecated); other nodes are kept as-is
+                    arg.value if isinstance(arg, ast.Constant) and isinstance(arg.value, str) else arg
+                    for arg in node.args
+                ]
+                keywords = {
+                    str(kw.arg): ast.literal_eval(kw.value)
+                    for kw in node.keywords
+                }
+                funcs.append(
+                    ParsedPlaywrightCode(
+                        {
+                            "function_name": function_name,
+                            "arguments": arguments,
+                            "keywords": keywords,
+                        }
+                    )
+                )
+
+        if len(funcs) != 1:  # each chain item must contain exactly one (non-nested) call
+            raise ValueError(f"Fail to parse {item} in {code}")
+
+        if (
+            funcs[0]["function_name"]
+            not in PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS
+        ):
+            raise ValueError(
+                f"Invalid playwright code {item}, "
+                f"the function needs to be one of {PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS}"
+            )
+
+        parsed_chain.append(funcs[0])
+
+    last_action = parsed_chain[-1]
+    if last_action["function_name"] not in PLAYWRIGHT_ACTIONS:
+        raise ValueError(
+            f"Invalid playwright action {last_action}, "
+            f"the action needs to be one of {PLAYWRIGHT_ACTIONS}"
+        )
+
+    return parsed_chain
+
+
+@beartype
+class ActionParsingError(Exception):  # raised by the create_*_action parsers for malformed action strings
+    def __init__(self, message: str) -> None:
+        self.message = message  # also kept as an attribute, besides being passed to Exception
+        super().__init__(self.message)
+
+
+@beartype
+def create_playwright_action(playwright_code: str) -> Action:
+    """Build an Action from Playwright-style code; raise ActionParsingError if it is malformed or unknown."""
+    # The action is the name of the last call in the chain (text before its "(").
+    regex = r"\.(?![^\(\)]*\))"
+    action = re.split(regex, playwright_code)[-1].split("(")[0]
+    match action:
+        case "press":
+            p = r'press\((?:"|\')(.+?)(?:"|\')\)'
+            match = re.search(p, playwright_code)
+            if not match:
+                raise ActionParsingError(
+                    f"Invalid press action, required to be page.press(KEY_COMB_STR)"
+                )
+            key_comb = match.group(1)
+            return create_key_press_action(key_comb=key_comb)
+        case "scroll":
+            direction = "up" if "up" in playwright_code else "down"
+            return create_scroll_action(direction=direction)
+        case "click":
+            return create_click_action(pw_code=playwright_code)
+        case "clear":
+            return create_clear_action(pw_code=playwright_code)
+        case "hover":
+            return create_hover_action(pw_code=playwright_code)
+        case "type" | "fill":
+            p = r'(?:type|fill)\((?:"|\')(.+?)(?:"|\')\)'  # group the alternation so a bare "type" cannot match
+            match = re.search(p, playwright_code)
+            if not match:
+                raise ActionParsingError(
+                    f"Invalid type/fill action, required to be page.type(TEXT)"
+                )
+            text = match.group(1)
+            return create_type_action(text=text, pw_code=playwright_code)
+        case "select_option":
+            return create_select_option_action(pw_code=playwright_code)
+        case "check":
+            return create_check_action(pw_code=playwright_code)
+        case "goto":
+            p = r'goto\((?:"|\')(.+?)(?:"|\')\)'
+            match = re.search(p, playwright_code)
+            if not match:
+                raise ActionParsingError(
+                    f"Invalid goto action, required to be page.goto(URL_STR)"
+                )
+            url = match.group(1)
+            return create_goto_url_action(url)
+        case "page_focus":
+            # get the page number
+            p = r"page_focus\((\d+)\)"
+            match = re.search(p, playwright_code)
+            if not match:
+                raise ActionParsingError("page focus requires a page number")
+            page_num = int(match.group(1))
+            return create_page_focus_action(page_num)
+        case "new_tab":
+            return create_new_tab_action()
+        case "go_back":
+            return create_go_back_action()
+        case "go_forward":
+            return create_go_forward_action()
+        case "page_close":
+            return create_page_close_action()
+        case "stop":  # page.stop(answer)
+            p = r'stop\(?"(.+)?"\)'
+            match = re.search(p, playwright_code)
+            if not match:  # no quoted answer found: stop with an empty answer
+                answer = ""
+            else:
+                answer = match.group(1)
+            return create_stop_action(answer)
+
+    raise ActionParsingError(f"Unknown playwright action {action}")
+
+
+@beartype
+def create_id_based_action(action_str: str) -> Action:
+    """Build an Action from an id-based string such as ``click [12]``; raise ActionParsingError if invalid."""
+    action_str = action_str.strip()
+    if "[" in action_str:  # the action name is the text before the first "["
+        action = action_str.split("[")[0].strip()
+    else:  # no brackets: the action name is the first whitespace-separated word
+        actions = action_str.split()
+        if actions:
+            action = actions[0].strip()
+        else:
+            raise ActionParsingError(f"No action specified: {action_str}")
+    match action:
+        case "click":
+            match = re.search(r"click ?\[(\d+)\]", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid click action {action_str}")
+            element_id = match.group(1)
+            return create_click_action(element_id=element_id)
+        case "clear":
+            match = re.search(r"clear ?\[(\d+)\]", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid clear action {action_str}")
+            element_id = match.group(1)
+            return create_clear_action(element_id=element_id)
+        case "hover":
+            match = re.search(r"hover ?\[(\d+)\]", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid hover action {action_str}")
+            element_id = match.group(1)
+            return create_hover_action(element_id=element_id)
+        case "type":
+            # add default enter flag
+            if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
+                action_str += " [1]"
+
+            match = re.search(
+                r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str
+            )
+            if not match:
+                raise ActionParsingError(f"Invalid type action {action_str}")
+            element_id, text, enter_flag = (
+                match.group(1),
+                match.group(2),
+                match.group(3),
+            )
+            if enter_flag == "1":  # flag 1 appends a newline to the typed text
+                text += "\n"
+            return create_type_action(text=text, element_id=element_id)
+        case "press":
+            match = re.search(r"press ?\[(.+)\]", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid press action {action_str}")
+            key_comb = match.group(1)
+            return create_key_press_action(key_comb=key_comb)
+        case "scroll":
+            # up or down
+            match = re.search(r"scroll ?\[?(up|down)\]?", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid scroll action {action_str}")
+            direction = match.group(1)
+            return create_scroll_action(direction=direction)
+        case "goto":
+            match = re.search(r"goto ?\[(.+)\]", action_str)
+            if not match:
+                raise ActionParsingError(f"Invalid goto action {action_str}")
+            url = match.group(1)
+            return create_goto_url_action(url=url)
+        case "new_tab":
+            return create_new_tab_action()
+        case "go_back":
+            return create_go_back_action()
+        case "go_forward":
+            return create_go_forward_action()
+        case "tab_focus":
+            match = re.search(r"tab_focus ?\[(\d+)\]", action_str)
+            if not match:
+                raise ActionParsingError(
+                    f"Invalid tab_focus action {action_str}"
+                )
+            page_number = int(match.group(1))
+            return create_page_focus_action(page_number)
+        case "close_tab":
+            return create_page_close_action()
+        case "stop":  # stop answer
+            match = re.search(r"stop ?\[(.+)\]", action_str)
+            if not match:  # some tasks don't require an answer
+                answer = ""
+            else:
+                answer = match.group(1)
+            return create_stop_action(answer)
+
+    raise ActionParsingError(f"Invalid action {action_str}")
+
+
+@beartype
+def create_webrl_id_based_action(action_str: str) -> Action:
+    """Parse a WebRL call string into an Action.
+
+    Recognised forms are ``do(action=..., argument=..., element=...)``,
+    ``exit(message=...)`` and ``go_backward()``; leading text before the
+    recognised call is discarded.
+
+    Raises:
+        ActionParsingError: the string cannot be parsed, names an unknown
+            operation/action, or omits an argument that the action needs.
+    """
+    import ast
+
+    def remove_comments(code):
+        # Keep only the last recognised call, dropping any text before it.
+        for key in ['exit(','do(','go_backward(']:
+            if key in code:
+                return key + code.split(key)[-1]
+        # Otherwise skip leading comment lines and return the rest.
+        lines = code.split('\n')
+        for i, line in enumerate(lines):
+            if not line.strip().startswith('#'):
+                return '\n'.join(lines[i:])
+        return ''
+
+    def parse_function_call(expression):
+        expression = remove_comments(expression).strip()
+        # Parse the string into an AST expression.
+        func_call = ast.parse(expression, mode='eval').body
+        if not isinstance(func_call, ast.Call):
+            return {"operation": expression}
+        func_name = func_call.func.id
+        result = {"operation": func_name}
+        # Only keyword arguments are interpreted; positional ones are ignored.
+        for kw in func_call.keywords:
+            if func_name == "do" and kw.arg == "action":
+                result["action"] = ast.literal_eval(kw.value)
+                continue
+            parsed_kwargs = result.setdefault("kwargs", {})
+            if kw.arg != "element":
+                parsed_kwargs[kw.arg] = ast.literal_eval(kw.value)
+                continue
+            try:
+                # ``element`` may be a nested find_element_by_instruction(...) call.
+                inner_func = kw.value
+                if isinstance(inner_func, ast.Call) and inner_func.func.id == 'find_element_by_instruction':
+                    for inner_kw in inner_func.keywords:
+                        if inner_kw.arg == "instruction":
+                            parsed_kwargs["instruction"] = ast.literal_eval(inner_kw.value)
+                else:
+                    parsed_kwargs[kw.arg] = ast.literal_eval(inner_func)
+            except Exception:
+                parsed_kwargs[kw.arg] = ast.literal_eval(kw.value)
+        return result
+
+    action_str = action_str.strip()
+    try:
+        action = parse_function_call(action_str)
+    except Exception as e:
+        raise ActionParsingError(f"No action specified: {action_str}") from e
+    kwargs = action.get("kwargs", {})
+
+    def require(name):
+        # A missing argument is malformed input: report it as ActionParsingError,
+        # like every other parse failure here, instead of leaking a KeyError.
+        if name not in kwargs:
+            raise ActionParsingError(f"Missing argument '{name}' in action {action_str}")
+        return kwargs[name]
+
+    match action["operation"]:
+        case "do":
+            # A ``do(...)`` without ``action`` matches no case and is reported as invalid.
+            match str(action.get("action", "")).lower():
+                case "press enter":
+                    return create_key_press_action(key_comb='enter')
+                case "scroll up":
+                    return create_scroll_action(direction='up')
+                case "scroll down":
+                    return create_scroll_action(direction='down')
+                case "click":
+                    return create_click_action(element_id=require("element"))
+                case "type":
+                    element_id = require("element")
+                    return create_type_action_webrl(text=require("argument"), element_id=element_id)
+                case "hover":
+                    return create_hover_action(element_id=require("element"))
+                case "select dropdown option":
+                    element_id = require("element")
+                    return create_select_dropdown_option_action(argument=require("argument"), element_id=element_id)
+                case "go forward":
+                    return create_go_forward_action()
+                case "go backward":
+                    return create_go_back_action()
+                case "search":
+                    element_id = require("element")
+                    return create_search_action(text=require("argument"), element_id=element_id)
+        case "go_backward":
+            # remove_comments() already extracts ``go_backward(`` calls; dispatch them too.
+            return create_go_back_action()
+        case "exit":  # stop answer
+            return create_stop_action(require("message"))
+
+    raise ActionParsingError(f"Invalid action {action_str}")
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/agent.py b/VAB-WebArena-Lite/new/agent.py
index 461ba7e..96caf17 100644
--- a/VAB-WebArena-Lite/new/agent.py
+++ b/VAB-WebArena-Lite/new/agent.py
@@ -14,6 +14,7 @@ from browser_env.actions import (
create_id_based_action,
create_none_action,
create_playwright_action,
+ create_webrl_id_based_action
)
from browser_env.utils import Observation, StateInfo
from llms import (
@@ -108,13 +109,15 @@ class PromptAgent(Agent):
lm_config: lm_config.LMConfig,
prompt_constructor: PromptConstructor,
captioning_fn = None,
+ planner_ip = None
) -> None:
super().__init__()
self.lm_config = lm_config
self.prompt_constructor = prompt_constructor
self.action_set_tag = action_set_tag
self.captioning_fn = captioning_fn
-
+ self.planner_ip = planner_ip
+
# Check if the model is multimodal.
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:
self.multimodal_inputs = True
@@ -165,7 +168,10 @@ class PromptAgent(Agent):
lm_config = self.lm_config
n = 0
while True:
- response = call_llm(lm_config, prompt)
+ if self.planner_ip is not None and self.planner_ip != "":
+ response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
+ else:
+ response = call_llm(lm_config, prompt)
force_prefix = self.prompt_constructor.instruction[
"meta_data"
].get("force_prefix", "")
@@ -183,6 +189,8 @@ class PromptAgent(Agent):
action = create_playwright_action(parsed_response)
elif self.action_set_tag == "som":
action = create_id_based_action(parsed_response)
+ elif self.action_set_tag == 'webrl_id':
+ action = create_webrl_id_based_action(parsed_response)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
@@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
- captioning_fn=captioning_fn
+ captioning_fn=captioning_fn,
+ planner_ip=args.planner_ip
)
else:
raise NotImplementedError(
diff --git a/VAB-WebArena-Lite/new/envs.py b/VAB-WebArena-Lite/new/envs.py
new file mode 100644
index 0000000..2dfe7dc
--- /dev/null
+++ b/VAB-WebArena-Lite/new/envs.py
@@ -0,0 +1,319 @@
+import json
+import os
+import re
+import subprocess
+import time
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Union
+
+import numpy as np
+import numpy.typing as npt
+import requests
+from beartype import beartype
+from gymnasium import Env
+from gymnasium.spaces import Box, Text
+from playwright.sync_api import (
+ CDPSession,
+ Page,
+ Playwright,
+ ViewportSize,
+ expect,
+ sync_playwright,
+)
+
+DATASET = os.environ["DATASET"]
+if DATASET == "visualwebarena":
+ from browser_env.env_config import (
+ CLASSIFIEDS,
+ CLASSIFIEDS_RESET_TOKEN,
+ )
+
+from .actions import Action, execute_action, get_action_space, execute_action_webrl
+from .processors import ObservationHandler, ObservationMetadata
+from .utils import (
+ AccessibilityTree,
+ DetachedPage,
+ Observation,
+ png_bytes_to_numpy,
+)
+
+
+@dataclass
+class PlaywrightScript:  # Parsed form of one textual action; instances are built positionally by parse_action.
+    function: str  # command keyword, e.g. goto, get_by_role
+    destination: str  # second token: URL for goto / role for get_by_role, e.g. https://www.google.com/, combobox
+    name: str | None = None  # third token of a get_by_role action, e.g. Search, Avatar 2009
+    operation: str | None = None  # fourth token, e.g. click, fill, press
+    value: str | None = None  # optional fifth token, e.g. avatar movie, Enter
+
+
+def parse_action(action: str) -> PlaywrightScript:  # Split a single-space separated action string into a PlaywrightScript.
+    splitted = action.strip().split(" ")
+    assert len(splitted) >= 2  # need at least "<function> <destination>"; raises AssertionError otherwise
+    match splitted[:2]:
+        case ["goto", url]:
+            assert len(splitted) == 2  # goto takes exactly one argument
+            return PlaywrightScript("goto", url)
+        case ["get_by_role", destination]:
+            assert len(splitted) >= 4  # name and operation are mandatory after the role
+            match splitted[2:]:
+                case [name, operation]:
+                    return PlaywrightScript(
+                        "get_by_role", destination, name, operation
+                    )
+                case [name, operation, value]:
+                    return PlaywrightScript(
+                        "get_by_role", destination, name, operation, value
+                    )
+                case _:  # more than five tokens in total
+                    raise ValueError("Invalid action")
+        case _:  # first token is neither goto nor get_by_role
+            raise ValueError(f"Invalid action {action}")
+
+
+class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
+    """
+    The goal of this environment is to produce a prototype of a browser environment.
+    In the end, we want to support a fully configurable browser environment with wide
+    range of action spaces and observation spaces, both structured and unstructured.
+    But in this prototype, we just support action space specified by Playwright script,
+    and observation space is the html content of the page.
+    """
+
+    @beartype
+    def __init__(
+        self,
+        max_page_length: int = 8192,  # accepted but not referenced anywhere in this class
+        headless: bool = True,
+        slow_mo: int = 0,
+        observation_type: str = "html",  # one of the strings matched below; anything else raises ValueError
+        current_viewport_only: bool = False,
+        viewport_size: ViewportSize = {"width": 1280, "height": 720},
+        save_trace_enabled: bool = False,
+        sleep_after_execution: float = 0.0,  # seconds; converted to ms in reset() and forwarded to the executors in step()
+        captioning_fn=None,
+    ):
+        # TODO: make Space[Action] = ActionSpace
+        self.action_space = get_action_space()  # type: ignore[assignment]
+        self.headless = headless
+        self.slow_mo = slow_mo
+        self.current_viewport_only = current_viewport_only
+        self.reset_finished = False  # set to True by reset(); step() and close() check it
+        self.viewport_size = viewport_size
+        self.save_trace_enabled = save_trace_enabled
+        self.sleep_after_execution = sleep_after_execution
+
+        match observation_type:  # derive the text / image / main observation kinds from the single flag
+            case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
+                self.text_observation_type = observation_type
+                self.image_observation_type = ""
+                self.main_observation_type = "text"
+            case "image":
+                self.image_observation_type = observation_type
+                self.text_observation_type = ""  # type: ignore[assignment]
+                self.main_observation_type = "image"
+            case "image_som":
+                self.image_observation_type = observation_type
+                self.text_observation_type = observation_type  # type: ignore[assignment]
+                self.main_observation_type = "image"
+            case _:
+                raise ValueError(
+                    f"Unsupported observation type: {observation_type}"
+                )
+
+        self.observation_handler = ObservationHandler(
+            self.main_observation_type,
+            self.text_observation_type,
+            self.image_observation_type,
+            self.current_viewport_only,
+            self.viewport_size,
+            captioning_fn,
+        )
+
+        self.observation_space = (
+            self.observation_handler.get_observation_space()
+        )
+
+    @beartype
+    def setup(self, config_file: Path | None = None) -> None:  # Launch Chromium, create a context and open the start page(s).
+        self.context_manager = sync_playwright()
+        self.playwright = self.context_manager.__enter__()  # entered manually; the matching __exit__ is in reset()/close()
+        self.browser = self.playwright.chromium.launch(
+            headless=self.headless, slow_mo=self.slow_mo
+        )
+
+        if config_file:
+            with open(config_file, "r") as f:
+                instance_config = json.load(f)
+        else:
+            instance_config = {}  # no config: every .get() below falls back to its default
+
+        # Reset site if needed. Currently only supported for Classifieds.
+        # TODO(jykoh): Add reset functionality for Shopping/Reddit.
+        if instance_config.get("require_reset", False):
+            if "classifieds" in instance_config["sites"]:  # NOTE: CLASSIFIEDS* are imported only when DATASET == "visualwebarena"
+                # Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
+                response = requests.post(
+                    f"{CLASSIFIEDS}/index.php?page=reset",
+                    data={"token": CLASSIFIEDS_RESET_TOKEN},
+                )
+
+                # Check if the request was successful
+                if response.status_code == 200:
+                    print("Reset Classifieds site.")
+                else:  # a failed reset is only reported; setup continues
+                    print(
+                        "Failed to reset Classifieds site:",
+                        response.status_code,
+                    )
+            else:
+                print(
+                    "WARNING: Reset is not supported for this site. Please manually reset the site."
+                )
+
+        storage_state = instance_config.get("storage_state", None)
+        start_url = instance_config.get("start_url", None)
+        geolocation = instance_config.get("geolocation", None)
+
+        # Use custom viewport size if specified in the config, otherwise use the default.
+        viewport_size = self.viewport_size.copy()
+        viewport_size.update(instance_config.get("viewport_size", {}))
+        self.observation_handler.viewport_size = viewport_size
+
+        self.context = self.browser.new_context(
+            viewport=viewport_size,
+            storage_state=storage_state,
+            geolocation=geolocation,
+            device_scale_factor=1,
+        )
+        if self.save_trace_enabled:
+            self.context.tracing.start(screenshots=True, snapshots=True)
+
+        if start_url:
+            start_urls = start_url.split(" |AND| ")  # several start URLs may be joined with " |AND| "; one tab each
+            for url in start_urls:
+                page = self.context.new_page()
+                if self.text_observation_type in [
+                    "accessibility_tree",
+                    "accessibility_tree_with_captioner",
+                ]:
+                    client = page.context.new_cdp_session(page)
+                    client.send("Accessibility.enable")
+                    client.detach()
+                page.goto(url)
+            # set the first page as the current page
+            self.page = self.context.pages[0]
+            self.page.bring_to_front()
+        else:  # no start URL: open a single blank page
+            self.page = self.context.new_page()
+            if self.text_observation_type in [
+                "accessibility_tree",
+                "accessibility_tree_with_captioner",
+            ]:
+                client = self.page.context.new_cdp_session(self.page)
+                client.send("Accessibility.enable")
+                client.detach()
+
+    def _get_obs(self) -> dict[str, Observation]:  # Observation of the current page from the observation handler.
+        obs = self.observation_handler.get_observation(self.page)
+        return obs
+
+    def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:  # Metadata the handler keeps for its latest observation.
+        metadata = self.observation_handler.get_observation_metadata()
+        return metadata
+
+    @beartype
+    def reset(
+        self,
+        *,
+        seed: int | None = None,
+        options: dict[str, str] | None = None,
+    ) -> tuple[dict[str, Observation], dict[str, Any]]:
+        """
+        Reset the environment.
+        :param options: options for the environment. The current supported options are:
+            - "storage_state": the storage state of the browser. It is a file path to a json file.
+        """  # NOTE(review): the code below reads options["config_file"], not "storage_state"
+        super().reset(seed=seed, options=options)
+        if self.reset_finished:  # a previous session exists: shut Playwright down before starting over
+            self.context_manager.__exit__()
+
+        if options is not None and "config_file" in options:
+            config_file = Path(options["config_file"])
+            if config_file.exists():
+                self.setup(config_file=config_file)
+            else:
+                raise ValueError(f"Config file {config_file} does not exist.")
+        else:
+            self.setup()
+        self.reset_finished = True
+        timeout_in_ms = 120000
+        self.page.set_default_timeout(timeout_in_ms)
+        self.page.set_default_navigation_timeout(timeout_in_ms)
+        self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))  # seconds -> milliseconds
+
+        observation = self._get_obs()
+        observation_metadata = self._get_obs_metadata()
+        info = {
+            "page": DetachedPage(self.page.url, ""),  # page content is left empty here, unlike in step()
+            "fail_error": "",
+            "observation_metadata": observation_metadata,
+        }
+
+        return (observation, info)
+
+    def save_trace(self, trace_path: str | Path) -> None:  # Stop tracing and write it to trace_path; no-op when tracing is disabled.
+        if self.save_trace_enabled:
+            self.context.tracing.stop(path=trace_path)
+
+    def close(self) -> None:  # Exit the Playwright context manager if reset() has run at least once.
+        if self.reset_finished:
+            self.context_manager.__exit__()
+
+    def step(
+        self, action: Action
+    ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
+        if not self.reset_finished:
+            raise RuntimeError("Call reset first before calling step.")
+
+        success = False
+        fail_error = ""
+        try:
+            if self.text_observation_type == 'webrl':  # webrl observations use the dedicated executor
+                self.page = execute_action_webrl(
+                    action,
+                    self.page,
+                    self.context,
+                    self.observation_handler.action_processor,
+                    self.sleep_after_execution,
+                )
+            else:
+                self.page = execute_action(
+                    action,
+                    self.page,
+                    self.context,
+                    self.observation_handler.action_processor,
+                    self.sleep_after_execution,
+                )
+            success = True
+        except Exception as e:  # failures are not raised; they are reported via info["fail_error"] and a 0.0 reward
+            fail_error = str(e)
+
+        observation = self._get_obs()
+        observation_metadata = self._get_obs_metadata()
+
+        info = {
+            "page": DetachedPage(self.page.url, self.page.content()),
+            "fail_error": fail_error,
+            "observation_metadata": observation_metadata,
+        }
+        msg = (
+            observation,
+            float(success),  # reward
+            False,  # terminated
+            False,  # truncated
+            info,
+        )
+        return msg
diff --git a/VAB-WebArena-Lite/new/evaluators.py b/VAB-WebArena-Lite/new/evaluators.py
index 85b5363..aac511b 100644
--- a/VAB-WebArena-Lite/new/evaluators.py
+++ b/VAB-WebArena-Lite/new/evaluators.py
@@ -172,6 +172,10 @@ class StringEvaluator(Evaluator):
# prevent false positive (e.g, 0)
if len(word_tokenize(clean_ref)) == 1:
tok_pred = word_tokenize(clean_pred)
+ for token in tok_pred:
+ if '/' in token:
+ sub_tokens = token.split('/')
+ tok_pred.extend(sub_tokens)
return float(clean_ref in tok_pred)
else:
return float(clean_ref in clean_pred)
diff --git a/VAB-WebArena-Lite/new/helper_functions_browser.py b/VAB-WebArena-Lite/new/helper_functions_browser.py
new file mode 100644
index 0000000..cb902ba
--- /dev/null
+++ b/VAB-WebArena-Lite/new/helper_functions_browser.py
@@ -0,0 +1,239 @@
+import base64
+import io
+import json
+import re
+from pathlib import Path
+from typing import Any
+
+from PIL import Image
+
+from agent.prompts import *
+from browser_env import (
+ Action,
+ ActionTypes,
+ ObservationMetadata,
+ StateInfo,
+ action2str,
+)
+
+HTML_TEMPLATE = """
+
+
+
+
+
+
+ {body}
+
+
+"""
+
+
+def get_render_action(
+ action: Action,
+ observation_metadata: dict[str, ObservationMetadata],
+ action_set_tag: str,
+) -> str:
+ """Parse the predicted actions for rendering purpose. More comprehensive information"""
+ match action_set_tag:
+ case "id_accessibility_tree":
+ text_meta_data = observation_metadata["text"]
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ node_content = text_meta_data["obs_nodes_info"][
+ action["element_id"]
+ ]["text"]
+ else:
+ node_content = "No match found"
+
+ action_str = f"{action['raw_prediction']}
"
+ action_str += f""
+ action_str += f"{action2str(action, action_set_tag, node_content)}
"
+
+ case "som":
+ text_meta_data = observation_metadata["text"]
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ node_content = text_meta_data["obs_nodes_info"][
+ action["element_id"]
+ ]
+ else:
+ node_content = "No match found"
+ action_str = f"{action['raw_prediction']}
"
+ action_str += f""
+ action_str += f"{action2str(action, action_set_tag, node_content)}
"
+
+ case "playwright":
+ action_str = action["pw_code"]
+ case "webrl_id":
+ action_str = action["raw_prediction"]
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type'], action_set_tag}"
+ )
+ return action_str
+
+
+def get_action_description(
+    action: Action,
+    observation_metadata: dict[str, ObservationMetadata],
+    action_set_tag: str,  # selects the branch below: id_accessibility_tree, som, playwright or webrl_id
+    prompt_constructor: PromptConstructor | None,  # only consulted for the action_splitter in the format-error hint
+) -> str:
+    """Generate the text version of the predicted actions to store in action history for prompt use.
+    May contain hint information to recover from the failures"""
+
+    match action_set_tag:
+        case "id_accessibility_tree":
+            text_meta_data = observation_metadata["text"]
+            if action["action_type"] in [
+                ActionTypes.CLICK,
+                ActionTypes.HOVER,
+                ActionTypes.TYPE,
+            ]:
+                action_name = str(action["action_type"]).split(".")[1].lower()  # text after the "." of the enum's str(), lower-cased
+                if action["element_id"] in text_meta_data["obs_nodes_info"]:
+                    node_content = text_meta_data["obs_nodes_info"][
+                        action["element_id"]
+                    ]["text"]
+                    node_content = " ".join(node_content.split()[1:])  # drop the first whitespace-separated token of the node text
+                    action_str = action2str(
+                        action, action_set_tag, node_content
+                    )
+                else:  # the referenced element is not in the observation: return a recovery hint instead
+                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
+            else:
+                if (
+                    action["action_type"] == ActionTypes.NONE
+                    and prompt_constructor is not None
+                ):  # NONE action: remind the model of the expected output format
+                    action_splitter = prompt_constructor.instruction[
+                        "meta_data"
+                    ]["action_splitter"]
+                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
+                else:
+                    action_str = action2str(action, action_set_tag, "")
+
+        case "som":
+            text_meta_data = observation_metadata["image"]  # despite the name, this branch reads the "image" metadata
+            if action["action_type"] in [
+                ActionTypes.CLICK,
+                ActionTypes.HOVER,
+                ActionTypes.TYPE,
+            ]:
+                action_name = str(action["action_type"]).split(".")[1].lower()
+                if action["element_id"] in text_meta_data["obs_nodes_info"]:
+                    action_str = action2str(action, action_set_tag, "")
+                else:
+                    print(  # debug output: the missing id and the full node table
+                        'action["element_id"], text_meta_data["obs_nodes_info"]',
+                        action["element_id"],
+                        text_meta_data["obs_nodes_info"],
+                    )
+                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
+            else:
+                if (
+                    action["action_type"] == ActionTypes.NONE
+                    and prompt_constructor is not None
+                ):
+                    action_splitter = prompt_constructor.instruction[
+                        "meta_data"
+                    ]["action_splitter"]
+                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
+                else:
+                    action_str = action2str(action, action_set_tag, "")
+
+        case "playwright":
+            action_str = action["pw_code"]
+
+        case "webrl_id":  # the raw model output is stored as-is
+            action_str = action["raw_prediction"]
+
+        case _:  # reached for an unrecognised action_set_tag; note the message reports action_type
+            raise ValueError(f"Unknown action type {action['action_type']}")
+
+    return action_str
+
+
+class RenderHelper(object):
+ """Helper class to render text and image observations and meta data in the trajectory"""
+
+ def __init__(
+ self, config_file: str, result_dir: str, action_set_tag: str
+ ) -> None:
+ with open(config_file, "r") as f:
+ _config = json.load(f)
+ _config_str = ""
+ for k, v in _config.items():
+ _config_str += f"{k}: {v}\n"
+ _config_str = f"{_config_str}
\n"
+ task_id = _config["task_id"]
+
+ self.action_set_tag = action_set_tag
+
+ self.render_file = open(
+ Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
+ )
+ self.render_file.truncate(0)
+ # write init template
+ self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
+ self.render_file.read()
+ self.render_file.flush()
+
+ def render(
+ self,
+ action: Action,
+ state_info: StateInfo,
+ meta_data: dict[str, Any],
+ render_screenshot: bool = False,
+ ) -> None:
+ """Render the trajectory"""
+ # text observation
+ observation = state_info["observation"]
+ text_obs = observation["text"]
+ info = state_info["info"]
+ new_content = f"New Page
\n"
+ new_content += f"\n"
+ new_content += f"{text_obs}
\n"
+
+ if render_screenshot:
+ # image observation
+ img_obs = observation["image"]
+ image = Image.fromarray(img_obs)
+ byte_io = io.BytesIO()
+ image.save(byte_io, format="PNG")
+ byte_io.seek(0)
+ image_bytes = base64.b64encode(byte_io.read())
+ image_str = image_bytes.decode("utf-8")
+ new_content += f"

\n"
+
+ # meta data
+ new_content += f"
{meta_data['action_history'][-1]}
\n"
+
+ # action
+ action_str = get_render_action(
+ action,
+ info["observation_metadata"],
+ action_set_tag=self.action_set_tag,
+ )
+ # with yellow background
+ action_str = f"
{action_str}
"
+ new_content += f"{action_str}\n"
+
+ # add new content
+ self.render_file.seek(0)
+ html = self.render_file.read()
+ html_body = re.findall(r"(.*?)", html, re.DOTALL)[0]
+ html_body += new_content
+
+ html = HTML_TEMPLATE.format(body=html_body)
+ self.render_file.seek(0)
+ self.render_file.truncate()
+ self.render_file.write(html)
+ self.render_file.flush()
+
+ def close(self) -> None:
+ self.render_file.close()
diff --git a/VAB-WebArena-Lite/new/helper_functions.py b/VAB-WebArena-Lite/new/helper_functions_eval.py
similarity index 100%
rename from VAB-WebArena-Lite/new/helper_functions.py
rename to VAB-WebArena-Lite/new/helper_functions_eval.py
diff --git a/VAB-WebArena-Lite/new/html_tools/__init__.py b/VAB-WebArena-Lite/new/html_tools/__init__.py
new file mode 100755
index 0000000..6aac581
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/__init__.py
@@ -0,0 +1,7 @@
+from .identifier import IdentifierTool
+from .prompt import HtmlPrompt
+from .html_parser import HtmlParser
+
+from .utils import print_html_object
+from .configs import basic_attrs, mind2web_keep_attrs
+from .fetch import get_parsed_html
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/__init__.py b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py
new file mode 100755
index 0000000..b900a67
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py
@@ -0,0 +1,3 @@
+from .html_prompt import prompts
+from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs
+from .config import config_meta
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/config.py b/VAB-WebArena-Lite/new/html_tools/configs/config.py
new file mode 100755
index 0000000..d7da96c
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/config.py
@@ -0,0 +1,56 @@
+basic_attrs = [  # passed as "attr_list" to HtmlParser by fetch.get_parsed_html
+    'title',
+    'value',
+    'type',
+    'placeholder',
+    'selected',
+    'data-value',
+    'data-text',
+    'data-testid',
+    'data-label',
+    'data-bbox',
+    'data-status'
+]
+
+mind2web_keep_attrs = [  # re-exported by the package __init__; presumably the Mind2Web attribute set - TODO confirm usage
+    'alt',
+    'aria_description',
+    'aria_label',
+    'aria_role',
+    'input_checked',
+    'input_value',
+    'label',
+    'name',
+    'option_selected',
+    'placeholder',
+    'role',
+    'text_value',
+    'title',
+    'type',
+    'value',
+]
+
+miniwob_attrs = [  # re-exported by configs/__init__; presumably for MiniWoB pages - TODO confirm usage
+    'id',
+    'type',
+    'value',
+]
+
+config_meta = """
+======= Configs =======
+Columns:
+ - id: {id_attr}
+ - label: {label_attr}
+Position: {use_position}
+ - window: {window_size}
+ - rect_dict: {rect}
+Keep:
+ - parents: {parent_chain}
+ - attrs: {keep_attrs}
+ - elems: {keep_elem}
+ - obs_elem: {obs_elem}
+Generator:
+ - prompt: {prompt_name}
+ - label: {identifier_name}
+========================
+"""
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
new file mode 100755
index 0000000..904d021
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
@@ -0,0 +1,22 @@
+refine_prompt = {
+ 'dom': '<{tag}{label}|{attr}{content}{subtree} >',
+ 'label': '[{label}]',
+ 'attr': '{attr}',
+ 'attr_splitter': '; ',
+ 'subtree_splitter': ' ',
+}
+
+xml_prompt = {
+ 'dom': '<{tag}{label}{attr}>{content}{subtree} {tag}>',
+ 'label': ' id="{label}"',
+ 'attr': '{key}="{attr}"',
+ 'attr_splitter': ' ',
+ 'subtree_splitter': ' ',
+}
+
+prompts = {
+ 'refine': refine_prompt,
+ 'xml': xml_prompt,
+ 'new_data': refine_prompt,
+}
+
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/fetch.py b/VAB-WebArena-Lite/new/html_tools/fetch.py
new file mode 100755
index 0000000..25d5637
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/fetch.py
@@ -0,0 +1,108 @@
+import os
+import json
+import base64
+from .html_parser import HtmlParser
+from .configs import basic_attrs
+from .scripts import *
+
+def get_window(page):  # Return (scrollX, scrollY, innerWidth, innerHeight) as evaluated inside the page.
+    x = page.evaluate("window.scrollX")
+    y = page.evaluate("window.scrollY")
+    w = page.evaluate("window.innerWidth")
+    h = page.evaluate("window.innerHeight")
+    return (x, y, w, h)
+
+def modify_page(page):  # Label the page's clickable elements in place; return a packet of HTML, window, screenshots and element info.
+    page.wait_for_timeout(500)
+
+    try:
+        page.evaluate(remove_id_script)
+    except Exception:  # best-effort cleanup: ignore script failures, but let KeyboardInterrupt/SystemExit propagate
+        pass
+
+    packet = {
+        "raw_html": page.evaluate("document.documentElement.outerHTML"),  # HTML before the scripts below modify the DOM
+        "window": get_window(page)
+    }
+
+    page.evaluate(prepare_script)
+    page.wait_for_timeout(100)
+
+    img_bytes = page.screenshot(path="debug_info/screenshot_raw.png")  # also written to disk for debugging
+    raw_image = base64.b64encode(img_bytes).decode()
+
+    page.evaluate(clickable_checker_script)
+    page.wait_for_timeout(50)
+
+    # get all clickable elements
+    start_id = 0
+    items, start_id = page.evaluate(label_script, {
+        "selector": ".possible-clickable-element",
+        "startIndex": start_id
+    })
+    page.wait_for_timeout(50)
+
+    # mark our own labels and get the images
+    items = page.evaluate(label_marker_script, items)
+    page.wait_for_timeout(100)
+    img_bytes = page.screenshot(path="debug_info/marked.png")
+    marked_image = base64.b64encode(img_bytes).decode()
+
+    # remove markers on the page
+    page.evaluate(remove_label_mark_script)
+
+    packet.update({
+        "raw_image": raw_image,  # base64-encoded screenshot taken before labelling
+        "marked_image": marked_image,  # base64-encoded screenshot taken with label markers drawn
+        "modified_html": page.evaluate("document.documentElement.outerHTML")
+    })
+
+    # element_info, include "all_elements" and "clickable_elements"
+    element_info = page.evaluate(element_info_script)
+    page.wait_for_timeout(100)
+    packet.update(element_info)  # merges the script's keys into the top level of the packet
+    return packet
+
+def save_debug_info(packet):  # Dump the modified HTML, parsed HTML and element list of a packet into debug_info/.
+    with open("debug_info/raw.html", "w", encoding="utf-8") as f:  # explicit UTF-8: page HTML may not fit the locale encoding
+        f.write(packet["modified_html"])
+    with open("debug_info/parsed.html", "w", encoding="utf-8") as f:
+        f.write(packet["html"])
+    with open("debug_info/all_element.json", "w", encoding="utf-8") as f:
+        f.write(json.dumps(packet["all_elements"]))
+
+def get_parsed_html(page):  # Run modify_page, parse the modified HTML with HtmlParser and return the packet with an added "html" key.
+    if not os.path.exists("debug_info"):  # modify_page and save_debug_info write into this directory
+        os.makedirs("debug_info")
+
+    print("parsing html...")
+
+    packet = modify_page(page)
+    raw_html = packet["modified_html"]  # despite the name, this is the HTML after labelling
+
+    args = {
+        "use_position": True,
+        "rect_dict": {},  # starts empty; the parser fills it from data-bbox when dataset == 'pipeline'
+        "window_size": packet["window"],
+        "id-attr": "data-backend-node-id",  # NOTE(review): HtmlParser.parse_args reads 'id_attr', so this key is ignored and 'temp_id' is used - confirm intent
+        "label_attr": "data-label-id",
+        "label_generator": "order",
+        "regenerate_label": False,
+        "attr_list": basic_attrs,
+        "prompt": "xml",
+        "dataset": "pipeline"
+    }
+
+    hp = HtmlParser(raw_html, args)
+    res = hp.parse_tree()
+    page_html = res.get("html", "")
+
+    packet["html"] = page_html
+
+    # for debug
+    save_debug_info(packet)
+
+    print("parsing finished.")
+
+    return packet
+
diff --git a/VAB-WebArena-Lite/new/html_tools/html_parser.py b/VAB-WebArena-Lite/new/html_tools/html_parser.py
new file mode 100755
index 0000000..080342f
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/html_parser.py
@@ -0,0 +1,447 @@
+from lxml import html
+import time, copy, random
+import json, re, os
+
+from .identifier import IdentifierTool
+from .prompt import HtmlPrompt
+from .configs import config_meta
+from .utils import get_xpath_top_down, rect2tuple
+
+class HtmlParser():
+    def __init__(self, ctx: str, args: dict[str]={}) -> None:  # Build the DOM tree from ctx, then apply the options in args.
+        stt = time.time()
+        self.dom_tree = self.ctx2tree(ctx)
+        # tool related
+        self.bids2label = {}  # filled by parse_tree
+        self.bids2xpath = {}  # filled by mark_id
+        self.used_labels = {}  # filled by mark_id; handed to IdentifierTool in parse_args
+
+        # parse args
+        self.parse_args(args)
+        self.init_time = time.time() - stt  # seconds spent in construction
+
+    def parse_args(self, args: dict[str]={}) -> None:  # Resolve parser options from args, store them, and build the prompt and label helpers.
+        def attr_check(attr, type_model='str'):  # True when attr is not None, has exactly type(type_model), and is not an empty str
+            if attr is None:
+                return False
+            attr_type = type(attr)
+            if attr_type != type(type_model):
+                return False
+            if attr_type == type('str') and len(attr) == 0:  # emptiness is only rejected for strings
+                return False
+            return True
+
+        args = {} if args is None else args
+
+        # [Position] use_pos: False -> use full page, otherwise use window_size
+        dataset = args.get('dataset', '')
+        use_position = args.get('use_position', False)
+        window_size = args.get('window_size', None)
+        rect = args.get('rect_dict', None)
+        if use_position:
+            if not attr_check(window_size, ()):  # must be a tuple
+                raise ValueError('window_size must be set when use_position is True')
+            if not attr_check(rect, {}):  # must be a dict (an empty one is accepted)
+                raise ValueError('rect_dict must be set when use_position is True')
+
+        if not attr_check(rect, {}):
+            rect = {}
+
+        # [Label] for vimium is temp_clickable_label, otherwise keep all of it
+        label_attr = args.get('label_attr', '')
+        get_new_label = args.get('regenerate_label', False)
+        label_method = args.get('label_generator', None)
+        regen_label = not attr_check(label_method)  # True when no non-empty string generator name was given
+
+        # [id] for mind2web is backend_node_id, for normal website use our method
+        id_attr = args.get('id_attr', '')
+        regen_id = not attr_check(id_attr)
+
+        if regen_id:  # no id attribute supplied: fall back to 'temp_id' (mark_id runs below)
+            id_attr = 'temp_id'
+
+        # [attributes]
+        keep_attrs = args.get('attr_list', [])
+        if not attr_check(keep_attrs, []):
+            keep_attrs = []
+
+        # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
+        parent_chain = args.get('parent_chain', False)
+        keep_elem = args.get('keep_elem', [])
+        obs_elem = args.get('obs_elem', [])
+
+        # store the resolved options on the instance
+        self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset)
+
+        # [Prompt]
+        prompt = args.get('prompt', None)
+        self.prompt = HtmlPrompt(prompt)
+
+        # traverse and get special data
+        if regen_id or regen_label:
+            self.mark_id()
+
+        if get_new_label:  # discard the labels collected by mark_id so new ones are generated
+            self.used_labels = {}
+
+        self.identifier = IdentifierTool(label_method, self.used_labels)
+
+    def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='',
+                 id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[],
+                 parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None:  # Copy the options onto self.
+
+        self.use_position = use_position
+        self.window_size = window_size
+        self.rect = rect_dict  # NOTE: stored by reference, not copied (this includes the mutable default)
+        self.label_attr = label_attr
+        self.id_attr = id_attr
+        self.keep_attrs = keep_attrs
+        self.keep = keep_elem  # keep_elem is exposed as self.keep
+        self.obs = obs_elem  # obs_elem is exposed as self.obs
+        self.parent_chain = parent_chain
+        self.get_new_label = get_new_label
+        self.dataset = dataset
+
+    def get_config(self):  # Return (summary dict, config_meta text formatted with that dict).
+        config = {
+            'id_attr': self.id_attr,
+            'keep_attrs': self.keep_attrs[:5],  # only the first five entries are reported
+            'label_attr': self.label_attr,
+            'use_position': self.use_position,
+            'window_size': self.window_size,
+            'rect': dict(list(self.rect.items())[:3]),  # only the first three rect entries are reported
+            'keep_elem': self.keep[:5],
+            'obs_elem': self.obs[:5],
+            'parent_chain': self.parent_chain,
+            'prompt_name': self.prompt.name,
+            'identifier_name': self.identifier.name
+        }
+
+        return config, config_meta.format(**config)
+
+    def update_rect_dict(self, rect_dict: dict[str]={}) -> None:  # Replace self.rect with rect_dict (stored by reference).
+        self.rect = rect_dict
+
+ @staticmethod
+ def ctx2tree(ctx: str) -> html.HtmlElement:
+ # remove useless tags, eg. style and script
+ ctx = re.sub('', '', ctx)
+ ctx = re.sub('', '', ctx)
+ ctx = re.sub('', '', ctx)
+ ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip()
+ dom_tree = html.fromstring(ctx.encode('utf-8'))
+ match = re.search('
html.HtmlElement:
+ node = tree.xpath('//*')[0]
+ while True:
+ parent = node.getparent()
+ if parent is None:
+ break
+ node = parent
+ return node
+
+    def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement:  # First element whose id_attr equals bid.
+        nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]')
+        if len(nodes) == 0:
+            return None  # no match: returns None despite the annotated return type
+        return nodes[0]
+
+    def id_label_converter(self, label: str) -> str:  # Look label up in bids2label; '' when the key is absent.
+        return self.bids2label.get(label, '')
+
+    def id_xpath_converter(self, label: str) -> str:  # Look label up in bids2xpath; '' when the key is absent.
+        return self.bids2xpath.get(label, '')
+
+    def mark_id(self) -> None:  # Run get_xpath_top_down over the tree root and keep its xpath map and used labels.
+        root = self.get_root(self.dom_tree)
+        _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)  # first return value is unused
+        self.used_labels = used_labels
+        self.bids2xpath = i2xpath
+
+ def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]:
+ def get_text(str: str) -> str:
+ return '' if str is None else str.strip()[:500]
+
+ def check_attr(attr: str, node: html.HtmlElement) -> bool:
+ tag = node.tag
+ if (
+ ( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] )
+ or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' )
+ or ( attr == 'aria-hidden' and node.attrib.get(attr, '') == 'true')
+ # or ( attr == 'value' and tag in ['option'] )
+ ):
+ return False
+ return True
+
+ def is_visible(node: html.HtmlElement, bid: str) -> bool:
+ if self.dataset == 'mind2web':
+ bound = node.attrib.get('bounding_box_rect', None)
+ self.rect[bid] = rect2tuple(bound)
+
+ if self.dataset == 'pipeline':
+ bound = node.attrib.get('data-bbox', None)
+ self.rect[bid] = rect2tuple(bound)
+
+ if node.attrib.get('aria-hidden', 'false') == 'true':
+ return False
+
+ if not self.use_position:
+ return True
+
+ rect = self.rect.get(bid, None)
+ if rect is None:
+ return False
+
+ if self.window_size is None:
+ return True
+
+ # get window size
+ wx, wy, ww, wh = self.window_size
+ wx, wy = 0, 0
+ x, y, w, h = rect
+ if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
+ return False
+
+ return True
+
+ def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[],
+ parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]):
+ # basic information
+ bid = node.attrib.get(self.id_attr, '')
+ tag = node.tag
+ label = node.attrib.get(self.label_attr, '')
+
+ # element which is keeped equivalent to visible
+ visible = is_visible(node, bid)
+ in_keep_list = bid in keep
+ in_obs_list = (bid in obs or len(label) > 0) and visible
+ par_keep = par_keep and tag == "option"
+ keep_element = in_keep_list or in_obs_list or visible or par_keep
+
+ if label:
+ keep_element = True
+
+ # mark label
+ bids2label, labeled_elems = {}, []
+ have_label = False
+ if in_keep_list or in_obs_list:
+ if label is None or len(label) == 0 or get_new_label:
+ label = self.identifier.generate()
+ node.attrib[self.label_attr] = label
+ bids2label[bid] = label
+ bids2label[label] = bid
+ have_label = True
+
+ # get text or alt_text of current element
+ text = get_text(node.text)
+
+ classes = {}
+ # keep attributes if needed
+ keep_all_attrs = len(self.keep_attrs) == 0
+ keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs
+
+ # traverse attributes
+ for attr in keep_attrs:
+ if attr not in node.attrib or not check_attr(attr, node):
+ continue
+ if attr in [self.id_attr, self.label_attr]:
+ continue
+ val = get_text(node.attrib[attr])
+ if len(val) > 0 or keep_all_attrs:
+ classes[attr] = val
+
+ have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0
+ par_keep = keep_element and tag == 'select'
+
+ parts = []
+ clickable_count = 0
+ children = node.getchildren()
+ for child in children:
+ cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep)
+ clickable_count += 1 if cmsg.get('have_clickable', False) else 0
+ bids2label.update(cmsg.get('bids2label', {}))
+ labeled_elems.extend(cmsg.get('label_element', []))
+ if len(cres) != 0:
+ parts.append(cres)
+
+ dom = self.prompt.subtree_constructor(parts)
+
+ # remove
if all children are text
+ keep_as_all_text = (dom.count('<') == dom.count(' 0
+ if keep_as_all_text:
+ matches = re.findall(r']+) >', dom)
+ dom = self.prompt.subtree_constructor(matches)
+
+ keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
+ keep_as_parent = len(dom) > 0 and parent_chain
+ if in_keep_list or keep_element or keep_as_parent:
+ dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)
+
+ if have_label:
+ labeled_elems.append(bid)
+
+ control_msg = {
+ 'have_clickable': bool(clickable_count or have_text),
+ 'bids2label': bids2label,
+ 'label_element': labeled_elems,
+ }
+
+ return dom, control_msg
+
+ dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label)
+ return dom, cmsg
+
+    def parse_tree(self) -> dict[str]:  # Parse from the tree root; return {'html': rendered text, 'parse_time': seconds}.
+        # start from here
+        stt = time.time()
+        root = self.get_root(self.dom_tree)
+        dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
+        self.bids2label = cmsg.get('bids2label', {})
+        self.keep = list(set(self.keep + cmsg.get('label_element', [])))  # merge newly labelled ids, de-duplicated (order not preserved)
+
+        obj = {
+            'html': dom,
+            'parse_time': time.time() - stt
+        }
+
+        return obj
+
+ # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
+ def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int,
+ max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]:
+ # Collect the `self.id_attr` values of the nodes worth keeping around each bid in `keep`:
+ # the node itself, its ancestors, a capped number of descendants and a window of siblings.
+ # With `keep_parent`, every ancestor of the original `keep` nodes is added as well.
+ # The result is built from a set, so its order is arbitrary.
+
+ # Parent chain of `node`, nearest first; recursion stops once current_depth exceeds max_depth.
+ # Returns elements rather than ids, despite the list[str] annotation.
+ def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
+ if current_depth > max_depth:
+ return []
+
+ anscendants = []
+ parent = node.getparent()
+ if parent is not None:
+ anscendants.append(parent)
+ anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1))
+
+ return anscendants
+
+ # Descendants of `node` in pre-order (each child followed by its own subtree), limited by the
+ # same depth rule. Returns elements rather than ids, despite the list[str] annotation.
+ def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
+ if current_depth > max_depth:
+ return []
+
+ descendants = []
+ for child in node:
+ descendants.append(child)
+ descendants.extend(get_descendants(child, max_depth, current_depth + 1))
+
+ return descendants
+
+ # work on a copy so the caller's `keep` list is never touched
+ to_keep = set(copy.deepcopy(keep))
+ nodes_to_keep = set()
+
+ # at least one expansion round runs, even when dfs_count is 0 or negative
+ for _ in range(max(1, dfs_count)):
+ for bid in to_keep:
+ candidate_node = self.get_node_by_bid(tree, bid)
+ if candidate_node is None:
+ continue
+
+ nodes_to_keep.add(candidate_node.attrib[self.id_attr])
+ # get all ancestors or with max depth
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)])
+
+ # get descendants with max depth
+ # the max_children cap is applied to the flattened pre-order list, not per level
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children])
+ # get siblings within range
+ parent = candidate_node.getparent()
+ if parent is None:
+ continue
+
+ # 'text' pseudo-elements do not count as siblings
+ siblings = [x for x in parent.getchildren() if x.tag != 'text']
+ if candidate_node not in siblings:
+ continue
+
+ idx_in_sibling = siblings.index(candidate_node)
+ # window of max_sibling neighbours on each side, candidate included
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '')
+ for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]])
+
+ # NOTE(review): leading whitespace was collapsed in this copy of the patch, so the nesting of
+ # the next four statements cannot be read off here. Presumably they run once per dfs round
+ # (shrinking the limits, then expanding from everything kept so far) - confirm against the
+ # applied file.
+ max_children = int(max_children * 0.5)
+ max_depth = int(max_depth * 0.5)
+ max_sibling = int(max_sibling * 0.7)
+
+ to_keep = copy.deepcopy(nodes_to_keep)
+
+ if keep_parent:
+ for bid in keep:
+ candidate_node = self.get_node_by_bid(tree, bid)
+ if candidate_node is None:
+ continue
+ # the XPath ancestor axis has no depth limit, unlike get_anscendants above
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")])
+
+ # nodes without an id attribute contribute '' to the result
+ return list(nodes_to_keep)
+
+    def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement:
+        """Drop every node whose id is not in *nodes_to_keep*, hoisting its children.
+
+        A ``text`` pseudo-element is kept or dropped according to its parent's
+        id.  A kept node is also spliced out when it is a bare wrapper: it has
+        no attributes, no ``text`` child, at most one child, and a parent.  The
+        tree is modified in place and returned; the root is never removed.
+
+        Args:
+            tree: root of the (already cloned) DOM tree to prune.
+            nodes_to_keep: values of ``self.id_attr`` for the nodes to retain.
+        """
+        # Hash the ids once: membership is tested for every node of the tree.
+        keep_ids = set(nodes_to_keep)
+
+        def splice_out(node):
+            # Move the children in front of the node, preserving their order,
+            # then detach the node itself from its parent.
+            for child in node.getchildren():
+                node.addprevious(child)
+            node.getparent().remove(node)
+
+        # Walk the XPath result back to front, as the original code did.
+        for node in tree.xpath('//*')[::-1]:
+            parent = node.getparent()
+            owner = node if node.tag != 'text' else parent
+            is_keep = owner.attrib.get(self.id_attr, '') in keep_ids
+
+            if not is_keep and parent is not None:
+                splice_out(node)
+                continue
+
+            is_bare_wrapper = (
+                len(node.attrib) == 0
+                and not any(x.tag == 'text' for x in node.getchildren())
+                and parent is not None
+                and node.tag != 'text'
+                and len(node.getchildren()) <= 1
+            )
+            if is_bare_wrapper:
+                splice_out(node)
+
+        return tree
+
+    def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30,
+                   max_sibling: int=3, keep_parent: bool=False) -> None:
+        """Replace ``self.dom_tree`` with a pruned deep copy centred on ``self.keep``.
+
+        The limits are forwarded to ``get_keep_elements``; the tree held before
+        the call is left untouched because pruning happens on the copy.
+        """
+        working_copy = copy.deepcopy(self.dom_tree)
+        kept_ids = self.get_keep_elements(
+            working_copy, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent,
+        )
+        working_copy = self.prune(working_copy, kept_ids)
+        self.dom_tree = working_copy
+
+    def get_segment(self, bid: str) -> str:
+        """Render only the neighbourhood of element *bid* and return it as text.
+
+        Works on a deep copy of ``self.dom_tree``, so the stored tree is not
+        modified.  The neighbourhood limits are fixed: depth 0, at most two
+        descendants and one sibling on each side.
+        """
+        snapshot = copy.deepcopy(self.dom_tree)
+        around_bid = self.get_keep_elements(snapshot, [bid], 0, 2, 1)
+        snapshot = self.prune(snapshot, around_bid)
+        rendered, _ = self.parse(snapshot, self.keep, [], False)
+        return rendered
+
+    def get_rect_data(self, bids: list[str]) -> list[dict[str]]:
+        """Return one ``{'bid', 'label', 'rect'}`` record per id in *bids*, in order.
+
+        Ids missing from ``self.bids2label`` get the label ``''``; ids missing
+        from ``self.rect`` get the rect ``None``.
+        """
+        return [
+            {
+                'bid': bid,
+                'label': self.bids2label.get(bid, ''),
+                'rect': self.rect.get(bid, None),
+            }
+            for bid in bids
+        ]
+
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/identifier.py b/VAB-WebArena-Lite/new/html_tools/identifier.py
new file mode 100755
index 0000000..793ebed
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/identifier.py
@@ -0,0 +1,64 @@
+import secrets
+
+class IdentifierTool:
+    """Hand out unique upper-case letter labels, sequentially or at random.
+
+    ``method`` selects the strategy: ``'order'`` yields ``A..Z``, then
+    ``AA..ZZ``, then three-letter labels; ``'random'`` draws two-letter labels
+    and switches to three letters once more than 280 labels are in use.
+    ``None`` is treated as ``'order'``; any other value raises ``ValueError``.
+
+    ``existing_labels`` maps labels that are already taken to ``True``.  When a
+    dict is supplied it is used as-is, so the caller sees every newly issued
+    label appear in it.  When omitted, each tool starts with its own empty
+    registry.
+    """
+
+    def __init__(self, method: str='order', existing_labels: 'dict[str, bool] | None'=None) -> None:
+        self.methods = {
+            'order': self.get_identifier_in_order,
+            'random': self.get_random_identifier,
+        }
+
+        if method is None:
+            method = 'order'
+
+        self.func = self.methods.get(method, None)
+        self.name = method
+        if self.func is None:
+            raise ValueError(f'Invalid method for identifier: {method}')
+
+        self.reset(existing_labels)
+
+    def reset(self, exists: 'dict[str, bool] | None'=None) -> None:
+        """Restart the sequence, using *exists* (or a fresh dict) as the registry."""
+        self.identifier = -1
+        # A default of None, not {}: a shared default dict would be mutated by
+        # generate() and leak used labels into every later default-built tool.
+        self.exists = {} if exists is None else exists
+
+    @staticmethod
+    def _index_to_label(index: int) -> str:
+        """Map 0..25 to 'A'..'Z', 26..701 to 'AA'..'ZZ', and larger values to three letters."""
+        if index < 26:
+            return chr(index + 65)
+        index -= 26
+        c0 = index // 676
+        c1 = (index // 26) % 26
+        c2 = index % 26
+        label = f'{chr(c1 + 65)}{chr(c2 + 65)}'
+        return label if c0 == 0 else f'{chr(c0 + 64)}{label}'
+
+    def get_identifier_in_order(self) -> str:
+        """Return the next sequential label that is not already registered."""
+        self.identifier += 1
+        label = self._index_to_label(self.identifier)
+
+        while label in self.exists:
+            self.identifier += 1
+            label = self._index_to_label(self.identifier)
+
+        self.exists[label] = True
+        return label
+
+    def get_random_identifier(self) -> str:
+        """Return a random, not yet registered label of two or three letters."""
+        secret_generator = secrets.SystemRandom()
+
+        def get_random_label(n: int=2) -> str:
+            return ''.join(chr(secret_generator.randint(65, 90)) for _ in range(n))
+
+        # Move to three letters well before the 676 two-letter labels run out,
+        # so the retry loop below stays short.
+        wc = 3 if len(self.exists) > 280 else 2
+
+        label = get_random_label(wc)
+        while label in self.exists:
+            label = get_random_label(wc)
+
+        self.exists[label] = True
+        return label
+
+    def generate(self):
+        """Return a new label using the strategy chosen at construction."""
+        return self.func()
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/prompt.py b/VAB-WebArena-Lite/new/html_tools/prompt.py
new file mode 100755
index 0000000..38d6b94
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/prompt.py
@@ -0,0 +1,97 @@
+from .configs import prompts
+
+class HtmlPrompt:
+    """Render DOM nodes as prompt text using one of the template sets in ``prompts``.
+
+    The selected template set (``self.prompt``) is a dict of format strings read
+    under the keys ``'dom'``, ``'label'``, ``'attr'``, ``'attr_splitter'`` and
+    ``'subtree_splitter'``.
+    """
+
+    def __init__(self, prompt: str='') -> None:
+        # None selects the 'xml' templates; any other unknown name is rejected.
+        prompt = self.extract(prompt, 'xml')
+        if prompt not in prompts:
+            raise Exception('Unknown prompt: ' + prompt)
+
+        constructors = {
+            'refine': self.normal_prompt_constructor,
+            'xml': self.normal_prompt_constructor,
+            'new_data': self.new_data_prompt_constructor,
+        }
+
+        self.name = prompt
+        self.prompt = prompts[prompt]
+        self.constructor = constructors[prompt]
+
+    @staticmethod
+    def extract(data, default=''):
+        """Return *data*, or *default* when *data* is None."""
+        return data if data is not None else default
+
+    @staticmethod
+    def _prefixed(data: str, prefix: str) -> str:
+        """Return *prefix* + *data*, or '' when *data* is empty."""
+        return prefix + data if len(data) > 0 else ''
+
+    def subtree_constructor(self, subtree: 'list[str] | None'=None) -> str:
+        """Join already-rendered child strings with the template's subtree splitter."""
+        return self.prompt['subtree_splitter'].join(self.extract(subtree, []))
+
+    def _unique_attrs(self, class_dict: dict, empty_keys: 'list[str] | None'=None) -> 'list[str]':
+        """Format attributes, skipping any whose value was already emitted.
+
+        When *empty_keys* is a list, attributes with the value ``''`` are not
+        formatted; their names are appended to that list instead.
+        """
+        rendered, seen = [], set()
+        for key, val in class_dict.items():
+            if empty_keys is not None and val == '':
+                empty_keys.append(key)
+                continue
+            if val in seen:
+                continue
+            seen.add(val)
+            rendered.append(self.prompt['attr'].format(key=key, attr=val))
+        return rendered
+
+    def _assemble(self, tag: str, label: str, content: str, subtree_str: str, attrs: 'list[str]') -> str:
+        """Fill the 'dom' template from the already-formatted parts."""
+        label_str = ''
+        if len(label) > 0:
+            label_str = self.prompt['label'].format(label=label)
+
+        classes_str = self.prompt['attr_splitter'].join(attrs)
+        # Text content follows the attributes with the attribute splitter, or a
+        # plain space when there are no attributes.
+        content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
+
+        return self.prompt['dom'].format(
+            tag=tag,
+            label=label_str,
+            attr=self._prefixed(classes_str, ' '),
+            content=self._prefixed(content, content_splitter),
+            subtree=self._prefixed(subtree_str, ' '),
+        )
+
+    def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='',
+                                  class_dict: 'dict[str, str] | None'=None) -> str:
+        """Render one node; attributes with a repeated value are emitted once."""
+        class_dict = self.extract(class_dict, {})
+        return self._assemble(
+            self.extract(tag),
+            self.extract(label),
+            self.extract(content),
+            self.extract(subtree_str, ''),
+            self._unique_attrs(class_dict),
+        )
+
+    def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='',
+                                    class_dict: 'dict[str, str] | None'=None) -> str:
+        """Render one node, folding empty-valued attribute names into a 'message' attribute."""
+        class_dict = self.extract(class_dict, {})
+
+        message = []
+        classes = self._unique_attrs(class_dict, message)
+        if len(message) > 0:
+            message_str = ' '.join(message)
+            classes.append(self.prompt['attr'].format(key='message', attr=message_str))
+
+        return self._assemble(
+            self.extract(tag),
+            self.extract(label),
+            self.extract(content),
+            self.extract(subtree_str, ''),
+            classes,
+        )
+
+    def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='',
+                           class_dict: 'dict[str, str] | None'=None) -> str:
+        """Render one node with the constructor matching the selected template set."""
+        return self.constructor(tag, label, content, subtree_str, class_dict)
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
new file mode 100755
index 0000000..0d36c9a
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
@@ -0,0 +1,43 @@
+import os
+from pathlib import Path
+rootdir = Path(__file__).parent
+
+
+def _read_script(filename: str) -> str:
+    """Return the source text of a JavaScript helper stored next to this module."""
+    # Decode explicitly so loading does not depend on the platform's locale encoding.
+    return (rootdir / filename).read_text(encoding='utf-8')
+
+
+prepare_script = _read_script('prepare.js')
+clickable_checker_script = _read_script('clickable_checker.js')
+label_script = _read_script('label.js')
+element_info_script = _read_script('element_info.js')
+
+# draw label on page
+label_marker_script = _read_script('label_marker.js')
+
+# Remove the label markers drawn on the page: every element carrying the
+# "our-dom-marker" class is detached from document.body.
+remove_label_mark_script = """
+ () => {
+ document.querySelectorAll(".our-dom-marker").forEach(item => {
+ document.body.removeChild(item);
+ });
+ }
+"""
+
+# Undo the page annotations: strip the 'possible-clickable-element' class and the
+# data-* attributes listed below.
+# NOTE(review): only elements that still carry the 'possible-clickable-element'
+# class are visited, so attributes set on other elements are left in place.
+remove_id_script = """
+ () => {
+ Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => {
+ element.classList.remove('possible-clickable-element');
+ element.removeAttribute('data-value');
+ element.removeAttribute('data-text');
+ element.removeAttribute('data-label');
+ element.removeAttribute('data-bbox');
+ element.removeAttribute('data-status');
+ element.removeAttribute('data-backend-node-id');
+ element.removeAttribute('data-label-id');
+ });
+ }
+"""
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
new file mode 100755
index 0000000..42267f0
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
@@ -0,0 +1,148 @@
+() => {
+ // Tag every element that looks clickable and has a visible area with the
+ // 'possible-clickable-element' class. When clickable elements are nested, only the
+ // innermost ones are tagged. Returns nothing.
+ var items = Array.prototype.slice.call(
+ document.querySelectorAll('*')
+ ).map(function(element) {
+ var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
+ var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
+
+ // Keep only the client rects whose centre point hits this element or one of its
+ // descendants, then clip them to the viewport.
+ var rects = [...element.getClientRects()].filter(bb => {
+ var center_x = bb.left + bb.width / 2;
+ var center_y = bb.top + bb.height / 2;
+ var elAtCenter = document.elementFromPoint(center_x, center_y);
+
+ if (!elAtCenter) return false;
+ return elAtCenter === element || element.contains(elAtCenter)
+ }).map(bb => {
+ const rect = {
+ left: Math.max(0, bb.left),
+ top: Math.max(0, bb.top),
+ right: Math.min(vw, bb.right),
+ bottom: Math.min(vh, bb.bottom)
+ };
+ return {
+ ...rect,
+ width: rect.right - rect.left,
+ height: rect.bottom - rect.top
+ }
+ });
+ // var rects = [];
+ // total clipped area over all accepted rects
+ var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
+
+ const tagName = element.tagName.toLowerCase?.() || "";
+ // first signal: an onclick handler or a pointer cursor
+ let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer");
+
+ // An img that references an existing <map> through "usemap" counts as clickable.
+ if (tagName === "img") {
+ let mapName = element.getAttribute("usemap");
+ if (mapName) {
+ const imgClientRects = element.getClientRects();
+ // NOTE(review): a string pattern replaces only the first double quote.
+ mapName = mapName.replace(/^#/, "").replace('"', '\\"');
+ const map = document.querySelector(`map[name=\"${mapName}\"]`);
+ if (map && (imgClientRects.length > 0)) isClickable = true;
+ }
+ }
+
+ // ARIA roles that imply interactivity, then contentEditable elements
+ if (!isClickable) {
+ const role = element.getAttribute("role");
+ const clickableRoles = [
+ "button",
+ "tab",
+ "link",
+ "checkbox",
+ "menuitem",
+ "menuitemcheckbox",
+ "menuitemradio",
+ "radio",
+ ];
+ if (role != null && clickableRoles.includes(role.toLowerCase())) {
+ isClickable = true;
+ } else {
+ const contentEditable = element.getAttribute("contentEditable");
+ if (
+ contentEditable != null &&
+ ["", "contenteditable", "true"].includes(contentEditable.toLowerCase())
+ ) {
+ isClickable = true;
+ }
+ }
+ }
+
+ // Check for jsaction event listeners on the element.
+ // A rule is "action" (event defaults to click) or "event:action"; the "_" appended
+ // before destructuring becomes actionName when the action has no "." separator.
+ if (!isClickable && element.hasAttribute("jsaction")) {
+ const jsactionRules = element.getAttribute("jsaction").split(";");
+ for (let jsactionRule of jsactionRules) {
+ const ruleSplit = jsactionRule.trim().split(":");
+ if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) {
+ const [eventType, namespace, actionName] = ruleSplit.length === 1
+ ? ["click", ...ruleSplit[0].trim().split("."), "_"]
+ : [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"];
+ if (!isClickable) {
+ isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_");
+ }
+ }
+ }
+ }
+
+ // tags that are interactive by nature
+ if (!isClickable) {
+ const clickableTags = [
+ "input",
+ "textarea",
+ "select",
+ "button",
+ "a",
+ "iframe",
+ "video",
+ "object",
+ "embed",
+ "details"
+ ];
+ isClickable = clickableTags.includes(tagName);
+ }
+
+ // a label is clickable when it controls an enabled form control; an img when its
+ // inline cursor style is a zoom cursor
+ if (!isClickable) {
+ if (tagName === "label")
+ isClickable = (element.control != null) && !element.control.disabled;
+ else if (tagName === "img")
+ isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor);
+ }
+
+ // An element with a class name containing the text "button" might be clickable. However, real
+ // clickables are often wrapped in elements with such class names.
+ // NOTE: in this script such a match is treated as fully clickable; nothing is marked
+ // as unreliable.
+ const className = element.getAttribute("class");
+ if (!isClickable && className && className.toLowerCase().includes("button")) {
+ isClickable = true;
+ }
+
+ // Elements with a non-negative tabindex are accepted as clickable as well; no
+ // second-class status is recorded for them here.
+ const tabIndexValue = element.getAttribute("tabindex");
+ const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1;
+ if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) {
+ isClickable = true;
+ }
+
+ const idValue = element.getAttribute("id");
+ const id = idValue ? idValue.toLowerCase() : "";
+ // NOTE(review): debugging leftover. `clickable_msg` is assigned without a declaration
+ // and is never read in this script.
+ if (isClickable && area == 0) {
+ const textValue = element.textContent.trim().replace(/\s{2,}/g, ' ');
+ clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}`
+ }
+
+ return {
+ element: element,
+ include: isClickable,
+ area,
+ rects,
+ text: element.textContent.trim().replace(/\s{2,}/g, ' ')
+ };
+ }).filter(item =>
+ item.include && (item.area >= 1)
+ );
+
+ // drop every candidate that contains another candidate, keeping the innermost ones
+ items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))
+
+ items.forEach(item => {
+ item.element.classList.add('possible-clickable-element');
+ });
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
new file mode 100755
index 0000000..c396220
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
@@ -0,0 +1,34 @@
+() => {
+    // Summarise one element: the ids written by the labelling scripts plus a few
+    // descriptive attributes. Missing attributes are reported as "".
+    const describe = (node) => {
+        const attr = (name) => node.getAttribute(name) || "";
+        return {
+            "bid": attr("data-backend-node-id"),
+            "label": attr("data-label-id"),
+            "tag": node.tagName.toLowerCase?.() || "",
+            // "data-bbox" holds comma-separated numbers; wrap it to parse it as a JSON array
+            "area": JSON.parse("[" + attr("data-bbox") + "]"),
+            "text": node.innerText?.trim().replace(/\s{2,}/g, " ") || "",
+            "id": attr("id"),
+            "role": attr("role"),
+            "aria-label": attr("aria-label"),
+            "href": attr("href"),
+        };
+    };
+
+    const everything = Array.from(document.querySelectorAll("*"));
+    const labelled = everything.filter((node) => node.getAttribute("data-label-id"));
+
+    return {
+        all_elements: everything.map((node) => describe(node)),
+        clickable_elements: labelled.map((node) => describe(node))
+    };
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label.js b/VAB-WebArena-Lite/new/html_tools/scripts/label.js
new file mode 100755
index 0000000..285ca89
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/label.js
@@ -0,0 +1,74 @@
+(packet) => {
+ // Assign consecutive numeric "data-label-id" values, starting at packet.startIndex, to the
+ // innermost elements matching packet.selector that have a visible first client rect.
+ // Returns [items, index]: one record per innermost match and the next unused index.
+
+ // Spreadsheet-style letters for an index; currently unused (its call below is commented out).
+ function int2str(index) {
+ var str = "";
+ while (index >= 0) {
+ str = String.fromCharCode(65 + index % 26) + str;
+ index = Math.floor(index / 26) - 1;
+ }
+ return str;
+ };
+
+ // NOTE(review): `selector` and `index` are assigned without a declaration, so outside strict
+ // mode they become globals on the page; other scripts may be relying on that - confirm
+ // before declaring them locally.
+ selector = packet.selector
+ index = packet.startIndex
+ var items = Array.prototype.slice.call(
+ document.querySelectorAll(selector)
+ );
+
+ var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
+ var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
+
+ // keep only matches that do not contain another match
+ items = items.filter(
+ x => !items.some(y => x.contains(y) && !(x == y))
+ ).map(element => {
+ var bb = element.getClientRects();
+ var rect = {
+ left: 0,
+ top: 0,
+ right: 0,
+ bottom: 0,
+ width: 0,
+ height: 0
+ };
+ var keep = false;
+ var text = "", id = -1;
+ if (bb.length > 0) {
+ // only the first client rect is considered, clipped to the viewport
+ bb = bb[0];
+ rect = {
+ left: Math.max(0, bb.left),
+ top: Math.max(0, bb.top),
+ right: Math.min(vw, bb.right),
+ bottom: Math.min(vh, bb.bottom)
+ };
+ rect = {
+ ...rect,
+ width: rect.right - rect.left,
+ height: rect.bottom - rect.top
+ };
+ // a rect counts as visible when either dimension is positive
+ if (rect.width > 0 || rect.height > 0) {
+ keep = true;
+ // a negative start index disables labelling; id then stays -1
+ if (index >= 0) {
+ // id = int2str(index++);
+ id = index++;
+ element.setAttribute("data-label-id", id);
+ }
+ var childNodes = element.childNodes;
+
+ // concatenate the element's own text nodes (descendant text is not included)
+ for (var i = 0; i < childNodes.length; i++) {
+ if (childNodes[i].nodeType == Node.TEXT_NODE) {
+ text += childNodes[i].textContent;
+ }
+ }
+ }
+ }
+
+ // NOTE(review): `keep` is hard-coded to true here, so the filter below removes nothing and
+ // elements without a visible rect are returned with id -1 and an all-zero rect. The local
+ // `keep` computed above is unused - confirm whether that is intended.
+ return {
+ keep: true,
+ id,
+ rects: rect,
+ tag: element.tagName.toLowerCase?.() || "",
+ text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || ""
+ };
+ }).filter(x => x.keep);
+
+ return [items, index];
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
new file mode 100755
index 0000000..6c55af5
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
@@ -0,0 +1,65 @@
+(items) => {
+    // Draw a highlight over every labelled item: a yellow dashed frame two pixels outside the
+    // bounding box, a dark dashed frame exactly on it, and a floating tag showing the item id.
+    // All frames carry the "our-dom-marker" class so they can be found and removed later.
+    // `items` is returned unchanged.
+    const borderColor = "#FFFF00";
+
+    // Build one fixed-position, click-through frame.
+    const makeFrame = (outline, left, top, width, height) => {
+        const frame = document.createElement("div");
+        frame.classList.add("our-dom-marker");
+        frame.style.outline = outline;
+        frame.style.position = "fixed";
+        frame.style.left = left + "px";
+        frame.style.top = top + "px";
+        frame.style.width = width + "px";
+        frame.style.height = height + "px";
+        frame.style.pointerEvents = "none";
+        frame.style.boxSizing = "border-box";
+        frame.style.zIndex = 2147483647;
+        return frame;
+    };
+
+    // NOTE(review): the loose comparison is kept as it was. With numeric ids, 0 != "" is false,
+    // so an item whose id is 0 gets no marker, while -1 (the "no label" value used by label.js)
+    // passes the filter - confirm the intended behaviour.
+    items.filter(
+        item => item.id != ""
+    ).forEach((item) => {
+        const bbox = item.rects;
+        // The id is read into a local: the previous code interpolated an undeclared `index`
+        // before assigning it, which only worked if another script had leaked a global.
+        const labelText = item.id;
+
+        const outerElement = makeFrame(
+            `2px dashed ${borderColor}`, bbox.left - 2, bbox.top - 2, bbox.width + 4, bbox.height + 4
+        );
+        const innerElement = makeFrame(
+            `2px dashed #222288`, bbox.left, bbox.top, bbox.width, bbox.height
+        );
+
+        // Add floating label at the corner; near the top edge it is shifted up by less so it
+        // stays inside the viewport.
+        const label = document.createElement("span");
+        const topPosition = bbox.top < 25 ? bbox.top : 25;
+        label.textContent = labelText;
+        label.style.position = "absolute";
+        label.style.top = `-${topPosition}px`;
+        label.style.left = "0px";
+        label.style.background = borderColor;
+        label.style.color = "black";
+        label.style.padding = "2px 4px";
+        label.style.fontSize = "16px";
+        label.style.borderRadius = "2px";
+        label.style.fontWeight = "bold";
+        outerElement.appendChild(label);
+
+        document.body.appendChild(outerElement);
+        document.body.appendChild(innerElement);
+    })
+    return items;
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
new file mode 100755
index 0000000..9e76a62
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
@@ -0,0 +1,83 @@
+() => {
+    // Annotate the live DOM so it can be serialised later:
+    //   data-backend-node-id  running index of the element in document order
+    //   data-bbox             "left,top,width,height" of the first client rect
+    //   data-text             the element's own text nodes, or the selected option of a select
+    //   data-value            current value of select, input and textarea elements
+    //   data-status           "selected" on the chosen option, checked state on checkboxes
+
+    // Coordinates are rounded to whole pixels first; the * 100 / 100 that follows leaves that
+    // integer as it is. Kept exactly as before because consumers read these values.
+    const roundCoord = (value) => (Math.round(value) * 100) / 100;
+    const roundSize = (value) => Math.round(value * 100) / 100;
+
+    Array.from(document.querySelectorAll("*")).forEach((element, backendId) => {
+        // mark backend node id
+        element.setAttribute("data-backend-node-id", backendId);
+
+        const tag = element.tagName.toLowerCase?.() || "";
+
+        const boxes = element.getClientRects();
+        if (boxes.length > 0) {
+            const first = boxes[0];
+            const left = roundCoord(first.left);
+            const top = roundCoord(first.top);
+            const width = roundSize(roundCoord(first.right) - left);
+            const height = roundSize(roundCoord(first.bottom) - top);
+            element.setAttribute("data-bbox", `${left},${top},${width},${height}`);
+        }
+
+        if (element.hasChildNodes()) {
+            // only direct text nodes, whitespace-normalised, empty ones dropped
+            const ownTexts = Array.from(element.childNodes)
+                .filter((node) => node.nodeType == Node.TEXT_NODE)
+                .map((node) => node.textContent.trim().replace(/\s{2,}/g, " ") || "")
+                .filter((text) => text.length > 0);
+            element.setAttribute("data-text", ownTexts.join(","));
+        }
+
+        if (tag == "select") {
+            // fix select issue: expose the current choice; this overrides data-text set above
+            const chosen = element.options[element.selectedIndex];
+            const value = element.value;
+            const text = chosen?.text || "";
+            element.setAttribute("data-value", value);
+            element.setAttribute("data-text", text);
+            chosen?.setAttribute("data-status", "selected");
+        } else if (tag == "input") {
+            const inputType = element.getAttribute("type") || "";
+            if (inputType == "checkbox") {
+                element.setAttribute("data-status", element.checked ? "checked" : "not-checked");
+            }
+        }
+    });
+
+    // fix input and textarea issue: mirror the live value into an attribute
+    for (const field of Array.from(document.querySelectorAll("input, textarea"))) {
+        field.setAttribute("data-value", field.value);
+    }
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/utils.py b/VAB-WebArena-Lite/new/html_tools/utils.py
new file mode 100755
index 0000000..9e0be65
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/utils.py
@@ -0,0 +1,101 @@
+from lxml import html
+def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0,
+                       in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str]]:
+    """Number the subtree rooted at *element* in pre-order and record an XPath per node.
+
+    Every visited element gets ``attrib[id_column]`` set to its number (as a
+    string).  An element with an ``id`` attribute that is not inside an ``svg``
+    is addressed as ``//*[@id="..."]``; otherwise a step is appended to the
+    parent's path, with a 1-based position only when several siblings share the
+    tag.  Inside ``svg`` the step uses the ``*[name()="tag"]`` form.
+
+    Returns ``(next_free_id, i2xpath, used_labels)``.  ``i2xpath`` maps the
+    number to its path and also maps the path, ``'xpath/' + path`` and
+    ``'xpath=/' + path`` back to the number.  ``used_labels`` maps every
+    non-empty ``label_column`` value found in the subtree to ``True``.
+    """
+    tag = element.tag.lower()
+    in_svg = in_svg or (tag == 'svg')
+
+    if in_svg or 'id' not in element.attrib:
+        step = f'*[name()="{tag}"]' if in_svg else tag
+        position = f'[{order}]' if order > 0 else ''
+        path = path + '/' + step + position
+    else:
+        node_id = element.attrib['id']
+        path = f'//*[@id="{node_id}"]'
+
+    # add temp id
+    bid = str(temp_id)
+    element.attrib[id_column] = bid
+
+    used_labels = {}
+    ori_label = element.attrib.get(label_column, '')
+    if ori_label != '':
+        used_labels[ori_label] = True
+
+    i2xpath = {}
+    i2xpath[bid] = path
+    i2xpath[path] = bid
+    i2xpath[f'xpath/{path}'] = bid
+    i2xpath[f'xpath=/{path}'] = bid
+
+    next_id = temp_id + 1
+
+    # first pass: how many children share each tag
+    children = element.getchildren()
+    totals = {}
+    for child in children:
+        child_tag = child.tag.lower()
+        totals[child_tag] = totals.get(child_tag, 0) + 1
+
+    # second pass: recurse, passing a position only for tags that repeat
+    seen = {}
+    for child in children:
+        child_tag = child.tag.lower()
+        seen[child_tag] = seen.get(child_tag, 0) + 1
+        child_order = seen[child_tag] if totals[child_tag] > 1 else 0
+        next_id, child_paths, child_labels = get_xpath_top_down(
+            child, id_column, label_column, path, child_order, in_svg, next_id)
+        i2xpath.update(child_paths)
+        used_labels.update(child_labels)
+
+    return next_id, i2xpath, used_labels
+
+def print_html_object(obj: str='') -> str:
+ # Re-flow a bracketed string for display: every '<' starts a new line and raises the indent
+ # level, every '>' lowers it again; the text between brackets is stripped before it is emitted.
+ # NOTE(review): leading whitespace was collapsed in this copy of the patch, including inside
+ # the `' ' * tab_cnt` literals, so neither the nesting of the branches below nor the indent
+ # width can be confirmed from here - check against the applied file.
+ tab_cnt = 0
+ result, content, sep = '', '', ''
+ # remember which bracket was seen last
+ last_is_left, last_is_right = False, False
+ for ch in obj:
+ if ch == '<':
+ result += '\n'
+ # flush text collected since the previous bracket
+ if len(content.strip()) > 0:
+ result += sep + content.strip() + '\n'
+ result += sep + '<'
+
+ tab_cnt += 1
+ sep = ' ' * tab_cnt
+
+ content = ''
+ last_is_right = False
+ last_is_left = True
+ elif ch == '>':
+ # directly after a '<' the collected text is appended unstripped
+ if last_is_left:
+ result += content
+ else:
+ if last_is_right:
+ result += '\n'
+ if len(content.strip()) > 0:
+ result += sep + content.strip() + '\n'
+
+ tab_cnt -= 1
+ sep = ' ' * tab_cnt
+
+ if not last_is_left:
+ result += sep
+
+ result += '>'
+ content = ''
+
+ last_is_right = True
+ last_is_left = False
+ else:
+ content += ch
+
+ # text after the final bracket is still in `content` and is not appended
+ return result
+
+def rect2tuple(rect: str) -> 'tuple[float, float, float, float] | None':
+    """Parse a ``"a,b,c,d"`` rectangle string into a 4-tuple of floats.
+
+    Returns ``None`` for anything that is not a string with exactly four
+    comma-separated numeric fields; a non-numeric field no longer raises.
+    """
+    if not isinstance(rect, str):
+        return None
+
+    fields = rect.strip().split(',')
+    if len(fields) != 4:
+        return None
+
+    try:
+        return tuple(float(field) for field in fields)
+    except ValueError:
+        # Same answer as for every other malformed input.
+        return None
diff --git a/VAB-WebArena-Lite/new/openai_utils.py b/VAB-WebArena-Lite/new/openai_utils.py
index 8374c7a..c9754f2 100644
--- a/VAB-WebArena-Lite/new/openai_utils.py
+++ b/VAB-WebArena-Lite/new/openai_utils.py
@@ -36,7 +36,6 @@ def retry_with_exponential_backoff( # type: ignore
# Initialize variables
num_retries = 0
delay = initial_delay
-
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
@@ -142,27 +141,32 @@ async def agenerate_from_openai_completion(
@retry_with_exponential_backoff
def generate_from_openai_completion(
prompt: str,
- engine: str,
+ model: str,
temperature: float,
max_tokens: int,
top_p: float,
- context_length: int,
stop_token: str | None = None,
+ api_key: str | None = None,
+ base_url: str | None = None
) -> str:
- if "OPENAI_API_KEY" not in os.environ:
+ if "OPENAI_API_KEY" not in os.environ and api_key is None:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
-
+ # NOTE(review): assigning `client` here makes it a local name for the whole function (there is
+ # no `global client` in this body). When api_key is None the assignment is skipped and the
+ # `client.completions.create` call below raises UnboundLocalError instead of using a
+ # module-level client. Bind the name on both paths before merging.
+ if api_key is not None:
+ client = OpenAI(api_key=api_key, base_url=base_url)
response = client.completions.create(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop=[stop_token],
)
- answer: str = response["choices"][0]["text"]
+ # Try dict-style access first, then attribute access on the response object.
+ # NOTE(review): the bare `except:` also catches KeyboardInterrupt and SystemExit; presumably
+ # only the "response is not subscriptable" case is meant - narrow it once the exception type
+ # raised by the client library in use is confirmed.
+ try:
+ answer: str = response["choices"][0]["text"]
+ except:
+ answer: str = response.choices[0].text
return answer
diff --git a/VAB-WebArena-Lite/new/p_webrl.json b/VAB-WebArena-Lite/new/p_webrl.json
new file mode 100644
index 0000000..c172455
--- /dev/null
+++ b/VAB-WebArena-Lite/new/p_webrl.json
@@ -0,0 +1,13 @@
+{
+ "intro": "",
+ "examples": [],
+ "template": "",
+ "meta_data": {
+ "observation": "webrl",
+ "action_type": "webrl_id",
+ "keywords": [],
+ "prompt_constructor": "WebRLPromptConstructor",
+ "answer_phrase": "",
+ "action_splitter": ""
+ }
+ }
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/processors.py b/VAB-WebArena-Lite/new/processors.py
new file mode 100644
index 0000000..d3f422d
--- /dev/null
+++ b/VAB-WebArena-Lite/new/processors.py
@@ -0,0 +1,1351 @@
+import json
+import pkgutil
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from io import BytesIO, StringIO
+from typing import Any, Optional, TypedDict, Union
+from urllib.parse import urljoin, urlparse
+
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import playwright
+import requests
+from gymnasium import spaces
+from PIL import Image, ImageDraw, ImageFont
+from playwright.sync_api import CDPSession, Page, ViewportSize
+from .html_tools.fetch import get_parsed_html
+
+from browser_env.constants import (
+ ASCII_CHARSET,
+ FREQ_UNICODE_CHARSET,
+ IGNORED_ACTREE_PROPERTIES,
+ INJECTED_ATTR_NAME,
+ UTTERANCE_MAX_LENGTH,
+ BID_ATTR,
+ DATA_REGEXP,
+ IN_VIEWPORT_RATIO_THRESHOLD,
+)
+
+from .utils import (
+ AccessibilityTree,
+ AccessibilityTreeNode,
+ BrowserConfig,
+ BrowserInfo,
+ DOMNode,
+ DOMTree,
+ Observation,
+ png_bytes_to_numpy,
+)
+
+
+def remove_unicode(input_string):
+    """Return ``input_string`` with every run of non-ASCII characters removed."""
+    # Characters outside the 7-bit range are dropped, not transliterated.
+    return re.sub(r"[^\x00-\x7F]+", "", input_string)
+
+
+class ObservationProcessor:
+    """Base class for objects that turn a browser page into an observation."""
+
+    def process(self, page: Page) -> Observation:
+        """Build an observation from ``page``; subclasses must override this."""
+        raise NotImplementedError
+
+
+class ObservationMetadata(TypedDict):
+    """Side information stored next to an observation."""
+
+    # Per-node details keyed by the node id used in the observation.
+    obs_nodes_info: dict[str, Any]
+
+
+def create_empty_metadata() -> ObservationMetadata:
+    """Return a fresh metadata dict whose ``obs_nodes_info`` is empty."""
+    metadata: ObservationMetadata = {"obs_nodes_info": {}}
+    return metadata
+
+
+def extract_data_items_from_aria(string: str) -> tuple[list[str], str]:
+    """Split temporary data items off an "aria-roledescription" value.
+
+    ``string`` is matched in full against ``DATA_REGEXP``. Without a match the
+    result is ``([], string)``. With a match, the last capture group is
+    returned as the original aria text and all preceding groups as the data
+    items.
+
+    Returns:
+        ``(data_items, original_aria)``; ``data_items`` is always a list so
+        both return paths agree with the annotation.
+    """
+    match = DATA_REGEXP.fullmatch(string)
+    if match is None:
+        return [], string
+
+    # Star-unpacking yields a list for the leading groups (the previous slice
+    # returned a tuple, unlike the no-match path).
+    *data_items, original_aria = match.groups()
+    return data_items, original_aria
+
+
+class TextObervationProcessor(ObservationProcessor):
+    def __init__(
+        self,
+        observation_type: str,
+        current_viewport_only: bool,
+        viewport_size: ViewportSize,
+        captioning_fn=None,
+    ):
+        """Store the observation settings and start with empty metadata.
+
+        ``captioning_fn`` and the caption cache are only attached for the
+        observation types that use image captions.
+        """
+        self.observation_type = observation_type
+        self.current_viewport_only = current_viewport_only
+        self.viewport_size = viewport_size
+        self.observation_tag = "text"
+        # use the store meta data of this observation type
+        self.meta_data = create_empty_metadata()
+
+        uses_captions = self.observation_type in (
+            "accessibility_tree_with_captioner",
+            "image_som",
+        )
+        if uses_captions:
+            self.captioning_fn = captioning_fn
+            # Cache captions.
+            self.url2caption = {}
+
+    def fetch_browser_info(
+        self,
+        page: Page,
+    ) -> BrowserInfo:
+        """Capture a DOM snapshot of ``page`` plus window geometry.
+
+        Returns a dict with the rescaled ``DOMSnapshot.captureSnapshot``
+        result under ``"DOMTree"`` and the window offsets/sizes under
+        ``"config"``. Raises ``AssertionError`` when
+        ``window.devicePixelRatio`` is not 1.0.
+        """
+        # extract domtree
+        client = page.context.new_cdp_session(page)
+        tree = client.send(
+            "DOMSnapshot.captureSnapshot",
+            {
+                "computedStyles": [],
+                "includeDOMRects": True,
+                "includePaintOrder": True,
+            },
+        )
+        client.detach()
+
+        # calibrate the bounds, in some cases, the bounds are scaled somehow
+        bounds = tree["documents"][0]["layout"]["bounds"]
+        b = bounds[0]
+        # n: width of the first layout box relative to the configured viewport
+        # width; every bound is divided by it below.
+        n = b[2] / self.viewport_size["width"]
+        bounds = [[x / n for x in bound] for bound in bounds]
+        tree["documents"][0]["layout"]["bounds"] = bounds
+        # add union bound placeholder
+        tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]
+
+        # extract browser info
+        win_upper_bound = page.evaluate("window.pageYOffset")
+        win_left_bound = page.evaluate("window.pageXOffset")
+        # NOTE(review): these read window.screen.*, i.e. the screen size rather
+        # than the viewport size -- confirm this is intended.
+        win_width = page.evaluate("window.screen.width")
+        win_height = page.evaluate("window.screen.height")
+        win_right_bound = win_left_bound + win_width
+        win_lower_bound = win_upper_bound + win_height
+        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
+
+        config: BrowserConfig = {
+            "win_upper_bound": win_upper_bound,
+            "win_left_bound": win_left_bound,
+            "win_width": win_width,
+            "win_height": win_height,
+            "win_right_bound": win_right_bound,
+            "win_lower_bound": win_lower_bound,
+            "device_pixel_ratio": device_pixel_ratio,
+        }
+
+        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
+        info: BrowserInfo = {"DOMTree": tree, "config": config}
+
+        return info
+
+    @staticmethod
+    def get_bounding_client_rect(
+        client: CDPSession, backend_node_id: str
+    ) -> dict[str, Any]:
+        """Ask the page for the bounding client rect of one backend node.
+
+        The node is resolved through ``DOM.resolveNode`` and a small JS
+        function is run on it via ``Runtime.callFunctionOn``. Any exception
+        is converted into ``{"result": {"subtype": "error"}}`` so callers can
+        treat failures like an error response.
+        """
+        try:
+            resolved = client.send(
+                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
+            )
+            call_params = {
+                "objectId": resolved["object"]["objectId"],
+                "functionDeclaration": """
+                        function() {
+                            if (this.nodeType == 3) {
+                                var range = document.createRange();
+                                range.selectNode(this);
+                                var rect = range.getBoundingClientRect().toJSON();
+                                range.detach();
+                                return rect;
+                            } else {
+                                return this.getBoundingClientRect().toJSON();
+                            }
+                        }
+                    """,
+                "returnByValue": True,
+            }
+            return client.send("Runtime.callFunctionOn", call_params)
+        except Exception:
+            return {"result": {"subtype": "error"}}
+
+    @staticmethod
+    def get_element_in_viewport_ratio(
+        elem_left_bound: float,
+        elem_top_bound: float,
+        width: float,
+        height: float,
+        config: BrowserConfig,
+    ) -> float:
+        """Return the fraction of an element's area that lies in the window.
+
+        The window is the rectangle from (0, 0) to
+        (``config["win_width"]``, ``config["win_height"]``); the element is
+        given by its top-left corner and size in the same coordinates.
+
+        Returns:
+            A value in ``[0, 1]``; ``0.0`` for an element with zero area.
+        """
+        elem_area = width * height
+        if elem_area == 0:
+            # Nothing to be visible; also avoids dividing by zero.
+            return 0.0
+
+        elem_right_bound = elem_left_bound + width
+        elem_lower_bound = elem_top_bound + height
+
+        win_left_bound = 0
+        win_right_bound = config["win_width"]
+        win_top_bound = 0
+        win_lower_bound = config["win_height"]
+
+        # Compute the overlap in x and y axes
+        overlap_width = max(
+            0,
+            min(elem_right_bound, win_right_bound)
+            - max(elem_left_bound, win_left_bound),
+        )
+        overlap_height = max(
+            0,
+            min(elem_lower_bound, win_lower_bound) - max(elem_top_bound, win_top_bound),
+        )
+
+        # Overlap area over element area. The parentheses matter: without
+        # them the expression evaluates as (overlap / width) * height.
+        return overlap_width * overlap_height / elem_area
+
+    def fetch_page_html(
+        self,
+        info: BrowserInfo,
+        page: Page,
+        current_viewport_only: bool,
+    ) -> DOMTree:
+        """Turn the DOM snapshot in ``info`` into a flat list of DOM nodes.
+
+        Each node records its name, value, attributes, parent/child ids and a
+        bounding rect fetched via ``get_bounding_client_rect``. When
+        ``current_viewport_only`` is true, nodes without a rect, with zero
+        width/height, or below ``IN_VIEWPORT_RATIO_THRESHOLD`` are removed and
+        their children are re-attached to the removed node's parent.
+        """
+        # adopted from [natbot](https://github.com/nat/natbot)
+        tree = info["DOMTree"]
+        strings = tree["strings"]
+        document = tree["documents"][0]
+        nodes = document["nodes"]
+
+        # make a dom tree that is easier to navigate
+        dom_tree: DOMTree = []
+        # parent id -> ordered list of child ids
+        graph = defaultdict(list)
+        client = page.context.new_cdp_session(page)
+        for node_idx in range(len(nodes["nodeName"])):
+            cur_node: DOMNode = {
+                "nodeId": "",
+                "nodeType": "",
+                "nodeName": "",
+                "nodeValue": "",
+                "attributes": "",
+                "backendNodeId": "",
+                "parentId": "",
+                "childIds": [],
+                "cursor": 0,
+                "union_bound": None,
+            }
+
+            # Snapshot fields are indices into the shared `strings` table.
+            node_type_idx = nodes["nodeType"][node_idx]
+            node_type = "generic"
+            if node_type_idx >= 0 and node_type_idx < len(strings):
+                node_type = strings[node_type_idx]
+
+            node_name = strings[nodes["nodeName"][node_idx]]
+
+            node_value_idx = nodes["nodeValue"][node_idx]
+            node_value = ""
+            if node_value_idx >= 0 and node_value_idx < len(strings):
+                # collapse all whitespace runs to single spaces
+                node_value = " ".join(strings[node_value_idx].split())
+
+            # attributes arrive as a flat [name, value, name, value, ...] list
+            node_attributes = [strings[i] for i in nodes["attributes"][node_idx]]
+            node_attributes_str = ""
+            for i in range(0, len(node_attributes), 2):
+                a = node_attributes[i]
+                b = node_attributes[i + 1]
+                b = " ".join(b.split())
+                node_attributes_str += f'{a}="{b}" '
+            node_attributes_str = node_attributes_str.strip()
+
+            cur_node["nodeId"] = str(node_idx)
+            cur_node["nodeType"] = node_type
+            cur_node["nodeName"] = node_name
+            cur_node["nodeValue"] = node_value
+            cur_node["attributes"] = node_attributes_str
+            cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
+            cur_node["parentId"] = str(nodes["parentIndex"][node_idx])
+
+            if cur_node["parentId"] != "-1":
+                graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))
+
+            # get the bound
+            if cur_node["parentId"] == "-1":
+                # node without a parent gets a fixed placeholder rect
+                cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
+            else:
+                response = self.get_bounding_client_rect(
+                    client, cur_node["backendNodeId"]
+                )
+                if response.get("result", {}).get("subtype", "") == "error":
+                    cur_node["union_bound"] = None
+                else:
+                    x = response["result"]["value"]["x"]
+                    y = response["result"]["value"]["y"]
+                    width = response["result"]["value"]["width"]
+                    height = response["result"]["value"]["height"]
+                    cur_node["union_bound"] = [x, y, width, height]
+
+            dom_tree.append(cur_node)
+
+        client.detach()
+        # add parent children index to the node
+        for parent_id, child_ids in graph.items():
+            dom_tree[int(parent_id)]["childIds"] = child_ids
+
+        # remove the nodes that are not in the current viewport
+        if current_viewport_only:
+
+            def remove_node_in_graph(node: DOMNode) -> None:
+                # update the node information in the accessibility tree
+                node_id = node["nodeId"]
+                parent_id = node["parentId"]
+                child_ids = node["childIds"]
+
+                # update the children of the parent node
+                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
+                # remove the nodeid from parent
+                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
+                dom_tree[int(parent_id)]["childIds"].pop(index)
+
+                # Insert children_nodeids in the same location
+                for child_id in child_ids:
+                    dom_tree[int(parent_id)]["childIds"].insert(index, child_id)
+                    index += 1
+
+                # update children node's parent
+                for child_id in child_ids:
+                    dom_tree[int(child_id)]["parentId"] = parent_id
+                # mark as removed
+                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"
+
+            config = info["config"]
+            for cursor, node in enumerate(dom_tree):
+                if not node["union_bound"]:
+                    remove_node_in_graph(node)
+                    continue
+
+                [x, y, width, height] = node["union_bound"]
+
+                # invisible node
+                if width == 0.0 or height == 0.0:
+                    remove_node_in_graph(node)
+                    continue
+
+                in_viewport_ratio = self.get_element_in_viewport_ratio(
+                    elem_left_bound=float(x),
+                    elem_top_bound=float(y),
+                    width=float(width),
+                    height=float(height),
+                    config=config,
+                )
+
+                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
+                    remove_node_in_graph(node)
+
+            # drop every node that was marked as removed above
+            dom_tree = [
+                node for node in dom_tree if node.get("parentId", "-1") != "[REMOVED]"
+            ]
+
+        return dom_tree
+
+    @staticmethod
+    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
+        """Render ``dom_tree`` as indented text, starting from its first node.
+
+        A node is printed (and recorded in the returned info dict under its
+        list position) only when it has attributes or a value. Children of a
+        skipped node are printed at the skipped node's depth.
+
+        Returns:
+            ``(html_text, obs_nodes_info)``.
+        """
+        obs_nodes_info = {}
+        cursor_by_node_id = {}
+        for position, entry in enumerate(dom_tree):
+            cursor_by_node_id[entry["nodeId"]] = position
+
+        def render(cursor: int, level: int) -> str:
+            current = dom_tree[cursor]
+            pieces = []
+            keep = True
+            try:
+                label = f"[{cursor}] <{current['nodeName']}"
+                if current["attributes"]:
+                    label += f" {current['attributes']}"
+                label += f"> {current['nodeValue']}"
+                keep = bool(current["attributes"] or current["nodeValue"])
+
+                if keep:
+                    obs_nodes_info[str(cursor)] = {
+                        "backend_id": current["backendNodeId"],
+                        "union_bound": current["union_bound"],
+                        "text": label,
+                    }
+                    pieces.append("\t" * level + f"{label}\n")
+            except Exception:
+                # a node that cannot be rendered is treated as skipped
+                keep = False
+
+            for child_id in current["childIds"]:
+                next_level = level + 1 if keep else level
+                pieces.append(render(cursor_by_node_id[child_id], next_level))
+
+            return "".join(pieces)
+
+        html = render(0, 0)
+        return html, obs_nodes_info
+
+    def fetch_page_accessibility_tree(
+        self,
+        page: Page,
+        info: BrowserInfo,
+        current_viewport_only: bool,
+    ) -> AccessibilityTree:
+        """Fetch the page's full accessibility tree and attach bounds.
+
+        Duplicate node ids are dropped (first occurrence wins) and each node
+        gets a ``union_bound`` entry. When ``current_viewport_only`` is true,
+        nodes without bounds, with zero width/height, or below
+        ``IN_VIEWPORT_RATIO_THRESHOLD`` are removed and their children are
+        re-attached to the removed node's parent.
+        """
+        client = page.context.new_cdp_session(page)
+        accessibility_tree: AccessibilityTree = client.send(
+            "Accessibility.getFullAXTree", {}
+        )["nodes"]
+
+        # a few nodes are repeated in the accessibility tree
+        seen_ids = set()
+        _accessibility_tree = []
+        for node in accessibility_tree:
+            if node["nodeId"] not in seen_ids:
+                _accessibility_tree.append(node)
+                seen_ids.add(node["nodeId"])
+        accessibility_tree = _accessibility_tree
+
+        # node id -> position in the de-duplicated list
+        nodeid_to_cursor = {}
+        for cursor, node in enumerate(accessibility_tree):
+            nodeid_to_cursor[node["nodeId"]] = cursor
+            # usually because the node is not visible etc
+            if "backendDOMNodeId" not in node:
+                node["union_bound"] = None
+                continue
+            backend_node_id = str(node["backendDOMNodeId"])
+            if node["role"]["value"] == "RootWebArea":
+                # always inside the viewport
+                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
+            else:
+                response = self.get_bounding_client_rect(
+                    client,
+                    backend_node_id
+                )
+                if response.get("result", {}).get("subtype", "") == "error":
+                    node["union_bound"] = None
+                else:
+                    x = response["result"]["value"]["x"]
+                    y = response["result"]["value"]["y"]
+                    width = response["result"]["value"]["width"]
+                    height = response["result"]["value"]["height"]
+                    node["union_bound"] = [x, y, width, height]
+
+        client.detach()
+        # filter nodes that are not in the current viewport
+        if current_viewport_only:
+
+            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
+                # update the node information in the accessibility tree
+                nodeid = node["nodeId"]
+                node_cursor = nodeid_to_cursor[nodeid]
+                parent_nodeid = node["parentId"]
+                children_nodeids = node["childIds"]
+                parent_cursor = nodeid_to_cursor[parent_nodeid]
+                # update the children of the parent node
+                assert (
+                    accessibility_tree[parent_cursor].get("parentId", "Root")
+                    is not None
+                )
+                # remove the nodeid from parent's childIds
+                index = accessibility_tree[parent_cursor]["childIds"].index(nodeid)
+                accessibility_tree[parent_cursor]["childIds"].pop(index)
+                # Insert children_nodeids in the same location
+                for child_nodeid in children_nodeids:
+                    accessibility_tree[parent_cursor]["childIds"].insert(
+                        index, child_nodeid
+                    )
+                    index += 1
+                # update children node's parent
+                for child_nodeid in children_nodeids:
+                    child_cursor = nodeid_to_cursor[child_nodeid]
+                    accessibility_tree[child_cursor]["parentId"] = parent_nodeid
+                # mark as removed
+                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"
+
+            config = info["config"]
+            for node in accessibility_tree:
+                if not node["union_bound"]:
+                    remove_node_in_graph(node)
+                    continue
+
+                [x, y, width, height] = node["union_bound"]
+
+                # invisible node
+                if width == 0 or height == 0:
+                    remove_node_in_graph(node)
+                    continue
+
+                in_viewport_ratio = self.get_element_in_viewport_ratio(
+                    elem_left_bound=float(x),
+                    elem_top_bound=float(y),
+                    width=float(width),
+                    height=float(height),
+                    config=config,
+                )
+
+                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
+                    remove_node_in_graph(node)
+
+            # drop every node that was marked as removed above
+            accessibility_tree = [
+                node
+                for node in accessibility_tree
+                if node.get("parentId", "Root") != "[REMOVED]"
+            ]
+
+        return accessibility_tree
+
+    @staticmethod
+    def parse_accessibility_tree(
+        accessibility_tree: AccessibilityTree,
+    ) -> tuple[str, dict[str, Any]]:
+        """Parse the accessibility tree into a string text
+
+        Traversal starts at the first node of the list. Each printed node
+        becomes one line ``[id] role 'name' prop: value ...`` indented by one
+        tab per printed ancestor, and is recorded in the returned info dict
+        under its node id. Returns ``(tree_text, obs_nodes_info)``.
+        """
+        node_id_to_idx = {}
+        for idx, node in enumerate(accessibility_tree):
+            node_id_to_idx[node["nodeId"]] = idx
+
+        obs_nodes_info = {}
+
+        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
+            tree_str = ""
+            node = accessibility_tree[idx]
+            indent = "\t" * depth
+            valid_node = True
+            try:
+                role = node["role"]["value"]
+                name = node["name"]["value"]
+                node_str = f"[{obs_node_id}] {role} {repr(name)}"
+                properties = []
+                for property in node.get("properties", []):
+                    try:
+                        if property["name"] in IGNORED_ACTREE_PROPERTIES:
+                            continue
+                        properties.append(
+                            f'{property["name"]}: {property["value"]["value"]}'
+                        )
+                    except KeyError:
+                        # properties lacking a name or value are skipped
+                        pass
+
+                if properties:
+                    node_str += " " + " ".join(properties)
+
+                # check valid
+                if not node_str.strip():
+                    valid_node = False
+
+                # empty generic node
+                if not name.strip():
+                    if not properties:
+                        if role in [
+                            "generic",
+                            "img",
+                            "list",
+                            "strong",
+                            "paragraph",
+                            "banner",
+                            "navigation",
+                            "Section",
+                            "LabelText",
+                            "Legend",
+                            "listitem",
+                        ]:
+                            valid_node = False
+                    elif role in ["listitem"]:
+                        # unnamed list items are dropped even with properties
+                        valid_node = False
+
+                if valid_node:
+                    tree_str += f"{indent}{node_str}"
+                    obs_nodes_info[obs_node_id] = {
+                        "backend_id": node["backendDOMNodeId"],
+                        "union_bound": node["union_bound"],
+                        "text": node_str,
+                    }
+
+            except Exception as e:
+                # any failure while rendering the node makes it invalid
+                valid_node = False
+
+            for _, child_node_id in enumerate(node["childIds"]):
+                if child_node_id not in node_id_to_idx:
+                    continue
+                # mark this to save some tokens
+                child_depth = depth + 1 if valid_node else depth
+                child_str = dfs(
+                    node_id_to_idx[child_node_id], child_node_id, child_depth
+                )
+                if child_str.strip():
+                    if tree_str.strip():
+                        tree_str += "\n"
+                    tree_str += child_str
+
+            return tree_str
+
+        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
+        return tree_str, obs_nodes_info
+
+    @staticmethod
+    def clean_accesibility_tree(tree_str: str) -> str:
+        """Drop StaticText lines that repeat text from recently kept lines.
+
+        A line mentioning "statictext" (case-insensitive) is kept only when it
+        matches ``[<digits>] StaticText <quoted text>``, the unquoted text is
+        non-empty, and that text does not occur in any of the last three kept
+        lines. All other lines are kept unchanged.
+        """
+        kept: list[str] = []
+        static_text_re = re.compile(r"\[\d+\] StaticText (.+)", re.DOTALL)
+        for row in tree_str.split("\n"):
+            if "statictext" not in row.lower():
+                kept.append(row)
+                continue
+
+            found = static_text_re.search(row)
+            if found is None:
+                continue
+
+            unquoted = found.group(1)[1:-1]  # remove the quotes
+            if not unquoted:
+                continue
+            if any(unquoted in earlier for earlier in kept[-3:]):
+                continue
+            kept.append(row)
+
+        return "\n".join(kept)
+
+    def fetch_image_related(self, page: Page, browser_info: BrowserInfo) -> str:
+        """Build the text content for the caption-based observation types.
+
+        On an image URL the (cached) caption of that image is returned. On any
+        other page, uncached ``<img>`` sources are captioned with
+        ``self.captioning_fn`` and each image's ``alt`` is rewritten in the
+        page to include caption and URL. The accessibility tree text is
+        returned for "accessibility_tree_with_captioner", otherwise "".
+        """
+        # Check if the current page is an image url
+        if page.url.endswith((".jpg", ".jpeg", ".png")):
+            print("NOTE: We are on an image page!!!")
+            # Load image from current url and run captioning on it.
+            if page.url not in self.url2caption and self.captioning_fn is not None:
+                try:
+                    image = Image.open(requests.get(page.url, stream=True).raw)
+                    caption = self.captioning_fn([image])[0].strip()
+                    self.url2caption[page.url] = remove_unicode(caption)
+                except Exception as e:
+                    print("L579 WARNING: ", e)
+            # falls back to the literal "Image" when no caption is cached
+            content = self.url2caption.get(page.url, "Image")
+
+        else:
+            if self.captioning_fn is not None:
+                images = page.query_selector_all("img")
+                # sources that still need a caption
+                image_urls = []
+                for image in images:
+                    try:
+                        image_url = image.get_attribute("src")
+                        if not image_url.startswith(("http://", "https://", "www.")):
+                            image_url = urljoin(page.url, image_url)
+                        if image_url not in self.url2caption:
+                            image_urls.append(image_url)
+                    except Exception as e:
+                        print("L604 WARNING: ", e)
+
+                # Run image captioning on image_url pixels. This is for models which use captioning as a baseline.
+                if len(image_urls) > 0:
+                    image_pixels = []
+                    valid_urls = []
+                    for url in image_urls:
+                        if "data:image/svg" in url:
+                            continue
+                        else:
+                            try:
+                                image = Image.open(requests.get(url, stream=True).raw)
+                                image_pixels.append(image)
+                                valid_urls.append(url)
+                            except Exception as e:
+                                print("L616 WARNING: ", e)
+
+                    # Caption images.
+                    if image_pixels:
+                        # Run in batches of 4.
+                        bs = 4
+                        captions = []
+                        for i in range(0, len(image_pixels), bs):
+                            try:
+                                captions.extend(
+                                    self.captioning_fn(image_pixels[i : i + bs])
+                                )
+                            except Exception as e:
+                                print("L628 WARNING: ", e)
+                                # keep captions aligned with valid_urls
+                                captions.extend([""] * len(image_pixels[i : i + bs]))
+                        assert len(valid_urls) == len(
+                            captions
+                        ), f"len(images)={len(valid_urls)}, len(captions)={len(captions)}"
+                        for image_url, caption in zip(valid_urls, captions):
+                            self.url2caption[image_url] = remove_unicode(
+                                caption.strip()
+                            )
+
+                image_idx = 0
+                for image in images:
+                    try:
+                        original_alt = image.get_attribute("alt") or ""
+                        image_url = image.get_attribute("src")
+                        if not image_url.startswith(("http://", "https://", "www.")):
+                            image_url = urljoin(page.url, image_url)
+
+                        updated_alt = original_alt
+
+                        if image_url in self.url2caption:
+                            if self.url2caption[image_url] not in updated_alt:
+                                updated_alt = f"{updated_alt}, description: {self.url2caption[image_url]}"
+                        elif "data:image/svg" not in image_url:
+                            print(f"WARNING: {image_url} not in self.url2caption")
+
+                        if "url:" not in updated_alt:
+                            updated_alt = f"{updated_alt}, url: {image_url}"
+
+                        # json.dumps yields a quoted, escaped JS string literal
+                        safe_updated_alt = json.dumps(updated_alt)
+                        image.evaluate(f"node => node.alt = {safe_updated_alt}")
+                    except Exception as e:
+                        print("L653 WARNING:", e)
+
+            if self.observation_type == "accessibility_tree_with_captioner":
+                frame_ax_trees = self.fetch_page_accessibility_tree(
+                    page,
+                    browser_info,
+                    current_viewport_only=self.current_viewport_only
+                )
+                content, obs_nodes_info = self.parse_accessibility_tree(frame_ax_trees)
+                content = self.clean_accesibility_tree(content)
+                self.obs_nodes_info = obs_nodes_info
+                self.meta_data["obs_nodes_info"] = obs_nodes_info
+            else:
+                content = ""  # Not used for SoM
+
+        return content
+
+    def process(self, page: Page) -> str:
+        """Return the text observation for ``page``.
+
+        The result is a tab-title header line, a blank line, and the content
+        for ``self.observation_type``. Also updates ``self.browser_config``
+        and, for tree-based types, ``self.obs_nodes_info`` and the metadata.
+
+        Raises:
+            ValueError: for an unknown observation type.
+        """
+        # get the tab info
+        open_tabs = page.context.pages
+        try:
+            tab_titles = [tab.title() for tab in open_tabs]
+            current_tab_idx = open_tabs.index(page)
+            for idx in range(len(open_tabs)):
+                if idx == current_tab_idx:
+                    tab_titles[idx] = f"Tab {idx} (current): {open_tabs[idx].title()}"
+                else:
+                    tab_titles[idx] = f"Tab {idx}: {open_tabs[idx].title()}"
+            tab_title_str = " | ".join(tab_titles)
+        except Exception:
+            # fall back to bare tab indices when the titles cannot be read
+            tab_title_str = " | ".join([f"Tab {idx}" for idx in range(len(open_tabs))])
+
+        try:
+            browser_info = self.fetch_browser_info(page)
+        except Exception:
+            # retry once after waiting for the "load" state
+            page.wait_for_load_state("load", timeout=500)
+            browser_info = self.fetch_browser_info(page)
+
+        if self.observation_type == "html":
+            dom_tree = self.fetch_page_html(
+                browser_info,
+                page,
+                self.current_viewport_only,
+            )
+            content, obs_nodes_info = self.parse_html(dom_tree)
+            self.obs_nodes_info = obs_nodes_info
+            self.meta_data["obs_nodes_info"] = obs_nodes_info
+
+        elif self.observation_type == "accessibility_tree":
+            accessibility_tree = self.fetch_page_accessibility_tree(
+                page,
+                browser_info,
+                self.current_viewport_only,
+            )
+            content, obs_nodes_info = self.parse_accessibility_tree(accessibility_tree)
+            content = self.clean_accesibility_tree(content)
+            self.obs_nodes_info = obs_nodes_info
+            self.meta_data["obs_nodes_info"] = obs_nodes_info
+
+        elif self.observation_type in [
+            "accessibility_tree_with_captioner",
+            "image_som",
+        ]:
+            content = self.fetch_image_related(
+                page,
+                browser_info,
+            )
+
+        elif self.observation_type == "":
+            # empty type: header only
+            content = ""
+
+        else:
+            raise ValueError(f"Invalid observation type: {self.observation_type}")
+
+        self.browser_config = browser_info["config"]
+        content = f"{tab_title_str}\n\n{content}"
+
+        return content
+
+    def get_element_center(self, element_id: str) -> tuple[float, float]:
+        """Return the centre of an observed element as viewport fractions.
+
+        The element's ``union_bound`` (x, y, width, height) is looked up in
+        ``self.obs_nodes_info``; the centre is divided by the viewport width
+        and height respectively.
+        """
+        left, top, box_width, box_height = self.obs_nodes_info[element_id][
+            "union_bound"
+        ]
+        mid_x = left + box_width / 2
+        mid_y = top + box_height / 2
+        viewport = self.viewport_size
+        return (mid_x / viewport["width"], mid_y / viewport["height"])
+
+
+class ImageObservationProcessor(ObservationProcessor):
+    def __init__(
+        self,
+        observation_type: str,
+        viewport_size: Optional[ViewportSize] = None,
+    ):
+        """Store the observation type and viewport size; start with empty metadata."""
+        self.observation_tag = "image"
+        self.observation_type = observation_type
+        self.viewport_size = viewport_size
+        self.meta_data = create_empty_metadata()
+
+    def get_page_bboxes(self, page: Page) -> list[list[float]]:
+        """JavaScript code to return bounding boxes and other metadata from HTML elements.
+
+        The script is evaluated in ``page`` and builds one CSV row per
+        matched element with non-zero size (header: ID, Element, Top, Right,
+        Bottom, Left, Width, Height, Alt, Class, Id, TextContent,
+        Interactable). Coordinates include the scroll offset and are scaled
+        by ``devicePixelRatio``.
+
+        NOTE(review): the value returned is the script's CSV string, not the
+        ``list[list[float]]`` the annotation suggests -- confirm and fix the
+        annotation.
+        """
+        js_script = """
+        (() => {
+            const interactableSelectors = [
+                'a[href]:not(:has(img))', 'a[href] img', 'button', 'input:not([type="hidden"])', 'textarea', 'select',
+                '[tabindex]:not([tabindex="-1"])', '[contenteditable="true"]', '[role="button"]', '[role="link"]',
+                '[role="checkbox"]', '[role="menuitem"]', '[role="tab"]', '[draggable="true"]',
+                '.btn', 'a[href="/notifications"]', 'a[href="/submit"]', '.fa.fa-star.is-rating-item', 'input[type="checkbox"]'
+
+            ];
+
+            const textSelectors = ['p', 'span', 'div:not(:has(*))', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'article'];
+            const modifiedTextSelectors = textSelectors.map(selector =>
+                `:not(${interactableSelectors.join(', ')}):not(style) > ${selector}`
+            );
+
+            const combinedSelectors = [...interactableSelectors, ...modifiedTextSelectors];
+            const elements = document.querySelectorAll(combinedSelectors.join(', '));
+
+            const pixelRatio = window.devicePixelRatio;
+            let csvContent = "ID,Element,Top,Right,Bottom,Left,Width,Height,Alt,Class,Id,TextContent,Interactable\\n";
+            let counter = 1;
+
+            elements.forEach(element => {
+                const rect = element.getBoundingClientRect();
+                if (rect.width === 0 || rect.height === 0) return;
+                let altText = element.getAttribute('alt') || '';
+                altText = altText.replace(/"/g, ''); // Escape double quotes in alt text
+                const classList = element.className || '';
+                const id = element.id || '';
+                let textContent = element.textContent || '';
+                textContent = textContent.replace(/"/g, ''); // Escape double quotes in textContent
+
+                // Determine if the element is interactable
+                const isInteractable = interactableSelectors.some(selector => element.matches(selector));
+
+                const dataString = [
+                    counter, element.tagName, (rect.top + window.scrollY) * pixelRatio,
+                    (rect.right + window.scrollX) * pixelRatio, (rect.bottom + window.scrollY) * pixelRatio,
+                    (rect.left + window.scrollX) * pixelRatio, rect.width * pixelRatio, rect.height * pixelRatio,
+                    altText, classList, id, textContent, isInteractable
+                ].map(value => `"${value}"`).join(",");
+
+                csvContent += dataString + "\\n";
+                counter++;
+            });
+
+            return csvContent;
+        })();
+        """
+        # Save the bbox as a CSV
+        csv_content = page.evaluate(js_script)
+        return csv_content
+
+    def draw_bounding_boxes(
+        self,
+        data_string,
+        screenshot_img,
+        viewport_size=None,
+        add_ids=True,
+        bbox_color=None,
+        min_width=8,
+        min_height=8,
+        bbox_padding=0,
+        bbox_border=2,
+        plot_ids=None,
+    ):
+        """
+        Draw numbered boxes for interactable elements onto a screenshot copy.
+
+        data_string: CSV text with the columns read below (ID, Element, Top,
+        Right, Bottom, Left, Width, Height, Alt, TextContent, Interactable).
+        viewport_size: when given, boxes outside the viewport or covering
+        more than 80% of it are discarded.
+        min_width and min_height: Minimum dimensions of the bounding box to be plotted.
+        plot_ids: when given, only rows whose ID is in it are drawn.
+
+        Reads ``self.browser_config`` for the window offsets. Returns
+        ``(image, id2center, content_str)`` where ``id2center`` maps each
+        visible id to ``(center_x, center_y, width, height)`` and
+        ``content_str`` lists the elements as text, one per line.
+        """
+        # Read CSV data
+        df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"')
+        df["Area"] = df["Width"] * df["Height"]
+        # Remove bounding boxes that are clipped.
+        b_x, b_y = (
+            self.browser_config["win_left_bound"],
+            self.browser_config["win_upper_bound"],
+        )
+        if viewport_size is not None:
+            df = df[
+                (df["Bottom"] - b_y >= 0)
+                & (df["Top"] - b_y <= viewport_size["height"])
+                & (df["Right"] - b_x >= 0)
+                & (df["Left"] - b_x <= viewport_size["width"])
+            ]
+            viewport_area = viewport_size["width"] * viewport_size["height"]
+            # Filter out bounding boxes that too large (more than 80% of the viewport)
+            df = df[df["Area"] <= 0.8 * viewport_area]
+
+        # Open the screenshot image
+        img = screenshot_img.copy()
+        draw = ImageDraw.Draw(img)
+
+        # Load a TTF font with a larger size
+        font_path = "media/SourceCodePro-SemiBold.ttf"
+        font_size, padding = 16, 2
+        font = ImageFont.truetype(font_path, font_size)
+
+        # Create a color cycle using one of the categorical color palettes in matplotlib
+        color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
+        bbox_id2visid = {}
+        bbox_id2desc = {}
+        index = 0
+        id2center = {}
+        existing_text_rectangles = []
+        # labels are collected here and drawn after all boxes
+        text_to_draw = []
+        # Provide [id] textContent inputs to the model as text.
+        text_content_elements = []
+        text_content_text = set()  # Store text of interactable elements
+
+        # Iterate through each row in the CSV and draw bounding boxes
+        for _, row in df.iterrows():
+            if not row["Interactable"]:
+                content = ""
+                # Add image alt-text to the text representation.
+                if row["Element"] == "IMG" and pd.notna(row["Alt"]):
+                    content += row["Alt"]
+                # Add HTML textContent (if any) to the text representation.
+                if pd.notna(row["TextContent"]):
+                    content += (
+                        row["TextContent"].strip().replace("\n", "").replace("\t", "")
+                    )[
+                        :200
+                    ]  # Limit to 200 characters to avoid having too much text
+
+                # Check if the text is a CSS selector
+                if content and not (content.startswith(".") and "{" in content):
+                    # Add elements which are not interactable as StaticText
+                    if content not in text_content_text:
+                        text_content_elements.append(f"[] [StaticText] [{content}]")
+                        text_content_text.add(content)
+                continue
+
+            if (plot_ids is not None) and (row["ID"] not in plot_ids):
+                continue
+
+            unique_id = str(index + 1)
+            bbox_id2visid[row["ID"]] = (
+                unique_id  # map the bounding box ID to the unique character ID
+            )
+            top, right, bottom, left, width, height = (
+                row["Top"],
+                row["Right"],
+                row["Bottom"],
+                row["Left"],
+                row["Width"],
+                row["Height"],
+            )
+            # shift by the window offsets
+            left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y
+            id2center[unique_id] = (
+                (left + right) / 2,
+                (bottom + top) / 2,
+                width,
+                height,
+            )
+
+            if width >= min_width and height >= min_height:
+                # Get the next color in the cycle
+                color = bbox_color or color_cycle[index % len(color_cycle)]
+                draw.rectangle(
+                    [
+                        left - bbox_padding,
+                        top - bbox_padding,
+                        right + bbox_padding,
+                        bottom + bbox_padding,
+                    ],
+                    outline=color,
+                    width=bbox_border,
+                )
+                bbox_id2desc[row["ID"]] = color
+
+                # Draw the text on top of the rectangle
+                if add_ids:
+                    # Calculate list of possible text positions
+                    text_positions = [
+                        (left - font_size, top - font_size),  # Top-left corner
+                        (
+                            left,
+                            top - font_size,
+                        ),  # A little to the right of the top-left corner
+                        (right, top - font_size),  # Top-right corner
+                        (
+                            right - font_size - 2 * padding,
+                            top - font_size,
+                        ),  # A little to the left of the top-right corner
+                        (left - font_size, bottom),  # Bottom-left corner
+                        (
+                            left,
+                            bottom,
+                        ),  # A little to the right of the bottom-left corner
+                        (
+                            right - font_size - 2 * padding,
+                            bottom,
+                        ),  # A little to the left of the bottom-right corner
+                        (
+                            left,
+                            bottom,
+                        ),  # A little to the right of the bottom-left corner
+                        (
+                            right - font_size - 2 * padding,
+                            bottom,
+                        ),  # A little to the left of the bottom-right corner
+                    ]
+                    text_width = draw.textlength(unique_id, font=font)
+                    text_height = font_size  # Assume the text is one line
+
+                    if viewport_size is not None:
+                        for text_position in text_positions:
+                            new_text_rectangle = [
+                                text_position[0] - padding,
+                                text_position[1] - padding,
+                                text_position[0] + text_width + padding,
+                                text_position[1] + text_height + padding,
+                            ]
+
+                            # Check if the new text rectangle is within the viewport
+                            if (
+                                new_text_rectangle[0] >= 0
+                                and new_text_rectangle[1] >= 0
+                                and new_text_rectangle[2] <= viewport_size["width"]
+                                and new_text_rectangle[3] <= viewport_size["height"]
+                            ):
+                                # If the rectangle is within the viewport, check for overlaps
+                                overlaps = False
+                                for existing_rectangle in existing_text_rectangles:
+                                    if self.rectangles_overlap(
+                                        new_text_rectangle,
+                                        existing_rectangle,
+                                        padding * 2,
+                                    ):
+                                        overlaps = True
+                                        break
+
+                                if not overlaps:
+                                    break
+                            else:
+                                # If the rectangle is outside the viewport, try the next position
+                                continue
+                        else:
+                            # If none of the corners work, move the text rectangle by a fixed amount
+                            text_position = (
+                                text_positions[0][0] + padding,
+                                text_positions[0][1],
+                            )
+                            new_text_rectangle = [
+                                text_position[0] - padding,
+                                text_position[1] - padding,
+                                text_position[0] + text_width + padding,
+                                text_position[1] + text_height + padding,
+                            ]
+
+                    existing_text_rectangles.append(new_text_rectangle)
+                    text_to_draw.append(
+                        (new_text_rectangle, text_position, unique_id, color)
+                    )
+
+                    content = ""
+                    if row["Element"] == "IMG" and pd.notna(row["Alt"]):
+                        content += row["Alt"]
+                    if pd.notna(row["TextContent"]):
+                        content += (
+                            row["TextContent"]
+                            .strip()
+                            .replace("\n", "")
+                            .replace("\t", "")
+                        )[
+                            :200
+                        ]  # Limit to 200 characters
+                    text_content_elements.append(
+                        f"[{unique_id}] [{row['Element']}] [{content}]"
+                    )
+                    if content in text_content_text:
+                        # Remove text_content_elements with content
+                        text_content_elements = [
+                            element
+                            for element in text_content_elements
+                            if element.strip() != content
+                        ]
+                    text_content_text.add(content)
+
+            # the visible id counter advances for every interactable row that
+            # reaches this point, drawn or not
+            index += 1
+
+        for text_rectangle, text_position, unique_id, color in text_to_draw:
+            # Draw a background rectangle for the text
+            draw.rectangle(text_rectangle, fill=color)
+            draw.text(text_position, unique_id, font=font, fill="white")
+
+        content_str = "\n".join(text_content_elements)
+        return img, id2center, content_str
+
+    def rectangles_overlap(self, rect1, rect2, padding):
+        """
+        Tell whether two rectangles overlap once ``rect2`` is shrunk by ``padding``.
+        Each rectangle is represented as a list [x1, y1, x2, y2].
+        """
+        reaches_left_edge = rect1[2] >= rect2[0] + padding
+        reaches_right_edge = rect1[0] <= rect2[2] - padding
+        reaches_bottom_edge = rect1[1] <= rect2[3] - padding
+        reaches_top_edge = rect1[3] >= rect2[1] + padding
+        return (
+            reaches_left_edge
+            and reaches_right_edge
+            and reaches_bottom_edge
+            and reaches_top_edge
+        )
+
+    def _capture_som(self, page: Page):
+        """Screenshot ``page`` and overlay the set-of-marks bounding boxes.
+
+        Stores the id -> centre mapping on ``self.som_id_info`` and in the
+        metadata, and returns ``(annotated_image_array, content_str)``.
+        """
+        screenshot_bytes = page.screenshot()
+        som_bboxes = self.get_page_bboxes(page)
+        screenshot_img = Image.open(BytesIO(screenshot_bytes))
+        bbox_img, id2center, content_str = self.draw_bounding_boxes(
+            som_bboxes,
+            screenshot_img,
+            viewport_size=self.viewport_size,
+        )
+        self.som_id_info = id2center
+        self.meta_data["obs_nodes_info"] = id2center
+        return np.array(bbox_img), content_str
+
+    def process(self, page: Page) -> npt.NDArray[np.uint8]:
+        """Return the image observation for ``page``.
+
+        For "image_som" the screenshot is annotated with bounding boxes and
+        the element listing is returned next to it; otherwise the plain
+        screenshot is returned with an empty string. Each step is retried
+        once after waiting for the page to load.
+
+        NOTE(review): the value returned is an ``(image, text)`` pair, not
+        the bare array the annotation suggests.
+        """
+        try:
+            browser_info = self.fetch_browser_info(page)
+        except Exception:
+            page.wait_for_load_state("load", timeout=500)
+            browser_info = self.fetch_browser_info(page)
+
+        self.browser_config = browser_info["config"]
+
+        if self.observation_type == "image_som":
+            # Produce the SoM image, with bounding boxes
+            try:
+                return self._capture_som(page)
+            except Exception:
+                # `except Exception` (not a bare except) so Ctrl-C and
+                # interpreter exit are not swallowed by the retry.
+                page.wait_for_event("load")
+                return self._capture_som(page)
+
+        try:
+            screenshot = png_bytes_to_numpy(page.screenshot())
+        except Exception:
+            page.wait_for_event("load")
+            screenshot = png_bytes_to_numpy(page.screenshot())
+        return screenshot, ""
+
+    def fetch_browser_info(self, page: Page) -> BrowserInfo:
+        """Snapshot the DOM over CDP and collect the window geometry."""
+        client = page.context.new_cdp_session(page)
+        try:
+            # extract domtree
+            tree = client.send(
+                "DOMSnapshot.captureSnapshot",
+                {
+                    "computedStyles": [],
+                    "includeDOMRects": True,
+                    "includePaintOrder": True,
+                },
+            )
+        finally:
+            # release the CDP session even when the snapshot request fails
+            client.detach()
+        # calibrate the bounds, in some cases, the bounds are scaled somehow
+        layout = tree["documents"][0]["layout"]
+        # n: third entry of the first bound over the viewport width (presumably a width ratio)
+        n = layout["bounds"][0][2] / self.viewport_size["width"]
+        bounds = [[x / n for x in bound] for bound in layout["bounds"]]
+        layout["bounds"] = bounds
+        # add union bound placeholder
+        layout["unionBounds"] = [None for _ in bounds]
+
+        # extract browser info
+        win_upper_bound = page.evaluate("window.pageYOffset")
+        win_left_bound = page.evaluate("window.pageXOffset")
+        win_width = page.evaluate("window.screen.width")
+        win_height = page.evaluate("window.screen.height")
+        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
+
+        config: BrowserConfig = {
+            "win_upper_bound": win_upper_bound,
+            "win_left_bound": win_left_bound,
+            "win_width": win_width,
+            "win_height": win_height,
+            "win_right_bound": win_left_bound + win_width,
+            "win_lower_bound": win_upper_bound + win_height,
+            "device_pixel_ratio": device_pixel_ratio,
+        }
+
+        info: BrowserInfo = {"DOMTree": tree, "config": config}
+        return info
+
+    def get_element_center(self, element_id: str) -> tuple[float, float]:
+        """Return the SoM element's center as fractions of the viewport size.
+
+        Raises ValueError unless the observation type is "image_som".
+        """
+        if self.observation_type != "image_som":
+            raise ValueError(
+                "get_element_center() is only supported for 'image_som' observation type."
+            )
+        center_x, center_y, _, _ = self.som_id_info[element_id]
+        width, height = self.viewport_size["width"], self.viewport_size["height"]
+        return (center_x / width, center_y / height)
+
+
+class TextObervationProcessorWebRL(TextObervationProcessor):
+    """Text processor that returns the parsed HTML of the page as observation."""
+
+    def __init__(
+        self,
+        observation_type: str,
+        current_viewport_only: bool,
+        viewport_size: ViewportSize,
+        captioning_fn=None,
+    ):
+        super().__init__(
+            observation_type,
+            current_viewport_only,
+            viewport_size,
+            captioning_fn,
+        )
+
+    def process(self, page: Page) -> str:
+        """Return the parsed HTML and cache every tag carrying id + data-bbox."""
+        from bs4 import BeautifulSoup
+
+        html = get_parsed_html(page)["html"]
+        soup = BeautifulSoup(html, 'html.parser')
+        obs_nodes_info = {}
+        for tag in soup.find_all(True):
+            if tag.has_attr('id') and tag.has_attr('data-bbox'):
+                backend_id = tag['id']
+                obs_nodes_info[str(backend_id)] = {
+                    "backend_id": backend_id,
+                    "union_bound": [float(num) for num in tag['data-bbox'].split(',')],
+                    "text": str(tag),
+                }
+        self.obs_nodes_info = obs_nodes_info
+        self.meta_data["obs_nodes_info"] = obs_nodes_info
+        return html
+
+    def get_element_center(self, element_id: str, page: Page=None) -> tuple[float, float]:
+        """Return the element's center as fractions of the viewport size.
+
+        With `page`, the live bounding box of `[data-label-id=...]` is used;
+        if that element is missing or has no box, or `page` is None, the
+        box cached by the last `process()` call is used instead.
+        """
+        bbox = None
+        if page is not None:
+            element = page.query_selector(f"[data-label-id='{element_id}']")
+            # query_selector / bounding_box return None for absent / hidden elements
+            bbox = element.bounding_box() if element is not None else None
+        if bbox is not None:
+            x, y, width, height = bbox['x'], bbox['y'], bbox['width'], bbox['height']
+        else:
+            x, y, width, height = self.obs_nodes_info[element_id]["union_bound"]
+        return (
+            (x + width / 2) / self.viewport_size["width"],
+            (y + height / 2) / self.viewport_size["height"],
+        )
+
+class ObservationHandler:
+    """Main entry point to access all observation processor
+
+    Owns one text processor and one image processor built from the given types.
+    """
+
+    def __init__(
+        self,
+        main_observation_type: str,
+        text_observation_type: str,
+        image_observation_type: str,
+        current_viewport_only: bool,
+        viewport_size: ViewportSize,
+        captioning_fn=None,
+    ) -> None:
+        self.main_observation_type = main_observation_type
+        # "webrl" gets its dedicated processor; every other type uses the default.
+        if text_observation_type == "webrl":
+            text_processor_cls = TextObervationProcessorWebRL
+        else:
+            text_processor_cls = TextObervationProcessor
+        self.text_processor = text_processor_cls(
+            text_observation_type,
+            current_viewport_only,
+            viewport_size,
+            captioning_fn,
+        )
+        self.image_processor = ImageObservationProcessor(
+            image_observation_type, viewport_size
+        )
+        self.viewport_size = viewport_size
+
+    def get_observation_space(self) -> spaces.Dict:
+        """Describe the observation dict: bounded text plus an RGB image."""
+        # Note the swapped axes (height first); each position stores RGB values.
+        image_shape = (
+            self.viewport_size["height"],
+            self.viewport_size["width"],
+            3,
+        )
+        lower = np.zeros(image_shape, dtype=np.uint8)
+        upper = np.ones(image_shape, dtype=np.uint8) * 255.0
+        image_space = spaces.Box(lower, upper, dtype=np.uint8)
+
+        text_space = spaces.Text(
+            min_length=0,
+            max_length=UTTERANCE_MAX_LENGTH,
+            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
+        )
+        return spaces.Dict({"text": text_space, "image": image_space})
+
+    def get_observation(self, page: Page) -> dict[str, Observation]:
+        """Run both processors; a non-empty image listing replaces the text."""
+        text_obs = self.text_processor.process(page)
+        image_obs, content_str = self.image_processor.process(page)
+        return {
+            "text": text_obs if content_str == "" else content_str,
+            "image": image_obs,
+        }
+
+    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
+        """Expose the metadata dicts of both processors."""
+        return {
+            "text": self.text_processor.meta_data,
+            "image": self.image_processor.meta_data,
+        }
+
+    @property
+    def action_processor(self) -> ObservationProcessor:
+        """Return the main processor that is associated with the action space"""
+        processors = {
+            "text": self.text_processor,
+            "image": self.image_processor,
+        }
+        if self.main_observation_type not in processors:
+            raise ValueError("Invalid main observation type")
+        return processors[self.main_observation_type]
diff --git a/VAB-WebArena-Lite/new/prompt_constructor.py b/VAB-WebArena-Lite/new/prompt_constructor.py
index 9cd5a2a..fc59b06 100644
--- a/VAB-WebArena-Lite/new/prompt_constructor.py
+++ b/VAB-WebArena-Lite/new/prompt_constructor.py
@@ -499,3 +499,59 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor):
raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented"
)
+
+class WebRLPromptConstructor(PromptConstructor):
+    """The agent will direct predict the action"""
+
+    # Character budget shared by the history and the current observation.
+    MAX_PROMPT_CHARS = 16384 - 512
+
+    def __init__(
+        self,
+        instruction_path: str | Path,
+        lm_config: lm_config.LMConfig,
+        tokenizer: Tokenizer,
+    ):
+        super().__init__(instruction_path, lm_config, tokenizer)
+
+    def construct(
+        self,
+        trajectory: Trajectory,
+        intent: str,
+        meta_data: dict[str, Any] = {},
+    ) -> APIInput:
+        """Construct prompt given the trajectory"""
+        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]
+
+        obs = state_info["observation"][self.obs_modality]
+        max_obs_length = self.lm_config.gen_config["max_obs_length"]
+        if max_obs_length:
+            if self.lm_config.provider == "google":
+                print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
+                obs = obs[:max_obs_length]
+            else:
+                try:
+                    obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]
+                except Exception:
+                    print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
+                    obs = obs[:max_obs_length]
+
+        # The first action_history entry is skipped; the rest are the past actions.
+        previous_actions = meta_data["action_history"][1:]
+        index = len(previous_actions)
+        rounds = []
+        for i, action_str in enumerate(previous_actions):
+            # Only round 0 repeats the intent; later rounds elide the html.
+            user_turn = intent if i == 0 else "** Simplified html **"
+            rounds.append(f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{user_turn}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{action_str}\n\n")
+        history = "".join(rounds)
+        # Clamp at 0: a negative slice bound would trim from the end instead
+        # of emptying the observation when the history alone exceeds the budget.
+        obs = obs[:max(0, self.MAX_PROMPT_CHARS - len(history))]
+        current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+        prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"
+        return prompt
+
+    def extract_action(self, response: str) -> str:
+        """Return the model response unchanged."""
+        return response
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/run.py b/VAB-WebArena-Lite/new/run.py
index e406e2c..6f16161 100644
--- a/VAB-WebArena-Lite/new/run.py
+++ b/VAB-WebArena-Lite/new/run.py
@@ -13,11 +13,13 @@ import tempfile
import time
from pathlib import Path
from typing import List
+import cv2
+import shutil
import openai
import requests
import torch
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
from agent import (
PromptAgent,
@@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
+def text_wrap(text, font, max_width):
+    """Wrap `text` into lines no wider than `max_width` pixels for `font`.
+
+    Newlines split paragraphs; an empty line is kept between paragraphs.
+    """
+    lines = []
+    for paragraph in text.split('\n'):
+        line = ''
+        for word in paragraph.split(' '):
+            # Candidate line with the next word appended
+            test_line = f"{line} {word}".strip()
+            test_line_bbox = font.getbbox(test_line)
+            test_line_width = test_line_bbox[2] - test_line_bbox[0]
+            if test_line_width <= max_width:
+                line = test_line
+            else:
+                # Too wide: flush the current line, if any, so an over-long first word adds no blank line
+                if line:
+                    lines.append(line)
+                line = word
+        # Flush the last line of the paragraph
+        if line:
+            lines.append(line)
+        # Separator line that preserves the paragraph break
+        lines.append('')
+    # Drop the trailing separator (one is always appended after the last paragraph)
+    lines.pop()
+    return lines
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@@ -88,6 +118,7 @@ def config() -> argparse.Namespace:
"html",
"image",
"image_som",
+ "webrl",
],
default="accessibility_tree",
help="Observation type",
@@ -176,6 +207,9 @@ def config() -> argparse.Namespace:
# logging related
parser.add_argument("--result_dir", type=str, default="")
+
+ # if use self-deployed model
+ parser.add_argument("--planner_ip", type=str, default=None)
args = parser.parse_args()
# check the whether the action space is compatible with the observation space
@@ -196,7 +230,7 @@ def config() -> argparse.Namespace:
def early_stop(
- trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
+ trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None
) -> tuple[bool, str]:
"""Check whether need to stop early"""
@@ -228,28 +262,38 @@ def early_stop(
if len(action_seq) == 0:
return False, ""
- last_action: Action = action_seq[-1]
+ if actions is None:
+ last_action: Action = action_seq[-1]
+ if last_action["action_type"] != ActionTypes.TYPE:
+ if len(last_k_actions) >= k:
+ if all(
+ [
+ is_equivalent(action, last_action)
+ for action in last_k_actions
+ ]
+ ):
+ return True, f"Same action for {k} times"
+ else:
+ # check the action sequence
+ if (
+ sum([is_equivalent(action, last_action) for action in action_seq])
+ >= k
+ ):
+ return True, f"Same typing action for {k} times"
+ return False, ""
- if last_action["action_type"] != ActionTypes.TYPE:
+ else:
+ last_k_actions = actions[-k:]
+ last_action = actions[-1]
if len(last_k_actions) >= k:
if all(
[
- is_equivalent(action, last_action)
+ action == last_action
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
-
- else:
- # check the action sequence
- if (
- sum([is_equivalent(action, last_action) for action in action_seq])
- >= k
- ):
- return True, f"Same typing action for {k} times"
-
- return False, ""
-
+ return False, ""
def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
obj = {
@@ -387,18 +431,22 @@ def test(
obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info)
-
meta_data = {"action_history": ["None"]}
out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
actions = []
+ os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True)
+ if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")):
+ shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+ os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+
while True:
update_action_history(out_path, task_id, actions=actions)
-
+ # If no actions variable is passed, the behavior of early_stop is the same as the original one.
early_stop_flag, stop_info = early_stop(
- trajectory, max_steps, early_stop_thresholds
+ trajectory, max_steps, early_stop_thresholds, actions
)
-
+
if early_stop_flag:
action = create_stop_action(f"Early stop: {stop_info}")
else:
@@ -407,14 +455,14 @@ def test(
trajectory,
intent,
images=images,
- meta_data=meta_data,
+ meta_data=meta_data
)
except ValueError as e:
# get the error message
action = create_stop_action(f"ERROR: {str(e)}")
-
+
trajectory.append(action)
-
+
action_str = get_action_description(
action,
state_info["info"]["observation_metadata"],
@@ -426,13 +474,50 @@ def test(
render_helper.render(
action, state_info, meta_data, args.render_screenshot
)
+
+ current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png")
+ _ = env.page.viewport_size
+ env.page.screenshot(path="/dev/null")
+ env.page.screenshot(path=current_screenshot)
+ element_id = action["element_id"]
+ if element_id != "":
+ element = env.page.query_selector(f"[data-label-id='{element_id}']")
+ if element:
+ bbox = element.bounding_box()
+ bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']),int(bbox['height'])]
+ image = cv2.imread(current_screenshot)
+ cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
+ cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
+ cv2.imwrite(current_screenshot, image)
+            font_path = "/workspace/qzh/Arial.ttf"  # NOTE(review): machine-specific path; the original comment said FreeSans but this is Arial
+            font_size = 30  # adjust this value to enlarge or shrink the caption text
+ image = Image.open(current_screenshot)
+ font = ImageFont.truetype(font_path, font_size)
+ draw = ImageDraw.Draw(image)
+ image_width = image.width
+ wrapped_text = text_wrap(action_str, font, image_width)
+ line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1]
+ text_height = line_height * len(wrapped_text)
+ new_image_height = image.height + text_height + 20 # 20 is extra white space
+ new_image = Image.new('RGB', (image.width, new_image_height), color=(255, 255, 255)) # white background
+ draw_new = ImageDraw.Draw(new_image)
+ y_text = 10 # Initial position of text from top
+ for line in wrapped_text:
+ text_bbox = draw_new.textbbox((0, 0), line, font=font)
+ text_width = text_bbox[2] - text_bbox[0]
+ text_position = ((image_width - text_width) // 2, y_text) # Center text horizontally
+ draw_new.text(text_position, line, font=font, fill=(0, 0, 0)) # black text
+ y_text += line_height # move to next line
+ new_image.paste(image, (0, text_height + 20))
+ new_image.save(current_screenshot)
+
meta_data["action_history"].append(action_str)
actions.append(action_str)
- print(action_str)
+ print('Action String: ', action_str)
if action["action_type"] == ActionTypes.STOP:
break
-
+
obs, _, terminated, _, info = env.step(action)
state_info = {"observation": obs, "info": info}
trajectory.append(state_info)
@@ -540,7 +625,7 @@ if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
- args.sleep_after_execution = 2.5
+ args.sleep_after_execution = 3.0
prepare(args)
test_config_base_dir = args.test_config_base_dir
diff --git a/VAB-WebArena-Lite/new/score.py b/VAB-WebArena-Lite/new/score.py
new file mode 100644
index 0000000..95235e5
--- /dev/null
+++ b/VAB-WebArena-Lite/new/score.py
@@ -0,0 +1,86 @@
+import os, json, sys, copy
+
+USE_TASKS = [i for i in range(165)]
+
+def get_result(res_dict, src="all"):
+    """Print the solved task ids and an accuracy summary for ``res_dict``.
+
+    ``res_dict`` maps task id to score; scores >= 1.0 count as solved.
+    """
+    if not res_dict:
+        return ''
+    success_id = [k for k, v in res_dict.items() if v >= 1.0]
+    score, finish_count = len(success_id), len(res_dict)
+    pacc, acc = score / finish_count * 100, score / TASKS * 100
+    print(sorted(success_id))
+    # Show the real denominator (TASKS) instead of a stale hard-coded total.
+    meta = """
+--------
+src file: {}
+successed: {:3} / {:4} ({})
+partial accuracy: {:7}
+overall accuracy: {:7}
+--------
+""".format(src, score, finish_count, TASKS, round(pacc, 2), round(acc, 2))
+    print(meta)
+
+def export_result(res_dict, src=".", note=("1.0", "0.0"), show_all=False):
+    """Write one line per task in USE_TASKS to ``<src>/export.txt``.
+
+    A task scored >= 1.0 gets ``note[0]``; any other scored task gets
+    ``note[1]``. Unscored tasks get an empty line, or ``note[1]`` when
+    ``show_all`` is true.
+    """
+    rows = []
+    for task_id in USE_TASKS:
+        if task_id in res_dict:
+            rows.append(note[0] if res_dict[task_id] >= 1.0 else note[1])
+        elif show_all:
+            rows.append(note[1])
+        else:
+            rows.append("")
+    out_string = "".join(row + "\n" for row in rows)
+
+    with open(os.path.join(src, 'export.txt'), 'w') as f:
+        f.write(out_string)
+
+TASKS = 165  # denominator used for "overall accuracy" in get_result
+
+files = sys.argv[1]  # comma-separated list of result directories
+file_list = files.split(',')
+
+all_result = {}  # best score seen per task id across all directories
+
+for src in file_list:
+    path = os.path.join(src, 'actions')
+
+    result = {}
+    finished = os.listdir(path)
+
+    for file in finished:
+        if not file.endswith('.json'):
+            continue
+        with open(os.path.join(path, file), 'r') as f:
+            data = json.load(f)
+
+        if not isinstance(data, dict):
+            continue
+
+        task_id = data.get('task_id', 1000)
+        # Records without a task_id fall back to the placeholder id 1000, which
+        # is outside USE_TASKS and so is counted by get_result but never exported.
+
+        task_score = data.get('score', 0)
+        if task_score < 0:  # presumably a run that was never scored -- confirm against run.py
+            continue
+
+        result[task_id] = task_score
+        if task_id not in all_result or task_score > all_result[task_id]:
+            all_result[task_id] = task_score
+
+    get_result(result, src)
+    export_result(result, src=src)
+
+if len(file_list) > 1:
+    get_result(all_result)
+export_result(all_result, show_all=True)  # written to ./export.txt (default src)
diff --git a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
index c77225a..adb1822 100644
--- a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
+++ b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
@@ -187,7 +187,7 @@
"string_match"
],
"reference_answers": {
- "must_include": [
+ "fuzzy_match": [
"0"
]
},
@@ -774,8 +774,8 @@
"string_match"
],
"reference_answers": {
- "fuzzy_match": [
- "914km"
+ "must_include": [
+ "914"
]
},
"reference_url": "",
@@ -1139,7 +1139,7 @@
"string_match"
],
"reference_answers": {
- "must_include": [
+ "fuzzy_match": [
"0"
]
},
@@ -1679,7 +1679,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "N/A"
+ "fuzzy_match": ["N/A"]
},
"reference_url": "",
"program_html": [],
@@ -2605,7 +2605,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "yjlou"
+        "must_include": ["yjlou"]
},
"reference_url": "",
"program_html": [],
@@ -3712,7 +3712,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "N/A"
+ "fuzzy_match": "N/A"
},
"reference_url": "",
"program_html": [],
diff --git a/VAB-WebArena-Lite/new/tokenizers.py b/VAB-WebArena-Lite/new/tokenizers.py
index 2234040..b6e8e74 100644
--- a/VAB-WebArena-Lite/new/tokenizers.py
+++ b/VAB-WebArena-Lite/new/tokenizers.py
@@ -1,15 +1,18 @@
from typing import Any
import tiktoken
-from transformers import LlamaTokenizer # type: ignore
+from transformers import LlamaTokenizer, AutoTokenizer # type: ignore
class Tokenizer(object):
def __init__(self, provider: str, model_name: str) -> None:
if provider == "openai":
- self.tokenizer = tiktoken.encoding_for_model(model_name)
+            try:
+                self.tokenizer = tiktoken.encoding_for_model(model_name)
+            except Exception:  # The provider is in openai format but the model is a finetuned model
+                self.tokenizer = None
elif provider == "huggingface":
- self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+ self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
# turn off adding special tokens automatically
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
self.tokenizer.add_bos_token = False # type: ignore[attr-defined]
diff --git a/VAB-WebArena-Lite/new/utils.py b/VAB-WebArena-Lite/new/utils.py
index 4558585..2a8b87c 100644
--- a/VAB-WebArena-Lite/new/utils.py
+++ b/VAB-WebArena-Lite/new/utils.py
@@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any]
def call_llm(
lm_config: lm_config.LMConfig,
prompt: APIInput,
+ api_key = None,
+ base_url = None
) -> str:
response: str
if lm_config.provider == "openai":
@@ -39,11 +41,13 @@ def call_llm(
assert isinstance(prompt, str)
response = generate_from_openai_completion(
prompt=prompt,
- engine=lm_config.model,
+ model=lm_config.model,
temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"],
stop_token=lm_config.gen_config["stop_token"],
+ api_key=api_key,
+ base_url=base_url
)
else:
raise ValueError(
diff --git a/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
new file mode 100644
index 0000000..f8c4784
--- /dev/null
+++ b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+DATASET='webarena' # TODO: select from ['webarena', 'visualwebarena']
+result_dir='' # TODO: set your result_dir
+provider='openai' # TODO: select from ['openai', 'finetune', ...]
+model='' # TODO: assign model name, which is used for action generation
+planner_ip='' # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
+instruction_path='agent/prompts/jsons/p_webrl.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
+test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
+temperature=0.0
+
+SERVER='' # TODO: your server address
+MAP_SERVER='' # TODO: the server address for MAP tasks
+OPENAI_API_KEY='' # TODO: if you test OpenAI APIs
+OPENAI_ORGANIZATION=''
+CONDA_ENV_NAME='' # TODO: the name of your conda environment for testing WebArena
+
+ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
+
+# get the number of tmux panes
+num_panes=$(tmux list-panes | wc -l)
+
+# calculate how many panes need to be created
+let "panes_to_create = 7 - num_panes"
+
+# array of tmux commands to create each pane
+tmux_commands=(
+ 'tmux split-window -h'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 0; tmux split-window -v'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 3; tmux split-window -v'
+ 'tmux select-pane -t 5; tmux split-window -v'
+)
+
+# create panes up to 7
+for ((i=0; i<$panes_to_create; i++)); do
+ eval ${tmux_commands[$i]}
+done
+
+# ---- job launching helpers ----
+
+# Function to run a job
+run_job() {
+    tmux select-pane -t $1  # $1: tmux pane index; $2/$3: test start/end index
+    COMMAND="python run.py \
+        --instruction_path ${instruction_path} \
+        --test_start_idx $2 \
+        --test_end_idx $3 \
+        --result_dir ${result_dir} \
+        --test_config_base_dir ${test_config_base_dir} \
+        --provider ${provider} \
+        --model ${model} \
+        --mode completion \
+        --planner_ip ${planner_ip} \
+        --stop_token \"<|eot_id|>\" \
+        --temperature ${temperature} \
+        --max_obs_length 0 \
+        --max_tokens 2048 \
+        --viewport_width 1280 \
+        --viewport_height 720 \
+        --parsing_failure_th 5 \
+        --repeating_action_failure_th 5 \
+        --action_set_tag webrl_id --observation_type webrl"
+    tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m  # `until` reruns run.py until it exits 0
+    sleep 3  # brief pause after dispatching the command to the pane
+}
+
+TOLERANCE=2
+run_batch() {
+    # Arguments are consecutive index boundaries: job i runs in tmux pane i
+    # with start index args[i-1] and end index args[i].
+    local args=("$@")
+    local num_jobs=${#args[@]}
+    local i
+
+    # Start one job per pane, then block until no pane is running python.
+    launch_and_wait() {
+        for ((i=1; i<num_jobs; i++)); do
+            run_job "$i" "${args[i-1]}" "${args[i]}"
+        done
+
+        while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
+            sleep 100 # poll every 100 seconds
+        done
+    }
+
+    launch_and_wait
+
+    # Run checker; rerun the whole batch until it reports success
+    while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
+        echo "Check failed, rerunning jobs..."
+        launch_and_wait
+    done
+
+}
+run_batch 0 28 56 84 112 140 165
+
diff --git a/VAB-WebArena-Lite/replace.sh b/VAB-WebArena-Lite/replace.sh
index add244c..7a674f0 100644
--- a/VAB-WebArena-Lite/replace.sh
+++ b/VAB-WebArena-Lite/replace.sh
@@ -14,6 +14,14 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
+cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json
+
+# browser_env
+cp -f new/actions.py visualwebarena/browser_env/actions.py
+cp -f new/envs.py visualwebarena/browser_env/envs.py
+cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py
+cp -f new/processors.py visualwebarena/browser_env/processors.py
+cp -rf new/html_tools visualwebarena/browser_env/
# llms
cp -f new/utils.py visualwebarena/llms/utils.py
@@ -22,15 +30,19 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
+cp -f new/utils.py visualwebarena/llms/utils.py
# eval
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
-cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
+cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py
# misc
cp -f README.md visualwebarena/README.md
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
+cp -f new/score.py visualwebarena/score.py
+cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh
+
# 3. remove temporary files
mv visualwebarena/* .
rm -rf new