From 521d7e999af0b88034411bfd534a8a82db67cd2f Mon Sep 17 00:00:00 2001
From: QZH-777 <1961710177@qq.com>
Date: Thu, 14 Nov 2024 15:51:41 +0800
Subject: [PATCH 1/4] add webrl mode
---
VAB-WebArena-Lite/new/actions.py | 2006 +++++++++++++++++
VAB-WebArena-Lite/new/agent.py | 15 +-
VAB-WebArena-Lite/new/envs.py | 319 +++
VAB-WebArena-Lite/new/evaluators.py | 4 +
.../new/helper_functions_browser.py | 239 ++
..._functions.py => helper_functions_eval.py} | 0
VAB-WebArena-Lite/new/html_tools/__init__.py | 7 +
.../new/html_tools/configs/__init__.py | 3 +
.../new/html_tools/configs/config.py | 56 +
.../new/html_tools/configs/html_prompt.py | 22 +
VAB-WebArena-Lite/new/html_tools/fetch.py | 108 +
.../new/html_tools/html_parser.py | 447 ++++
.../new/html_tools/identifier.py | 64 +
VAB-WebArena-Lite/new/html_tools/prompt.py | 97 +
.../new/html_tools/scripts/__init__.py | 43 +
.../html_tools/scripts/clickable_checker.js | 148 ++
.../new/html_tools/scripts/element_info.js | 34 +
.../new/html_tools/scripts/label.js | 74 +
.../new/html_tools/scripts/label_marker.js | 65 +
.../new/html_tools/scripts/prepare.js | 83 +
VAB-WebArena-Lite/new/html_tools/utils.py | 101 +
VAB-WebArena-Lite/new/openai_utils.py | 18 +-
VAB-WebArena-Lite/new/p_webrl.json | 13 +
VAB-WebArena-Lite/new/processors.py | 1351 +++++++++++
VAB-WebArena-Lite/new/prompt_constructor.py | 56 +
VAB-WebArena-Lite/new/run.py | 137 +-
VAB-WebArena-Lite/new/score.py | 86 +
.../new/test_webarena_lite.raw.json | 14 +-
VAB-WebArena-Lite/new/tokenizers.py | 9 +-
VAB-WebArena-Lite/new/utils.py | 6 +-
.../new/wa_parallel_run_webrl.sh | 97 +
VAB-WebArena-Lite/replace.sh | 14 +-
32 files changed, 5688 insertions(+), 48 deletions(-)
create mode 100644 VAB-WebArena-Lite/new/actions.py
create mode 100644 VAB-WebArena-Lite/new/envs.py
create mode 100644 VAB-WebArena-Lite/new/helper_functions_browser.py
rename VAB-WebArena-Lite/new/{helper_functions.py => helper_functions_eval.py} (100%)
create mode 100755 VAB-WebArena-Lite/new/html_tools/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/config.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/fetch.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/html_parser.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/identifier.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/prompt.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
create mode 100755 VAB-WebArena-Lite/new/html_tools/utils.py
create mode 100644 VAB-WebArena-Lite/new/p_webrl.json
create mode 100644 VAB-WebArena-Lite/new/processors.py
create mode 100644 VAB-WebArena-Lite/new/score.py
create mode 100644 VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
diff --git a/VAB-WebArena-Lite/new/actions.py b/VAB-WebArena-Lite/new/actions.py
new file mode 100644
index 0000000..eedb327
--- /dev/null
+++ b/VAB-WebArena-Lite/new/actions.py
@@ -0,0 +1,2006 @@
+"""
+Browser Env action space.
+Inspited by Farama-Foundation/miniwob-plusplus
+"""
+import ast
+import random
+import re
+import string
+from enum import IntEnum
+from itertools import chain
+from typing import Any, TypedDict, Union, cast
+import time
+
+import numpy as np
+import numpy.typing as npt
+from beartype import beartype
+from beartype.door import is_bearable
+from gymnasium import spaces
+from playwright._impl._api_structures import ViewportSize
+from playwright.async_api import BrowserContext as ABrowserContext
+from playwright.async_api import Locator as ALocator
+from playwright.async_api import Page as APage
+from playwright.sync_api import BrowserContext, Locator, Page
+
+from browser_env.constants import (
+ ASCII_CHARSET,
+ FREQ_UNICODE_CHARSET,
+ MAX_ANSWER_LENGTH,
+ MAX_ELEMENT_ID,
+ MAX_ELEMENT_INDEX_IN_VIEWPORT,
+ MAX_PAGE_NUMBER,
+ MAX_VANILLA_STR_LENGTH,
+ PLAYWRIGHT_ACTIONS,
+ PLAYWRIGHT_LOCATORS,
+ ROLES,
+ SPECIAL_KEY_MAPPINGS,
+ SPECIAL_KEYS,
+ SPECIAL_LOCATORS,
+ TEXT_MAX_LENGTH,
+ TYPING_MAX_LENGTH,
+ URL_MAX_LENGTH,
+ RolesType,
+)
+from browser_env.processors import ObservationProcessor
+
+
+class ParsedPlaywrightCode(TypedDict):
+ function_name: str
+ arguments: list[str]
+ keywords: dict[str, Any]
+
+
+from browser_env.processors import (
+ ObservationProcessor,
+ TextObervationProcessor,
+)
+
+
+@beartype
+def is_in_viewport(
+ element: Locator, viewport: ViewportSize, threshold: float = 0.3
+) -> bool:
+ """Given a playwright locator, check if it is in the viewport"""
+ box = element.bounding_box()
+ assert box is not None
+ boxx0 = box["x"]
+ boxx1 = box["x"] + box["width"]
+ boxy0 = box["y"]
+ boxy1 = box["y"] + box["height"]
+ viewportx0, viewporty0 = 0, 0
+ viewportx1, viewporty1 = viewport["width"], viewport["height"]
+ inter = max(0, min(boxx1, viewportx1) - max(boxx0, viewportx0)) * max(
+ 0, min(boxy1, viewporty1) - max(boxy0, viewporty0)
+ )
+ ratio = inter / (box["width"] * box["height"])
+ return ratio > threshold
+
+
+@beartype
+async def async_is_in_viewport(
+ element: ALocator, viewport: ViewportSize, threshold: float = 0.3
+) -> bool:
+ box = await element.bounding_box()
+ assert box is not None
+ boxx0 = box["x"]
+ boxx1 = box["x"] + box["width"]
+ boxy0 = box["y"]
+ boxy1 = box["y"] + box["height"]
+ viewportx0, viewporty0 = 0, 0
+ viewportx1, viewporty1 = viewport["width"], viewport["height"]
+ inter = max(0, min(boxx1, viewportx1) - max(boxx0, viewportx0)) * max(
+ 0, min(boxy1, viewporty1) - max(boxy0, viewporty0)
+ )
+ ratio = inter / (box["width"] * box["height"])
+ return ratio > threshold
+
+
+class Action(TypedDict):
+ action_type: int
+ coords: npt.NDArray[np.float32]
+ element_role: int
+ element_name: str
+ text: list[int]
+ page_number: int
+ url: str
+ nth: int
+ element_id: str
+ direction: str
+ key_comb: str
+ pw_code: str
+ answer: str
+ raw_prediction: str # raw prediction from the model
+
+
+@beartype
+def action2str(
+ action: Action, action_set_tag: str, semantic_element: str = ""
+) -> str:
+ """Return the string representation of an action
+
+ sementic_element: the semantic information of the element
+ such as a line in an accessibility tree
+ """
+ if action_set_tag in [
+ "id_accessibility_tree",
+ "id_accessibility_tree_with_captioner",
+ ]:
+ element_id = action["element_id"]
+ match action["action_type"]:
+ case ActionTypes.CLICK:
+ # [ID=X] xxxxx
+ action_str = f"click [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.TYPE:
+ text = "".join([_id2key[i] for i in action["text"]])
+ action_str = f"type [{element_id}] [{text}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.HOVER:
+ action_str = f"hover [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.SCROLL:
+ action_str = f"scroll [{action['direction']}]"
+ case ActionTypes.KEY_PRESS:
+ action_str = f"press [{action['key_comb']}]"
+ case ActionTypes.GOTO_URL:
+ action_str = f"goto [{action['url']}]"
+ case ActionTypes.NEW_TAB:
+ action_str = "new_tab"
+ case ActionTypes.PAGE_CLOSE:
+ action_str = "close_tab"
+ case ActionTypes.GO_BACK:
+ action_str = "go_back"
+ case ActionTypes.GO_FORWARD:
+ action_str = "go_forward"
+ case ActionTypes.PAGE_FOCUS:
+ action_str = f"page_focus [{action['page_number']}]"
+ case ActionTypes.CLEAR:
+ action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.STOP:
+ action_str = f"stop [{action['answer']}]"
+ case ActionTypes.NONE:
+ action_str = "none"
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type']}"
+ )
+ elif action_set_tag == "som":
+ element_id = action["element_id"]
+ match action["action_type"]:
+ case ActionTypes.CLICK:
+ # [ID=X] xxxxx
+ action_str = f"click [{element_id}] where [{element_id}]"
+ case ActionTypes.CLEAR:
+ action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}"
+ case ActionTypes.TYPE:
+ text = "".join([_id2key[i] for i in action["text"]])
+ action_str = (
+ f"type [{element_id}] [{text}] where [{element_id}]"
+ )
+ case ActionTypes.HOVER:
+ action_str = f"hover [{element_id}] where [{element_id}]"
+ case ActionTypes.SCROLL:
+ action_str = f"scroll [{action['direction']}]"
+ case ActionTypes.KEY_PRESS:
+ action_str = f"press [{action['key_comb']}]"
+ case ActionTypes.GOTO_URL:
+ action_str = f"goto [{action['url']}]"
+ case ActionTypes.NEW_TAB:
+ action_str = "new_tab"
+ case ActionTypes.PAGE_CLOSE:
+ action_str = "close_tab"
+ case ActionTypes.GO_BACK:
+ action_str = "go_back"
+ case ActionTypes.GO_FORWARD:
+ action_str = "go_forward"
+ case ActionTypes.PAGE_FOCUS:
+ action_str = f"page_focus [{action['page_number']}]"
+ case ActionTypes.STOP:
+ action_str = f"stop [{action['answer']}]"
+ case ActionTypes.NONE:
+ action_str = "none"
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type']}"
+ )
+ else:
+ raise NotImplementedError(f"Unknown action set tag {action_set_tag}")
+
+ return action_str
+
+
+def action2create_function(action: Action) -> str:
+ match (action["action_type"]):
+ case ActionTypes.NONE:
+ return "create_none_action()"
+ # mouse wheel and keyboard action
+ case ActionTypes.SCROLL:
+ direction = "up" if "up" in action["direction"] else "down"
+ return f"create_scroll_action({repr(direction)})"
+ case ActionTypes.KEY_PRESS:
+ return f"create_key_press_action({repr(action['key_comb'])})"
+ # inter-page actions
+ case ActionTypes.PAGE_FOCUS:
+ return f"create_page_focus_action({action['page_number']})"
+ case ActionTypes.NEW_TAB:
+ return "create_new_tab_action()"
+ case ActionTypes.GO_BACK:
+ return "create_go_back_action()"
+ case ActionTypes.GO_FORWARD:
+ return "create_go_forward_action()"
+ case ActionTypes.GOTO_URL:
+ return f"create_goto_url_action({repr(action['url'])})"
+ case ActionTypes.PAGE_CLOSE:
+ return "create_page_close_action()"
+
+ # low-level keyboard and mouse actions
+ case ActionTypes.MOUSE_CLICK:
+ return f"create_mouse_click_action({action['coords'][0]}, {action['coords'][1]})"
+ case ActionTypes.MOUSE_HOVER:
+ return f"create_mouse_hover_action({action['coords'][0]}, {action['coords'][1]})"
+ case ActionTypes.KEYBOARD_TYPE:
+ return f"create_keyboard_type_action({list(map(lambda x: _id2key[x], action['text']))})"
+
+ # mid-level keyboard and mouse actions
+ case ActionTypes.CLICK:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_click_action({args_str})"
+ case ActionTypes.CLEAR:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_clear_action({args_str})"
+ case ActionTypes.HOVER:
+ args = []
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_hover_action({args_str})"
+ case ActionTypes.TYPE:
+ args = []
+ text = "".join(map(lambda x: _id2key[x], action["text"]))
+ args.append(f"text={repr(text)}")
+ args.append(f"element_id={repr(action['element_id'])}")
+ args.append(
+ f"element_role={repr(_id2role[action['element_role']])}"
+ )
+ args.append(f"element_name={repr(action['element_name'])}")
+ args.append(f"pw_code={repr(action['pw_code'])}")
+ args_str = ", ".join(args)
+ return f"create_type_action({args_str})"
+
+ # high-level actions, only support locators from playwright
+ case ActionTypes.CHECK:
+ return f"create_check_action(pw_code={repr(action['pw_code'])})"
+ case ActionTypes.SELECT_OPTION:
+ return f"create_select_option_action(pw_code={repr(action['pw_code'])})"
+ case ActionTypes.STOP:
+ return f'create_stop_action({repr(action["answer"])})'
+
+ raise ValueError(f"Invalid action type: {action['action_type']}")
+
+
+class ActionTypes(IntEnum):
+ """Valid action types for browser env."""
+
+ NONE = 0
+ # mouse wheel and keyboard, universal across all action spaces
+ SCROLL = 1
+ KEY_PRESS = 2
+
+ # low level mouse and keyboard actions
+ MOUSE_CLICK = 3
+ KEYBOARD_TYPE = 4
+ MOUSE_HOVER = 5
+
+ # mid level mouse and keyboard actions
+ CLICK = 6
+ TYPE = 7
+ HOVER = 8
+
+ # page level actions, universal across all action spaces
+ PAGE_FOCUS = 9
+ NEW_TAB = 10
+ GO_BACK = 11
+ GO_FORWARD = 12
+ GOTO_URL = 13
+ PAGE_CLOSE = 14
+
+ # high-leval actions that playwright support
+ CHECK = 15
+ SELECT_OPTION = 16
+
+ STOP = 17
+ CLEAR = 18
+
+ # webrl actions
+ SEARCH = 19
+ SELECT_DROPDOWN_OPTION = 20
+
+ def __str__(self) -> str:
+ return f"ACTION_TYPES.{self.name}"
+
+
+@beartype
+def is_equivalent(a: Action, b: Action) -> bool:
+ """Return True if two actions are equal."""
+ if a["action_type"] != b["action_type"]:
+ return False
+ match (a["action_type"]):
+ case ActionTypes.NONE:
+ return True
+ case ActionTypes.SCROLL:
+ da = "up" if "up" in a["direction"] else "down"
+ db = "up" if "up" in b["direction"] else "down"
+ return da == db
+ case ActionTypes.KEY_PRESS:
+ return a["key_comb"] == b["key_comb"]
+ case ActionTypes.MOUSE_CLICK | ActionTypes.MOUSE_HOVER:
+ return np.allclose(a["coords"], b["coords"])
+ case ActionTypes.KEYBOARD_TYPE:
+ return a["text"] == b["text"]
+ case ActionTypes.CLICK | ActionTypes.HOVER | ActionTypes.TYPE: # TODO: can be further optimized
+ if a["element_id"] and b["element_id"]:
+ return a["element_id"] == b["element_id"]
+ elif a["element_role"] and b["element_role"]:
+ return (
+ a["element_role"] == b["element_role"]
+ and a["element_name"] == b["element_name"]
+ )
+ elif a["pw_code"] and b["pw_code"]:
+ return a["pw_code"] == b["pw_code"]
+ else:
+ return False
+ case ActionTypes.PAGE_FOCUS:
+ return a["page_number"] == b["page_number"]
+ case ActionTypes.NEW_TAB:
+ return True
+ case ActionTypes.GO_BACK:
+ return True
+ case ActionTypes.GO_FORWARD:
+ return True
+ case ActionTypes.GOTO_URL:
+ return a["url"] == b["url"]
+ case ActionTypes.PAGE_CLOSE:
+ return True
+ case ActionTypes.CHECK | ActionTypes.SELECT_OPTION:
+ return a["pw_code"] == b["pw_code"]
+ case ActionTypes.STOP:
+ return a["answer"] == b["answer"]
+ case _:
+ raise ValueError(f"Unknown action type: {a['action_type']}")
+
+
+_key2id: dict[str, int] = {
+ key: i
+ for i, key in enumerate(
+ chain(SPECIAL_KEYS, ASCII_CHARSET, FREQ_UNICODE_CHARSET, ["\n"])
+ )
+}
+_id2key: list[str] = sorted(_key2id, key=_key2id.get) # type: ignore[arg-type]
+_role2id: dict[RolesType, int] = {
+ cast(RolesType, role): i
+ for i, role in enumerate(chain(ROLES, SPECIAL_LOCATORS))
+}
+_id2role: list[RolesType] = sorted(_role2id, key=_role2id.get) # type: ignore[arg-type]
+
+
+@beartype
+def _keys2ids(keys: list[int | str] | str) -> list[int]:
+ return list(
+ map(
+ lambda key: _key2id.get(str(key), _key2id.get(key, " "))
+ if is_bearable(key, str)
+ else int(key),
+ keys,
+ )
+ )
+
+
+def get_action_space() -> spaces.Dict:
+ """Return the space of serialized actions."""
+ space = spaces.Dict(
+ {
+ "action_type": spaces.Discrete(len(ActionTypes)),
+ # coords (left, top) is used for COORD_CLICK
+ "coords": spaces.Box(
+ np.array([0.0, 0.0], dtype=np.float32),
+ np.array([1.0, 1.0], dtype=np.float32),
+ ),
+ # element role is used for FOCUS_AND_CLICK and FOCUS_AND_TYPE
+ "element_role": spaces.Discrete(
+ len(ROLES) + len(SPECIAL_LOCATORS)
+ ),
+ # element name is used with element role
+ "element_name": spaces.Text(TEXT_MAX_LENGTH),
+ "element_id": spaces.Text(TEXT_MAX_LENGTH),
+ # text is only used for TYPE and FOCUS_AND_TYPE
+ "text": spaces.MultiDiscrete(
+ [
+ len(ASCII_CHARSET)
+ + len(SPECIAL_KEYS)
+ + len(FREQ_UNICODE_CHARSET)
+ ]
+ * TYPING_MAX_LENGTH
+ ),
+ "page_number": spaces.Discrete(MAX_PAGE_NUMBER),
+ "url": spaces.Text(URL_MAX_LENGTH),
+ "nth": spaces.Discrete(MAX_ELEMENT_INDEX_IN_VIEWPORT),
+ "key_comb": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "direction": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "pw_code": spaces.Text(MAX_VANILLA_STR_LENGTH),
+ "answer": spaces.Text(MAX_ANSWER_LENGTH),
+ }
+ )
+ return space
+
+
+def create_random_action() -> Action:
+ """Return a random action."""
+ return {
+ "action_type": np.random.randint(len(ActionTypes)),
+ "coords": np.random.rand(2).astype(np.float32),
+ "element_role": np.random.randint(len(ROLES) + len(SPECIAL_LOCATORS)),
+ "element_name": "".join(
+ random.choices(ASCII_CHARSET, k=np.random.randint(TEXT_MAX_LENGTH))
+ ),
+ "text": list(
+ random.choices(
+ list(range(len(ASCII_CHARSET))),
+ k=np.random.randint(TYPING_MAX_LENGTH),
+ )
+ ),
+ "page_number": np.random.randint(MAX_PAGE_NUMBER),
+ "url": "".join(
+ random.choices(ASCII_CHARSET, k=np.random.randint(URL_MAX_LENGTH))
+ ),
+ "nth": np.random.randint(MAX_ELEMENT_INDEX_IN_VIEWPORT),
+ "element_id": str(np.random.randint(MAX_ELEMENT_ID)),
+ "key_comb": "+".join(
+ random.choices(SPECIAL_KEYS, k=np.random.randint(3))
+ ),
+ "direction": random.choice(["up", "down"]),
+ "pw_code": "".join(
+ random.choices(
+ string.ascii_uppercase + string.digits,
+ k=np.random.randint(MAX_VANILLA_STR_LENGTH),
+ )
+ ),
+ "answer": str(np.random.randint(MAX_ANSWER_LENGTH)),
+ "raw_prediction": str(np.random.randint(MAX_ANSWER_LENGTH)),
+ }
+
+
+@beartype
+def create_none_action() -> Action:
+ """Return a valid action object that does nothing."""
+ return {
+ "action_type": ActionTypes.NONE,
+ "coords": np.zeros(2, dtype=np.float32),
+ "element_role": 0,
+ "element_name": "",
+ "text": [],
+ "page_number": 0,
+ "url": "",
+ "nth": 0,
+ "pw_code": "", # str that requires further processing
+ "element_id": "",
+ "key_comb": "",
+ "direction": "",
+ "answer": "",
+ "raw_prediction": "",
+ }
+
+
+@beartype
+def create_stop_action(answer: str) -> Action:
+ action = create_none_action()
+ action.update({"action_type": ActionTypes.STOP, "answer": answer})
+ return action
+
+
+@beartype
+def create_scroll_action(direction: str) -> Action:
+ """Return the playwright action"""
+ assert direction in ["up", "down"]
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SCROLL,
+ "direction": direction,
+ }
+ )
+ return action
+
+
+@beartype
+def create_mouse_hover_action(
+ left: float | None = None, top: float | None = None
+) -> Action:
+ """Return a valid action object with type COORD_CLICK."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.MOUSE_HOVER,
+ "coords": np.array([left, top], dtype=np.float32),
+ }
+ )
+ return action
+
+
+@beartype
+def create_key_press_action(key_comb: str) -> Action:
+ """Return the key press action"""
+
+ def map_keys(key_comb: str) -> str:
+ keys = key_comb.split("+")
+ mapped_keys = []
+ for key in keys:
+ mapped_key = SPECIAL_KEY_MAPPINGS.get(key.lower(), key)
+ mapped_keys.append(mapped_key)
+ return "+".join(mapped_keys)
+
+ action = create_none_action()
+ mapped_key_comb = map_keys(key_comb)
+ action.update(
+ {
+ "action_type": ActionTypes.KEY_PRESS,
+ "key_comb": mapped_key_comb,
+ }
+ )
+ return action
+
+
+@beartype
+def create_page_focus_action(page_number: int) -> Action:
+ """Return a valid action object with type PAGE_FOCUS."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.PAGE_FOCUS,
+ "page_number": page_number,
+ }
+ )
+ return action
+
+
+@beartype
+def create_new_tab_action() -> Action:
+ """Return a valid action object with type NEW_TAB."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.NEW_TAB,
+ }
+ )
+ return action
+
+
+@beartype
+def create_go_back_action() -> Action:
+ """Return a valid action object with type GO_BACK."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GO_BACK,
+ }
+ )
+ return action
+
+
+@beartype
+def create_go_forward_action() -> Action:
+ """Return a valid action object with type GO_FORWARD."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GO_FORWARD,
+ }
+ )
+ return action
+
+
+@beartype
+def create_goto_url_action(url: str) -> Action:
+ """Return a valid action object with type GOTO_URL."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.GOTO_URL,
+ "url": url,
+ }
+ )
+ return action
+
+
+@beartype
+def create_page_close_action() -> Action:
+ """Return a valid action object with type PAGE_CLOSE."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.PAGE_CLOSE,
+ }
+ )
+ return action
+
+
+@beartype
+def create_mouse_click_action(
+ left: float | None = None, top: float | None = None
+) -> Action:
+ """Return a valid action object with type COORD_CLICK."""
+ action = create_none_action()
+ if left and top:
+ action.update(
+ {
+ "action_type": ActionTypes.MOUSE_CLICK,
+ "coords": np.array([left, top], dtype=np.float32),
+ }
+ )
+ elif (not left) and (not top):
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ }
+ )
+ else:
+ raise ValueError("left and top must be both None or both not None")
+ return action
+
+
+@beartype
+def create_clear_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLEAR,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_keyboard_type_action(keys: list[int | str] | str) -> Action:
+ """Return a valid action object with type TYPE."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.KEYBOARD_TYPE,
+ "text": _keys2ids(keys),
+ }
+ )
+ return action
+
+
+@beartype
+def create_click_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_hover_action(
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.HOVER,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_type_action(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": _keys2ids(text),
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_type_action_webrl(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": text,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_search_action(
+ text: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SEARCH,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "text": text,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_select_dropdown_option_action(
+ argument: str,
+ element_id: str = "",
+ element_role: RolesType = "link",
+ element_name: str = "",
+ pw_code: str = "",
+ nth: int = 0,
+) -> Action:
+ """Return a valid action object with type SELECT_DROPDOWN_OPTION."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SELECT_DROPDOWN_OPTION,
+ "element_id": element_id,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ "argument": argument,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_check_action(pw_code: str) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CHECK,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_select_option_action(
+ pw_code: str,
+) -> Action:
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.SELECT_OPTION,
+ "pw_code": pw_code,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_action(
+ element_role: RolesType, element_name: str = "", nth: int = 0
+) -> Action:
+ """Return a valid action object with type CLICK.
+
+ Keep compatible with the old version."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_and_click_action(
+ element_role: RolesType, element_name: str = "", nth: int = 0
+) -> Action:
+ """Return a valid action object with type CLICK.
+
+ Keep compatible with the old version."""
+
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.CLICK,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def create_focus_and_type_action(
+ keys: list[int | str] | str,
+ element_role: RolesType,
+ element_name: str = "",
+ nth: int = 0,
+) -> Action:
+ """Return a valid action object with type TYPE.
+
+ Keep compatible with the old version."""
+ action = create_none_action()
+ action.update(
+ {
+ "action_type": ActionTypes.TYPE,
+ "element_role": _role2id[element_role],
+ "element_name": element_name,
+ "text": _keys2ids(keys),
+ "nth": nth,
+ }
+ )
+ return action
+
+
+@beartype
+def execute_scroll(direction: str, page: Page) -> None:
+ # perform the action
+ # code from natbot
+ if direction == "up":
+ page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
+ )
+ elif direction == "down":
+ page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
+ )
+
+@beartype
+def execute_scroll_webrl(direction: str, page: Page) -> None:
+ # perform the action which move 2/3 of the height of the page at a time
+ if direction == "up":
+ page.mouse.wheel(0, -page.viewport_size['height'] * 2.0 / 3)
+ elif direction == "down":
+ page.mouse.wheel(0, page.viewport_size['height'] * 2.0 / 3)
+
+@beartype
+async def aexecute_scroll(direction: str, page: APage) -> None:
+ # perform the action
+ # code from natbot
+ if direction == "up":
+ await page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;"
+ )
+ elif direction == "down":
+ await page.evaluate(
+ "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;"
+ )
+
+
+@beartype
+def execute_key_press(key: str, page: Page) -> None:
+ """Press a key."""
+ if "Meta" in key and "Mac" not in page.evaluate("navigator.platform"):
+ key = key.replace("Meta", "Control")
+ page.keyboard.press(key)
+
+
+@beartype
+async def aexecute_key_press(key: str, page: APage) -> None:
+ """Press a key."""
+ if "Meta" in key and "Mac" not in await page.evaluate(
+ "navigator.platform"
+ ):
+ key = key.replace("Meta", "Control")
+ await page.keyboard.press(key)
+
+
+@beartype
+def execute_mouse_hover(left: float, top: float, page: Page) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ page.mouse.move(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+@beartype
+async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ await page.mouse.move(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+def execute_mouse_click(left: float, top: float, page: Page) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ page.mouse.click(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+@beartype
+async def aexecute_mouse_click(left: float, top: float, page: APage) -> None:
+ """Click at coordinates (left, top)."""
+ viewport_size = page.viewport_size
+ assert viewport_size
+ await page.mouse.click(
+ left * viewport_size["width"], top * viewport_size["height"]
+ )
+
+
+@beartype
+def execute_keyboard_type(text: str, page: Page) -> None:
+ """Fill the focused element with text."""
+ page.keyboard.type(text)
+
+
+@beartype
+async def aexecute_keyboard_type(text: str, page: APage) -> None:
+ """Fill the focused element with text."""
+ await page.keyboard.type(text)
+
+
+@beartype
+def execute_click_current(page: Page) -> None:
+ """Click at the current mouse position."""
+ locators = page.locator("*:focus")
+ if not locators.count():
+ for frame in page.frames[1:]:
+ locators = frame.locator("*:focus")
+ if locators.count():
+ break
+ locators.click()
+
+
+@beartype
+async def aexecute_click_current(page: APage) -> None:
+ """Click at the current mouse position."""
+ locators = page.locator("*:focus")
+ locator_count = await locators.count()
+ if not locator_count:
+ for frame in page.frames[1:]:
+ locators = frame.locator("*:focus")
+ locator_count = await locators.count()
+ if locator_count:
+ break
+ await locators.click()
+ await page.wait_for_load_state("load")
+
+
+@beartype
+def execute_type(keys: list[int], page: Page) -> None:
+ """Send keystrokes to the focused element."""
+ text = "".join([_id2key[key] for key in keys])
+ page.keyboard.type(text)
+
+
+@beartype
+async def aexecute_type(keys: list[int], page: APage) -> None:
+ """Send keystrokes to the focused element."""
+ text = "".join([_id2key[key] for key in keys])
+ await page.keyboard.type(text)
+
+
+@beartype
+def execute_focus(
+ element_role: int, element_name: str, nth: int, page: Page
+) -> None:
+ """Click the specified DOM element."""
+ element_role_str = _id2role[element_role]
+ if page.viewport_size is None:
+ raise ValueError("Viewport size is not set for the current page")
+ element_location_list: list[tuple[Locator, float, float]] = []
+ for frame in page.frames:
+ match element_role_str:
+ case "alt_text":
+ locators = frame.get_by_alt_text(element_name)
+ case "label":
+ locators = frame.get_by_label(element_name)
+ case "placeholder":
+ locators = frame.get_by_placeholder(element_name)
+ case _:
+ locators = frame.get_by_role(
+ role=element_role_str, name=element_name
+ )
+ for locator_idx in range(locators.count()):
+ locator = locators.nth(locator_idx)
+ if is_in_viewport(locator, page.viewport_size):
+ bounding_box = locator.bounding_box()
+ assert bounding_box
+ element_location_list.append(
+ (locator, bounding_box["x"], bounding_box["y"])
+ )
+ if len(element_location_list) <= nth:
+ raise ValueError(
+ f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested"
+ )
+ element_location_list.sort(key=lambda x: (x[2], x[1])) # row major order
+ element_location_list[nth][0].focus()
+
+
+@beartype
+async def aexecute_focus(
+ element_role: int, element_name: str, nth: int, page: APage
+) -> None:
+ """Click the specified DOM element."""
+ element_role_str = _id2role[element_role]
+ if page.viewport_size is None:
+ raise ValueError("Viewport size is not set for the current page")
+ element_location_list: list[tuple[ALocator, float, float]] = []
+ for frame in page.frames:
+ match element_role_str:
+ case "alt_text":
+ locators = frame.get_by_alt_text(element_name)
+ case "label":
+ locators = frame.get_by_label(element_name)
+ case "placeholder":
+ locators = frame.get_by_placeholder(element_name)
+ case _:
+ locators = frame.get_by_role(
+ role=element_role_str, name=element_name
+ )
+ for locator_idx in range(await locators.count()):
+ locator = locators.nth(locator_idx)
+ if await async_is_in_viewport(locator, page.viewport_size):
+ bounding_box = await locator.bounding_box()
+ assert bounding_box
+ element_location_list.append(
+ (locator, bounding_box["x"], bounding_box["y"])
+ )
+ if len(element_location_list) <= nth:
+ raise ValueError(
+ f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested"
+ )
+ element_location_list.sort(key=lambda x: (x[2], x[1])) # row major order
+ await element_location_list[nth][0].focus()
+
+
+@beartype
+def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator:
+ locator = page
+ for call in locator_calls:
+ function_name = call["function_name"]
+ arguments = call["arguments"]
+ keywords = call["keywords"]
+ locator = getattr(locator, function_name)(*arguments, **keywords)
+ return locator # type: ignore[return-value]
+
+
+@beartype
+async def alocate(
+ locator_calls: list[ParsedPlaywrightCode], page: APage
+) -> ALocator:
+ locator = page
+ for call in locator_calls:
+ function_name = call["function_name"]
+ arguments = call["arguments"]
+ keywords = call["keywords"]
+ locator = await getattr(locator, function_name)(*arguments, **keywords)
+ return locator # type: ignore[return-value]
+
+
+@beartype
+def execute_playwright_click(
+ locator_code: list[ParsedPlaywrightCode],
+ page: Page,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = locate(locator_code, page)
+
+ # perform the action
+ locator.click(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_click(
+ locator_code: list[ParsedPlaywrightCode],
+ page: APage,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = await alocate(locator_code, page)
+
+ # perform the action
+ await locator.click(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_hover(
+ locator_code: list[ParsedPlaywrightCode], page: Page
+) -> None:
+ locator = locate(locator_code, page)
+
+ # perform the action
+ locator.hover()
+
+
+@beartype
+async def aexecute_playwright_hover(
+ locator_code: list[ParsedPlaywrightCode], page: APage
+) -> None:
+ locator = await alocate(locator_code, page)
+
+ # perform the action
+ await locator.hover()
+
+
+@beartype
+def execute_playwright_type(
+ text: str,
+ locator_code: list[ParsedPlaywrightCode],
+ page: Page,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = locate(locator_code, page)
+ # perform the action
+ pw_action_args = [text] + pw_action_args # text is the first argument
+ locator.type(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_type(
+ text: str,
+ locator_code: list[ParsedPlaywrightCode],
+ page: APage,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = await alocate(locator_code, page)
+ # perform the action
+ pw_action_args = [text] + pw_action_args # text is the first argument
+ await locator.type(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_select_option(
+ locator_code: list[ParsedPlaywrightCode],
+ page: Page,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = locate(locator_code, page)
+ # perform the action
+ locator.select_option(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+async def aexecute_playwright_select_option(
+ locator_code: list[ParsedPlaywrightCode],
+ page: APage,
+ pw_action_args: list[str] = [],
+ pw_action_kwargs: dict[str, Any] = {},
+) -> None:
+ locator = await alocate(locator_code, page)
+ # perform the action
+ await locator.select_option(*pw_action_args, **pw_action_kwargs)
+
+
+@beartype
+def execute_playwright_check(
+ locator_code: list[ParsedPlaywrightCode], page: Page
+) -> None:
+ locator = locate(locator_code, page)
+ # perform the action
+ locator.check()
+
+
+@beartype
+async def aexecute_playwright_check(
+ locator_code: list[ParsedPlaywrightCode], page: APage
+) -> None:
+ locator = await alocate(locator_code, page)
+ # perform the action
+ await locator.check()
+
+
+@beartype
+def execute_action_webrl(
+ action: Action,
+ page: Page,
+ browser_ctx: BrowserContext,
+ obseration_processor: ObservationProcessor,
+ sleep_after_execution: float = 0.0,
+) -> Page:
+ """Execute the action on the ChromeDriver."""
+ action_type = action["action_type"]
+ num_tabs_before = len(browser_ctx.pages)
+ match action_type:
+ case ActionTypes.NONE:
+ pass
+ case ActionTypes.SCROLL:
+ direction = "up" if "up" in action["direction"] else "down"
+ execute_scroll_webrl(direction, page)
+ case ActionTypes.KEY_PRESS:
+ keys = action["key_comb"]
+ execute_key_press(keys, page)
+ case ActionTypes.MOUSE_CLICK:
+ execute_mouse_click(action["coords"][0], action["coords"][1], page)
+ case ActionTypes.CLICK:
+ # check each kind of locator in order
+ # TODO[shuyanzh]: order is temp now
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id, page) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ case ActionTypes.HOVER:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_hover(element_center[0], element_center[1], page)
+ case ActionTypes.TYPE:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ execute_key_press("Meta+A", page)
+ execute_key_press('Backspace', page)
+ # execute_mouse_click(element_center[0], element_center[1], page)
+ text = _keys2ids(action["text"])
+ execute_type(text, page)
+ case ActionTypes.SEARCH:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ execute_key_press("Meta+A", page)
+ execute_key_press('Backspace', page)
+ # execute_mouse_click(element_center[0], element_center[1], page)
+ text = _keys2ids(action["text"])
+ execute_type(text, page)
+ time.sleep(2)
+ execute_key_press("Enter", page)
+ case ActionTypes.GO_BACK:
+ page.go_back()
+ case ActionTypes.GO_FORWARD:
+ page.go_forward()
+ case ActionTypes.GOTO_URL:
+ page.goto(action["url"])
+ case ActionTypes.SELECT_DROPDOWN_OPTION:
+ # Click
+ element_id = action["element_id"]
+ argument = action["argument"]
+ element_center = obseration_processor.get_element_center(element_id, page) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ # get element
+ device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+ center_x, center_y = element_center[0] * page.viewport_size["width"], element_center[1] * page.viewport_size["height"]
+ last_turn_element = page.evaluate_handle(f"""() => document.elementFromPoint({center_x / device_pixel_ratio}, {center_y / device_pixel_ratio})""")
+ # get select element options
+ select_element_options = [{"value": option.get_attribute('value'), "text": option.text_content().strip(' \n')} for option in
+ last_turn_element.query_selector_all("option")]
+ selector_option_dict = dict((o["text"].lower(), o["value"]) for o in select_element_options)
+ value = None
+ for key in selector_option_dict.keys():
+ if argument.lower() in key.lower():
+ value = selector_option_dict[key]
+ break
+ if value is not None:
+ last_turn_element.select_option(value=value)
+ case _:
+ raise ValueError(f"Unknown action type: {action_type}")
+
+ page.wait_for_timeout(int(sleep_after_execution * 1000))
+ num_tabs_now = len(browser_ctx.pages)
+ # if a new tab is opened by clicking, switch to the new tab
+ if num_tabs_now > num_tabs_before:
+ page = browser_ctx.pages[-1]
+ page.bring_to_front()
+
+ return page
+
+@beartype
+def execute_action(
+ action: Action,
+ page: Page,
+ browser_ctx: BrowserContext,
+ obseration_processor: ObservationProcessor,
+ sleep_after_execution: float = 0.0,
+) -> Page:
+ """Execute the action on the ChromeDriver."""
+ action_type = action["action_type"]
+ num_tabs_before = len(browser_ctx.pages)
+ match action_type:
+ case ActionTypes.NONE:
+ pass
+
+ case ActionTypes.SCROLL:
+ direction = "up" if "up" in action["direction"] else "down"
+ execute_scroll(direction, page)
+ case ActionTypes.KEY_PRESS:
+ keys = action["key_comb"]
+ execute_key_press(keys, page)
+
+ case ActionTypes.MOUSE_CLICK:
+ execute_mouse_click(action["coords"][0], action["coords"][1], page)
+ case ActionTypes.CLEAR:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ execute_key_press("Meta+A", page)
+ execute_key_press('Backspace', page)
+ case ActionTypes.MOUSE_HOVER:
+ execute_mouse_hover(action["coords"][0], action["coords"][1], page)
+ case ActionTypes.KEYBOARD_TYPE:
+ execute_type(action["text"], page)
+ case ActionTypes.CLICK:
+ # check each kind of locator in order
+ # TODO[shuyanzh]: order is temp now
+ if action["element_id"]:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ execute_focus(element_role, element_name, nth, page)
+ execute_click_current(page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ # [shuyanzh], don't support action args and kwargs now
+ execute_playwright_click(locator_code=locator_code, page=page)
+ else:
+ raise ValueError("No proper locator found for click action")
+ case ActionTypes.HOVER:
+ if action["element_id"]:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_hover(element_center[0], element_center[1], page)
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ execute_focus(element_role, element_name, nth, page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ # [shuyanzh], don't support action args and kwargs now
+ execute_playwright_hover(locator_code=locator_code, page=page)
+ else:
+ raise NotImplementedError(
+ "No proper locator found for hover action"
+ )
+ case ActionTypes.TYPE:
+ if action["element_id"]:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ execute_mouse_click(element_center[0], element_center[1], page)
+ execute_type(action["text"], page)
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ execute_focus(element_role, element_name, nth, page)
+ execute_type(action["text"], page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ text = parsed_code[-1]["arguments"][0]
+ # [shuyanzh], don't support action args and kwargs now
+ execute_playwright_type(
+ text=text, locator_code=locator_code, page=page
+ )
+ else:
+ raise NotImplementedError(
+ "No proper locator found for type action"
+ )
+
+ case ActionTypes.PAGE_FOCUS:
+ page = browser_ctx.pages[action["page_number"]]
+ page.bring_to_front()
+ case ActionTypes.NEW_TAB:
+ page = browser_ctx.new_page()
+ case ActionTypes.GO_BACK:
+ page.go_back()
+ case ActionTypes.GO_FORWARD:
+ page.go_forward()
+ case ActionTypes.GOTO_URL:
+ page.goto(action["url"])
+ case ActionTypes.PAGE_CLOSE:
+ page.close()
+ if len(browser_ctx.pages) > 0:
+ page = browser_ctx.pages[-1]
+ else:
+ page = browser_ctx.new_page()
+
+ case ActionTypes.SELECT_OPTION:
+ if action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ execute_playwright_select_option(locator_code, page)
+ else:
+ raise NotImplementedError(
+ "No proper locator found for select option action"
+ )
+ case ActionTypes.CHECK:
+ if action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ execute_playwright_check(locator_code, page)
+ else:
+ raise NotImplementedError(
+ "No proper locator found for select option action"
+ )
+
+ case _:
+ raise ValueError(f"Unknown action type: {action_type}")
+
+ page.wait_for_timeout(int(sleep_after_execution * 1000))
+ num_tabs_now = len(browser_ctx.pages)
+ # if a new tab is opened by clicking, switch to the new tab
+ if num_tabs_now > num_tabs_before:
+ page = browser_ctx.pages[-1]
+ page.bring_to_front()
+
+ return page
+
+
+
+@beartype
+async def aexecute_action(
+ action: Action, page: APage, browser_ctx: ABrowserContext
+) -> APage:
+ """Execute the async action on the ChromeDriver."""
+ action_type = action["action_type"]
+ match action_type:
+ case ActionTypes.NONE:
+ pass
+ case ActionTypes.SCROLL:
+ direction = "up" if "up" in action["direction"] else "down"
+ await aexecute_scroll(direction, page)
+ case ActionTypes.KEY_PRESS:
+ keys = action["key_comb"]
+ await aexecute_key_press(keys, page)
+
+ case ActionTypes.MOUSE_CLICK:
+ await aexecute_mouse_click(
+ action["coords"][0], action["coords"][1], page
+ )
+ case ActionTypes.CLEAR:
+ element_id = action["element_id"]
+ element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined]
+ await execute_mouse_click(element_center[0], element_center[1], page)
+ await execute_key_press("Meta+A", page)
+ await execute_key_press('Backspace', page)
+ case ActionTypes.MOUSE_HOVER:
+ await aexecute_mouse_hover(
+ action["coords"][0], action["coords"][1], page
+ )
+ case ActionTypes.KEYBOARD_TYPE:
+ await aexecute_type(action["text"], page)
+
+ case ActionTypes.CLICK:
+ # check each kind of locator in order
+ # TODO[shuyanzh]: order is temp now
+ if action["element_id"]:
+ raise NotImplementedError
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ await aexecute_focus(element_role, element_name, nth, page)
+ await aexecute_click_current(page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ # [shuyanzh], don't support action args and kwargs now
+ await aexecute_playwright_click(
+ locator_code=locator_code, page=page
+ )
+ else:
+ raise ValueError("No proper locator found for click action")
+ case ActionTypes.HOVER:
+ if action["element_id"]:
+ raise NotImplementedError
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ await aexecute_focus(element_role, element_name, nth, page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ # [shuyanzh], don't support action args and kwargs now
+ await aexecute_playwright_hover(
+ locator_code=locator_code, page=page
+ )
+ else:
+ raise NotImplementedError(
+ "No proper locator found for hover action"
+ )
+ case ActionTypes.TYPE:
+ if action["element_id"]:
+ raise NotImplementedError
+ elif action["element_role"] and action["element_name"]:
+ element_role = int(action["element_role"])
+ element_name = action["element_name"]
+ nth = action["nth"]
+ await aexecute_focus(element_role, element_name, nth, page)
+ await aexecute_type(action["text"], page)
+ elif action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ text = parsed_code[-1]["arguments"][0]
+ # [shuyanzh], don't support action args and kwargs now
+ await aexecute_playwright_type(
+ text=text, locator_code=locator_code, page=page
+ )
+ else:
+ raise NotImplementedError(
+ "No proper locator found for type action"
+ )
+
+ case ActionTypes.PAGE_FOCUS:
+ page = browser_ctx.pages[action["page_number"]]
+ await page.bring_to_front()
+ case ActionTypes.NEW_TAB:
+ page = await browser_ctx.new_page()
+ case ActionTypes.GO_BACK:
+ await page.go_back()
+ case ActionTypes.GO_FORWARD:
+ await page.go_forward()
+ case ActionTypes.GOTO_URL:
+ await page.goto(action["url"])
+ case ActionTypes.PAGE_CLOSE:
+ await page.close()
+ if len(browser_ctx.pages) > 0:
+ page = browser_ctx.pages[-1]
+ else:
+ page = await browser_ctx.new_page()
+
+ case ActionTypes.SELECT_OPTION:
+ if action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ await aexecute_playwright_select_option(locator_code, page)
+ else:
+ raise NotImplementedError(
+ "No proper locator found for select option action"
+ )
+ case ActionTypes.CHECK:
+ if action["pw_code"]:
+ parsed_code = parse_playwright_code(action["pw_code"])
+ locator_code = parsed_code[:-1]
+ await aexecute_playwright_check(locator_code, page)
+ else:
+ raise NotImplementedError(
+ "No proper locator found for select option action"
+ )
+
+ case _:
+ raise ValueError(f"Unknown action type: {action_type}")
+
+ return page
+
+
+@beartype
+def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]:
+ # extract function calls
+ if not code.startswith("page."):
+ raise ValueError(
+ f'Playwright action must start with "page.", but got {code}'
+ )
+
+ regex = r"\.(?![^\(\)]*\))"
+ chain = re.split(regex, code)[1:]
+
+ parsed_chain = []
+
+ for item in chain:
+ tree = ast.parse(item)
+ funcs = []
+ for node in ast.walk(tree):
+ if isinstance(node, ast.Call):
+ function_name = node.func.id # type: ignore[attr-defined]
+ arguments = [
+ ast.literal_eval(arg) if isinstance(arg, ast.Str) else arg
+ for arg in node.args
+ ]
+ keywords = {
+ str(kw.arg): ast.literal_eval(kw.value)
+ for kw in node.keywords
+ }
+ funcs.append(
+ ParsedPlaywrightCode(
+ {
+ "function_name": function_name,
+ "arguments": arguments,
+ "keywords": keywords,
+ }
+ )
+ )
+
+ if len(funcs) != 1:
+ raise ValueError(f"Fail to parse {item} in {code}")
+
+ if (
+ funcs[0]["function_name"]
+ not in PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS
+ ):
+ raise ValueError(
+ f"Invalid playwright code {item}, ",
+ f"the function needs to be one of {PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS}",
+ )
+
+ parsed_chain.append(funcs[0])
+
+ last_action = parsed_chain[-1]
+ if last_action["function_name"] not in PLAYWRIGHT_ACTIONS:
+ raise ValueError(
+ f"Invalid playwright action {last_action},",
+ f"the action needs to be one of {PLAYWRIGHT_ACTIONS}",
+ )
+
+ return parsed_chain
+
+
+@beartype
+class ActionParsingError(Exception):
+ def __init__(self, message: str) -> None:
+ self.message = message
+ super().__init__(self.message)
+
+
+@beartype
+def create_playwright_action(playwright_code: str) -> Action:
+ """Main function to return individual playwright action"""
+ # get the last action
+ regex = r"\.(?![^\(\)]*\))"
+ action = re.split(regex, playwright_code)[-1].split("(")[0]
+ match action:
+ case "press":
+ p = r'press\((?:"|\')(.+?)(?:"|\')\)'
+ match = re.search(p, playwright_code)
+ if not match:
+ raise ActionParsingError(
+ f"Invalid press action, required to be page.press(KEY_COMB_STR)"
+ )
+ key_comb = match.group(1)
+ return create_key_press_action(key_comb=key_comb)
+ case "scroll":
+ direction = "up" if "up" in playwright_code else "down"
+ return create_scroll_action(direction=direction)
+ case "click":
+ return create_click_action(pw_code=playwright_code)
+ case "clear":
+ return create_clear_action(pw_code=playwright_code)
+ case "hover":
+ return create_hover_action(pw_code=playwright_code)
+ case "type" | "fill":
+ p = r'type|fill\((?:"|\')(.+?)(?:"|\')\)'
+ match = re.search(p, playwright_code)
+ if not match:
+ raise ActionParsingError(
+ f"Invalid type/fill action, required to be page.type(TEXT)"
+ )
+ text = match.group(1)
+ return create_type_action(text=text, pw_code=playwright_code)
+ case "select_option":
+ return create_select_option_action(pw_code=playwright_code)
+ case "check":
+ return create_check_action(pw_code=playwright_code)
+ case "goto":
+ p = r'goto\((?:"|\')(.+?)(?:"|\')\)'
+ match = re.search(p, playwright_code)
+ if not match:
+ raise ActionParsingError(
+ f"Invalid goto action, required to be page.goto(URL_STR)"
+ )
+ url = match.group(1)
+ return create_goto_url_action(url)
+ case "page_focus":
+ # get the page number
+ p = r"page_focus\((\d+)\)"
+ match = re.search(p, playwright_code)
+ if not match:
+ raise ActionParsingError("page focus requires a page number")
+ page_num = int(match.group(1))
+ return create_page_focus_action(page_num)
+ case "new_tab":
+ return create_new_tab_action()
+ case "go_back":
+ return create_go_back_action()
+ case "go_forward":
+ return create_go_forward_action()
+ case "page_close":
+ return create_page_close_action()
+ case "stop": # page.stop(answer)
+ p = r'stop\(?"(.+)?"\)'
+ match = re.search(p, playwright_code)
+ if not match:
+ answer = ""
+ else:
+ answer = match.group(1)
+ return create_stop_action(answer)
+
+ raise ActionParsingError(f"Unknown playwright action {action}")
+
+
+@beartype
+def create_id_based_action(action_str: str) -> Action:
+ """Main function to return individual id based action"""
+ action_str = action_str.strip()
+ if "[" in action_str:
+ action = action_str.split("[")[0].strip()
+ else:
+ actions = action_str.split()
+ if actions:
+ action = actions[0].strip()
+ else:
+ raise ActionParsingError(f"No action specified: {action_str}")
+ match action:
+ case "click":
+ match = re.search(r"click ?\[(\d+)\]", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid click action {action_str}")
+ element_id = match.group(1)
+ return create_click_action(element_id=element_id)
+ case "clear":
+ match = re.search(r"clear ?\[(\d+)\]", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid clear action {action_str}")
+ element_id = match.group(1)
+ return create_clear_action(element_id=element_id)
+ case "hover":
+ match = re.search(r"hover ?\[(\d+)\]", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid hover action {action_str}")
+ element_id = match.group(1)
+ return create_hover_action(element_id=element_id)
+ case "type":
+ # add default enter flag
+ if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
+ action_str += " [1]"
+
+ match = re.search(
+ r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str
+ )
+ if not match:
+ raise ActionParsingError(f"Invalid type action {action_str}")
+ element_id, text, enter_flag = (
+ match.group(1),
+ match.group(2),
+ match.group(3),
+ )
+ if enter_flag == "1":
+ text += "\n"
+ return create_type_action(text=text, element_id=element_id)
+ case "press":
+ match = re.search(r"press ?\[(.+)\]", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid press action {action_str}")
+ key_comb = match.group(1)
+ return create_key_press_action(key_comb=key_comb)
+ case "scroll":
+ # up or down
+ match = re.search(r"scroll ?\[?(up|down)\]?", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid scroll action {action_str}")
+ direction = match.group(1)
+ return create_scroll_action(direction=direction)
+ case "goto":
+ match = re.search(r"goto ?\[(.+)\]", action_str)
+ if not match:
+ raise ActionParsingError(f"Invalid goto action {action_str}")
+ url = match.group(1)
+ return create_goto_url_action(url=url)
+ case "new_tab":
+ return create_new_tab_action()
+ case "go_back":
+ return create_go_back_action()
+ case "go_forward":
+ return create_go_forward_action()
+ case "tab_focus":
+ match = re.search(r"tab_focus ?\[(\d+)\]", action_str)
+ if not match:
+ raise ActionParsingError(
+ f"Invalid tab_focus action {action_str}"
+ )
+ page_number = int(match.group(1))
+ return create_page_focus_action(page_number)
+ case "close_tab":
+ return create_page_close_action()
+ case "stop": # stop answer
+ match = re.search(r"stop ?\[(.+)\]", action_str)
+ if not match: # some tasks don't require an answer
+ answer = ""
+ else:
+ answer = match.group(1)
+ return create_stop_action(answer)
+
+ raise ActionParsingError(f"Invalid action {action_str}")
+
+
+@beartype
+def create_webrl_id_based_action(action_str: str) -> Action:
+ """Main function to return individual webrl id based action"""
+ import ast
+ def remove_comments(code):
+ # 按行分割代码
+ for key in ['exit(','do(','go_backward(']:
+ if key in code:
+ return key + code.split(key)[-1]
+ lines = code.split('\n')
+ for i, line in enumerate(lines):
+ if line.strip().startswith('#'):
+ # 跳过注释行
+ continue
+ else:
+ # 返回非注释行及其后面的部分
+ return '\n'.join(lines[i:])
+ return ''
+
+ def parse_function_call(expression):
+ expression = remove_comments(expression)
+ # 将字符串解析为 AST
+ expression = expression.strip()
+ tree = ast.parse(expression, mode='eval')
+ # 提取函数名称
+ func_call = tree.body
+ if not isinstance(func_call, ast.Call):
+ return {
+ "operation": expression,
+ }
+ func_name = func_call.func.id
+ result = {
+ "operation": func_name,
+ }
+ # 提取参数
+ args = func_call.args
+ kwargs = func_call.keywords
+ for kw in kwargs:
+ if func_name == "do" and kw.arg == "action":
+ result["action"] = ast.literal_eval(kw.value)
+ # elif func_name == "do" and kw.arg == "argument":
+ # result["argument"] = ast.literal_eval(kw.value)
+ else:
+ if "kwargs" not in result:
+ result["kwargs"] = {}
+ if kw.arg == "element":
+ try:
+ # 解析元素的内部函数
+ inner_func = kw.value
+ if isinstance(inner_func, ast.Call) and inner_func.func.id == 'find_element_by_instruction':
+ for inner_kw in inner_func.keywords:
+ if inner_kw.arg == "instruction":
+ result["kwargs"]["instruction"] = ast.literal_eval(inner_kw.value)
+ else:
+ result["kwargs"][kw.arg] = ast.literal_eval(inner_func)
+ except Exception:
+ result["kwargs"][kw.arg] = ast.literal_eval(kw.value)
+ else:
+ result["kwargs"][kw.arg] = ast.literal_eval(kw.value)
+ return result
+
+ action_str = action_str.strip()
+ try:
+ action = parse_function_call(action_str)
+ except Exception as e:
+ raise ActionParsingError(f"No action specified: {action_str}")
+ operation = action["operation"]
+ match operation:
+ case "do":
+ action_type = action["action"].lower()
+ match action_type:
+ case "press enter":
+ return create_key_press_action(key_comb='enter')
+ case "scroll up":
+ return create_scroll_action(direction='up')
+ case "scroll down":
+ return create_scroll_action(direction='down')
+ case "click":
+ element_id = action["kwargs"]["element"]
+ return create_click_action(element_id=element_id)
+ case "type":
+ element_id = action["kwargs"]["element"]
+ text = action["kwargs"]["argument"]
+ return create_type_action_webrl(text=text, element_id=element_id)
+ case "hover":
+ element_id = action["kwargs"]["element"]
+ return create_hover_action(element_id=element_id)
+ case "select dropdown option":
+ element_id = action["kwargs"]["element"]
+ argument = action["kwargs"]["argument"]
+ return create_select_dropdown_option_action(argument=argument, element_id=element_id)
+ case "go forward":
+ return create_go_forward_action()
+ case "go backward":
+ return create_go_back_action()
+ case "search":
+ element_id = action["kwargs"]["element"]
+ text = action["kwargs"]["argument"]
+ return create_search_action(text=text, element_id=element_id)
+ case "exit": # stop answer
+ answer = action['kwargs']['message']
+ return create_stop_action(answer)
+
+ raise ActionParsingError(f"Invalid action {action_str}")
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/agent.py b/VAB-WebArena-Lite/new/agent.py
index 461ba7e..96caf17 100644
--- a/VAB-WebArena-Lite/new/agent.py
+++ b/VAB-WebArena-Lite/new/agent.py
@@ -14,6 +14,7 @@ from browser_env.actions import (
create_id_based_action,
create_none_action,
create_playwright_action,
+ create_webrl_id_based_action
)
from browser_env.utils import Observation, StateInfo
from llms import (
@@ -108,13 +109,15 @@ class PromptAgent(Agent):
lm_config: lm_config.LMConfig,
prompt_constructor: PromptConstructor,
captioning_fn = None,
+ planner_ip = None
) -> None:
super().__init__()
self.lm_config = lm_config
self.prompt_constructor = prompt_constructor
self.action_set_tag = action_set_tag
self.captioning_fn = captioning_fn
-
+ self.planner_ip = planner_ip
+
# Check if the model is multimodal.
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:
self.multimodal_inputs = True
@@ -165,7 +168,10 @@ class PromptAgent(Agent):
lm_config = self.lm_config
n = 0
while True:
- response = call_llm(lm_config, prompt)
+ if self.planner_ip is not None and self.planner_ip != "":
+ response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
+ else:
+ response = call_llm(lm_config, prompt)
force_prefix = self.prompt_constructor.instruction[
"meta_data"
].get("force_prefix", "")
@@ -183,6 +189,8 @@ class PromptAgent(Agent):
action = create_playwright_action(parsed_response)
elif self.action_set_tag == "som":
action = create_id_based_action(parsed_response)
+ elif self.action_set_tag == 'webrl_id':
+ action = create_webrl_id_based_action(parsed_response)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
@@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
- captioning_fn=captioning_fn
+ captioning_fn=captioning_fn,
+ planner_ip=args.planner_ip
)
else:
raise NotImplementedError(
diff --git a/VAB-WebArena-Lite/new/envs.py b/VAB-WebArena-Lite/new/envs.py
new file mode 100644
index 0000000..2dfe7dc
--- /dev/null
+++ b/VAB-WebArena-Lite/new/envs.py
@@ -0,0 +1,319 @@
+import json
+import os
+import re
+import subprocess
+import time
+from collections import defaultdict
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Union
+
+import numpy as np
+import numpy.typing as npt
+import requests
+from beartype import beartype
+from gymnasium import Env
+from gymnasium.spaces import Box, Text
+from playwright.sync_api import (
+ CDPSession,
+ Page,
+ Playwright,
+ ViewportSize,
+ expect,
+ sync_playwright,
+)
+
+DATASET = os.environ["DATASET"]
+if DATASET == "visualwebarena":
+ from browser_env.env_config import (
+ CLASSIFIEDS,
+ CLASSIFIEDS_RESET_TOKEN,
+ )
+
+from .actions import Action, execute_action, get_action_space, execute_action_webrl
+from .processors import ObservationHandler, ObservationMetadata
+from .utils import (
+ AccessibilityTree,
+ DetachedPage,
+ Observation,
+ png_bytes_to_numpy,
+)
+
+
+@dataclass
+class PlaywrightScript:
+ function: str # goto, get_by_role
+ destination: str # https://www.google.com/, combobox
+ name: str | None = None # Search, Avatar 2009
+ operation: str | None = None # click, fill, press
+ value: str | None = None # avatar movie, Enter
+
+
+def parse_action(action: str) -> PlaywrightScript:
+ splitted = action.strip().split(" ")
+ assert len(splitted) >= 2
+ match splitted[:2]:
+ case ["goto", url]:
+ assert len(splitted) == 2
+ return PlaywrightScript("goto", url)
+ case ["get_by_role", destination]:
+ assert len(splitted) >= 4
+ match splitted[2:]:
+ case [name, operation]:
+ return PlaywrightScript(
+ "get_by_role", destination, name, operation
+ )
+ case [name, operation, value]:
+ return PlaywrightScript(
+ "get_by_role", destination, name, operation, value
+ )
+ case _:
+ raise ValueError("Invalid action")
+ case _:
+ raise ValueError(f"Invalid action {action}")
+
+
+class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
+ """
+ The goal of this environment is to produce a prototype of a browser environment.
+ In the end, we want to support a fully configurable browser environment with wide
+ range of action spaces and observation spaces, both structured and unstructured.
+ But in this prototype, we just support action space specified by Playwright script,
+ and observation space is the html content of the page.
+ """
+
+ @beartype
+ def __init__(
+ self,
+ max_page_length: int = 8192,
+ headless: bool = True,
+ slow_mo: int = 0,
+ observation_type: str = "html",
+ current_viewport_only: bool = False,
+ viewport_size: ViewportSize = {"width": 1280, "height": 720},
+ save_trace_enabled: bool = False,
+ sleep_after_execution: float = 0.0,
+ captioning_fn=None,
+ ):
+ # TODO: make Space[Action] = ActionSpace
+ self.action_space = get_action_space() # type: ignore[assignment]
+ self.headless = headless
+ self.slow_mo = slow_mo
+ self.current_viewport_only = current_viewport_only
+ self.reset_finished = False
+ self.viewport_size = viewport_size
+ self.save_trace_enabled = save_trace_enabled
+ self.sleep_after_execution = sleep_after_execution
+
+ match observation_type:
+ case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
+ self.text_observation_type = observation_type
+ self.image_observation_type = ""
+ self.main_observation_type = "text"
+ case "image":
+ self.image_observation_type = observation_type
+ self.text_observation_type = "" # type: ignore[assignment]
+ self.main_observation_type = "image"
+ case "image_som":
+ self.image_observation_type = observation_type
+ self.text_observation_type = observation_type # type: ignore[assignment]
+ self.main_observation_type = "image"
+ case _:
+ raise ValueError(
+ f"Unsupported observation type: {observation_type}"
+ )
+
+ self.observation_handler = ObservationHandler(
+ self.main_observation_type,
+ self.text_observation_type,
+ self.image_observation_type,
+ self.current_viewport_only,
+ self.viewport_size,
+ captioning_fn,
+ )
+
+ self.observation_space = (
+ self.observation_handler.get_observation_space()
+ )
+
+ @beartype
+ def setup(self, config_file: Path | None = None) -> None:
+ self.context_manager = sync_playwright()
+ self.playwright = self.context_manager.__enter__()
+ self.browser = self.playwright.chromium.launch(
+ headless=self.headless, slow_mo=self.slow_mo
+ )
+
+ if config_file:
+ with open(config_file, "r") as f:
+ instance_config = json.load(f)
+ else:
+ instance_config = {}
+
+ # Reset site if needed. Currently only supported for Classifieds.
+ # TODO(jykoh): Add reset functionality for Shopping/Reddit.
+ if instance_config.get("require_reset", False):
+ if "classifieds" in instance_config["sites"]:
+ # Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
+ response = requests.post(
+ f"{CLASSIFIEDS}/index.php?page=reset",
+ data={"token": CLASSIFIEDS_RESET_TOKEN},
+ )
+
+ # Check if the request was successful
+ if response.status_code == 200:
+ print("Reset Classifieds site.")
+ else:
+ print(
+ "Failed to reset Classifieds site:",
+ response.status_code,
+ )
+ else:
+ print(
+ "WARNING: Reset is not supported for this site. Please manually reset the site."
+ )
+
+ storage_state = instance_config.get("storage_state", None)
+ start_url = instance_config.get("start_url", None)
+ geolocation = instance_config.get("geolocation", None)
+
+ # Use custom viewport size if specified in the config, otherwise use the default.
+ viewport_size = self.viewport_size.copy()
+ viewport_size.update(instance_config.get("viewport_size", {}))
+ self.observation_handler.viewport_size = viewport_size
+
+ self.context = self.browser.new_context(
+ viewport=viewport_size,
+ storage_state=storage_state,
+ geolocation=geolocation,
+ device_scale_factor=1,
+ )
+ if self.save_trace_enabled:
+ self.context.tracing.start(screenshots=True, snapshots=True)
+
+ if start_url:
+ start_urls = start_url.split(" |AND| ")
+ for url in start_urls:
+ page = self.context.new_page()
+ if self.text_observation_type in [
+ "accessibility_tree",
+ "accessibility_tree_with_captioner",
+ ]:
+ client = page.context.new_cdp_session(page)
+ client.send("Accessibility.enable")
+ client.detach()
+ page.goto(url)
+ # set the first page as the current page
+ self.page = self.context.pages[0]
+ self.page.bring_to_front()
+ else:
+ self.page = self.context.new_page()
+ if self.text_observation_type in [
+ "accessibility_tree",
+ "accessibility_tree_with_captioner",
+ ]:
+ client = self.page.context.new_cdp_session(self.page)
+ client.send("Accessibility.enable")
+ client.detach()
+
+ def _get_obs(self) -> dict[str, Observation]:
+ obs = self.observation_handler.get_observation(self.page)
+ return obs
+
+ def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
+ metadata = self.observation_handler.get_observation_metadata()
+ return metadata
+
+ @beartype
+ def reset(
+ self,
+ *,
+ seed: int | None = None,
+ options: dict[str, str] | None = None,
+ ) -> tuple[dict[str, Observation], dict[str, Any]]:
+ """
+ Reset the environment.
+ :param options: options for the environment. The current supported options are:
+ - "storage_state": the storage state of the browser. It is a file path to a json file.
+ """
+ super().reset(seed=seed, options=options)
+ if self.reset_finished:
+ self.context_manager.__exit__()
+
+ if options is not None and "config_file" in options:
+ config_file = Path(options["config_file"])
+ if config_file.exists():
+ self.setup(config_file=config_file)
+ else:
+ raise ValueError(f"Config file {config_file} does not exist.")
+ else:
+ self.setup()
+ self.reset_finished = True
+ timeout_in_ms = 120000
+ self.page.set_default_timeout(timeout_in_ms)
+ self.page.set_default_navigation_timeout(timeout_in_ms)
+ self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))
+
+ observation = self._get_obs()
+ observation_metadata = self._get_obs_metadata()
+ info = {
+ "page": DetachedPage(self.page.url, ""),
+ "fail_error": "",
+ "observation_metadata": observation_metadata,
+ }
+
+ return (observation, info)
+
+ def save_trace(self, trace_path: str | Path) -> None:
+ if self.save_trace_enabled:
+ self.context.tracing.stop(path=trace_path)
+
+ def close(self) -> None:
+ if self.reset_finished:
+ self.context_manager.__exit__()
+
+ def step(
+ self, action: Action
+ ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
+ if not self.reset_finished:
+ raise RuntimeError("Call reset first before calling step.")
+
+ success = False
+ fail_error = ""
+ try:
+ if self.text_observation_type == 'webrl':
+ self.page = execute_action_webrl(
+ action,
+ self.page,
+ self.context,
+ self.observation_handler.action_processor,
+ self.sleep_after_execution,
+ )
+ else:
+ self.page = execute_action(
+ action,
+ self.page,
+ self.context,
+ self.observation_handler.action_processor,
+ self.sleep_after_execution,
+ )
+ success = True
+ except Exception as e:
+ fail_error = str(e)
+
+ observation = self._get_obs()
+ observation_metadata = self._get_obs_metadata()
+
+ info = {
+ "page": DetachedPage(self.page.url, self.page.content()),
+ "fail_error": fail_error,
+ "observation_metadata": observation_metadata,
+ }
+ msg = (
+ observation,
+ float(success), # reward
+ False, # terminated
+ False, # truncated
+ info,
+ )
+ return msg
diff --git a/VAB-WebArena-Lite/new/evaluators.py b/VAB-WebArena-Lite/new/evaluators.py
index 85b5363..aac511b 100644
--- a/VAB-WebArena-Lite/new/evaluators.py
+++ b/VAB-WebArena-Lite/new/evaluators.py
@@ -172,6 +172,10 @@ class StringEvaluator(Evaluator):
# prevent false positive (e.g, 0)
if len(word_tokenize(clean_ref)) == 1:
tok_pred = word_tokenize(clean_pred)
+ for token in tok_pred:
+ if '/' in token:
+ sub_tokens = token.split('/')
+ tok_pred.extend(sub_tokens)
return float(clean_ref in tok_pred)
else:
return float(clean_ref in clean_pred)
diff --git a/VAB-WebArena-Lite/new/helper_functions_browser.py b/VAB-WebArena-Lite/new/helper_functions_browser.py
new file mode 100644
index 0000000..cb902ba
--- /dev/null
+++ b/VAB-WebArena-Lite/new/helper_functions_browser.py
@@ -0,0 +1,239 @@
+import base64
+import io
+import json
+import re
+from pathlib import Path
+from typing import Any
+
+from PIL import Image
+
+from agent.prompts import *
+from browser_env import (
+ Action,
+ ActionTypes,
+ ObservationMetadata,
+ StateInfo,
+ action2str,
+)
+
+HTML_TEMPLATE = """
+
+
+
+
+
+
+ {body}
+
+
+"""
+
+
+def get_render_action(
+ action: Action,
+ observation_metadata: dict[str, ObservationMetadata],
+ action_set_tag: str,
+) -> str:
+ """Parse the predicted actions for rendering purpose. More comprehensive information"""
+ match action_set_tag:
+ case "id_accessibility_tree":
+ text_meta_data = observation_metadata["text"]
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ node_content = text_meta_data["obs_nodes_info"][
+ action["element_id"]
+ ]["text"]
+ else:
+ node_content = "No match found"
+
+ action_str = f"{action['raw_prediction']}
"
+ action_str += f""
+ action_str += f"{action2str(action, action_set_tag, node_content)}
"
+
+ case "som":
+ text_meta_data = observation_metadata["text"]
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ node_content = text_meta_data["obs_nodes_info"][
+ action["element_id"]
+ ]
+ else:
+ node_content = "No match found"
+ action_str = f"{action['raw_prediction']}
"
+ action_str += f""
+ action_str += f"{action2str(action, action_set_tag, node_content)}
"
+
+ case "playwright":
+ action_str = action["pw_code"]
+ case "webrl_id":
+ action_str = action["raw_prediction"]
+ case _:
+ raise ValueError(
+ f"Unknown action type {action['action_type'], action_set_tag}"
+ )
+ return action_str
+
+
+def get_action_description(
+ action: Action,
+ observation_metadata: dict[str, ObservationMetadata],
+ action_set_tag: str,
+ prompt_constructor: PromptConstructor | None,
+) -> str:
+ """Generate the text version of the predicted actions to store in action history for prompt use.
+ May contain hint information to recover from the failures"""
+
+ match action_set_tag:
+ case "id_accessibility_tree":
+ text_meta_data = observation_metadata["text"]
+ if action["action_type"] in [
+ ActionTypes.CLICK,
+ ActionTypes.HOVER,
+ ActionTypes.TYPE,
+ ]:
+ action_name = str(action["action_type"]).split(".")[1].lower()
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ node_content = text_meta_data["obs_nodes_info"][
+ action["element_id"]
+ ]["text"]
+ node_content = " ".join(node_content.split()[1:])
+ action_str = action2str(
+ action, action_set_tag, node_content
+ )
+ else:
+ action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
+ else:
+ if (
+ action["action_type"] == ActionTypes.NONE
+ and prompt_constructor is not None
+ ):
+ action_splitter = prompt_constructor.instruction[
+ "meta_data"
+ ]["action_splitter"]
+ action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
+ else:
+ action_str = action2str(action, action_set_tag, "")
+
+ case "som":
+ text_meta_data = observation_metadata["image"]
+ if action["action_type"] in [
+ ActionTypes.CLICK,
+ ActionTypes.HOVER,
+ ActionTypes.TYPE,
+ ]:
+ action_name = str(action["action_type"]).split(".")[1].lower()
+ if action["element_id"] in text_meta_data["obs_nodes_info"]:
+ action_str = action2str(action, action_set_tag, "")
+ else:
+ print(
+ 'action["element_id"], text_meta_data["obs_nodes_info"]',
+ action["element_id"],
+ text_meta_data["obs_nodes_info"],
+ )
+ action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
+ else:
+ if (
+ action["action_type"] == ActionTypes.NONE
+ and prompt_constructor is not None
+ ):
+ action_splitter = prompt_constructor.instruction[
+ "meta_data"
+ ]["action_splitter"]
+ action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
+ else:
+ action_str = action2str(action, action_set_tag, "")
+
+ case "playwright":
+ action_str = action["pw_code"]
+
+ case "webrl_id":
+ action_str = action["raw_prediction"]
+
+ case _:
+ raise ValueError(f"Unknown action type {action['action_type']}")
+
+ return action_str
+
+
+class RenderHelper(object):
+ """Helper class to render text and image observations and meta data in the trajectory"""
+
+ def __init__(
+ self, config_file: str, result_dir: str, action_set_tag: str
+ ) -> None:
+ with open(config_file, "r") as f:
+ _config = json.load(f)
+ _config_str = ""
+ for k, v in _config.items():
+ _config_str += f"{k}: {v}\n"
+ _config_str = f"{_config_str}
\n"
+ task_id = _config["task_id"]
+
+ self.action_set_tag = action_set_tag
+
+ self.render_file = open(
+ Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
+ )
+ self.render_file.truncate(0)
+ # write init template
+ self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
+ self.render_file.read()
+ self.render_file.flush()
+
+ def render(
+ self,
+ action: Action,
+ state_info: StateInfo,
+ meta_data: dict[str, Any],
+ render_screenshot: bool = False,
+ ) -> None:
+ """Render the trajectory"""
+ # text observation
+ observation = state_info["observation"]
+ text_obs = observation["text"]
+ info = state_info["info"]
+ new_content = f"New Page
\n"
+ new_content += f"\n"
+ new_content += f"{text_obs}
\n"
+
+ if render_screenshot:
+ # image observation
+ img_obs = observation["image"]
+ image = Image.fromarray(img_obs)
+ byte_io = io.BytesIO()
+ image.save(byte_io, format="PNG")
+ byte_io.seek(0)
+ image_bytes = base64.b64encode(byte_io.read())
+ image_str = image_bytes.decode("utf-8")
+ new_content += f"

\n"
+
+ # meta data
+ new_content += f"
{meta_data['action_history'][-1]}
\n"
+
+ # action
+ action_str = get_render_action(
+ action,
+ info["observation_metadata"],
+ action_set_tag=self.action_set_tag,
+ )
+ # with yellow background
+ action_str = f"
{action_str}
"
+ new_content += f"{action_str}\n"
+
+ # add new content
+ self.render_file.seek(0)
+ html = self.render_file.read()
+ html_body = re.findall(r"(.*?)", html, re.DOTALL)[0]
+ html_body += new_content
+
+ html = HTML_TEMPLATE.format(body=html_body)
+ self.render_file.seek(0)
+ self.render_file.truncate()
+ self.render_file.write(html)
+ self.render_file.flush()
+
+ def close(self) -> None:
+ self.render_file.close()
diff --git a/VAB-WebArena-Lite/new/helper_functions.py b/VAB-WebArena-Lite/new/helper_functions_eval.py
similarity index 100%
rename from VAB-WebArena-Lite/new/helper_functions.py
rename to VAB-WebArena-Lite/new/helper_functions_eval.py
diff --git a/VAB-WebArena-Lite/new/html_tools/__init__.py b/VAB-WebArena-Lite/new/html_tools/__init__.py
new file mode 100755
index 0000000..6aac581
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/__init__.py
@@ -0,0 +1,7 @@
+from .identifier import IdentifierTool
+from .prompt import HtmlPrompt
+from .html_parser import HtmlParser
+
+from .utils import print_html_object
+from .configs import basic_attrs, mind2web_keep_attrs
+from .fetch import get_parsed_html
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/__init__.py b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py
new file mode 100755
index 0000000..b900a67
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py
@@ -0,0 +1,3 @@
+from .html_prompt import prompts
+from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs
+from .config import config_meta
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/config.py b/VAB-WebArena-Lite/new/html_tools/configs/config.py
new file mode 100755
index 0000000..d7da96c
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/config.py
@@ -0,0 +1,56 @@
+basic_attrs = [
+ 'title',
+ 'value',
+ 'type',
+ 'placeholder',
+ 'selected',
+ 'data-value',
+ 'data-text',
+ 'data-testid',
+ 'data-label',
+ 'data-bbox',
+ 'data-status'
+]
+
+mind2web_keep_attrs = [
+ 'alt',
+ 'aria_description',
+ 'aria_label',
+ 'aria_role',
+ 'input_checked',
+ 'input_value',
+ 'label',
+ 'name',
+ 'option_selected',
+ 'placeholder',
+ 'role',
+ 'text_value',
+ 'title',
+ 'type',
+ 'value',
+]
+
+miniwob_attrs = [
+ 'id',
+ 'type',
+ 'value',
+]
+
+config_meta = """
+======= Configs =======
+Columns:
+ - id: {id_attr}
+ - label: {label_attr}
+Position: {use_position}
+ - window: {window_size}
+ - rect_dict: {rect}
+Keep:
+ - parents: {parent_chain}
+ - attrs: {keep_attrs}
+ - elems: {keep_elem}
+ - obs_elem: {obs_elem}
+Generator:
+ - prompt: {prompt_name}
+ - label: {identifier_name}
+========================
+"""
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
new file mode 100755
index 0000000..904d021
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py
@@ -0,0 +1,22 @@
+refine_prompt = {
+ 'dom': '<{tag}{label}|{attr}{content}{subtree} >',
+ 'label': '[{label}]',
+ 'attr': '{attr}',
+ 'attr_splitter': '; ',
+ 'subtree_splitter': ' ',
+}
+
+xml_prompt = {
+ 'dom': '<{tag}{label}{attr}>{content}{subtree} {tag}>',
+ 'label': ' id="{label}"',
+ 'attr': '{key}="{attr}"',
+ 'attr_splitter': ' ',
+ 'subtree_splitter': ' ',
+}
+
+prompts = {
+ 'refine': refine_prompt,
+ 'xml': xml_prompt,
+ 'new_data': refine_prompt,
+}
+
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/fetch.py b/VAB-WebArena-Lite/new/html_tools/fetch.py
new file mode 100755
index 0000000..25d5637
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/fetch.py
@@ -0,0 +1,108 @@
+import os
+import json
+import base64
+from .html_parser import HtmlParser
+from .configs import basic_attrs
+from .scripts import *
+
+def get_window(page):
+ x = page.evaluate("window.scrollX")
+ y = page.evaluate("window.scrollY")
+ w = page.evaluate("window.innerWidth")
+ h = page.evaluate("window.innerHeight")
+ return (x, y, w, h)
+
+def modify_page(page):
+ page.wait_for_timeout(500)
+
+ try:
+ page.evaluate(remove_id_script)
+ except:
+ pass
+
+ packet = {
+ "raw_html": page.evaluate("document.documentElement.outerHTML"),
+ "window": get_window(page)
+ }
+
+ page.evaluate(prepare_script)
+ page.wait_for_timeout(100)
+
+ img_bytes = page.screenshot(path="debug_info/screenshot_raw.png")
+ raw_image = base64.b64encode(img_bytes).decode()
+
+ page.evaluate(clickable_checker_script)
+ page.wait_for_timeout(50)
+
+ # get all clickable elements
+ start_id = 0
+ items, start_id = page.evaluate(label_script, {
+ "selector": ".possible-clickable-element",
+ "startIndex": start_id
+ })
+ page.wait_for_timeout(50)
+
+ # mark our own labels and get the images
+ items = page.evaluate(label_marker_script, items)
+ page.wait_for_timeout(100)
+ img_bytes = page.screenshot(path="debug_info/marked.png")
+ marked_image = base64.b64encode(img_bytes).decode()
+
+ # remove markers on the page
+ page.evaluate(remove_label_mark_script)
+
+ packet.update({
+ "raw_image": raw_image,
+ "marked_image": marked_image,
+ "modified_html": page.evaluate("document.documentElement.outerHTML")
+ })
+
+ # element_info, include "all_elements" and "clickable_elements"
+ element_info = page.evaluate(element_info_script)
+ page.wait_for_timeout(100)
+ packet.update(element_info)
+ return packet
+
+def save_debug_info(packet):
+ with open("debug_info/raw.html", "w") as f:
+ f.write(packet["modified_html"])
+ with open("debug_info/parsed.html", "w") as f:
+ f.write(packet["html"])
+ with open("debug_info/all_element.json", "w") as f:
+ f.write(json.dumps(packet["all_elements"]))
+
+def get_parsed_html(page):
+ if not os.path.exists("debug_info"):
+ os.makedirs("debug_info")
+
+ print("parsing html...")
+
+ packet = modify_page(page)
+ raw_html = packet["modified_html"]
+
+ args = {
+ "use_position": True,
+ "rect_dict": {},
+ "window_size": packet["window"],
+ "id-attr": "data-backend-node-id",
+ "label_attr": "data-label-id",
+ "label_generator": "order",
+ "regenerate_label": False,
+ "attr_list": basic_attrs,
+ "prompt": "xml",
+ "dataset": "pipeline"
+ }
+
+ hp = HtmlParser(raw_html, args)
+ res = hp.parse_tree()
+ page_html = res.get("html", "")
+
+ packet["html"] = page_html
+
+ # for debug
+ save_debug_info(packet)
+
+ print("parsing finished.")
+
+ return packet
+
diff --git a/VAB-WebArena-Lite/new/html_tools/html_parser.py b/VAB-WebArena-Lite/new/html_tools/html_parser.py
new file mode 100755
index 0000000..080342f
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/html_parser.py
@@ -0,0 +1,447 @@
+from lxml import html
+import time, copy, random
+import json, re, os
+
+from .identifier import IdentifierTool
+from .prompt import HtmlPrompt
+from .configs import config_meta
+from .utils import get_xpath_top_down, rect2tuple
+
+class HtmlParser():
+ def __init__(self, ctx: str, args: dict[str]={}) -> None:
+ stt = time.time()
+ self.dom_tree = self.ctx2tree(ctx)
+ # tool related
+ self.bids2label = {}
+ self.bids2xpath = {}
+ self.used_labels = {}
+
+ # parse args
+ self.parse_args(args)
+ self.init_time = time.time() - stt
+
+ def parse_args(self, args: dict[str]={}) -> None:
+ def attr_check(attr, type_model='str'):
+ if attr is None:
+ return False
+ attr_type = type(attr)
+ if attr_type != type(type_model):
+ return False
+ if attr_type == type('str') and len(attr) == 0:
+ return False
+ return True
+
+ args = {} if args is None else args
+
+ # [Position] use_pos: False -> use full page, otherwise use window_size
+ dataset = args.get('dataset', '')
+ use_position = args.get('use_position', False)
+ window_size = args.get('window_size', None)
+ rect = args.get('rect_dict', None)
+ if use_position:
+ if not attr_check(window_size, ()):
+ raise ValueError('window_size must be set when use_position is True')
+ if not attr_check(rect, {}):
+ raise ValueError('rect_dict must be set when use_position is True')
+
+ if not attr_check(rect, {}):
+ rect = {}
+
+ # [Label] for vimium is temp_clickable_label, otherwise keep all of it
+ label_attr = args.get('label_attr', '')
+ get_new_label = args.get('regenerate_label', False)
+ label_method = args.get('label_generator', None)
+ regen_label = not attr_check(label_method)
+
+ # [id] for mind2web is backend_node_id, for normal website use our method
+ id_attr = args.get('id_attr', '')
+ regen_id = not attr_check(id_attr)
+
+ if regen_id:
+ id_attr = 'temp_id'
+
+ # [attributes]
+ keep_attrs = args.get('attr_list', [])
+ if not attr_check(keep_attrs, []):
+ keep_attrs = []
+
+ # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
+ parent_chain = args.get('parent_chain', False)
+ keep_elem = args.get('keep_elem', [])
+ obs_elem = args.get('obs_elem', [])
+
+ # sanity check
+ self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset)
+
+ # [Prompt]
+ prompt = args.get('prompt', None)
+ self.prompt = HtmlPrompt(prompt)
+
+ # traverse and get special data
+ if regen_id or regen_label:
+ self.mark_id()
+
+ if get_new_label:
+ self.used_labels = {}
+
+ self.identifier = IdentifierTool(label_method, self.used_labels)
+
+ def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='',
+ id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[],
+ parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None:
+
+ self.use_position = use_position
+ self.window_size = window_size
+ self.rect = rect_dict
+ self.label_attr = label_attr
+ self.id_attr = id_attr
+ self.keep_attrs = keep_attrs
+ self.keep = keep_elem
+ self.obs = obs_elem
+ self.parent_chain = parent_chain
+ self.get_new_label = get_new_label
+ self.dataset = dataset
+
+ def get_config(self):
+ config = {
+ 'id_attr': self.id_attr,
+ 'keep_attrs': self.keep_attrs[:5],
+ 'label_attr': self.label_attr,
+ 'use_position': self.use_position,
+ 'window_size': self.window_size,
+ 'rect': dict(list(self.rect.items())[:3]),
+ 'keep_elem': self.keep[:5],
+ 'obs_elem': self.obs[:5],
+ 'parent_chain': self.parent_chain,
+ 'prompt_name': self.prompt.name,
+ 'identifier_name': self.identifier.name
+ }
+
+ return config, config_meta.format(**config)
+
+ def update_rect_dict(self, rect_dict: dict[str]={}) -> None:
+ self.rect = rect_dict
+
+ @staticmethod
+ def ctx2tree(ctx: str) -> html.HtmlElement:
+ # remove useless tags, eg. style and script
+ ctx = re.sub('', '', ctx)
+ ctx = re.sub('', '', ctx)
+ ctx = re.sub('', '', ctx)
+ ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip()
+ dom_tree = html.fromstring(ctx.encode('utf-8'))
+ match = re.search('
html.HtmlElement:
+ node = tree.xpath('//*')[0]
+ while True:
+ parent = node.getparent()
+ if parent is None:
+ break
+ node = parent
+ return node
+
+ def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement:
+ nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]')
+ if len(nodes) == 0:
+ return None
+ return nodes[0]
+
+ def id_label_converter(self, label: str) -> str:
+ return self.bids2label.get(label, '')
+
+ def id_xpath_converter(self, label: str) -> str:
+ return self.bids2xpath.get(label, '')
+
+ def mark_id(self) -> None:
+ root = self.get_root(self.dom_tree)
+ _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
+ self.used_labels = used_labels
+ self.bids2xpath = i2xpath
+
+ def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]:
+ def get_text(str: str) -> str:
+ return '' if str is None else str.strip()[:500]
+
+ def check_attr(attr: str, node: html.HtmlElement) -> bool:
+ tag = node.tag
+ if (
+ ( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] )
+ or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' )
+ or ( attr == 'aria-hidden' and node.attrib.get(attr, '') == 'true')
+ # or ( attr == 'value' and tag in ['option'] )
+ ):
+ return False
+ return True
+
+ def is_visible(node: html.HtmlElement, bid: str) -> bool:
+ if self.dataset == 'mind2web':
+ bound = node.attrib.get('bounding_box_rect', None)
+ self.rect[bid] = rect2tuple(bound)
+
+ if self.dataset == 'pipeline':
+ bound = node.attrib.get('data-bbox', None)
+ self.rect[bid] = rect2tuple(bound)
+
+ if node.attrib.get('aria-hidden', 'false') == 'true':
+ return False
+
+ if not self.use_position:
+ return True
+
+ rect = self.rect.get(bid, None)
+ if rect is None:
+ return False
+
+ if self.window_size is None:
+ return True
+
+ # get window size
+ wx, wy, ww, wh = self.window_size
+ wx, wy = 0, 0
+ x, y, w, h = rect
+ if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
+ return False
+
+ return True
+
+ def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[],
+ parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]):
+ # basic information
+ bid = node.attrib.get(self.id_attr, '')
+ tag = node.tag
+ label = node.attrib.get(self.label_attr, '')
+
+ # element which is keeped equivalent to visible
+ visible = is_visible(node, bid)
+ in_keep_list = bid in keep
+ in_obs_list = (bid in obs or len(label) > 0) and visible
+ par_keep = par_keep and tag == "option"
+ keep_element = in_keep_list or in_obs_list or visible or par_keep
+
+ if label:
+ keep_element = True
+
+ # mark label
+ bids2label, labeled_elems = {}, []
+ have_label = False
+ if in_keep_list or in_obs_list:
+ if label is None or len(label) == 0 or get_new_label:
+ label = self.identifier.generate()
+ node.attrib[self.label_attr] = label
+ bids2label[bid] = label
+ bids2label[label] = bid
+ have_label = True
+
+ # get text or alt_text of current element
+ text = get_text(node.text)
+
+ classes = {}
+ # keep attributes if needed
+ keep_all_attrs = len(self.keep_attrs) == 0
+ keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs
+
+ # traverse attributes
+ for attr in keep_attrs:
+ if attr not in node.attrib or not check_attr(attr, node):
+ continue
+ if attr in [self.id_attr, self.label_attr]:
+ continue
+ val = get_text(node.attrib[attr])
+ if len(val) > 0 or keep_all_attrs:
+ classes[attr] = val
+
+ have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0
+ par_keep = keep_element and tag == 'select'
+
+ parts = []
+ clickable_count = 0
+ children = node.getchildren()
+ for child in children:
+ cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep)
+ clickable_count += 1 if cmsg.get('have_clickable', False) else 0
+ bids2label.update(cmsg.get('bids2label', {}))
+ labeled_elems.extend(cmsg.get('label_element', []))
+ if len(cres) != 0:
+ parts.append(cres)
+
+ dom = self.prompt.subtree_constructor(parts)
+
+ # remove
if all children are text
+ keep_as_all_text = (dom.count('<') == dom.count(' 0
+ if keep_as_all_text:
+ matches = re.findall(r']+) >', dom)
+ dom = self.prompt.subtree_constructor(matches)
+
+ keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
+ keep_as_parent = len(dom) > 0 and parent_chain
+ if in_keep_list or keep_element or keep_as_parent:
+ dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)
+
+ if have_label:
+ labeled_elems.append(bid)
+
+ control_msg = {
+ 'have_clickable': bool(clickable_count or have_text),
+ 'bids2label': bids2label,
+ 'label_element': labeled_elems,
+ }
+
+ return dom, control_msg
+
+ dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label)
+ return dom, cmsg
+
+ def parse_tree(self) -> dict[str]:
+ # start from here
+ stt = time.time()
+ root = self.get_root(self.dom_tree)
+ dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
+ self.bids2label = cmsg.get('bids2label', {})
+ self.keep = list(set(self.keep + cmsg.get('label_element', [])))
+
+ obj = {
+ 'html': dom,
+ 'parse_time': time.time() - stt
+ }
+
+ return obj
+
+ # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
+ def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int,
+ max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]:
+ def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
+ if current_depth > max_depth:
+ return []
+
+ anscendants = []
+ parent = node.getparent()
+ if parent is not None:
+ anscendants.append(parent)
+ anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1))
+
+ return anscendants
+
+ def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
+ if current_depth > max_depth:
+ return []
+
+ descendants = []
+ for child in node:
+ descendants.append(child)
+ descendants.extend(get_descendants(child, max_depth, current_depth + 1))
+
+ return descendants
+
+ to_keep = set(copy.deepcopy(keep))
+ nodes_to_keep = set()
+
+ for _ in range(max(1, dfs_count)):
+ for bid in to_keep:
+ candidate_node = self.get_node_by_bid(tree, bid)
+ if candidate_node is None:
+ continue
+
+ nodes_to_keep.add(candidate_node.attrib[self.id_attr])
+ # get all ancestors or with max depth
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)])
+
+ # get descendants with max depth
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children])
+ # get siblings within range
+ parent = candidate_node.getparent()
+ if parent is None:
+ continue
+
+ siblings = [x for x in parent.getchildren() if x.tag != 'text']
+ if candidate_node not in siblings:
+ continue
+
+ idx_in_sibling = siblings.index(candidate_node)
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '')
+ for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]])
+
+ max_children = int(max_children * 0.5)
+ max_depth = int(max_depth * 0.5)
+ max_sibling = int(max_sibling * 0.7)
+
+ to_keep = copy.deepcopy(nodes_to_keep)
+
+ if keep_parent:
+ for bid in keep:
+ candidate_node = self.get_node_by_bid(tree, bid)
+ if candidate_node is None:
+ continue
+ nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")])
+
+ return list(nodes_to_keep)
+
+ def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement:
+ # remove nodes not in nodes_to_keep
+ for node in tree.xpath('//*')[::-1]:
+ if node.tag != 'text':
+ is_keep = node.attrib.get(self.id_attr, '') in nodes_to_keep
+ is_candidate = node.attrib.get(self.id_attr, '') in self.keep
+ else:
+ is_keep = (node.getparent().attrib.get(self.id_attr, '') in nodes_to_keep)
+ is_candidate = (node.getparent().attrib.get(self.id_attr, '') in self.keep)
+
+ if not is_keep and node.getparent() is not None:
+ # insert all children into parent
+ for child in node.getchildren():
+ node.addprevious(child)
+ node.getparent().remove(node)
+ else:
+ # if not is_candidate or node.tag == 'text':
+ # node.attrib.pop(self.id_attr, None)
+ if (
+ len(node.attrib) == 0
+ and not any([x.tag == 'text' for x in node.getchildren()])
+ and node.getparent() is not None
+ and node.tag != "text"
+ and len(node.getchildren()) <= 1
+ ):
+ # insert all children into parent
+ for child in node.getchildren():
+ node.addprevious(child)
+ node.getparent().remove(node)
+
+ return tree
+
+ def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30,
+ max_sibling: int=3, keep_parent: bool=False) -> None:
+ # clone the tree
+ new_tree = copy.deepcopy(self.dom_tree)
+ nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent)
+ new_tree = self.prune(new_tree, nodes_to_keep)
+
+ self.dom_tree = new_tree
+
+ def get_segment(self, bid: str) -> str:
+ # clone the tree
+ new_tree = copy.deepcopy(self.dom_tree)
+ nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
+ new_tree = self.prune(new_tree, nodes_to_keep)
+ dom, _ = self.parse(new_tree, self.keep, [], False)
+ return dom
+
+ def get_rect_data(self, bids: list[str]) -> list[dict[str]]:
+ res = []
+ for bid in bids:
+ label = self.bids2label.get(bid, '')
+ rect = self.rect.get(bid, None)
+ res.append({
+ 'bid': bid,
+ 'label': label,
+ 'rect': rect
+ })
+ return res
+
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/identifier.py b/VAB-WebArena-Lite/new/html_tools/identifier.py
new file mode 100755
index 0000000..793ebed
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/identifier.py
@@ -0,0 +1,64 @@
+import secrets
+
+class IdentifierTool:
+ def __init__(self, method: str='order', existing_labels: dict[str]={}) -> None:
+ self.methods = {
+ 'order': self.get_identifier_in_order,
+ 'random': self.get_random_identifier,
+ }
+
+ if method is None:
+ method = 'order'
+
+ self.func = self.methods.get(method, None)
+ self.name = method
+ if self.func is None:
+ raise ValueError(f'Invalid method for identifier: {method}')
+
+ self.reset(existing_labels)
+
+ def reset(self, exists: dict[str]={}) -> None:
+ self.identifier = -1
+ self.exists = {} if exists is None else exists
+
+ def get_identifier_in_order(self) -> str:
+ def id2str(id: int) -> str:
+ if id < 26:
+ return chr(id + 65)
+ id -= 26
+ c0 = id // 676
+ c1 = (id // 26) % 26
+ c2 = id % 26
+ label = f'{chr(c1 + 65)}{chr(c2 + 65)}'
+ return label if c0 == 0 else f'{chr(c0 + 64)}{label}'
+
+ self.identifier += 1
+ label = id2str(self.identifier)
+
+ while label in self.exists:
+ self.identifier += 1
+ label = id2str(self.identifier)
+
+ self.exists[label] = True
+ return label
+
+ def get_random_identifier(self) -> str:
+ secret_generator = secrets.SystemRandom()
+
+ def get_random_label(n: int=2) -> str:
+ tmp = ''
+ for _ in range(n):
+ tmp += chr(secret_generator.randint(65, 90))
+ return tmp
+
+ wc = 3 if len(self.exists) > 280 else 2
+
+ label = get_random_label(wc)
+ while label in self.exists:
+ label = get_random_label(wc)
+
+ self.exists[label] = True
+ return label
+
+ def generate(self):
+ return self.func()
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/prompt.py b/VAB-WebArena-Lite/new/html_tools/prompt.py
new file mode 100755
index 0000000..38d6b94
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/prompt.py
@@ -0,0 +1,97 @@
+from .configs import prompts
+
+class HtmlPrompt:
+ def __init__(self, prompt: str='') -> None:
+ prompt = self.extract(prompt, 'xml')
+ if prompt not in prompts:
+ raise Exception('Unknown prompt: ' + prompt)
+
+ constructors = {
+ 'refine': self.normal_prompt_constructor,
+ 'xml': self.normal_prompt_constructor,
+ 'new_data': self.new_data_prompt_constructor,
+ }
+
+ self.name = prompt
+ self.prompt = prompts[prompt]
+ self.constructor = constructors[prompt]
+
+ @staticmethod
+ def extract(data, default=''):
+ return data if data is not None else default
+
+ def subtree_constructor(self, subtree: list[str]=[]) -> str:
+ return self.prompt['subtree_splitter'].join(subtree)
+
+ def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
+ def add_prefix(data, prefix):
+ return prefix + data if len(data) > 0 else ''
+
+ tag = self.extract(tag)
+ label = self.extract(label)
+ content = self.extract(content)
+ subtree_str = self.extract(subtree_str, '')
+ class_dict = self.extract(class_dict, {})
+
+ label_str = ''
+ if len(label) > 0:
+ label_str = self.prompt['label'].format(label=label)
+
+ classes = []
+ values = set()
+ for key, val in class_dict.items():
+ if val in values:
+ continue
+ values.add(val)
+ classes.append(self.prompt['attr'].format(key=key, attr=val))
+ classes_str = self.prompt['attr_splitter'].join(classes)
+
+ content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
+ classes_str = add_prefix(classes_str, ' ')
+ content_str = add_prefix(content, content_splitter)
+ subtree_str = add_prefix(subtree_str, ' ')
+
+ return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
+
+ def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
+ def add_prefix(data, prefix):
+ return prefix + data if len(data) > 0 else ''
+
+ tag = self.extract(tag)
+ label = self.extract(label)
+ content = self.extract(content)
+ subtree_str = self.extract(subtree_str, '')
+ class_dict = self.extract(class_dict, {})
+
+ label_str = ''
+ if len(label) > 0:
+ label_str = self.prompt['label'].format(label=label)
+
+ classes = []
+ values = set()
+
+ message = []
+ for key, val in class_dict.items():
+ if val == '':
+ message.append(key)
+ continue
+ if val in values:
+ continue
+ values.add(val)
+ classes.append(self.prompt['attr'].format(key=key, attr=val))
+
+ if len(message) > 0:
+ message_str = ' '.join(message)
+ classes.append(self.prompt['attr'].format(key='message', attr=message_str))
+
+ classes_str = self.prompt['attr_splitter'].join(classes)
+
+ content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
+ classes_str = add_prefix(classes_str, ' ')
+ content_str = add_prefix(content, content_splitter)
+ subtree_str = add_prefix(subtree_str, ' ')
+
+ return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
+
+ def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
+ return self.constructor(tag, label, content, subtree_str, class_dict)
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
new file mode 100755
index 0000000..0d36c9a
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py
@@ -0,0 +1,43 @@
+import os
+from pathlib import Path
+rootdir = Path(__file__).parent
+
+with open(os.path.join(rootdir,'prepare.js'), 'r') as f:
+ prepare_script = f.read()
+
+with open(os.path.join(rootdir, 'clickable_checker.js'), 'r') as f:
+ clickable_checker_script = f.read()
+
+with open(os.path.join(rootdir, 'label.js'), 'r') as f:
+ label_script = f.read()
+
+with open(os.path.join(rootdir, 'element_info.js'), 'r') as f:
+ element_info_script = f.read()
+
+# draw label on page
+with open(os.path.join(rootdir, 'label_marker.js'), 'r') as f:
+ label_marker_script = f.read()
+
+# remove label draw on page
+remove_label_mark_script = """
+ () => {
+ document.querySelectorAll(".our-dom-marker").forEach(item => {
+ document.body.removeChild(item);
+ });
+ }
+"""
+
+remove_id_script = """
+ () => {
+ Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => {
+ element.classList.remove('possible-clickable-element');
+ element.removeAttribute('data-value');
+ element.removeAttribute('data-text');
+ element.removeAttribute('data-label');
+ element.removeAttribute('data-bbox');
+ element.removeAttribute('data-status');
+ element.removeAttribute('data-backend-node-id');
+ element.removeAttribute('data-label-id');
+ });
+ }
+"""
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
new file mode 100755
index 0000000..42267f0
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js
@@ -0,0 +1,148 @@
+() => {
+ var items = Array.prototype.slice.call(
+ document.querySelectorAll('*')
+ ).map(function(element) {
+ var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
+ var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
+
+ var rects = [...element.getClientRects()].filter(bb => {
+ var center_x = bb.left + bb.width / 2;
+ var center_y = bb.top + bb.height / 2;
+ var elAtCenter = document.elementFromPoint(center_x, center_y);
+
+ if (!elAtCenter) return false;
+ return elAtCenter === element || element.contains(elAtCenter)
+ }).map(bb => {
+ const rect = {
+ left: Math.max(0, bb.left),
+ top: Math.max(0, bb.top),
+ right: Math.min(vw, bb.right),
+ bottom: Math.min(vh, bb.bottom)
+ };
+ return {
+ ...rect,
+ width: rect.right - rect.left,
+ height: rect.bottom - rect.top
+ }
+ });
+ // var rects = [];
+ var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
+
+ const tagName = element.tagName.toLowerCase?.() || "";
+ let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer");
+
+ // Insert area elements that provide click functionality to an img.
+ if (tagName === "img") {
+ let mapName = element.getAttribute("usemap");
+ if (mapName) {
+ const imgClientRects = element.getClientRects();
+ mapName = mapName.replace(/^#/, "").replace('"', '\\"');
+ const map = document.querySelector(`map[name=\"${mapName}\"]`);
+ if (map && (imgClientRects.length > 0)) isClickable = true;
+ }
+ }
+
+ if (!isClickable) {
+ const role = element.getAttribute("role");
+ const clickableRoles = [
+ "button",
+ "tab",
+ "link",
+ "checkbox",
+ "menuitem",
+ "menuitemcheckbox",
+ "menuitemradio",
+ "radio",
+ ];
+ if (role != null && clickableRoles.includes(role.toLowerCase())) {
+ isClickable = true;
+ } else {
+ const contentEditable = element.getAttribute("contentEditable");
+ if (
+ contentEditable != null &&
+ ["", "contenteditable", "true"].includes(contentEditable.toLowerCase())
+ ) {
+ isClickable = true;
+ }
+ }
+ }
+
+ // Check for jsaction event listeners on the element.
+ if (!isClickable && element.hasAttribute("jsaction")) {
+ const jsactionRules = element.getAttribute("jsaction").split(";");
+ for (let jsactionRule of jsactionRules) {
+ const ruleSplit = jsactionRule.trim().split(":");
+ if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) {
+ const [eventType, namespace, actionName] = ruleSplit.length === 1
+ ? ["click", ...ruleSplit[0].trim().split("."), "_"]
+ : [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"];
+ if (!isClickable) {
+ isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_");
+ }
+ }
+ }
+ }
+
+ if (!isClickable) {
+ const clickableTags = [
+ "input",
+ "textarea",
+ "select",
+ "button",
+ "a",
+ "iframe",
+ "video",
+ "object",
+ "embed",
+ "details"
+ ];
+ isClickable = clickableTags.includes(tagName);
+ }
+
+ if (!isClickable) {
+ if (tagName === "label")
+ isClickable = (element.control != null) && !element.control.disabled;
+ else if (tagName === "img")
+ isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor);
+ }
+
+ // An element with a class name containing the text "button" might be clickable. However, real
+ // clickables are often wrapped in elements with such class names. So, when we find clickables
+ // based only on their class name, we mark them as unreliable.
+ const className = element.getAttribute("class");
+ if (!isClickable && className && className.toLowerCase().includes("button")) {
+ isClickable = true;
+ }
+
+ // Elements with tabindex are sometimes useful, but usually not. We can treat them as second
+ // class citizens when it improves UX, so take special note of them.
+ const tabIndexValue = element.getAttribute("tabindex");
+ const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1;
+ if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) {
+ isClickable = true;
+ }
+
+ const idValue = element.getAttribute("id");
+ const id = idValue ? idValue.toLowerCase() : "";
+ if (isClickable && area == 0) {
+ const textValue = element.textContent.trim().replace(/\s{2,}/g, ' ');
+ clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}`
+ }
+
+ return {
+ element: element,
+ include: isClickable,
+ area,
+ rects,
+ text: element.textContent.trim().replace(/\s{2,}/g, ' ')
+ };
+ }).filter(item =>
+ item.include && (item.area >= 1)
+ );
+
+ items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))
+
+ items.forEach(item => {
+ item.element.classList.add('possible-clickable-element');
+ });
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
new file mode 100755
index 0000000..c396220
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js
@@ -0,0 +1,34 @@
+() => {
+ function getElementInfo(element) {
+ return {
+ "bid": element.getAttribute("data-backend-node-id") || "",
+ "label": element.getAttribute("data-label-id") || "",
+ "tag": element.tagName.toLowerCase?.() || "",
+ "area": JSON.parse("[" + (element.getAttribute("data-bbox") || "") + "]"),
+ "text": element.innerText?.trim().replace(/\s{2,}/g, " ") || "",
+ "id": element.getAttribute("id") || "",
+ "role": element.getAttribute("role") || "",
+ "aria-label": element.getAttribute("aria-label") || "",
+ "href": element.getAttribute("href") || "",
+ };
+ }
+
+ var all_items = Array.prototype.slice.call(
+ document.querySelectorAll("*")
+ ).map((element) => {
+ return getElementInfo(element);
+ });
+
+ var clickable_items = Array.prototype.slice.call(
+ document.querySelectorAll("*")
+ ).filter(
+ element => element.getAttribute("data-label-id")
+ ).map((element) => {
+ return getElementInfo(element);
+ });
+
+ return {
+ all_elements: all_items,
+ clickable_elements: clickable_items
+ };
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label.js b/VAB-WebArena-Lite/new/html_tools/scripts/label.js
new file mode 100755
index 0000000..285ca89
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/label.js
@@ -0,0 +1,74 @@
+(packet) => {
+ function int2str(index) {
+ var str = "";
+ while (index >= 0) {
+ str = String.fromCharCode(65 + index % 26) + str;
+ index = Math.floor(index / 26) - 1;
+ }
+ return str;
+ };
+
+ selector = packet.selector
+ index = packet.startIndex
+ var items = Array.prototype.slice.call(
+ document.querySelectorAll(selector)
+ );
+
+ var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
+ var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
+
+ items = items.filter(
+ x => !items.some(y => x.contains(y) && !(x == y))
+ ).map(element => {
+ var bb = element.getClientRects();
+ var rect = {
+ left: 0,
+ top: 0,
+ right: 0,
+ bottom: 0,
+ width: 0,
+ height: 0
+ };
+ var keep = false;
+ var text = "", id = -1;
+ if (bb.length > 0) {
+ bb = bb[0];
+ rect = {
+ left: Math.max(0, bb.left),
+ top: Math.max(0, bb.top),
+ right: Math.min(vw, bb.right),
+ bottom: Math.min(vh, bb.bottom)
+ };
+ rect = {
+ ...rect,
+ width: rect.right - rect.left,
+ height: rect.bottom - rect.top
+ };
+ if (rect.width > 0 || rect.height > 0) {
+ keep = true;
+ if (index >= 0) {
+ // id = int2str(index++);
+ id = index++;
+ element.setAttribute("data-label-id", id);
+ }
+ var childNodes = element.childNodes;
+
+ for (var i = 0; i < childNodes.length; i++) {
+ if (childNodes[i].nodeType == Node.TEXT_NODE) {
+ text += childNodes[i].textContent;
+ }
+ }
+ }
+ }
+
+ return {
+ keep: true,
+ id,
+ rects: rect,
+ tag: element.tagName.toLowerCase?.() || "",
+ text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || ""
+ };
+ }).filter(x => x.keep);
+
+ return [items, index];
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
new file mode 100755
index 0000000..6c55af5
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js
@@ -0,0 +1,65 @@
+(items) => {
+ function getRandomColor() {
+ var letters = '0123456789ABCDEF';
+ var color = '#';
+ for (var i = 0; i < 6; i++) {
+ color += letters[Math.floor(Math.random() * 16)];
+ }
+ return color;
+ }
+
+ items.filter(
+ item => item.id != ""
+ ).forEach((item) => {
+ const bbox = item.rects;
+ const id_string = `dom-marker-id-${index}`;
+
+ index = item.id;
+
+ outerElement = document.createElement("div");
+ outerElement.classList.add("our-dom-marker");
+ // var borderColor = getRandomColor();
+ var borderColor = "#FFFF00";
+ outerElement.style.outline = `2px dashed ${borderColor}`;
+ outerElement.style.position = "fixed";
+ outerElement.style.left = bbox.left - 2 + "px";
+ outerElement.style.top = bbox.top - 2 + "px";
+ outerElement.style.width = bbox.width + 4 + "px";
+ outerElement.style.height = bbox.height + 4 + "px";
+ outerElement.style.pointerEvents = "none";
+ outerElement.style.boxSizing = "border-box";
+ outerElement.style.zIndex = 2147483647;
+
+ innerElement = document.createElement("div");
+ innerElement.classList.add("our-dom-marker");
+ innerElement.style.outline = `2px dashed #222288`;
+ innerElement.style.position = "fixed";
+ innerElement.style.left = bbox.left + "px";
+ innerElement.style.top = bbox.top + "px";
+ innerElement.style.width = bbox.width + "px";
+ innerElement.style.height = bbox.height + "px";
+ innerElement.style.pointerEvents = "none";
+ innerElement.style.boxSizing = "border-box";
+ innerElement.style.zIndex = 2147483647;
+
+ // Add floating label at the corner
+ var label = document.createElement("span");
+ var topPosition = 25;
+ if (bbox.top < 25) topPosition = bbox.top;
+ label.textContent = index;
+ label.style.position = "absolute";
+ label.style.top = `-${topPosition}px`;
+ label.style.left = "0px";
+ label.style.background = borderColor;
+ label.style.color = "black";
+ label.style.padding = "2px 4px";
+ label.style.fontSize = "16px";
+ label.style.borderRadius = "2px";
+ label.style.fontWeight = "bold";
+ outerElement.appendChild(label);
+
+ document.body.appendChild(outerElement);
+ document.body.appendChild(innerElement);
+ })
+ return items;
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
new file mode 100755
index 0000000..9e76a62
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js
@@ -0,0 +1,83 @@
+() => {
+ // mark backend node id
+ var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
+ var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
+
+ var backendId = 0;
+ Array.prototype.slice.call(
+ document.querySelectorAll("*")
+ ).forEach((element) => {
+ element.setAttribute("data-backend-node-id", backendId);
+ backendId++;
+
+ var tag = element.tagName.toLowerCase?.() || "";
+ var bb = element.getClientRects();
+ var rect = {
+ left: 0,
+ top: 0,
+ right: 0,
+ bottom: 0,
+ width: 0,
+ height: 0
+ };
+
+ if (bb.length > 0) {
+ bb = bb[0];
+ // rect = {
+ // left: Math.round(Math.max(0, bb.left) * 100) / 100,
+ // top: Math.round(Math.max(0, bb.top) * 100) / 100,
+ // right: Math.round(Math.min(vw, bb.right) * 100) / 100,
+ // bottom: Math.round(Math.min(vh, bb.bottom) * 100) / 100
+ // };
+ rect = {
+ left: (Math.round(bb.left) * 100) / 100,
+ top: (Math.round(bb.top) * 100) / 100,
+ right: (Math.round(bb.right) * 100) / 100,
+ bottom: (Math.round(bb.bottom) * 100) / 100
+ };
+ rect = {
+ ...rect,
+ width: Math.round((rect.right - rect.left) * 100) / 100,
+ height: Math.round((rect.bottom - rect.top) * 100) / 100
+ };
+
+ element.setAttribute("data-bbox", `${rect.left},${rect.top},${rect.width},${rect.height}`);
+ }
+
+ if (element.hasChildNodes()) {
+ let children = Array.prototype.slice.call(element.childNodes);
+ var texts = children.filter(
+ (node) => node.nodeType == Node.TEXT_NODE
+ ).map(
+ (node) => node.textContent.trim().replace(/\s{2,}/g, " ") || ""
+ ).filter(
+ (text) => text.length > 0
+ )
+ element.setAttribute("data-text", texts.join(","));
+ }
+
+ // fix select issue
+ if (tag == "select") {
+ var value = element.value;
+ var text = element.options[element.selectedIndex]?.text || "";
+ element.setAttribute("data-value", value);
+ element.setAttribute("data-text", text);
+ element.options[element.selectedIndex]?.setAttribute("data-status", "selected");
+ }
+
+ if (tag == "input") {
+ var input_type = element.getAttribute("type") || "";
+ if (input_type == "checkbox") {
+ var status = element.checked? "checked" : "not-checked";
+ element.setAttribute("data-status", status);
+ }
+ }
+ });
+
+ // fix input and textarea issue
+ Array.prototype.slice.call(
+ document.querySelectorAll("input, textarea")
+ ).forEach(element => {
+ element.setAttribute("data-value", element.value);
+ });
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/html_tools/utils.py b/VAB-WebArena-Lite/new/html_tools/utils.py
new file mode 100755
index 0000000..9e0be65
--- /dev/null
+++ b/VAB-WebArena-Lite/new/html_tools/utils.py
@@ -0,0 +1,101 @@
+from lxml import html
+def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0,
+ in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str]]:
+ used_labels, i2xpath = {}, {}
+ # path
+ tag = element.tag.lower()
+ in_svg = in_svg or (tag == 'svg')
+
+ if not in_svg and 'id' in element.attrib:
+ node_id = element.attrib['id']
+ path = f'//*[@id="{node_id}"]'
+ else:
+ suffix = f'[{order}]' if order > 0 else ''
+ prefix = f'*[name()="{tag}"]' if in_svg else tag
+ path = path + '/' + prefix + suffix
+
+ # add temp id
+ element.attrib[id_column] = str(temp_id)
+ ori_label = element.attrib.get(label_column, '')
+ if ori_label != '':
+ used_labels[ori_label] = True
+
+ bid = str(temp_id)
+ i2xpath[bid] = path
+ i2xpath[path] = bid
+ i2xpath[f'xpath/{path}'] = bid
+ i2xpath[f'xpath=/{path}'] = bid
+
+ temp_id += 1
+
+ # traverse node
+ children = element.getchildren()
+ tag_dict = {}
+ id_list = []
+ for child in children:
+ ctag = child.tag.lower()
+ if ctag not in tag_dict:
+ tag_dict[ctag] = 0
+ tag_dict[ctag] += 1
+ id_list.append(tag_dict[ctag])
+
+ for cid, child in zip(id_list, children):
+ ctag = child.tag.lower()
+ cod = cid if tag_dict[ctag] > 1 else 0
+ temp_id, i2x, ulabels = get_xpath_top_down(child, id_column, label_column, path, cod, in_svg, temp_id)
+ i2xpath.update(i2x)
+ used_labels.update(ulabels)
+
+ return temp_id, i2xpath, used_labels
+
+def print_html_object(obj: str='') -> str:
+ tab_cnt = 0
+ result, content, sep = '', '', ''
+ last_is_left, last_is_right = False, False
+ for ch in obj:
+ if ch == '<':
+ result += '\n'
+ if len(content.strip()) > 0:
+ result += sep + content.strip() + '\n'
+ result += sep + '<'
+
+ tab_cnt += 1
+ sep = ' ' * tab_cnt
+
+ content = ''
+ last_is_right = False
+ last_is_left = True
+ elif ch == '>':
+ if last_is_left:
+ result += content
+ else:
+ if last_is_right:
+ result += '\n'
+ if len(content.strip()) > 0:
+ result += sep + content.strip() + '\n'
+
+ tab_cnt -= 1
+ sep = ' ' * tab_cnt
+
+ if not last_is_left:
+ result += sep
+
+ result += '>'
+ content = ''
+
+ last_is_right = True
+ last_is_left = False
+ else:
+ content += ch
+
+ return result
+
+def rect2tuple(rect: str) -> tuple[int, int, int, int]:
+ if rect is None or type(rect) != type('str'):
+ return None
+ rect = rect.strip()
+ if rect.count(',') != 3:
+ return None
+ rect = rect.split(',')
+ rect = [float(r) for r in rect]
+ return tuple(rect)
diff --git a/VAB-WebArena-Lite/new/openai_utils.py b/VAB-WebArena-Lite/new/openai_utils.py
index 8374c7a..c9754f2 100644
--- a/VAB-WebArena-Lite/new/openai_utils.py
+++ b/VAB-WebArena-Lite/new/openai_utils.py
@@ -36,7 +36,6 @@ def retry_with_exponential_backoff( # type: ignore
# Initialize variables
num_retries = 0
delay = initial_delay
-
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
@@ -142,27 +141,32 @@ async def agenerate_from_openai_completion(
@retry_with_exponential_backoff
def generate_from_openai_completion(
prompt: str,
- engine: str,
+ model: str,
temperature: float,
max_tokens: int,
top_p: float,
- context_length: int,
stop_token: str | None = None,
+ api_key: str | None = None,
+ base_url: str | None = None
) -> str:
- if "OPENAI_API_KEY" not in os.environ:
+ if "OPENAI_API_KEY" not in os.environ and api_key is None:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
-
+ if api_key is not None:
+ client = OpenAI(api_key=api_key, base_url=base_url)
response = client.completions.create(
prompt=prompt,
- engine=engine,
+ model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
stop=[stop_token],
)
- answer: str = response["choices"][0]["text"]
+ try:
+ answer: str = response["choices"][0]["text"]
+ except:
+ answer: str = response.choices[0].text
return answer
diff --git a/VAB-WebArena-Lite/new/p_webrl.json b/VAB-WebArena-Lite/new/p_webrl.json
new file mode 100644
index 0000000..c172455
--- /dev/null
+++ b/VAB-WebArena-Lite/new/p_webrl.json
@@ -0,0 +1,13 @@
+{
+ "intro": "",
+ "examples": [],
+ "template": "",
+ "meta_data": {
+ "observation": "webrl",
+ "action_type": "webrl_id",
+ "keywords": [],
+ "prompt_constructor": "WebRLPromptConstructor",
+ "answer_phrase": "",
+ "action_splitter": ""
+ }
+ }
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/processors.py b/VAB-WebArena-Lite/new/processors.py
new file mode 100644
index 0000000..d3f422d
--- /dev/null
+++ b/VAB-WebArena-Lite/new/processors.py
@@ -0,0 +1,1351 @@
+import json
+import pkgutil
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from io import BytesIO, StringIO
+from typing import Any, Optional, TypedDict, Union
+from urllib.parse import urljoin, urlparse
+
+import matplotlib.pyplot as plt
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
+import playwright
+import requests
+from gymnasium import spaces
+from PIL import Image, ImageDraw, ImageFont
+from playwright.sync_api import CDPSession, Page, ViewportSize
+from .html_tools.fetch import get_parsed_html
+
+from browser_env.constants import (
+ ASCII_CHARSET,
+ FREQ_UNICODE_CHARSET,
+ IGNORED_ACTREE_PROPERTIES,
+ INJECTED_ATTR_NAME,
+ UTTERANCE_MAX_LENGTH,
+ BID_ATTR,
+ DATA_REGEXP,
+ IN_VIEWPORT_RATIO_THRESHOLD,
+)
+
+from .utils import (
+ AccessibilityTree,
+ AccessibilityTreeNode,
+ BrowserConfig,
+ BrowserInfo,
+ DOMNode,
+ DOMTree,
+ Observation,
+ png_bytes_to_numpy,
+)
+
+
+def remove_unicode(input_string):
+ # Define a regex pattern to match Unicode characters
+ unicode_pattern = re.compile(r"[^\x00-\x7F]+")
+
+ # Use the pattern to replace Unicode characters with an empty string
+ cleaned_string = unicode_pattern.sub("", input_string)
+
+ return cleaned_string
+
+
+class ObservationProcessor:
+ def process(self, page: Page) -> Observation:
+ raise NotImplementedError
+
+
+class ObservationMetadata(TypedDict):
+ obs_nodes_info: dict[str, Any]
+
+
+def create_empty_metadata() -> ObservationMetadata:
+ return {
+ "obs_nodes_info": {},
+ }
+
+
+def extract_data_items_from_aria(string: str) -> tuple[list[str], str]:
+ """
+ Utility function to extract temporary data stored in the "aria-roledescription" attribute of a node
+ """
+
+ match = DATA_REGEXP.fullmatch(string)
+ if not match:
+ return [], string
+
+ groups = match.groups()
+ data_items = groups[:-1]
+ original_aria = groups[-1]
+ return data_items, original_aria
+
+
+class TextObervationProcessor(ObservationProcessor):
+ def __init__(
+ self,
+ observation_type: str,
+ current_viewport_only: bool,
+ viewport_size: ViewportSize,
+ captioning_fn=None,
+ ):
+ self.observation_type = observation_type
+ self.current_viewport_only = current_viewport_only
+ self.viewport_size = viewport_size
+ self.observation_tag = "text"
+ self.meta_data = (
+ create_empty_metadata()
+ ) # use the store meta data of this observation type
+
+ if self.observation_type in [
+ "accessibility_tree_with_captioner",
+ "image_som",
+ ]:
+ self.captioning_fn = captioning_fn
+ # Cache captions.
+ self.url2caption = {}
+
+ def fetch_browser_info(
+ self,
+ page: Page,
+ ) -> BrowserInfo:
+ # extract domtree
+ client = page.context.new_cdp_session(page)
+ tree = client.send(
+ "DOMSnapshot.captureSnapshot",
+ {
+ "computedStyles": [],
+ "includeDOMRects": True,
+ "includePaintOrder": True,
+ },
+ )
+ client.detach()
+
+ # calibrate the bounds, in some cases, the bounds are scaled somehow
+ bounds = tree["documents"][0]["layout"]["bounds"]
+ b = bounds[0]
+ n = b[2] / self.viewport_size["width"]
+ bounds = [[x / n for x in bound] for bound in bounds]
+ tree["documents"][0]["layout"]["bounds"] = bounds
+ # add union bound placeholder
+ tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]
+
+ # extract browser info
+ win_upper_bound = page.evaluate("window.pageYOffset")
+ win_left_bound = page.evaluate("window.pageXOffset")
+ win_width = page.evaluate("window.screen.width")
+ win_height = page.evaluate("window.screen.height")
+ win_right_bound = win_left_bound + win_width
+ win_lower_bound = win_upper_bound + win_height
+ device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+ assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
+
+ config: BrowserConfig = {
+ "win_upper_bound": win_upper_bound,
+ "win_left_bound": win_left_bound,
+ "win_width": win_width,
+ "win_height": win_height,
+ "win_right_bound": win_right_bound,
+ "win_lower_bound": win_lower_bound,
+ "device_pixel_ratio": device_pixel_ratio,
+ }
+
+ # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
+ info: BrowserInfo = {"DOMTree": tree, "config": config}
+
+ return info
+
+ @staticmethod
+ def get_bounding_client_rect(
+ client: CDPSession, backend_node_id: str
+ ) -> dict[str, Any]:
+ try:
+ remote_object = client.send(
+ "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
+ )
+ remote_object_id = remote_object["object"]["objectId"]
+ response = client.send(
+ "Runtime.callFunctionOn",
+ {
+ "objectId": remote_object_id,
+ "functionDeclaration": """
+ function() {
+ if (this.nodeType == 3) {
+ var range = document.createRange();
+ range.selectNode(this);
+ var rect = range.getBoundingClientRect().toJSON();
+ range.detach();
+ return rect;
+ } else {
+ return this.getBoundingClientRect().toJSON();
+ }
+ }
+ """,
+ "returnByValue": True,
+ },
+ )
+ return response
+ except Exception as e:
+ return {"result": {"subtype": "error"}}
+
+ @staticmethod
+ def get_element_in_viewport_ratio(
+ elem_left_bound: float,
+ elem_top_bound: float,
+ width: float,
+ height: float,
+ config: BrowserConfig,
+ ) -> float:
+ elem_right_bound = elem_left_bound + width
+ elem_lower_bound = elem_top_bound + height
+
+ win_left_bound = 0
+ win_right_bound = config["win_width"]
+ win_top_bound = 0
+ win_lower_bound = config["win_height"]
+
+ # Compute the overlap in x and y axes
+ overlap_width = max(
+ 0,
+ min(elem_right_bound, win_right_bound)
+ - max(elem_left_bound, win_left_bound),
+ )
+ overlap_height = max(
+ 0,
+ min(elem_lower_bound, win_lower_bound) - max(elem_top_bound, win_top_bound),
+ )
+
+ # Compute the overlap area
+ ratio = overlap_width * overlap_height / width * height
+ return ratio
+
+ def fetch_page_html(
+ self,
+ info: BrowserInfo,
+ page: Page,
+ current_viewport_only: bool,
+ ) -> DOMTree:
+ # adopted from [natbot](https://github.com/nat/natbot)
+ tree = info["DOMTree"]
+ strings = tree["strings"]
+ document = tree["documents"][0]
+ nodes = document["nodes"]
+
+ # make a dom tree that is easier to navigate
+ dom_tree: DOMTree = []
+ graph = defaultdict(list)
+ client = page.context.new_cdp_session(page)
+ for node_idx in range(len(nodes["nodeName"])):
+ cur_node: DOMNode = {
+ "nodeId": "",
+ "nodeType": "",
+ "nodeName": "",
+ "nodeValue": "",
+ "attributes": "",
+ "backendNodeId": "",
+ "parentId": "",
+ "childIds": [],
+ "cursor": 0,
+ "union_bound": None,
+ }
+
+ node_type_idx = nodes["nodeType"][node_idx]
+ node_type = "generic"
+ if node_type_idx >= 0 and node_type_idx < len(strings):
+ node_type = strings[node_type_idx]
+
+ node_name = strings[nodes["nodeName"][node_idx]]
+
+ node_value_idx = nodes["nodeValue"][node_idx]
+ node_value = ""
+ if node_value_idx >= 0 and node_value_idx < len(strings):
+ node_value = " ".join(strings[node_value_idx].split())
+
+ node_attributes = [strings[i] for i in nodes["attributes"][node_idx]]
+ node_attributes_str = ""
+ for i in range(0, len(node_attributes), 2):
+ a = node_attributes[i]
+ b = node_attributes[i + 1]
+ b = " ".join(b.split())
+ node_attributes_str += f'{a}="{b}" '
+ node_attributes_str = node_attributes_str.strip()
+
+ cur_node["nodeId"] = str(node_idx)
+ cur_node["nodeType"] = node_type
+ cur_node["nodeName"] = node_name
+ cur_node["nodeValue"] = node_value
+ cur_node["attributes"] = node_attributes_str
+ cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
+ cur_node["parentId"] = str(nodes["parentIndex"][node_idx])
+
+ if cur_node["parentId"] != "-1":
+ graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))
+
+ # get the bound
+ if cur_node["parentId"] == "-1":
+ cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
+ else:
+ response = self.get_bounding_client_rect(
+ client, cur_node["backendNodeId"]
+ )
+ if response.get("result", {}).get("subtype", "") == "error":
+ cur_node["union_bound"] = None
+ else:
+ x = response["result"]["value"]["x"]
+ y = response["result"]["value"]["y"]
+ width = response["result"]["value"]["width"]
+ height = response["result"]["value"]["height"]
+ cur_node["union_bound"] = [x, y, width, height]
+
+ dom_tree.append(cur_node)
+
+ client.detach()
+ # add parent children index to the node
+ for parent_id, child_ids in graph.items():
+ dom_tree[int(parent_id)]["childIds"] = child_ids
+
+ # remove the nodes that are not in the current viewport
+ if current_viewport_only:
+
+ def remove_node_in_graph(node: DOMNode) -> None:
+ # update the node information in the accessibility tree
+ node_id = node["nodeId"]
+ parent_id = node["parentId"]
+ child_ids = node["childIds"]
+
+ # update the children of the parent node
+ assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
+ # remove the nodeid from parent
+ index = dom_tree[int(parent_id)]["childIds"].index(node_id)
+ dom_tree[int(parent_id)]["childIds"].pop(index)
+
+ # Insert children_nodeids in the same location
+ for child_id in child_ids:
+ dom_tree[int(parent_id)]["childIds"].insert(index, child_id)
+ index += 1
+
+ # update children node's parent
+ for child_id in child_ids:
+ dom_tree[int(child_id)]["parentId"] = parent_id
+ # mark as removed
+ dom_tree[int(node_id)]["parentId"] = "[REMOVED]"
+
+ config = info["config"]
+ for cursor, node in enumerate(dom_tree):
+ if not node["union_bound"]:
+ remove_node_in_graph(node)
+ continue
+
+ [x, y, width, height] = node["union_bound"]
+
+ # invisible node
+ if width == 0.0 or height == 0.0:
+ remove_node_in_graph(node)
+ continue
+
+ in_viewport_ratio = self.get_element_in_viewport_ratio(
+ elem_left_bound=float(x),
+ elem_top_bound=float(y),
+ width=float(width),
+ height=float(height),
+ config=config,
+ )
+
+ if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
+ remove_node_in_graph(node)
+
+ dom_tree = [
+ node for node in dom_tree if node.get("parentId", "-1") != "[REMOVED]"
+ ]
+
+ return dom_tree
+
+ @staticmethod
+ def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
+ """Parse the html tree into a string text"""
+
+ obs_nodes_info = {}
+ nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}
+
+ def dfs(node_cursor: int, depth: int) -> str:
+ tree_str = ""
+ node = dom_tree[node_cursor]
+ indent = "\t" * depth
+ valid_node = True
+ try:
+ node_str = f"[{node_cursor}] <{node['nodeName']}"
+ if node["attributes"]:
+ node_str += f" {node['attributes']}"
+ node_str += f"> {node['nodeValue']}"
+ valid_node = bool(node["attributes"] or node["nodeValue"])
+
+ if valid_node:
+ obs_nodes_info[str(node_cursor)] = {
+ "backend_id": node["backendNodeId"],
+ "union_bound": node["union_bound"],
+ "text": node_str,
+ }
+ tree_str += f"{indent}{node_str}\n"
+
+ except Exception as e:
+ valid_node = False
+
+ for child_ids in node["childIds"]:
+ child_cursor = nodeid_to_cursor[child_ids]
+ child_depth = depth + 1 if valid_node else depth
+ child_str = dfs(child_cursor, child_depth)
+ tree_str += child_str
+
+ return tree_str
+
+ html = dfs(0, 0)
+ return html, obs_nodes_info
+
+ def fetch_page_accessibility_tree(
+ self,
+ page: Page,
+ info: BrowserInfo,
+ current_viewport_only: bool,
+ ) -> AccessibilityTree:
+ client = page.context.new_cdp_session(page)
+ accessibility_tree: AccessibilityTree = client.send(
+ "Accessibility.getFullAXTree", {}
+ )["nodes"]
+
+ # a few nodes are repeated in the accessibility tree
+ seen_ids = set()
+ _accessibility_tree = []
+ for node in accessibility_tree:
+ if node["nodeId"] not in seen_ids:
+ _accessibility_tree.append(node)
+ seen_ids.add(node["nodeId"])
+ accessibility_tree = _accessibility_tree
+
+ nodeid_to_cursor = {}
+ for cursor, node in enumerate(accessibility_tree):
+ nodeid_to_cursor[node["nodeId"]] = cursor
+ # usually because the node is not visible etc
+ if "backendDOMNodeId" not in node:
+ node["union_bound"] = None
+ continue
+ backend_node_id = str(node["backendDOMNodeId"])
+ if node["role"]["value"] == "RootWebArea":
+ # always inside the viewport
+ node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
+ else:
+ response = self.get_bounding_client_rect(
+ client,
+ backend_node_id
+ )
+ if response.get("result", {}).get("subtype", "") == "error":
+ node["union_bound"] = None
+ else:
+ x = response["result"]["value"]["x"]
+ y = response["result"]["value"]["y"]
+ width = response["result"]["value"]["width"]
+ height = response["result"]["value"]["height"]
+ node["union_bound"] = [x, y, width, height]
+
+ client.detach()
+ # filter nodes that are not in the current viewport
+ if current_viewport_only:
+
+ def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
+ # update the node information in the accessibility tree
+ nodeid = node["nodeId"]
+ node_cursor = nodeid_to_cursor[nodeid]
+ parent_nodeid = node["parentId"]
+ children_nodeids = node["childIds"]
+ parent_cursor = nodeid_to_cursor[parent_nodeid]
+ # update the children of the parent node
+ assert (
+ accessibility_tree[parent_cursor].get("parentId", "Root")
+ is not None
+ )
+ # remove the nodeid from parent's childIds
+ index = accessibility_tree[parent_cursor]["childIds"].index(nodeid)
+ accessibility_tree[parent_cursor]["childIds"].pop(index)
+ # Insert children_nodeids in the same location
+ for child_nodeid in children_nodeids:
+ accessibility_tree[parent_cursor]["childIds"].insert(
+ index, child_nodeid
+ )
+ index += 1
+ # update children node's parent
+ for child_nodeid in children_nodeids:
+ child_cursor = nodeid_to_cursor[child_nodeid]
+ accessibility_tree[child_cursor]["parentId"] = parent_nodeid
+ # mark as removed
+ accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"
+
+ config = info["config"]
+ for node in accessibility_tree:
+ if not node["union_bound"]:
+ remove_node_in_graph(node)
+ continue
+
+ [x, y, width, height] = node["union_bound"]
+
+ # invisible node
+ if width == 0 or height == 0:
+ remove_node_in_graph(node)
+ continue
+
+ in_viewport_ratio = self.get_element_in_viewport_ratio(
+ elem_left_bound=float(x),
+ elem_top_bound=float(y),
+ width=float(width),
+ height=float(height),
+ config=config,
+ )
+
+ if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
+ remove_node_in_graph(node)
+
+ accessibility_tree = [
+ node
+ for node in accessibility_tree
+ if node.get("parentId", "Root") != "[REMOVED]"
+ ]
+
+ return accessibility_tree
+
+ @staticmethod
+ def parse_accessibility_tree(
+ accessibility_tree: AccessibilityTree,
+ ) -> tuple[str, dict[str, Any]]:
+ """Parse the accessibility tree into a string text"""
+ node_id_to_idx = {}
+ for idx, node in enumerate(accessibility_tree):
+ node_id_to_idx[node["nodeId"]] = idx
+
+ obs_nodes_info = {}
+
+ def dfs(idx: int, obs_node_id: str, depth: int) -> str:
+ tree_str = ""
+ node = accessibility_tree[idx]
+ indent = "\t" * depth
+ valid_node = True
+ try:
+ role = node["role"]["value"]
+ name = node["name"]["value"]
+ node_str = f"[{obs_node_id}] {role} {repr(name)}"
+ properties = []
+ for property in node.get("properties", []):
+ try:
+ if property["name"] in IGNORED_ACTREE_PROPERTIES:
+ continue
+ properties.append(
+ f'{property["name"]}: {property["value"]["value"]}'
+ )
+ except KeyError:
+ pass
+
+ if properties:
+ node_str += " " + " ".join(properties)
+
+ # check valid
+ if not node_str.strip():
+ valid_node = False
+
+ # empty generic node
+ if not name.strip():
+ if not properties:
+ if role in [
+ "generic",
+ "img",
+ "list",
+ "strong",
+ "paragraph",
+ "banner",
+ "navigation",
+ "Section",
+ "LabelText",
+ "Legend",
+ "listitem",
+ ]:
+ valid_node = False
+ elif role in ["listitem"]:
+ valid_node = False
+
+ if valid_node:
+ tree_str += f"{indent}{node_str}"
+ obs_nodes_info[obs_node_id] = {
+ "backend_id": node["backendDOMNodeId"],
+ "union_bound": node["union_bound"],
+ "text": node_str,
+ }
+
+ except Exception as e:
+ valid_node = False
+
+ for _, child_node_id in enumerate(node["childIds"]):
+ if child_node_id not in node_id_to_idx:
+ continue
+ # mark this to save some tokens
+ child_depth = depth + 1 if valid_node else depth
+ child_str = dfs(
+ node_id_to_idx[child_node_id], child_node_id, child_depth
+ )
+ if child_str.strip():
+ if tree_str.strip():
+ tree_str += "\n"
+ tree_str += child_str
+
+ return tree_str
+
+ tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
+ return tree_str, obs_nodes_info
+
+ @staticmethod
+ def clean_accesibility_tree(tree_str: str) -> str:
+ """further clean accesibility tree"""
+ clean_lines: list[str] = []
+ for line in tree_str.split("\n"):
+ # remove statictext if the content already appears in the previous line
+ if "statictext" in line.lower():
+ prev_lines = clean_lines[-3:]
+ pattern = r"\[\d+\] StaticText (.+)"
+
+ match = re.search(pattern, line, re.DOTALL)
+ if match:
+ static_text = match.group(1)[1:-1] # remove the quotes
+ if static_text and all(
+ static_text not in prev_line for prev_line in prev_lines
+ ):
+ clean_lines.append(line)
+ else:
+ clean_lines.append(line)
+
+ return "\n".join(clean_lines)
+
+ def fetch_image_related(self, page: Page, browser_info: BrowserInfo) -> str:
+ # Check if the current page is an image url
+ if page.url.endswith((".jpg", ".jpeg", ".png")):
+ print("NOTE: We are on an image page!!!")
+ # Load image from current url and run captioning on it.
+ if page.url not in self.url2caption and self.captioning_fn is not None:
+ try:
+ image = Image.open(requests.get(page.url, stream=True).raw)
+ caption = self.captioning_fn([image])[0].strip()
+ self.url2caption[page.url] = remove_unicode(caption)
+ except Exception as e:
+ print("L579 WARNING: ", e)
+ content = self.url2caption.get(page.url, "Image")
+
+ else:
+ if self.captioning_fn is not None:
+ images = page.query_selector_all("img")
+ image_urls = []
+ for image in images:
+ try:
+ image_url = image.get_attribute("src")
+ if not image_url.startswith(("http://", "https://", "www.")):
+ image_url = urljoin(page.url, image_url)
+ if image_url not in self.url2caption:
+ image_urls.append(image_url)
+ except Exception as e:
+ print("L604 WARNING: ", e)
+
+ # Run image captioning on image_url pixels. This is for models which use captioning as a baseline.
+ if len(image_urls) > 0:
+ image_pixels = []
+ valid_urls = []
+ for url in image_urls:
+ if "data:image/svg" in url:
+ continue
+ else:
+ try:
+ image = Image.open(requests.get(url, stream=True).raw)
+ image_pixels.append(image)
+ valid_urls.append(url)
+ except Exception as e:
+ print("L616 WARNING: ", e)
+
+ # Caption images.
+ if image_pixels:
+ # Run in batches of 4.
+ bs = 4
+ captions = []
+ for i in range(0, len(image_pixels), bs):
+ try:
+ captions.extend(
+ self.captioning_fn(image_pixels[i : i + bs])
+ )
+ except Exception as e:
+ print("L628 WARNING: ", e)
+ captions.extend([""] * len(image_pixels[i : i + bs]))
+ assert len(valid_urls) == len(
+ captions
+ ), f"len(images)={len(valid_urls)}, len(captions)={len(captions)}"
+ for image_url, caption in zip(valid_urls, captions):
+ self.url2caption[image_url] = remove_unicode(
+ caption.strip()
+ )
+
+ image_idx = 0
+ for image in images:
+ try:
+ original_alt = image.get_attribute("alt") or ""
+ image_url = image.get_attribute("src")
+ if not image_url.startswith(("http://", "https://", "www.")):
+ image_url = urljoin(page.url, image_url)
+
+ updated_alt = original_alt
+
+ if image_url in self.url2caption:
+ if self.url2caption[image_url] not in updated_alt:
+ updated_alt = f"{updated_alt}, description: {self.url2caption[image_url]}"
+ elif "data:image/svg" not in image_url:
+ print(f"WARNING: {image_url} not in self.url2caption")
+
+ if "url:" not in updated_alt:
+ updated_alt = f"{updated_alt}, url: {image_url}"
+
+ safe_updated_alt = json.dumps(updated_alt)
+ image.evaluate(f"node => node.alt = {safe_updated_alt}")
+ except Exception as e:
+ print("L653 WARNING:", e)
+
+ if self.observation_type == "accessibility_tree_with_captioner":
+ frame_ax_trees = self.fetch_page_accessibility_tree(
+ page,
+ browser_info,
+ current_viewport_only=self.current_viewport_only
+ )
+ content, obs_nodes_info = self.parse_accessibility_tree(frame_ax_trees)
+ content = self.clean_accesibility_tree(content)
+ self.obs_nodes_info = obs_nodes_info
+ self.meta_data["obs_nodes_info"] = obs_nodes_info
+ else:
+ content = "" # Not used for SoM
+
+ return content
+
+ def process(self, page: Page) -> str:
+ # get the tab info
+ open_tabs = page.context.pages
+ try:
+ tab_titles = [tab.title() for tab in open_tabs]
+ current_tab_idx = open_tabs.index(page)
+ for idx in range(len(open_tabs)):
+ if idx == current_tab_idx:
+ tab_titles[idx] = f"Tab {idx} (current): {open_tabs[idx].title()}"
+ else:
+ tab_titles[idx] = f"Tab {idx}: {open_tabs[idx].title()}"
+ tab_title_str = " | ".join(tab_titles)
+ except Exception:
+ tab_title_str = " | ".join([f"Tab {idx}" for idx in range(len(open_tabs))])
+
+ try:
+ browser_info = self.fetch_browser_info(page)
+ except Exception:
+ page.wait_for_load_state("load", timeout=500)
+ browser_info = self.fetch_browser_info(page)
+
+ if self.observation_type == "html":
+ dom_tree = self.fetch_page_html(
+ browser_info,
+ page,
+ self.current_viewport_only,
+ )
+ content, obs_nodes_info = self.parse_html(dom_tree)
+ self.obs_nodes_info = obs_nodes_info
+ self.meta_data["obs_nodes_info"] = obs_nodes_info
+
+ elif self.observation_type == "accessibility_tree":
+ accessibility_tree = self.fetch_page_accessibility_tree(
+ page,
+ browser_info,
+ self.current_viewport_only,
+ )
+ content, obs_nodes_info = self.parse_accessibility_tree(accessibility_tree)
+ content = self.clean_accesibility_tree(content)
+ self.obs_nodes_info = obs_nodes_info
+ self.meta_data["obs_nodes_info"] = obs_nodes_info
+
+ elif self.observation_type in [
+ "accessibility_tree_with_captioner",
+ "image_som",
+ ]:
+ content = self.fetch_image_related(
+ page,
+ browser_info,
+ )
+
+ elif self.observation_type == "":
+ content = ""
+
+ else:
+ raise ValueError(f"Invalid observation type: {self.observation_type}")
+
+ self.browser_config = browser_info["config"]
+ content = f"{tab_title_str}\n\n{content}"
+
+ return content
+
+ def get_element_center(self, element_id: str) -> tuple[float, float]:
+ node_info = self.obs_nodes_info[element_id]
+ node_bound = node_info["union_bound"]
+ x, y, width, height = node_bound
+ center_x = x + width / 2
+ center_y = y + height / 2
+ return (
+ center_x / self.viewport_size["width"],
+ center_y / self.viewport_size["height"],
+ )
+
+
+class ImageObservationProcessor(ObservationProcessor):
+ def __init__(
+ self,
+ observation_type: str,
+ viewport_size: Optional[ViewportSize] = None,
+ ):
+ self.observation_type = observation_type
+ self.observation_tag = "image"
+ self.viewport_size = viewport_size
+ self.meta_data = create_empty_metadata()
+
+ def get_page_bboxes(self, page: Page) -> list[list[float]]:
+ """JavaScript code to return bounding boxes and other metadata from HTML elements."""
+ js_script = """
+ (() => {
+ const interactableSelectors = [
+ 'a[href]:not(:has(img))', 'a[href] img', 'button', 'input:not([type="hidden"])', 'textarea', 'select',
+ '[tabindex]:not([tabindex="-1"])', '[contenteditable="true"]', '[role="button"]', '[role="link"]',
+ '[role="checkbox"]', '[role="menuitem"]', '[role="tab"]', '[draggable="true"]',
+ '.btn', 'a[href="/notifications"]', 'a[href="/submit"]', '.fa.fa-star.is-rating-item', 'input[type="checkbox"]'
+
+ ];
+
+ const textSelectors = ['p', 'span', 'div:not(:has(*))', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'article'];
+ const modifiedTextSelectors = textSelectors.map(selector =>
+ `:not(${interactableSelectors.join(', ')}):not(style) > ${selector}`
+ );
+
+ const combinedSelectors = [...interactableSelectors, ...modifiedTextSelectors];
+ const elements = document.querySelectorAll(combinedSelectors.join(', '));
+
+ const pixelRatio = window.devicePixelRatio;
+ let csvContent = "ID,Element,Top,Right,Bottom,Left,Width,Height,Alt,Class,Id,TextContent,Interactable\\n";
+ let counter = 1;
+
+ elements.forEach(element => {
+ const rect = element.getBoundingClientRect();
+ if (rect.width === 0 || rect.height === 0) return;
+ let altText = element.getAttribute('alt') || '';
+ altText = altText.replace(/"/g, ''); // Escape double quotes in alt text
+ const classList = element.className || '';
+ const id = element.id || '';
+ let textContent = element.textContent || '';
+ textContent = textContent.replace(/"/g, ''); // Escape double quotes in textContent
+
+ // Determine if the element is interactable
+ const isInteractable = interactableSelectors.some(selector => element.matches(selector));
+
+ const dataString = [
+ counter, element.tagName, (rect.top + window.scrollY) * pixelRatio,
+ (rect.right + window.scrollX) * pixelRatio, (rect.bottom + window.scrollY) * pixelRatio,
+ (rect.left + window.scrollX) * pixelRatio, rect.width * pixelRatio, rect.height * pixelRatio,
+ altText, classList, id, textContent, isInteractable
+ ].map(value => `"${value}"`).join(",");
+
+ csvContent += dataString + "\\n";
+ counter++;
+ });
+
+ return csvContent;
+ })();
+ """
+ # Save the bbox as a CSV
+ csv_content = page.evaluate(js_script)
+ return csv_content
+
+ def draw_bounding_boxes(
+ self,
+ data_string,
+ screenshot_img,
+ viewport_size=None,
+ add_ids=True,
+ bbox_color=None,
+ min_width=8,
+ min_height=8,
+ bbox_padding=0,
+ bbox_border=2,
+ plot_ids=None,
+ ):
+ """
+ min_width and min_height: Minimum dimensions of the bounding box to be plotted.
+ """
+ # Read CSV data
+ df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"')
+ df["Area"] = df["Width"] * df["Height"]
+ # Remove bounding boxes that are clipped.
+ b_x, b_y = (
+ self.browser_config["win_left_bound"],
+ self.browser_config["win_upper_bound"],
+ )
+ if viewport_size is not None:
+ df = df[
+ (df["Bottom"] - b_y >= 0)
+ & (df["Top"] - b_y <= viewport_size["height"])
+ & (df["Right"] - b_x >= 0)
+ & (df["Left"] - b_x <= viewport_size["width"])
+ ]
+ viewport_area = viewport_size["width"] * viewport_size["height"]
+ # Filter out bounding boxes that too large (more than 80% of the viewport)
+ df = df[df["Area"] <= 0.8 * viewport_area]
+
+ # Open the screenshot image
+ img = screenshot_img.copy()
+ draw = ImageDraw.Draw(img)
+
+ # Load a TTF font with a larger size
+ font_path = "media/SourceCodePro-SemiBold.ttf"
+ font_size, padding = 16, 2
+ font = ImageFont.truetype(font_path, font_size)
+
+ # Create a color cycle using one of the categorical color palettes in matplotlib
+ color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
+ bbox_id2visid = {}
+ bbox_id2desc = {}
+ index = 0
+ id2center = {}
+ existing_text_rectangles = []
+ text_to_draw = []
+ # Provide [id] textContent inputs to the model as text.
+ text_content_elements = []
+ text_content_text = set() # Store text of interactable elements
+
+ # Iterate through each row in the CSV and draw bounding boxes
+ for _, row in df.iterrows():
+ if not row["Interactable"]:
+ content = ""
+ # Add image alt-text to the text representation.
+ if row["Element"] == "IMG" and pd.notna(row["Alt"]):
+ content += row["Alt"]
+ # Add HTML textContent (if any) to the text representation.
+ if pd.notna(row["TextContent"]):
+ content += (
+ row["TextContent"].strip().replace("\n", "").replace("\t", "")
+ )[
+ :200
+ ] # Limit to 200 characters to avoid having too much text
+
+ # Check if the text is a CSS selector
+ if content and not (content.startswith(".") and "{" in content):
+ # Add elements which are not interactable as StaticText
+ if content not in text_content_text:
+ text_content_elements.append(f"[] [StaticText] [{content}]")
+ text_content_text.add(content)
+ continue
+
+ if (plot_ids is not None) and (row["ID"] not in plot_ids):
+ continue
+
+ unique_id = str(index + 1)
+ bbox_id2visid[row["ID"]] = (
+ unique_id # map the bounding box ID to the unique character ID
+ )
+ top, right, bottom, left, width, height = (
+ row["Top"],
+ row["Right"],
+ row["Bottom"],
+ row["Left"],
+ row["Width"],
+ row["Height"],
+ )
+ left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y
+ id2center[unique_id] = (
+ (left + right) / 2,
+ (bottom + top) / 2,
+ width,
+ height,
+ )
+
+ if width >= min_width and height >= min_height:
+ # Get the next color in the cycle
+ color = bbox_color or color_cycle[index % len(color_cycle)]
+ draw.rectangle(
+ [
+ left - bbox_padding,
+ top - bbox_padding,
+ right + bbox_padding,
+ bottom + bbox_padding,
+ ],
+ outline=color,
+ width=bbox_border,
+ )
+ bbox_id2desc[row["ID"]] = color
+
+ # Draw the text on top of the rectangle
+ if add_ids:
+ # Calculate list of possible text positions
+ text_positions = [
+ (left - font_size, top - font_size), # Top-left corner
+ (
+ left,
+ top - font_size,
+ ), # A little to the right of the top-left corner
+ (right, top - font_size), # Top-right corner
+ (
+ right - font_size - 2 * padding,
+ top - font_size,
+ ), # A little to the left of the top-right corner
+ (left - font_size, bottom), # Bottom-left corner
+ (
+ left,
+ bottom,
+ ), # A little to the right of the bottom-left corner
+ (
+ right - font_size - 2 * padding,
+ bottom,
+ ), # A little to the left of the bottom-right corner
+ (
+ left,
+ bottom,
+ ), # A little to the right of the bottom-left corner
+ (
+ right - font_size - 2 * padding,
+ bottom,
+ ), # A little to the left of the bottom-right corner
+ ]
+ text_width = draw.textlength(unique_id, font=font)
+ text_height = font_size # Assume the text is one line
+
+ if viewport_size is not None:
+ for text_position in text_positions:
+ new_text_rectangle = [
+ text_position[0] - padding,
+ text_position[1] - padding,
+ text_position[0] + text_width + padding,
+ text_position[1] + text_height + padding,
+ ]
+
+ # Check if the new text rectangle is within the viewport
+ if (
+ new_text_rectangle[0] >= 0
+ and new_text_rectangle[1] >= 0
+ and new_text_rectangle[2] <= viewport_size["width"]
+ and new_text_rectangle[3] <= viewport_size["height"]
+ ):
+ # If the rectangle is within the viewport, check for overlaps
+ overlaps = False
+ for existing_rectangle in existing_text_rectangles:
+ if self.rectangles_overlap(
+ new_text_rectangle,
+ existing_rectangle,
+ padding * 2,
+ ):
+ overlaps = True
+ break
+
+ if not overlaps:
+ break
+ else:
+ # If the rectangle is outside the viewport, try the next position
+ continue
+ else:
+ # If none of the corners work, move the text rectangle by a fixed amount
+ text_position = (
+ text_positions[0][0] + padding,
+ text_positions[0][1],
+ )
+ new_text_rectangle = [
+ text_position[0] - padding,
+ text_position[1] - padding,
+ text_position[0] + text_width + padding,
+ text_position[1] + text_height + padding,
+ ]
+
+ existing_text_rectangles.append(new_text_rectangle)
+ text_to_draw.append(
+ (new_text_rectangle, text_position, unique_id, color)
+ )
+
+ content = ""
+ if row["Element"] == "IMG" and pd.notna(row["Alt"]):
+ content += row["Alt"]
+ if pd.notna(row["TextContent"]):
+ content += (
+ row["TextContent"]
+ .strip()
+ .replace("\n", "")
+ .replace("\t", "")
+ )[
+ :200
+ ] # Limit to 200 characters
+ text_content_elements.append(
+ f"[{unique_id}] [{row['Element']}] [{content}]"
+ )
+ if content in text_content_text:
+ # Remove text_content_elements with content
+ text_content_elements = [
+ element
+ for element in text_content_elements
+ if element.strip() != content
+ ]
+ text_content_text.add(content)
+
+ index += 1
+
+ for text_rectangle, text_position, unique_id, color in text_to_draw:
+ # Draw a background rectangle for the text
+ draw.rectangle(text_rectangle, fill=color)
+ draw.text(text_position, unique_id, font=font, fill="white")
+
+ content_str = "\n".join(text_content_elements)
+ return img, id2center, content_str
+
+ def rectangles_overlap(self, rect1, rect2, padding):
+ """
+ Check if two rectangles overlap.
+ Each rectangle is represented as a list [x1, y1, x2, y2].
+ """
+ return not (
+ rect1[2] < rect2[0] + padding
+ or rect1[0] > rect2[2] - padding
+ or rect1[1] > rect2[3] - padding
+ or rect1[3] < rect2[1] + padding
+ )
+
+ def process(self, page: Page) -> npt.NDArray[np.uint8]:
+ try:
+ browser_info = self.fetch_browser_info(page)
+ except Exception:
+ page.wait_for_load_state("load", timeout=500)
+ browser_info = self.fetch_browser_info(page)
+
+ self.browser_config = browser_info["config"]
+
+ if self.observation_type == "image_som":
+ # Produce the SoM image, with bounding boxes
+ try:
+ screenshot_bytes = page.screenshot()
+ som_bboxes = self.get_page_bboxes(page)
+ screenshot_img = Image.open(BytesIO(screenshot_bytes))
+ bbox_img, id2center, content_str = self.draw_bounding_boxes(
+ som_bboxes,
+ screenshot_img,
+ viewport_size=self.viewport_size,
+ )
+ self.som_id_info = id2center
+ self.meta_data["obs_nodes_info"] = id2center
+ screenshot_som = np.array(bbox_img)
+ return screenshot_som, content_str
+ except:
+ page.wait_for_event("load")
+ screenshot_bytes = page.screenshot()
+ som_bboxes = self.get_page_bboxes(page)
+ screenshot_img = Image.open(BytesIO(screenshot_bytes))
+ bbox_img, id2center, content_str = self.draw_bounding_boxes(
+ som_bboxes,
+ screenshot_img,
+ viewport_size=self.viewport_size,
+ )
+ self.som_id_info = id2center
+ self.meta_data["obs_nodes_info"] = id2center
+ screenshot_som = np.array(bbox_img)
+ return screenshot_som, content_str
+ else:
+ try:
+ screenshot = png_bytes_to_numpy(page.screenshot())
+ except:
+ page.wait_for_event("load")
+ screenshot = png_bytes_to_numpy(page.screenshot())
+ return screenshot, ""
+
+ def fetch_browser_info(self, page: Page) -> BrowserInfo:
+ client = page.context.new_cdp_session(page)
+ # extract domtree
+ tree = client.send(
+ "DOMSnapshot.captureSnapshot",
+ {
+ "computedStyles": [],
+ "includeDOMRects": True,
+ "includePaintOrder": True,
+ },
+ )
+ client.detach()
+ # calibrate the bounds, in some cases, the bounds are scaled somehow
+ bounds = tree["documents"][0]["layout"]["bounds"]
+ b = bounds[0]
+ n = b[2] / self.viewport_size["width"]
+ bounds = [[x / n for x in bound] for bound in bounds]
+ tree["documents"][0]["layout"]["bounds"] = bounds
+ # add union bound placeholder
+ tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]
+
+ # extract browser info
+ win_upper_bound = page.evaluate("window.pageYOffset")
+ win_left_bound = page.evaluate("window.pageXOffset")
+ win_width = page.evaluate("window.screen.width")
+ win_height = page.evaluate("window.screen.height")
+ win_right_bound = win_left_bound + win_width
+ win_lower_bound = win_upper_bound + win_height
+ device_pixel_ratio = page.evaluate("window.devicePixelRatio")
+ assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
+
+ config: BrowserConfig = {
+ "win_upper_bound": win_upper_bound,
+ "win_left_bound": win_left_bound,
+ "win_width": win_width,
+ "win_height": win_height,
+ "win_right_bound": win_right_bound,
+ "win_lower_bound": win_lower_bound,
+ "device_pixel_ratio": device_pixel_ratio,
+ }
+
+ # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
+ info: BrowserInfo = {"DOMTree": tree, "config": config}
+
+ return info
+
+ def get_element_center(self, element_id: str) -> tuple[float, float]:
+ if not self.observation_type == "image_som":
+ raise ValueError(
+ "get_element_center() is only supported for 'image_som' observation type."
+ )
+
+ browser_config = self.browser_config
+ center_x, center_y, width, height = self.som_id_info[element_id]
+ return (
+ center_x / self.viewport_size["width"],
+ center_y / self.viewport_size["height"],
+ )
+
+
+class TextObervationProcessorWebRL(TextObervationProcessor):
+ def __init__(
+ self,
+ observation_type: str,
+ current_viewport_only: bool,
+ viewport_size: ViewportSize,
+ captioning_fn=None,
+ ):
+ super().__init__(
+ observation_type,
+ current_viewport_only,
+ viewport_size,
+ captioning_fn,
+ )
+
+ def process(self, page: Page) -> str:
+ # get the tab info
+ page_info = get_parsed_html(page)
+ html = page_info["html"]
+ from bs4 import BeautifulSoup
+ soup = BeautifulSoup(html, 'html.parser')
+ obs_nodes_info = {}
+ for tag in soup.find_all(True):
+ if tag.has_attr('id') and tag.has_attr('data-bbox'):
+ backend_id = tag['id']
+ union_bound = tag['data-bbox']
+ union_bound = [float(num) for num in union_bound.split(',')]
+ obs_nodes_info[str(backend_id)] = {
+ "backend_id": backend_id,
+ "union_bound": union_bound,
+ "text": str(tag)
+ }
+ self.obs_nodes_info = obs_nodes_info
+ self.meta_data["obs_nodes_info"] = obs_nodes_info
+ return html
+
+ def get_element_center(self, element_id: str, page: Page=None) -> tuple[float, float]:
+
+ if page is not None:
+ element = page.query_selector(f"[data-label-id='{element_id}']")
+ bbox = element.bounding_box()
+ relative_bbox = (bbox['x'], bbox['y'], bbox['x'] + bbox['width'], bbox['y'] + bbox['height'])
+ center_x = (relative_bbox[0] + relative_bbox[2]) / 2
+ center_y = (relative_bbox[1] + relative_bbox[3]) / 2
+ else:
+ node_info = self.obs_nodes_info[element_id]
+ node_bound = node_info["union_bound"]
+ x, y, width, height = node_bound
+ center_x = x + width / 2
+ center_y = y + height / 2
+
+ return (
+ center_x / self.viewport_size["width"],
+ center_y / self.viewport_size["height"],
+ )
+
+class ObservationHandler:
+ """Main entry point to access all observation processor"""
+
+ def __init__(
+ self,
+ main_observation_type: str,
+ text_observation_type: str,
+ image_observation_type: str,
+ current_viewport_only: bool,
+ viewport_size: ViewportSize,
+ captioning_fn=None,
+ ) -> None:
+ self.main_observation_type = main_observation_type
+ if text_observation_type == "webrl":
+ self.text_processor = TextObervationProcessorWebRL(
+ text_observation_type,
+ current_viewport_only,
+ viewport_size,
+ captioning_fn,
+ )
+ else:
+ self.text_processor = TextObervationProcessor(
+ text_observation_type,
+ current_viewport_only,
+ viewport_size,
+ captioning_fn,
+ )
+ self.image_processor = ImageObservationProcessor(
+ image_observation_type, viewport_size
+ )
+ self.viewport_size = viewport_size
+
+ def get_observation_space(self) -> spaces.Dict:
+ text_space = spaces.Text(
+ min_length=0,
+ max_length=UTTERANCE_MAX_LENGTH,
+ charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
+ )
+
+ image_space = spaces.Box(
+ # Each position stores the RGB values. Note the swapped axes (height first).
+ np.zeros(
+ (self.viewport_size["height"], self.viewport_size["width"], 3),
+ dtype=np.uint8,
+ ),
+ np.ones(
+ (self.viewport_size["height"], self.viewport_size["width"], 3),
+ dtype=np.uint8,
+ )
+ * 255.0,
+ dtype=np.uint8,
+ )
+
+ return spaces.Dict({"text": text_space, "image": image_space})
+
+ def get_observation(self, page: Page) -> dict[str, Observation]:
+ text_obs = self.text_processor.process(page)
+ image_obs, content_str = self.image_processor.process(page)
+ if content_str != "":
+ text_obs = content_str
+ return {"text": text_obs, "image": image_obs}
+
+ def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
+ return {
+ "text": self.text_processor.meta_data,
+ "image": self.image_processor.meta_data,
+ }
+
+ @property
+ def action_processor(self) -> ObservationProcessor:
+ """Return the main processor that is associated with the action space"""
+ if self.main_observation_type == "text":
+ return self.text_processor
+ elif self.main_observation_type == "image":
+ return self.image_processor
+ else:
+ raise ValueError("Invalid main observation type")
diff --git a/VAB-WebArena-Lite/new/prompt_constructor.py b/VAB-WebArena-Lite/new/prompt_constructor.py
index 9cd5a2a..fc59b06 100644
--- a/VAB-WebArena-Lite/new/prompt_constructor.py
+++ b/VAB-WebArena-Lite/new/prompt_constructor.py
@@ -499,3 +499,59 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor):
raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented"
)
+
+class WebRLPromptConstructor(PromptConstructor):
+ """The agent will direct predict the action"""
+
+ def __init__(
+ self,
+ instruction_path: str | Path,
+ lm_config: lm_config.LMConfig,
+ tokenizer: Tokenizer,
+ ):
+ super().__init__(instruction_path, lm_config, tokenizer)
+
+ def construct(
+ self,
+ trajectory: Trajectory,
+ intent: str,
+ meta_data: dict[str, Any] = {},
+ ) -> APIInput:
+ """Construct prompt given the trajectory"""
+ state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
+
+ obs = state_info["observation"][self.obs_modality]
+ max_obs_length = self.lm_config.gen_config["max_obs_length"]
+ if max_obs_length:
+ if self.lm_config.provider == "google":
+ print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
+ obs = obs[:max_obs_length]
+ else:
+ try:
+ obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
+ except:
+ print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
+ obs = obs[:max_obs_length]
+
+ turn_num = len(meta_data["action_history"])
+ if turn_num == 1:
+ previous_action_str = []
+ else:
+ previous_action_str = meta_data["action_history"][1:]
+
+ index = turn_num - 1
+ history = ""
+ for i in range(index - 1, -1, -1):
+ if i == 0:
+ history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
+ else:
+ history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n** Simplified html **\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
+ if len(history) + len(obs) > (16384 - 512):
+ obs = obs[:(16384 - 512)-len(history)]
+ current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+ prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"
+
+ return prompt
+
+ def extract_action(self, response: str) -> str:
+ return response
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/run.py b/VAB-WebArena-Lite/new/run.py
index e406e2c..6f16161 100644
--- a/VAB-WebArena-Lite/new/run.py
+++ b/VAB-WebArena-Lite/new/run.py
@@ -13,11 +13,13 @@ import tempfile
import time
from pathlib import Path
from typing import List
+import cv2
+import shutil
import openai
import requests
import torch
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont
from agent import (
PromptAgent,
@@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
+def text_wrap(text, font, max_width):
+ lines = []
+ paragraphs = text.split('\n') # 按照 \n 分割文本为段落
+ for paragraph in paragraphs:
+ words = paragraph.split(' ')
+ line = ''
+ for word in words:
+ # 临时行
+ test_line = f"{line} {word}".strip()
+ # 获取临时行的宽度
+ test_line_bbox = font.getbbox(test_line)
+ test_line_width = test_line_bbox[2] - test_line_bbox[0]
+ if test_line_width <= max_width:
+ # 如果临时行的宽度不超过图片宽度,继续添加单词
+ line = test_line
+ else:
+ # 如果超过了最大宽度,保存当前行,开始新的一行
+ lines.append(line)
+ line = word
+ # 添加每段的最后一行
+ if line:
+ lines.append(line)
+ # 每个段落后添加一个空行,以保留段落的换行
+ lines.append('')
+ # 移除最后一个空行(不需要额外的空行)
+ if lines[-1] == '':
+ lines.pop()
+ return lines
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@@ -88,6 +118,7 @@ def config() -> argparse.Namespace:
"html",
"image",
"image_som",
+ "webrl",
],
default="accessibility_tree",
help="Observation type",
@@ -176,6 +207,9 @@ def config() -> argparse.Namespace:
# logging related
parser.add_argument("--result_dir", type=str, default="")
+
+ # if use self-deployed model
+ parser.add_argument("--planner_ip", type=str, default=None)
args = parser.parse_args()
# check the whether the action space is compatible with the observation space
@@ -196,7 +230,7 @@ def config() -> argparse.Namespace:
def early_stop(
- trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
+ trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None
) -> tuple[bool, str]:
"""Check whether need to stop early"""
@@ -228,28 +262,38 @@ def early_stop(
if len(action_seq) == 0:
return False, ""
- last_action: Action = action_seq[-1]
+ if actions is None:
+ last_action: Action = action_seq[-1]
+ if last_action["action_type"] != ActionTypes.TYPE:
+ if len(last_k_actions) >= k:
+ if all(
+ [
+ is_equivalent(action, last_action)
+ for action in last_k_actions
+ ]
+ ):
+ return True, f"Same action for {k} times"
+ else:
+ # check the action sequence
+ if (
+ sum([is_equivalent(action, last_action) for action in action_seq])
+ >= k
+ ):
+ return True, f"Same typing action for {k} times"
+ return False, ""
- if last_action["action_type"] != ActionTypes.TYPE:
+ else:
+ last_k_actions = actions[-k:]
+ last_action = actions[-1]
if len(last_k_actions) >= k:
if all(
[
- is_equivalent(action, last_action)
+ action == last_action
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
-
- else:
- # check the action sequence
- if (
- sum([is_equivalent(action, last_action) for action in action_seq])
- >= k
- ):
- return True, f"Same typing action for {k} times"
-
- return False, ""
-
+ return False, ""
def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
obj = {
@@ -387,18 +431,22 @@ def test(
obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info)
-
meta_data = {"action_history": ["None"]}
out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
actions = []
+ os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True)
+ if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")):
+ shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+ os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+
while True:
update_action_history(out_path, task_id, actions=actions)
-
+ # If no actions variable is passed, the behavior of early_stop is the same as the original one.
early_stop_flag, stop_info = early_stop(
- trajectory, max_steps, early_stop_thresholds
+ trajectory, max_steps, early_stop_thresholds, actions
)
-
+
if early_stop_flag:
action = create_stop_action(f"Early stop: {stop_info}")
else:
@@ -407,14 +455,14 @@ def test(
trajectory,
intent,
images=images,
- meta_data=meta_data,
+ meta_data=meta_data
)
except ValueError as e:
# get the error message
action = create_stop_action(f"ERROR: {str(e)}")
-
+
trajectory.append(action)
-
+
action_str = get_action_description(
action,
state_info["info"]["observation_metadata"],
@@ -426,13 +474,50 @@ def test(
render_helper.render(
action, state_info, meta_data, args.render_screenshot
)
+
+ current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png")
+ _ = env.page.viewport_size
+ env.page.screenshot(path="/dev/null")
+ env.page.screenshot(path=current_screenshot)
+ element_id = action["element_id"]
+ if element_id != "":
+ element = env.page.query_selector(f"[data-label-id='{element_id}']")
+ if element:
+ bbox = element.bounding_box()
+ bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']),int(bbox['height'])]
+ image = cv2.imread(current_screenshot)
+ cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
+ cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
+ cv2.imwrite(current_screenshot, image)
+ font_path = "/workspace/qzh/Arial.ttf" # 使用FreeSans字体
+ font_size = 30 # 你可以调整这个值来增大或缩小字体
+ image = Image.open(current_screenshot)
+ font = ImageFont.truetype(font_path, font_size)
+ draw = ImageDraw.Draw(image)
+ image_width = image.width
+ wrapped_text = text_wrap(action_str, font, image_width)
+ line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1]
+ text_height = line_height * len(wrapped_text)
+ new_image_height = image.height + text_height + 20 # 20 is extra white space
+ new_image = Image.new('RGB', (image.width, new_image_height), color=(255, 255, 255)) # white background
+ draw_new = ImageDraw.Draw(new_image)
+ y_text = 10 # Initial position of text from top
+ for line in wrapped_text:
+ text_bbox = draw_new.textbbox((0, 0), line, font=font)
+ text_width = text_bbox[2] - text_bbox[0]
+ text_position = ((image_width - text_width) // 2, y_text) # Center text horizontally
+ draw_new.text(text_position, line, font=font, fill=(0, 0, 0)) # black text
+ y_text += line_height # move to next line
+ new_image.paste(image, (0, text_height + 20))
+ new_image.save(current_screenshot)
+
meta_data["action_history"].append(action_str)
actions.append(action_str)
- print(action_str)
+ print('Action String: ', action_str)
if action["action_type"] == ActionTypes.STOP:
break
-
+
obs, _, terminated, _, info = env.step(action)
state_info = {"observation": obs, "info": info}
trajectory.append(state_info)
@@ -540,7 +625,7 @@ if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
- args.sleep_after_execution = 2.5
+ args.sleep_after_execution = 3.0
prepare(args)
test_config_base_dir = args.test_config_base_dir
diff --git a/VAB-WebArena-Lite/new/score.py b/VAB-WebArena-Lite/new/score.py
new file mode 100644
index 0000000..95235e5
--- /dev/null
+++ b/VAB-WebArena-Lite/new/score.py
@@ -0,0 +1,86 @@
+import os, json, sys, copy
+
+USE_TASKS = [i for i in range(165)]
+
+def get_result(res_dict, src="all"):
+ if len(res_dict) == 0:
+ return ''
+
+ success_id = [k for k, v in res_dict.items() if v >= 1.0]
+ score = len(success_id)
+ finish_count = len(res_dict)
+ pacc, acc = score / finish_count * 100, score / TASKS * 100
+
+ print(sorted(success_id))
+
+ meta = """
+--------
+src file: {}
+successed: {:3} / {:4} (812)
+partial accuracy: {:7}
+overall accuracy: {:7}
+--------
+""".format(src, int(score), finish_count, round(pacc, 2), round(acc, 2))
+
+ print(meta)
+
+def export_result(res_dict, src=".", note=["1.0", "0.0"], show_all=False):
+ out_string = ""
+ for id in USE_TASKS:
+ # with open(f"Pipeline/config_files/{id}.json", "r") as f:
+ # jd = json.load(f)
+
+ # if "map" in jd["sites"]:
+ # continue
+ if id in res_dict:
+ if res_dict[id] >= 1.0:
+ out_string += note[0]
+ else:
+ out_string += note[1]
+ elif show_all:
+ out_string += note[1]
+ out_string += "\n"
+
+ with open(os.path.join(src, 'export.txt'), 'w') as f:
+ f.write(out_string)
+
+TASKS = 165
+
+files = sys.argv[1]
+file_list = files.split(',')
+
+all_result = {}
+
+for src in file_list:
+ path = os.path.join(src, 'actions')
+
+ result = {}
+ finished = os.listdir(path)
+
+ for file in finished:
+ if not file.endswith('.json'):
+ continue
+ with open(os.path.join(path, file), 'r') as f:
+ data = json.load(f)
+
+ if not isinstance(data, dict):
+ continue
+
+ task_id = data.get('task_id', 1000)
+ # if task_id >= TASKS:
+ # continue
+
+ task_score = data.get('score', 0)
+ if task_score < 0:
+ continue
+
+ result[task_id] = task_score
+ if task_id not in all_result or task_score > all_result[task_id]:
+ all_result[task_id] = task_score
+
+ get_result(result, src)
+ export_result(result, src=src)
+
+if len(file_list) > 1:
+ get_result(all_result)
+export_result(all_result, show_all=True)
diff --git a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
index c77225a..adb1822 100644
--- a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
+++ b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json
@@ -187,7 +187,7 @@
"string_match"
],
"reference_answers": {
- "must_include": [
+ "fuzzy_match": [
"0"
]
},
@@ -774,8 +774,8 @@
"string_match"
],
"reference_answers": {
- "fuzzy_match": [
- "914km"
+ "must_include": [
+ "914"
]
},
"reference_url": "",
@@ -1139,7 +1139,7 @@
"string_match"
],
"reference_answers": {
- "must_include": [
+ "fuzzy_match": [
"0"
]
},
@@ -1679,7 +1679,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "N/A"
+ "fuzzy_match": ["N/A"]
},
"reference_url": "",
"program_html": [],
@@ -2605,7 +2605,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "yjlou"
+ "must_include": "yjlou"
},
"reference_url": "",
"program_html": [],
@@ -3712,7 +3712,7 @@
"string_match"
],
"reference_answers": {
- "exact_match": "N/A"
+ "fuzzy_match": "N/A"
},
"reference_url": "",
"program_html": [],
diff --git a/VAB-WebArena-Lite/new/tokenizers.py b/VAB-WebArena-Lite/new/tokenizers.py
index 2234040..b6e8e74 100644
--- a/VAB-WebArena-Lite/new/tokenizers.py
+++ b/VAB-WebArena-Lite/new/tokenizers.py
@@ -1,15 +1,18 @@
from typing import Any
import tiktoken
-from transformers import LlamaTokenizer # type: ignore
+from transformers import LlamaTokenizer, AutoTokenizer # type: ignore
class Tokenizer(object):
def __init__(self, provider: str, model_name: str) -> None:
if provider == "openai":
- self.tokenizer = tiktoken.encoding_for_model(model_name)
+ try:
+ self.tokenizer = tiktoken.encoding_for_model(model_name)
+ except: # The provider is in openai format but the model is a finetuned model
+ self.tokenizer = None
elif provider == "huggingface":
- self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+ self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
# turn off adding special tokens automatically
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
self.tokenizer.add_bos_token = False # type: ignore[attr-defined]
diff --git a/VAB-WebArena-Lite/new/utils.py b/VAB-WebArena-Lite/new/utils.py
index 4558585..2a8b87c 100644
--- a/VAB-WebArena-Lite/new/utils.py
+++ b/VAB-WebArena-Lite/new/utils.py
@@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any]
def call_llm(
lm_config: lm_config.LMConfig,
prompt: APIInput,
+ api_key = None,
+ base_url = None
) -> str:
response: str
if lm_config.provider == "openai":
@@ -39,11 +41,13 @@ def call_llm(
assert isinstance(prompt, str)
response = generate_from_openai_completion(
prompt=prompt,
- engine=lm_config.model,
+ model=lm_config.model,
temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"],
stop_token=lm_config.gen_config["stop_token"],
+ api_key=api_key,
+ base_url=base_url
)
else:
raise ValueError(
diff --git a/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
new file mode 100644
index 0000000..f8c4784
--- /dev/null
+++ b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+DATASET='webarena' # TODO: select from ['webarena', 'visualwebarena']
+result_dir='' # TODO: set your result_dir
+provider='openai' # TODO: select from ['openai', 'finetune', ...]
+model='' # TODO: assign model name, which is used for action generation
+planner_ip='' # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
+instruction_path='agent/prompts/jsons/p_webrl.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
+test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
+temperature=0.0
+
+SERVER='' # TODO: your server address
+MAP_SERVER='' # TODO: the server address for MAP tasks
+OPENAI_API_KEY='' # TODO: if you test OpenAI APIs
+OPENAI_ORGANIZATION=''
+CONDA_ENV_NAME='' # TODO: the name of your conda environment for testing WebArena
+
+ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
+
+# get the number of tmux panes
+num_panes=$(tmux list-panes | wc -l)
+
+# calculate how many panes need to be created
+let "panes_to_create = 7 - num_panes"
+
+# array of tmux commands to create each pane
+tmux_commands=(
+ 'tmux split-window -h'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 0; tmux split-window -v'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 3; tmux split-window -v'
+ 'tmux select-pane -t 5; tmux split-window -v'
+)
+
+# create panes up to 7
+for ((i=0; i<$panes_to_create; i++)); do
+ eval ${tmux_commands[$i]}
+done
+
+#!/bin/bash
+
+# Function to run a job
+run_job() {
+ tmux select-pane -t $1
+ COMMAND="python run.py \
+ --instruction_path ${instruction_path} \
+ --test_start_idx $2 \
+ --test_end_idx $3 \
+ --result_dir ${result_dir} \
+ --test_config_base_dir ${test_config_base_dir} \
+ --provider ${provider} \
+ --model ${model} \
+ --mode completion \
+ --planner_ip ${planner_ip} \
+ --stop_token \"<|eot_id|>\" \
+ --temperature ${temperature} \
+ --max_obs_length 0 \
+ --max_tokens 2048 \
+ --viewport_width 1280 \
+ --viewport_height 720 \
+ --parsing_failure_th 5 \
+ --repeating_action_failure_th 5 \
+ --action_set_tag webrl_id --observation_type webrl"
+ tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
+ sleep 3
+}
+
+TOLERANCE=2
+run_batch() {
+ args=("$@") # save all arguments in an array
+ num_jobs=${#args[@]} # get number of arguments
+
+ for ((i=1; i<$num_jobs; i++)); do
+ run_job $i ${args[i-1]} ${args[i]}
+ done
+
+ # Wait for all jobs to finish
+ while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
+ sleep 100 # wait for 10 seconds before checking again
+ done
+
+ # Run checker
+ while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
+ echo "Check failed, rerunning jobs..."
+ for ((i=1; i<$num_jobs; i++)); do
+ run_job $i ${args[i-1]} ${args[i]}
+ done
+
+ # Wait for all jobs to finish
+ while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
+ sleep 100 # wait for 10 seconds before checking again
+ done
+ done
+
+}
+run_batch 0 28 56 84 112 140 165
+
diff --git a/VAB-WebArena-Lite/replace.sh b/VAB-WebArena-Lite/replace.sh
index add244c..7a674f0 100644
--- a/VAB-WebArena-Lite/replace.sh
+++ b/VAB-WebArena-Lite/replace.sh
@@ -14,6 +14,14 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
+cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json
+
+# browser_env
+cp -f new/actions.py visualwebarena/browser_env/actions.py
+cp -f new/envs.py visualwebarena/browser_env/envs.py
+cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py
+cp -f new/processors.py visualwebarena/browser_env/processors.py
+cp -rf new/html_tools visualwebarena/browser_env/
# llms
cp -f new/utils.py visualwebarena/llms/utils.py
@@ -22,15 +30,19 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
+cp -f new/utils.py visualwebarena/llms/utils.py
# eval
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
-cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
+cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py
# misc
cp -f README.md visualwebarena/README.md
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
+cp -f new/score.py visualwebarena/score.py
+cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh
+
# 3. remove temporary files
mv visualwebarena/* .
rm -rf new
From a1a6cbd2099c3596a05a53182144168b1e06f924 Mon Sep 17 00:00:00 2001
From: QZH-777 <1961710177@qq.com>
Date: Thu, 14 Nov 2024 20:04:38 +0800
Subject: [PATCH 2/4] add webrl chat mode
---
VAB-WebArena-Lite/new/p_webrl_chat.json | 13 +++
VAB-WebArena-Lite/new/prompt_constructor.py | 60 ++++++++++++
VAB-WebArena-Lite/new/score.py | 25 +++++
.../new/wa_parallel_run_webrl_chat.sh | 97 +++++++++++++++++++
VAB-WebArena-Lite/replace.sh | 2 +
5 files changed, 197 insertions(+)
create mode 100644 VAB-WebArena-Lite/new/p_webrl_chat.json
create mode 100644 VAB-WebArena-Lite/new/wa_parallel_run_webrl_chat.sh
diff --git a/VAB-WebArena-Lite/new/p_webrl_chat.json b/VAB-WebArena-Lite/new/p_webrl_chat.json
new file mode 100644
index 0000000..ce1595b
--- /dev/null
+++ b/VAB-WebArena-Lite/new/p_webrl_chat.json
@@ -0,0 +1,13 @@
+{
+ "intro": "# Setup\nYou are a professional web browsing agent assistant that can fulfill user's high-level instructions. Given Simplified html of the browsed webpage at each step, you plan operations in python-style pseudo code using provided functions, or customize functions (if necessary) and then provide their implementations. \n# More details about the code\nYour code should be readable, simple, and only **ONE-LINE-OF-CODE** at a time, avoid using loop statement and only use if-else control if necessary. Predefined functions are as follow:\n\n```\ndef do(action, argument, element):\n\t\"\"\"A single browsing operation on the webpage.\n\tArgs:\n\t\t:param action: one of the actions from [\"Click\", \"Right Click\", \"Type\", \"Search\", \"Hover\", \"Scroll Up\", \"Scroll Down\", \"Press Enter\", \"Switch Tab\", \"Select Dropdown Option\", \"Wait\"].\n\t\t:param argument: optional. Only for \"Type\", \"Search\", \"Switch Page\", and \"Select Dropdown Option\", indicating the content to type in, page number(start from 0) to switch, or key to press.\n\t\t \"Search\" action is equivalent to \"Type\" action plus \"Enter\" key press.\n\t\t:param element: optional. Only for \"Click\", \"Right Click\", \"Type\", \"Search\", \"Select Dropdown Option\", and \"Hover\". Should be specific element id in the html.\n\tReturns:\n\t\tNone. The webpage will be updated after executing the action.\n\t\"\"\"\n\ndef exit(message):\n\t\"\"\"Ending the browsing process if the assistant think it has fulfilled the goal.\n\tArgs:\n\t\t:param message: optional. If user's instruction is a question, return assistant's answer in the message based on the browsing content.\n\tReturns:\n\t\tNone.\n\t\"\"\"\n\ndef go_backward():\n\t\"\"\"Go back to the previous page.\n\t\"\"\"\n\ndef go_forward():\n \"\"\"Go forward to the next page.\n \"\"\"\n```\n\nHere are some examples:\n- # Element: the 'REPORTS' section on the left sidebar\ndo(action=\"Click\", element=\"7\")\n- # Element: the 'Period' dropdown, middle center\ndo(action=\"Select Dropdown Option\", argument=\"Month\", element=\"20\")\n- # Element: the 'From' date picker input field, middle center\ndo(action=\"Type\", argument=\"01/01/2023\", element=\"22\")\n- do(action=\"Scroll Down\")\n- exit(message=\"The top-3 best-selling products in January 2023 are: 1\")\n- # Element: The search bar\ndo(action=\"Search\", argument=\"international airport near Carnegie Mellon University within a driving distance of 50 km\", element=\"13\")\n- # Note: Pittsburgh International Airport, Southern Beltway, Findlay Township, Allegheny County, 15231, United States\n# Element: The field labeled 'Pittsburgh International Airport' in the top left corner\ndo(action=\"Type\", argument=\"Cleveland Hopkins International Airport\", element=\"14\")\n\nREMEMBER: \n- only **ONE-LINE-OF-CODE** at a time\n- Don't generate an operation element that you do not see in the screenshot.\n- Use \"# Element\" to describe the element you choose in the html.\n- Use '# Note\" to record information useful to answer the instruction if needed.\n- If you find yourself fallen into some sort of loop, try to use another method or change your action.\n- If you think a page is still loading or still playing animation and you want to wait a while, use \"Wait\" action.\n- You are acting in a real world, try your best not to reject user's demand. Solve all the problem you encounter.\n- If you think you didn't get expected webpage, you should try using more precise and locative description of the element.\n- You must make sure the target element of `find_element*` exists on current screenshot, if not, you should navigate to the target place first.\n- You must identify potential errors or mistakes made by `find_element*` function and correct them. If the webpage is not as expected, you should try to re-do or un-do the operation.\n- You should **NEVER** try to use the browser's address bar at the top of the page to navigate.\n- Your answer shouldn't be in a code snippet format. Just write the function name and its arguments.\n- For quote, exit, go_backward, go_forward request, you should strictly obey the format of quote, exit, go_backward, go_forward functions, answers like do(\"Quote\", xxx, None) or do(\"quote\", xxx, None)are not allowed.\n- If you use do function to perform \"Click\", \"Right Click\", \"Type\", \"Search\", \"Select Dropdown Option\", and \"Hover\", the param element must not be None.\n",
+ "examples": [],
+ "template": "",
+ "meta_data": {
+ "observation": "webrl",
+ "action_type": "webrl_id",
+ "keywords": [],
+ "prompt_constructor": "WebRLChatPromptConstructor",
+ "answer_phrase": "",
+ "action_splitter": ""
+ }
+}
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/prompt_constructor.py b/VAB-WebArena-Lite/new/prompt_constructor.py
index fc59b06..3a21bdd 100644
--- a/VAB-WebArena-Lite/new/prompt_constructor.py
+++ b/VAB-WebArena-Lite/new/prompt_constructor.py
@@ -553,5 +553,65 @@ class WebRLPromptConstructor(PromptConstructor):
return prompt
+ def extract_action(self, response: str) -> str:
+ return response
+
+class WebRLChatPromptConstructor(PromptConstructor):
+ """The agent will direct predict the action"""
+
+ def __init__(
+ self,
+ instruction_path: str | Path,
+ lm_config: lm_config.LMConfig,
+ tokenizer: Tokenizer,
+ ):
+ super().__init__(instruction_path, lm_config, tokenizer)
+
+ def construct(
+ self,
+ trajectory: Trajectory,
+ intent: str,
+ meta_data: dict[str, Any] = {},
+ ) -> APIInput:
+ """Construct prompt given the trajectory"""
+ state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
+
+ obs = state_info["observation"][self.obs_modality]
+ max_obs_length = self.lm_config.gen_config["max_obs_length"]
+ if max_obs_length:
+ if self.lm_config.provider == "google":
+ print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
+ obs = obs[:max_obs_length]
+ else:
+ try:
+ obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
+ except:
+ print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
+ obs = obs[:max_obs_length]
+
+ turn_num = len(meta_data["action_history"])
+ if turn_num == 1:
+ previous_action_str = []
+ else:
+ previous_action_str = meta_data["action_history"][1:]
+
+ index = turn_num - 1
+ conversations = []
+ for i in range(index - 1, -1, -1):
+ if i == 0:
+ content_user = f"Task Instruction: {intent}\n\nRound {i}\n{intent}"
+ content_assistant = f"{previous_action_str[i]}"
+ else:
+ content_user = f"Round {i}\n** Simplified html **"
+ content_assistant = f"{previous_action_str[i]}"
+ conversation = [{'role': 'user', 'content': content_user}, {'role': 'assistant', 'content': content_assistant}]
+ conversations = conversation + conversations
+
+ system_turn = [{'role': 'system', 'content': self.instruction['intro']}]
+ current_turn = [{'role': 'user', 'content': f'Round {index}\n\n{obs}'}]
+ conversations = system_turn + conversations + current_turn
+
+ return conversations
+
def extract_action(self, response: str) -> str:
return response
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/score.py b/VAB-WebArena-Lite/new/score.py
index 95235e5..915c3e7 100644
--- a/VAB-WebArena-Lite/new/score.py
+++ b/VAB-WebArena-Lite/new/score.py
@@ -84,3 +84,28 @@ for src in file_list:
if len(file_list) > 1:
get_result(all_result)
export_result(all_result, show_all=True)
+
+with open('./config_files/wa/test_webarena_lite.raw.json') as fp:
+ configs = json.load(fp)
+sub_results = {}
+sub_ids = {}
+for item in configs:
+ web = tuple(item['sites'])
+ task_id = int(item['task_id'])
+ old_task_id = int(item['old_task_id'])
+ if web not in sub_results:
+ sub_results[web] = []
+ if web not in sub_ids:
+ sub_ids[web] = []
+ if task_id in all_result:
+ sub_results[web].append(all_result[task_id])
+ if all_result[task_id] == 1:
+ sub_ids[web].append(old_task_id)
+ else:
+ sub_results[web].append(0)
+for web in sub_results:
+ print(web, round(sum(sub_results[web]) / len(sub_results[web]) * 100, 1))
+
+print('\n\n')
+for web in sub_ids:
+ print(web, sorted(sub_ids[web]), len(sub_ids[web]))
\ No newline at end of file
diff --git a/VAB-WebArena-Lite/new/wa_parallel_run_webrl_chat.sh b/VAB-WebArena-Lite/new/wa_parallel_run_webrl_chat.sh
new file mode 100644
index 0000000..f018513
--- /dev/null
+++ b/VAB-WebArena-Lite/new/wa_parallel_run_webrl_chat.sh
@@ -0,0 +1,97 @@
+#!/bin/bash
+DATASET='webarena' # TODO: select from ['webarena', 'visualwebarena']
+result_dir='' # TODO: set your result_dir
+provider='openai' # TODO: select from ['openai', 'finetune', ...]
+model='' # TODO: assign model name, which is used for action generation
+planner_ip='' # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
+instruction_path='agent/prompts/jsons/p_webrl_chat.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
+test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
+temperature=0.0
+
+SERVER='' # TODO: your server address
+MAP_SERVER='' # TODO: the server address for MAP tasks
+OPENAI_API_KEY='' # TODO: if you test OpenAI APIs
+OPENAI_ORGANIZATION=''
+CONDA_ENV_NAME='' # TODO: the name of your conda environment for testing WebArena
+
+ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
+
+# get the number of tmux panes
+num_panes=$(tmux list-panes | wc -l)
+
+# calculate how many panes need to be created
+let "panes_to_create = 7 - num_panes"
+
+# array of tmux commands to create each pane
+tmux_commands=(
+ 'tmux split-window -h'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 0; tmux split-window -v'
+ 'tmux split-window -v'
+ 'tmux select-pane -t 3; tmux split-window -v'
+ 'tmux select-pane -t 5; tmux split-window -v'
+)
+
+# create panes up to 7
+for ((i=0; i<$panes_to_create; i++)); do
+ eval ${tmux_commands[$i]}
+done
+
+#!/bin/bash
+
+# Function to run a job
+run_job() {
+ tmux select-pane -t $1
+ COMMAND="python run.py \
+ --instruction_path ${instruction_path} \
+ --test_start_idx $2 \
+ --test_end_idx $3 \
+ --result_dir ${result_dir} \
+ --test_config_base_dir ${test_config_base_dir} \
+ --provider ${provider} \
+ --model ${model} \
+ --mode chat \
+ --planner_ip ${planner_ip} \
+ --stop_token \"<|eot_id|>\" \
+ --temperature ${temperature} \
+ --max_obs_length 0 \
+ --max_tokens 2048 \
+ --viewport_width 1280 \
+ --viewport_height 720 \
+ --parsing_failure_th 5 \
+ --repeating_action_failure_th 5 \
+ --action_set_tag webrl_id --observation_type webrl"
+ tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
+ sleep 3
+}
+
+TOLERANCE=2
+run_batch() {
+ args=("$@") # save all arguments in an array
+ num_jobs=${#args[@]} # get number of arguments
+
+ for ((i=1; i<$num_jobs; i++)); do
+ run_job $i ${args[i-1]} ${args[i]}
+ done
+
+ # Wait for all jobs to finish
+ while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
+ sleep 100 # wait for 10 seconds before checking again
+ done
+
+ # Run checker
+ while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
+ echo "Check failed, rerunning jobs..."
+ for ((i=1; i<$num_jobs; i++)); do
+ run_job $i ${args[i-1]} ${args[i]}
+ done
+
+ # Wait for all jobs to finish
+ while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
+ sleep 100 # wait for 10 seconds before checking again
+ done
+ done
+
+}
+run_batch 0 28 56 84 112 140 165
+
diff --git a/VAB-WebArena-Lite/replace.sh b/VAB-WebArena-Lite/replace.sh
index 7a674f0..c0b7fa7 100644
--- a/VAB-WebArena-Lite/replace.sh
+++ b/VAB-WebArena-Lite/replace.sh
@@ -15,6 +15,7 @@ cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json
+cp -f new/p_webrl_chat.json visualwebarena/agent/prompts/jsons/p_webrl_chat.json
# browser_env
cp -f new/actions.py visualwebarena/browser_env/actions.py
@@ -42,6 +43,7 @@ cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
cp -f new/score.py visualwebarena/score.py
cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh
+cp -f new/wa_parallel_run_webrl_chat.sh visualwebarena/wa_parallel_run_webrl_chat.sh
# 3. remove temporary files
mv visualwebarena/* .
From 8675710e50144ec4590f0a8b4b9f5d3fd3b51d54 Mon Sep 17 00:00:00 2001
From: Qi Zehan <67227342+QZH-777@users.noreply.github.com>
Date: Thu, 14 Nov 2024 20:42:04 +0800
Subject: [PATCH 3/4] Update README.md
---
VAB-WebArena-Lite/README.md | 69 +++++++++++++++++++++++++++++++++++++
1 file changed, 69 insertions(+)
diff --git a/VAB-WebArena-Lite/README.md b/VAB-WebArena-Lite/README.md
index 49140c4..32b4e1b 100644
--- a/VAB-WebArena-Lite/README.md
+++ b/VAB-WebArena-Lite/README.md
@@ -162,6 +162,75 @@ After all parallel groupes finish, run `score.py` to get the pass rate
python score.py
```
+## 🚀 Evaluating in WebRL Setting (Text Modal)
+
+[WebRL](https://github.com/THUDM/WebRL) is one of the top-performing models on WebArena-Lite. It uses a plain text modality as input. Additionally, we provide evaluation scripts that support this plain text modality.
+
+### Evaluation of Finetuned Models
+
+To run the finetuned model in WebRL setting, you can run evaluation with the following flags:
+
+```bash
+python run.py \
+ --instruction_path agent/prompts/jsons/p_webrl.json \
+ --test_start_idx 0 \
+ --test_end_idx 1 \
+ --result_dir \
+ --test_config_base_dir config_files/wa/test_webarena_lite \
+ --provider openai \
+ --mode completion \
+ --model \
+ --planner_ip \
+ --stop_token "<|eot_id|>" \
+ --max_obs_length 0 \
+ --max_tokens 2048 \
+ --viewport_width 1280 \
+ --viewport_height 720 \
+ --action_set_tag webrl_id --observation_type webrl
+```
+
+You need to first use tools like vllm to deploy the finetuned model locally. Once deployed, the model can be accessed through the OpenAI API call method.
+
+Ensure that the `--model` and `--planner_ip` fields are completed with the appropriate model name and the IP address of the deployed model instance.
+
+We also provide the parallel script.
+
+```bash
+# Remember to first launch a tmux session
+tmux
+bash wa_parallel_run_webrl.sh
+```
+
+### Evaluation of Proprietary Models
+
+To run the proprietary model in WebRL setting, you can run evaluation with the following flags:
+
+```bash
+python run.py \
+ --instruction_path agent/prompts/jsons/p_webrl_chat.json \
+ --test_start_idx 0 \
+ --test_end_idx 1 \
+ --result_dir \
+ --test_config_base_dir config_files/wa/test_webarena_lite \
+ --provider openai \
+ --model GPT-4o \
+ --mode chat \
+ --planner_ip '' \
+ --max_obs_length 0 \
+ --max_tokens 2048 \
+ --viewport_width 1280 \
+ --viewport_height 720 \
+ --action_set_tag webrl_id --observation_type webrl
+```
+
+You can switch the evaluation model by modifying `--model`. We also provide the parallel script.
+
+```bash
+# Remember to first launch a tmux session
+tmux
+bash wa_parallel_run_webrl_chat.sh
+```
+
### 🚨 Important: Refresh all websites before re-run another round of testing!
Since tasks in WebArena may involve changing status and database of websites (e.g., posting comments on Reddit), if websites are not all refreshed before another round of evaluation, the results would be problematic.
From 4982d4063e18980adce8896611412053fdf9d149 Mon Sep 17 00:00:00 2001
From: QZH-777 <1961710177@qq.com>
Date: Thu, 14 Nov 2024 20:46:16 +0800
Subject: [PATCH 4/4] add webrl chat mode
---
VAB-WebArena-Lite/new/run.py | 21 ---------------------
1 file changed, 21 deletions(-)
diff --git a/VAB-WebArena-Lite/new/run.py b/VAB-WebArena-Lite/new/run.py
index 6f16161..c2d707e 100644
--- a/VAB-WebArena-Lite/new/run.py
+++ b/VAB-WebArena-Lite/new/run.py
@@ -489,27 +489,6 @@ def test(
cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
cv2.imwrite(current_screenshot, image)
- font_path = "/workspace/qzh/Arial.ttf" # 使用FreeSans字体
- font_size = 30 # 你可以调整这个值来增大或缩小字体
- image = Image.open(current_screenshot)
- font = ImageFont.truetype(font_path, font_size)
- draw = ImageDraw.Draw(image)
- image_width = image.width
- wrapped_text = text_wrap(action_str, font, image_width)
- line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1]
- text_height = line_height * len(wrapped_text)
- new_image_height = image.height + text_height + 20 # 20 is extra white space
- new_image = Image.new('RGB', (image.width, new_image_height), color=(255, 255, 255)) # white background
- draw_new = ImageDraw.Draw(new_image)
- y_text = 10 # Initial position of text from top
- for line in wrapped_text:
- text_bbox = draw_new.textbbox((0, 0), line, font=font)
- text_width = text_bbox[2] - text_bbox[0]
- text_position = ((image_width - text_width) // 2, y_text) # Center text horizontally
- draw_new.text(text_position, line, font=font, fill=(0, 0, 0)) # black text
- y_text += line_height # move to next line
- new_image.paste(image, (0, text_height + 20))
- new_image.save(current_screenshot)
meta_data["action_history"].append(action_str)
actions.append(action_str)