From 521d7e999af0b88034411bfd534a8a82db67cd2f Mon Sep 17 00:00:00 2001 From: QZH-777 <1961710177@qq.com> Date: Thu, 14 Nov 2024 15:51:41 +0800 Subject: [PATCH] add webrl mode --- VAB-WebArena-Lite/new/actions.py | 2006 +++++++++++++++++ VAB-WebArena-Lite/new/agent.py | 15 +- VAB-WebArena-Lite/new/envs.py | 319 +++ VAB-WebArena-Lite/new/evaluators.py | 4 + .../new/helper_functions_browser.py | 239 ++ ..._functions.py => helper_functions_eval.py} | 0 VAB-WebArena-Lite/new/html_tools/__init__.py | 7 + .../new/html_tools/configs/__init__.py | 3 + .../new/html_tools/configs/config.py | 56 + .../new/html_tools/configs/html_prompt.py | 22 + VAB-WebArena-Lite/new/html_tools/fetch.py | 108 + .../new/html_tools/html_parser.py | 447 ++++ .../new/html_tools/identifier.py | 64 + VAB-WebArena-Lite/new/html_tools/prompt.py | 97 + .../new/html_tools/scripts/__init__.py | 43 + .../html_tools/scripts/clickable_checker.js | 148 ++ .../new/html_tools/scripts/element_info.js | 34 + .../new/html_tools/scripts/label.js | 74 + .../new/html_tools/scripts/label_marker.js | 65 + .../new/html_tools/scripts/prepare.js | 83 + VAB-WebArena-Lite/new/html_tools/utils.py | 101 + VAB-WebArena-Lite/new/openai_utils.py | 18 +- VAB-WebArena-Lite/new/p_webrl.json | 13 + VAB-WebArena-Lite/new/processors.py | 1351 +++++++++++ VAB-WebArena-Lite/new/prompt_constructor.py | 56 + VAB-WebArena-Lite/new/run.py | 137 +- VAB-WebArena-Lite/new/score.py | 86 + .../new/test_webarena_lite.raw.json | 14 +- VAB-WebArena-Lite/new/tokenizers.py | 9 +- VAB-WebArena-Lite/new/utils.py | 6 +- .../new/wa_parallel_run_webrl.sh | 97 + VAB-WebArena-Lite/replace.sh | 14 +- 32 files changed, 5688 insertions(+), 48 deletions(-) create mode 100644 VAB-WebArena-Lite/new/actions.py create mode 100644 VAB-WebArena-Lite/new/envs.py create mode 100644 VAB-WebArena-Lite/new/helper_functions_browser.py rename VAB-WebArena-Lite/new/{helper_functions.py => helper_functions_eval.py} (100%) create mode 100755 
VAB-WebArena-Lite/new/html_tools/__init__.py create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/__init__.py create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/config.py create mode 100755 VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py create mode 100755 VAB-WebArena-Lite/new/html_tools/fetch.py create mode 100755 VAB-WebArena-Lite/new/html_tools/html_parser.py create mode 100755 VAB-WebArena-Lite/new/html_tools/identifier.py create mode 100755 VAB-WebArena-Lite/new/html_tools/prompt.py create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/__init__.py create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/element_info.js create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label.js create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js create mode 100755 VAB-WebArena-Lite/new/html_tools/scripts/prepare.js create mode 100755 VAB-WebArena-Lite/new/html_tools/utils.py create mode 100644 VAB-WebArena-Lite/new/p_webrl.json create mode 100644 VAB-WebArena-Lite/new/processors.py create mode 100644 VAB-WebArena-Lite/new/score.py create mode 100644 VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh diff --git a/VAB-WebArena-Lite/new/actions.py b/VAB-WebArena-Lite/new/actions.py new file mode 100644 index 0000000..eedb327 --- /dev/null +++ b/VAB-WebArena-Lite/new/actions.py @@ -0,0 +1,2006 @@ +""" +Browser Env action space. 
+Inspited by Farama-Foundation/miniwob-plusplus +""" +import ast +import random +import re +import string +from enum import IntEnum +from itertools import chain +from typing import Any, TypedDict, Union, cast +import time + +import numpy as np +import numpy.typing as npt +from beartype import beartype +from beartype.door import is_bearable +from gymnasium import spaces +from playwright._impl._api_structures import ViewportSize +from playwright.async_api import BrowserContext as ABrowserContext +from playwright.async_api import Locator as ALocator +from playwright.async_api import Page as APage +from playwright.sync_api import BrowserContext, Locator, Page + +from browser_env.constants import ( + ASCII_CHARSET, + FREQ_UNICODE_CHARSET, + MAX_ANSWER_LENGTH, + MAX_ELEMENT_ID, + MAX_ELEMENT_INDEX_IN_VIEWPORT, + MAX_PAGE_NUMBER, + MAX_VANILLA_STR_LENGTH, + PLAYWRIGHT_ACTIONS, + PLAYWRIGHT_LOCATORS, + ROLES, + SPECIAL_KEY_MAPPINGS, + SPECIAL_KEYS, + SPECIAL_LOCATORS, + TEXT_MAX_LENGTH, + TYPING_MAX_LENGTH, + URL_MAX_LENGTH, + RolesType, +) +from browser_env.processors import ObservationProcessor + + +class ParsedPlaywrightCode(TypedDict): + function_name: str + arguments: list[str] + keywords: dict[str, Any] + + +from browser_env.processors import ( + ObservationProcessor, + TextObervationProcessor, +) + + +@beartype +def is_in_viewport( + element: Locator, viewport: ViewportSize, threshold: float = 0.3 +) -> bool: + """Given a playwright locator, check if it is in the viewport""" + box = element.bounding_box() + assert box is not None + boxx0 = box["x"] + boxx1 = box["x"] + box["width"] + boxy0 = box["y"] + boxy1 = box["y"] + box["height"] + viewportx0, viewporty0 = 0, 0 + viewportx1, viewporty1 = viewport["width"], viewport["height"] + inter = max(0, min(boxx1, viewportx1) - max(boxx0, viewportx0)) * max( + 0, min(boxy1, viewporty1) - max(boxy0, viewporty0) + ) + ratio = inter / (box["width"] * box["height"]) + return ratio > threshold + + +@beartype +async def 
async_is_in_viewport( + element: ALocator, viewport: ViewportSize, threshold: float = 0.3 +) -> bool: + box = await element.bounding_box() + assert box is not None + boxx0 = box["x"] + boxx1 = box["x"] + box["width"] + boxy0 = box["y"] + boxy1 = box["y"] + box["height"] + viewportx0, viewporty0 = 0, 0 + viewportx1, viewporty1 = viewport["width"], viewport["height"] + inter = max(0, min(boxx1, viewportx1) - max(boxx0, viewportx0)) * max( + 0, min(boxy1, viewporty1) - max(boxy0, viewporty0) + ) + ratio = inter / (box["width"] * box["height"]) + return ratio > threshold + + +class Action(TypedDict): + action_type: int + coords: npt.NDArray[np.float32] + element_role: int + element_name: str + text: list[int] + page_number: int + url: str + nth: int + element_id: str + direction: str + key_comb: str + pw_code: str + answer: str + raw_prediction: str # raw prediction from the model + + +@beartype +def action2str( + action: Action, action_set_tag: str, semantic_element: str = "" +) -> str: + """Return the string representation of an action + + sementic_element: the semantic information of the element + such as a line in an accessibility tree + """ + if action_set_tag in [ + "id_accessibility_tree", + "id_accessibility_tree_with_captioner", + ]: + element_id = action["element_id"] + match action["action_type"]: + case ActionTypes.CLICK: + # [ID=X] xxxxx + action_str = f"click [{element_id}] where [{element_id}] is {semantic_element}" + case ActionTypes.TYPE: + text = "".join([_id2key[i] for i in action["text"]]) + action_str = f"type [{element_id}] [{text}] where [{element_id}] is {semantic_element}" + case ActionTypes.HOVER: + action_str = f"hover [{element_id}] where [{element_id}] is {semantic_element}" + case ActionTypes.SCROLL: + action_str = f"scroll [{action['direction']}]" + case ActionTypes.KEY_PRESS: + action_str = f"press [{action['key_comb']}]" + case ActionTypes.GOTO_URL: + action_str = f"goto [{action['url']}]" + case ActionTypes.NEW_TAB: + action_str = 
"new_tab" + case ActionTypes.PAGE_CLOSE: + action_str = "close_tab" + case ActionTypes.GO_BACK: + action_str = "go_back" + case ActionTypes.GO_FORWARD: + action_str = "go_forward" + case ActionTypes.PAGE_FOCUS: + action_str = f"page_focus [{action['page_number']}]" + case ActionTypes.CLEAR: + action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}" + case ActionTypes.STOP: + action_str = f"stop [{action['answer']}]" + case ActionTypes.NONE: + action_str = "none" + case _: + raise ValueError( + f"Unknown action type {action['action_type']}" + ) + elif action_set_tag == "som": + element_id = action["element_id"] + match action["action_type"]: + case ActionTypes.CLICK: + # [ID=X] xxxxx + action_str = f"click [{element_id}] where [{element_id}]" + case ActionTypes.CLEAR: + action_str = f"clear [{element_id}] where [{element_id}] is {semantic_element}" + case ActionTypes.TYPE: + text = "".join([_id2key[i] for i in action["text"]]) + action_str = ( + f"type [{element_id}] [{text}] where [{element_id}]" + ) + case ActionTypes.HOVER: + action_str = f"hover [{element_id}] where [{element_id}]" + case ActionTypes.SCROLL: + action_str = f"scroll [{action['direction']}]" + case ActionTypes.KEY_PRESS: + action_str = f"press [{action['key_comb']}]" + case ActionTypes.GOTO_URL: + action_str = f"goto [{action['url']}]" + case ActionTypes.NEW_TAB: + action_str = "new_tab" + case ActionTypes.PAGE_CLOSE: + action_str = "close_tab" + case ActionTypes.GO_BACK: + action_str = "go_back" + case ActionTypes.GO_FORWARD: + action_str = "go_forward" + case ActionTypes.PAGE_FOCUS: + action_str = f"page_focus [{action['page_number']}]" + case ActionTypes.STOP: + action_str = f"stop [{action['answer']}]" + case ActionTypes.NONE: + action_str = "none" + case _: + raise ValueError( + f"Unknown action type {action['action_type']}" + ) + else: + raise NotImplementedError(f"Unknown action set tag {action_set_tag}") + + return action_str + + +def action2create_function(action: 
Action) -> str: + match (action["action_type"]): + case ActionTypes.NONE: + return "create_none_action()" + # mouse wheel and keyboard action + case ActionTypes.SCROLL: + direction = "up" if "up" in action["direction"] else "down" + return f"create_scroll_action({repr(direction)})" + case ActionTypes.KEY_PRESS: + return f"create_key_press_action({repr(action['key_comb'])})" + # inter-page actions + case ActionTypes.PAGE_FOCUS: + return f"create_page_focus_action({action['page_number']})" + case ActionTypes.NEW_TAB: + return "create_new_tab_action()" + case ActionTypes.GO_BACK: + return "create_go_back_action()" + case ActionTypes.GO_FORWARD: + return "create_go_forward_action()" + case ActionTypes.GOTO_URL: + return f"create_goto_url_action({repr(action['url'])})" + case ActionTypes.PAGE_CLOSE: + return "create_page_close_action()" + + # low-level keyboard and mouse actions + case ActionTypes.MOUSE_CLICK: + return f"create_mouse_click_action({action['coords'][0]}, {action['coords'][1]})" + case ActionTypes.MOUSE_HOVER: + return f"create_mouse_hover_action({action['coords'][0]}, {action['coords'][1]})" + case ActionTypes.KEYBOARD_TYPE: + return f"create_keyboard_type_action({list(map(lambda x: _id2key[x], action['text']))})" + + # mid-level keyboard and mouse actions + case ActionTypes.CLICK: + args = [] + args.append(f"element_id={repr(action['element_id'])}") + args.append( + f"element_role={repr(_id2role[action['element_role']])}" + ) + args.append(f"element_name={repr(action['element_name'])}") + args.append(f"pw_code={repr(action['pw_code'])}") + args_str = ", ".join(args) + return f"create_click_action({args_str})" + case ActionTypes.CLEAR: + args = [] + args.append(f"element_id={repr(action['element_id'])}") + args.append( + f"element_role={repr(_id2role[action['element_role']])}" + ) + args.append(f"element_name={repr(action['element_name'])}") + args.append(f"pw_code={repr(action['pw_code'])}") + args_str = ", ".join(args) + return 
f"create_clear_action({args_str})" + case ActionTypes.HOVER: + args = [] + args.append(f"element_id={repr(action['element_id'])}") + args.append( + f"element_role={repr(_id2role[action['element_role']])}" + ) + args.append(f"element_name={repr(action['element_name'])}") + args.append(f"pw_code={repr(action['pw_code'])}") + args_str = ", ".join(args) + return f"create_hover_action({args_str})" + case ActionTypes.TYPE: + args = [] + text = "".join(map(lambda x: _id2key[x], action["text"])) + args.append(f"text={repr(text)}") + args.append(f"element_id={repr(action['element_id'])}") + args.append( + f"element_role={repr(_id2role[action['element_role']])}" + ) + args.append(f"element_name={repr(action['element_name'])}") + args.append(f"pw_code={repr(action['pw_code'])}") + args_str = ", ".join(args) + return f"create_type_action({args_str})" + + # high-level actions, only support locators from playwright + case ActionTypes.CHECK: + return f"create_check_action(pw_code={repr(action['pw_code'])})" + case ActionTypes.SELECT_OPTION: + return f"create_select_option_action(pw_code={repr(action['pw_code'])})" + case ActionTypes.STOP: + return f'create_stop_action({repr(action["answer"])})' + + raise ValueError(f"Invalid action type: {action['action_type']}") + + +class ActionTypes(IntEnum): + """Valid action types for browser env.""" + + NONE = 0 + # mouse wheel and keyboard, universal across all action spaces + SCROLL = 1 + KEY_PRESS = 2 + + # low level mouse and keyboard actions + MOUSE_CLICK = 3 + KEYBOARD_TYPE = 4 + MOUSE_HOVER = 5 + + # mid level mouse and keyboard actions + CLICK = 6 + TYPE = 7 + HOVER = 8 + + # page level actions, universal across all action spaces + PAGE_FOCUS = 9 + NEW_TAB = 10 + GO_BACK = 11 + GO_FORWARD = 12 + GOTO_URL = 13 + PAGE_CLOSE = 14 + + # high-leval actions that playwright support + CHECK = 15 + SELECT_OPTION = 16 + + STOP = 17 + CLEAR = 18 + + # webrl actions + SEARCH = 19 + SELECT_DROPDOWN_OPTION = 20 + + def __str__(self) -> str: + 
return f"ACTION_TYPES.{self.name}" + + +@beartype +def is_equivalent(a: Action, b: Action) -> bool: + """Return True if two actions are equal.""" + if a["action_type"] != b["action_type"]: + return False + match (a["action_type"]): + case ActionTypes.NONE: + return True + case ActionTypes.SCROLL: + da = "up" if "up" in a["direction"] else "down" + db = "up" if "up" in b["direction"] else "down" + return da == db + case ActionTypes.KEY_PRESS: + return a["key_comb"] == b["key_comb"] + case ActionTypes.MOUSE_CLICK | ActionTypes.MOUSE_HOVER: + return np.allclose(a["coords"], b["coords"]) + case ActionTypes.KEYBOARD_TYPE: + return a["text"] == b["text"] + case ActionTypes.CLICK | ActionTypes.HOVER | ActionTypes.TYPE: # TODO: can be further optimized + if a["element_id"] and b["element_id"]: + return a["element_id"] == b["element_id"] + elif a["element_role"] and b["element_role"]: + return ( + a["element_role"] == b["element_role"] + and a["element_name"] == b["element_name"] + ) + elif a["pw_code"] and b["pw_code"]: + return a["pw_code"] == b["pw_code"] + else: + return False + case ActionTypes.PAGE_FOCUS: + return a["page_number"] == b["page_number"] + case ActionTypes.NEW_TAB: + return True + case ActionTypes.GO_BACK: + return True + case ActionTypes.GO_FORWARD: + return True + case ActionTypes.GOTO_URL: + return a["url"] == b["url"] + case ActionTypes.PAGE_CLOSE: + return True + case ActionTypes.CHECK | ActionTypes.SELECT_OPTION: + return a["pw_code"] == b["pw_code"] + case ActionTypes.STOP: + return a["answer"] == b["answer"] + case _: + raise ValueError(f"Unknown action type: {a['action_type']}") + + +_key2id: dict[str, int] = { + key: i + for i, key in enumerate( + chain(SPECIAL_KEYS, ASCII_CHARSET, FREQ_UNICODE_CHARSET, ["\n"]) + ) +} +_id2key: list[str] = sorted(_key2id, key=_key2id.get) # type: ignore[arg-type] +_role2id: dict[RolesType, int] = { + cast(RolesType, role): i + for i, role in enumerate(chain(ROLES, SPECIAL_LOCATORS)) +} +_id2role: list[RolesType] 
= sorted(_role2id, key=_role2id.get) # type: ignore[arg-type] + + +@beartype +def _keys2ids(keys: list[int | str] | str) -> list[int]: + return list( + map( + lambda key: _key2id.get(str(key), _key2id.get(key, " ")) + if is_bearable(key, str) + else int(key), + keys, + ) + ) + + +def get_action_space() -> spaces.Dict: + """Return the space of serialized actions.""" + space = spaces.Dict( + { + "action_type": spaces.Discrete(len(ActionTypes)), + # coords (left, top) is used for COORD_CLICK + "coords": spaces.Box( + np.array([0.0, 0.0], dtype=np.float32), + np.array([1.0, 1.0], dtype=np.float32), + ), + # element role is used for FOCUS_AND_CLICK and FOCUS_AND_TYPE + "element_role": spaces.Discrete( + len(ROLES) + len(SPECIAL_LOCATORS) + ), + # element name is used with element role + "element_name": spaces.Text(TEXT_MAX_LENGTH), + "element_id": spaces.Text(TEXT_MAX_LENGTH), + # text is only used for TYPE and FOCUS_AND_TYPE + "text": spaces.MultiDiscrete( + [ + len(ASCII_CHARSET) + + len(SPECIAL_KEYS) + + len(FREQ_UNICODE_CHARSET) + ] + * TYPING_MAX_LENGTH + ), + "page_number": spaces.Discrete(MAX_PAGE_NUMBER), + "url": spaces.Text(URL_MAX_LENGTH), + "nth": spaces.Discrete(MAX_ELEMENT_INDEX_IN_VIEWPORT), + "key_comb": spaces.Text(MAX_VANILLA_STR_LENGTH), + "direction": spaces.Text(MAX_VANILLA_STR_LENGTH), + "pw_code": spaces.Text(MAX_VANILLA_STR_LENGTH), + "answer": spaces.Text(MAX_ANSWER_LENGTH), + } + ) + return space + + +def create_random_action() -> Action: + """Return a random action.""" + return { + "action_type": np.random.randint(len(ActionTypes)), + "coords": np.random.rand(2).astype(np.float32), + "element_role": np.random.randint(len(ROLES) + len(SPECIAL_LOCATORS)), + "element_name": "".join( + random.choices(ASCII_CHARSET, k=np.random.randint(TEXT_MAX_LENGTH)) + ), + "text": list( + random.choices( + list(range(len(ASCII_CHARSET))), + k=np.random.randint(TYPING_MAX_LENGTH), + ) + ), + "page_number": np.random.randint(MAX_PAGE_NUMBER), + "url": "".join( + 
random.choices(ASCII_CHARSET, k=np.random.randint(URL_MAX_LENGTH)) + ), + "nth": np.random.randint(MAX_ELEMENT_INDEX_IN_VIEWPORT), + "element_id": str(np.random.randint(MAX_ELEMENT_ID)), + "key_comb": "+".join( + random.choices(SPECIAL_KEYS, k=np.random.randint(3)) + ), + "direction": random.choice(["up", "down"]), + "pw_code": "".join( + random.choices( + string.ascii_uppercase + string.digits, + k=np.random.randint(MAX_VANILLA_STR_LENGTH), + ) + ), + "answer": str(np.random.randint(MAX_ANSWER_LENGTH)), + "raw_prediction": str(np.random.randint(MAX_ANSWER_LENGTH)), + } + + +@beartype +def create_none_action() -> Action: + """Return a valid action object that does nothing.""" + return { + "action_type": ActionTypes.NONE, + "coords": np.zeros(2, dtype=np.float32), + "element_role": 0, + "element_name": "", + "text": [], + "page_number": 0, + "url": "", + "nth": 0, + "pw_code": "", # str that requires further processing + "element_id": "", + "key_comb": "", + "direction": "", + "answer": "", + "raw_prediction": "", + } + + +@beartype +def create_stop_action(answer: str) -> Action: + action = create_none_action() + action.update({"action_type": ActionTypes.STOP, "answer": answer}) + return action + + +@beartype +def create_scroll_action(direction: str) -> Action: + """Return the playwright action""" + assert direction in ["up", "down"] + action = create_none_action() + action.update( + { + "action_type": ActionTypes.SCROLL, + "direction": direction, + } + ) + return action + + +@beartype +def create_mouse_hover_action( + left: float | None = None, top: float | None = None +) -> Action: + """Return a valid action object with type COORD_CLICK.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.MOUSE_HOVER, + "coords": np.array([left, top], dtype=np.float32), + } + ) + return action + + +@beartype +def create_key_press_action(key_comb: str) -> Action: + """Return the key press action""" + + def map_keys(key_comb: str) -> str: + keys = 
key_comb.split("+") + mapped_keys = [] + for key in keys: + mapped_key = SPECIAL_KEY_MAPPINGS.get(key.lower(), key) + mapped_keys.append(mapped_key) + return "+".join(mapped_keys) + + action = create_none_action() + mapped_key_comb = map_keys(key_comb) + action.update( + { + "action_type": ActionTypes.KEY_PRESS, + "key_comb": mapped_key_comb, + } + ) + return action + + +@beartype +def create_page_focus_action(page_number: int) -> Action: + """Return a valid action object with type PAGE_FOCUS.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.PAGE_FOCUS, + "page_number": page_number, + } + ) + return action + + +@beartype +def create_new_tab_action() -> Action: + """Return a valid action object with type NEW_TAB.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.NEW_TAB, + } + ) + return action + + +@beartype +def create_go_back_action() -> Action: + """Return a valid action object with type GO_BACK.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.GO_BACK, + } + ) + return action + + +@beartype +def create_go_forward_action() -> Action: + """Return a valid action object with type GO_FORWARD.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.GO_FORWARD, + } + ) + return action + + +@beartype +def create_goto_url_action(url: str) -> Action: + """Return a valid action object with type GOTO_URL.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.GOTO_URL, + "url": url, + } + ) + return action + + +@beartype +def create_page_close_action() -> Action: + """Return a valid action object with type PAGE_CLOSE.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.PAGE_CLOSE, + } + ) + return action + + +@beartype +def create_mouse_click_action( + left: float | None = None, top: float | None = None +) -> Action: + """Return a valid action object with type COORD_CLICK.""" + 
action = create_none_action() + if left and top: + action.update( + { + "action_type": ActionTypes.MOUSE_CLICK, + "coords": np.array([left, top], dtype=np.float32), + } + ) + elif (not left) and (not top): + action.update( + { + "action_type": ActionTypes.CLICK, + } + ) + else: + raise ValueError("left and top must be both None or both not None") + return action + + +@beartype +def create_clear_action( + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.CLEAR, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_keyboard_type_action(keys: list[int | str] | str) -> Action: + """Return a valid action object with type TYPE.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.KEYBOARD_TYPE, + "text": _keys2ids(keys), + } + ) + return action + + +@beartype +def create_click_action( + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.CLICK, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_hover_action( + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.HOVER, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_type_action( + text: str, + 
element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.TYPE, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "text": _keys2ids(text), + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_type_action_webrl( + text: str, + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.TYPE, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "text": text, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_search_action( + text: str, + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.SEARCH, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "text": text, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_select_dropdown_option_action( + argument: str, + element_id: str = "", + element_role: RolesType = "link", + element_name: str = "", + pw_code: str = "", + nth: int = 0, +) -> Action: + """Return a valid action object with type SELECT_DROPDOWN_OPTION.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.SELECT_DROPDOWN_OPTION, + "element_id": element_id, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + "argument": argument, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_check_action(pw_code: str) -> Action: + action = 
create_none_action() + action.update( + { + "action_type": ActionTypes.CHECK, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_select_option_action( + pw_code: str, +) -> Action: + action = create_none_action() + action.update( + { + "action_type": ActionTypes.SELECT_OPTION, + "pw_code": pw_code, + } + ) + return action + + +@beartype +def create_focus_action( + element_role: RolesType, element_name: str = "", nth: int = 0 +) -> Action: + """Return a valid action object with type CLICK. + + Keep compatible with the old version.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.CLICK, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + } + ) + return action + + +@beartype +def create_focus_and_click_action( + element_role: RolesType, element_name: str = "", nth: int = 0 +) -> Action: + """Return a valid action object with type CLICK. + + Keep compatible with the old version.""" + + action = create_none_action() + action.update( + { + "action_type": ActionTypes.CLICK, + "element_role": _role2id[element_role], + "element_name": element_name, + "nth": nth, + } + ) + return action + + +@beartype +def create_focus_and_type_action( + keys: list[int | str] | str, + element_role: RolesType, + element_name: str = "", + nth: int = 0, +) -> Action: + """Return a valid action object with type TYPE. 
+ + Keep compatible with the old version.""" + action = create_none_action() + action.update( + { + "action_type": ActionTypes.TYPE, + "element_role": _role2id[element_role], + "element_name": element_name, + "text": _keys2ids(keys), + "nth": nth, + } + ) + return action + + +@beartype +def execute_scroll(direction: str, page: Page) -> None: + # perform the action + # code from natbot + if direction == "up": + page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" + ) + elif direction == "down": + page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" + ) + +@beartype +def execute_scroll_webrl(direction: str, page: Page) -> None: + # perform the action which move 2/3 of the height of the page at a time + if direction == "up": + page.mouse.wheel(0, -page.viewport_size['height'] * 2.0 / 3) + elif direction == "down": + page.mouse.wheel(0, page.viewport_size['height'] * 2.0 / 3) + +@beartype +async def aexecute_scroll(direction: str, page: APage) -> None: + # perform the action + # code from natbot + if direction == "up": + await page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop - window.innerHeight;" + ) + elif direction == "down": + await page.evaluate( + "(document.scrollingElement || document.body).scrollTop = (document.scrollingElement || document.body).scrollTop + window.innerHeight;" + ) + + +@beartype +def execute_key_press(key: str, page: Page) -> None: + """Press a key.""" + if "Meta" in key and "Mac" not in page.evaluate("navigator.platform"): + key = key.replace("Meta", "Control") + page.keyboard.press(key) + + +@beartype +async def aexecute_key_press(key: str, page: APage) -> None: + """Press a key.""" + if "Meta" in key and "Mac" not in await page.evaluate( + "navigator.platform" + 
): + key = key.replace("Meta", "Control") + await page.keyboard.press(key) + + +@beartype +def execute_mouse_hover(left: float, top: float, page: Page) -> None: + """Click at coordinates (left, top).""" + viewport_size = page.viewport_size + assert viewport_size + page.mouse.move( + left * viewport_size["width"], top * viewport_size["height"] + ) + + +@beartype +async def aexecute_mouse_hover(left: float, top: float, page: APage) -> None: + """Click at coordinates (left, top).""" + viewport_size = page.viewport_size + assert viewport_size + await page.mouse.move( + left * viewport_size["width"], top * viewport_size["height"] + ) + + +def execute_mouse_click(left: float, top: float, page: Page) -> None: + """Click at coordinates (left, top).""" + viewport_size = page.viewport_size + assert viewport_size + page.mouse.click( + left * viewport_size["width"], top * viewport_size["height"] + ) + + +@beartype +async def aexecute_mouse_click(left: float, top: float, page: APage) -> None: + """Click at coordinates (left, top).""" + viewport_size = page.viewport_size + assert viewport_size + await page.mouse.click( + left * viewport_size["width"], top * viewport_size["height"] + ) + + +@beartype +def execute_keyboard_type(text: str, page: Page) -> None: + """Fill the focused element with text.""" + page.keyboard.type(text) + + +@beartype +async def aexecute_keyboard_type(text: str, page: APage) -> None: + """Fill the focused element with text.""" + await page.keyboard.type(text) + + +@beartype +def execute_click_current(page: Page) -> None: + """Click at the current mouse position.""" + locators = page.locator("*:focus") + if not locators.count(): + for frame in page.frames[1:]: + locators = frame.locator("*:focus") + if locators.count(): + break + locators.click() + + +@beartype +async def aexecute_click_current(page: APage) -> None: + """Click at the current mouse position.""" + locators = page.locator("*:focus") + locator_count = await locators.count() + if not 
locator_count: + for frame in page.frames[1:]: + locators = frame.locator("*:focus") + locator_count = await locators.count() + if locator_count: + break + await locators.click() + await page.wait_for_load_state("load") + + +@beartype +def execute_type(keys: list[int], page: Page) -> None: + """Send keystrokes to the focused element.""" + text = "".join([_id2key[key] for key in keys]) + page.keyboard.type(text) + + +@beartype +async def aexecute_type(keys: list[int], page: APage) -> None: + """Send keystrokes to the focused element.""" + text = "".join([_id2key[key] for key in keys]) + await page.keyboard.type(text) + + +@beartype +def execute_focus( + element_role: int, element_name: str, nth: int, page: Page +) -> None: + """Click the specified DOM element.""" + element_role_str = _id2role[element_role] + if page.viewport_size is None: + raise ValueError("Viewport size is not set for the current page") + element_location_list: list[tuple[Locator, float, float]] = [] + for frame in page.frames: + match element_role_str: + case "alt_text": + locators = frame.get_by_alt_text(element_name) + case "label": + locators = frame.get_by_label(element_name) + case "placeholder": + locators = frame.get_by_placeholder(element_name) + case _: + locators = frame.get_by_role( + role=element_role_str, name=element_name + ) + for locator_idx in range(locators.count()): + locator = locators.nth(locator_idx) + if is_in_viewport(locator, page.viewport_size): + bounding_box = locator.bounding_box() + assert bounding_box + element_location_list.append( + (locator, bounding_box["x"], bounding_box["y"]) + ) + if len(element_location_list) <= nth: + raise ValueError( + f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested" + ) + element_location_list.sort(key=lambda x: (x[2], x[1])) # row major order + element_location_list[nth][0].focus() + + +@beartype +async def aexecute_focus( + element_role: int, element_name: str, nth: int, page: APage 
+) -> None: + """Click the specified DOM element.""" + element_role_str = _id2role[element_role] + if page.viewport_size is None: + raise ValueError("Viewport size is not set for the current page") + element_location_list: list[tuple[ALocator, float, float]] = [] + for frame in page.frames: + match element_role_str: + case "alt_text": + locators = frame.get_by_alt_text(element_name) + case "label": + locators = frame.get_by_label(element_name) + case "placeholder": + locators = frame.get_by_placeholder(element_name) + case _: + locators = frame.get_by_role( + role=element_role_str, name=element_name + ) + for locator_idx in range(await locators.count()): + locator = locators.nth(locator_idx) + if await async_is_in_viewport(locator, page.viewport_size): + bounding_box = await locator.bounding_box() + assert bounding_box + element_location_list.append( + (locator, bounding_box["x"], bounding_box["y"]) + ) + if len(element_location_list) <= nth: + raise ValueError( + f"There are only {len(element_location_list)} elements found in viewport, but {nth + 1} is requested" + ) + element_location_list.sort(key=lambda x: (x[2], x[1])) # row major order + await element_location_list[nth][0].focus() + + +@beartype +def locate(locator_calls: list[ParsedPlaywrightCode], page: Page) -> Locator: + locator = page + for call in locator_calls: + function_name = call["function_name"] + arguments = call["arguments"] + keywords = call["keywords"] + locator = getattr(locator, function_name)(*arguments, **keywords) + return locator # type: ignore[return-value] + + +@beartype +async def alocate( + locator_calls: list[ParsedPlaywrightCode], page: APage +) -> ALocator: + locator = page + for call in locator_calls: + function_name = call["function_name"] + arguments = call["arguments"] + keywords = call["keywords"] + locator = await getattr(locator, function_name)(*arguments, **keywords) + return locator # type: ignore[return-value] + + +@beartype +def execute_playwright_click( + locator_code: 
list[ParsedPlaywrightCode], + page: Page, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = locate(locator_code, page) + + # perform the action + locator.click(*pw_action_args, **pw_action_kwargs) + + +@beartype +async def aexecute_playwright_click( + locator_code: list[ParsedPlaywrightCode], + page: APage, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = await alocate(locator_code, page) + + # perform the action + await locator.click(*pw_action_args, **pw_action_kwargs) + + +@beartype +def execute_playwright_hover( + locator_code: list[ParsedPlaywrightCode], page: Page +) -> None: + locator = locate(locator_code, page) + + # perform the action + locator.hover() + + +@beartype +async def aexecute_playwright_hover( + locator_code: list[ParsedPlaywrightCode], page: APage +) -> None: + locator = await alocate(locator_code, page) + + # perform the action + await locator.hover() + + +@beartype +def execute_playwright_type( + text: str, + locator_code: list[ParsedPlaywrightCode], + page: Page, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = locate(locator_code, page) + # perform the action + pw_action_args = [text] + pw_action_args # text is the first argument + locator.type(*pw_action_args, **pw_action_kwargs) + + +@beartype +async def aexecute_playwright_type( + text: str, + locator_code: list[ParsedPlaywrightCode], + page: APage, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = await alocate(locator_code, page) + # perform the action + pw_action_args = [text] + pw_action_args # text is the first argument + await locator.type(*pw_action_args, **pw_action_kwargs) + + +@beartype +def execute_playwright_select_option( + locator_code: list[ParsedPlaywrightCode], + page: Page, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = 
locate(locator_code, page) + # perform the action + locator.select_option(*pw_action_args, **pw_action_kwargs) + + +@beartype +async def aexecute_playwright_select_option( + locator_code: list[ParsedPlaywrightCode], + page: APage, + pw_action_args: list[str] = [], + pw_action_kwargs: dict[str, Any] = {}, +) -> None: + locator = await alocate(locator_code, page) + # perform the action + await locator.select_option(*pw_action_args, **pw_action_kwargs) + + +@beartype +def execute_playwright_check( + locator_code: list[ParsedPlaywrightCode], page: Page +) -> None: + locator = locate(locator_code, page) + # perform the action + locator.check() + + +@beartype +async def aexecute_playwright_check( + locator_code: list[ParsedPlaywrightCode], page: APage +) -> None: + locator = await alocate(locator_code, page) + # perform the action + await locator.check() + + +@beartype +def execute_action_webrl( + action: Action, + page: Page, + browser_ctx: BrowserContext, + obseration_processor: ObservationProcessor, + sleep_after_execution: float = 0.0, +) -> Page: + """Execute the action on the ChromeDriver.""" + action_type = action["action_type"] + num_tabs_before = len(browser_ctx.pages) + match action_type: + case ActionTypes.NONE: + pass + case ActionTypes.SCROLL: + direction = "up" if "up" in action["direction"] else "down" + execute_scroll_webrl(direction, page) + case ActionTypes.KEY_PRESS: + keys = action["key_comb"] + execute_key_press(keys, page) + case ActionTypes.MOUSE_CLICK: + execute_mouse_click(action["coords"][0], action["coords"][1], page) + case ActionTypes.CLICK: + # check each kind of locator in order + # TODO[shuyanzh]: order is temp now + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id, page) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + case ActionTypes.HOVER: + element_id = action["element_id"] + element_center = 
obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_hover(element_center[0], element_center[1], page) + case ActionTypes.TYPE: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + execute_key_press("Meta+A", page) + execute_key_press('Backspace', page) + # execute_mouse_click(element_center[0], element_center[1], page) + text = _keys2ids(action["text"]) + execute_type(text, page) + case ActionTypes.SEARCH: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + execute_key_press("Meta+A", page) + execute_key_press('Backspace', page) + # execute_mouse_click(element_center[0], element_center[1], page) + text = _keys2ids(action["text"]) + execute_type(text, page) + time.sleep(2) + execute_key_press("Enter", page) + case ActionTypes.GO_BACK: + page.go_back() + case ActionTypes.GO_FORWARD: + page.go_forward() + case ActionTypes.GOTO_URL: + page.goto(action["url"]) + case ActionTypes.SELECT_DROPDOWN_OPTION: + # Click + element_id = action["element_id"] + argument = action["argument"] + element_center = obseration_processor.get_element_center(element_id, page) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + # get element + device_pixel_ratio = page.evaluate("window.devicePixelRatio") + center_x, center_y = element_center[0] * page.viewport_size["width"], element_center[1] * page.viewport_size["height"] + last_turn_element = page.evaluate_handle(f"""() => document.elementFromPoint({center_x / device_pixel_ratio}, {center_y / device_pixel_ratio})""") + # get select element options + select_element_options = [{"value": option.get_attribute('value'), "text": option.text_content().strip(' 
\n')} for option in + last_turn_element.query_selector_all("option")] + selector_option_dict = dict((o["text"].lower(), o["value"]) for o in select_element_options) + value = None + for key in selector_option_dict.keys(): + if argument.lower() in key.lower(): + value = selector_option_dict[key] + break + if value is not None: + last_turn_element.select_option(value=value) + case _: + raise ValueError(f"Unknown action type: {action_type}") + + page.wait_for_timeout(int(sleep_after_execution * 1000)) + num_tabs_now = len(browser_ctx.pages) + # if a new tab is opened by clicking, switch to the new tab + if num_tabs_now > num_tabs_before: + page = browser_ctx.pages[-1] + page.bring_to_front() + + return page + +@beartype +def execute_action( + action: Action, + page: Page, + browser_ctx: BrowserContext, + obseration_processor: ObservationProcessor, + sleep_after_execution: float = 0.0, +) -> Page: + """Execute the action on the ChromeDriver.""" + action_type = action["action_type"] + num_tabs_before = len(browser_ctx.pages) + match action_type: + case ActionTypes.NONE: + pass + + case ActionTypes.SCROLL: + direction = "up" if "up" in action["direction"] else "down" + execute_scroll(direction, page) + case ActionTypes.KEY_PRESS: + keys = action["key_comb"] + execute_key_press(keys, page) + + case ActionTypes.MOUSE_CLICK: + execute_mouse_click(action["coords"][0], action["coords"][1], page) + case ActionTypes.CLEAR: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + execute_key_press("Meta+A", page) + execute_key_press('Backspace', page) + case ActionTypes.MOUSE_HOVER: + execute_mouse_hover(action["coords"][0], action["coords"][1], page) + case ActionTypes.KEYBOARD_TYPE: + execute_type(action["text"], page) + case ActionTypes.CLICK: + # check each kind of locator in order + # TODO[shuyanzh]: order is temp now + if 
action["element_id"]: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = action["element_name"] + nth = action["nth"] + execute_focus(element_role, element_name, nth, page) + execute_click_current(page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + # [shuyanzh], don't support action args and kwargs now + execute_playwright_click(locator_code=locator_code, page=page) + else: + raise ValueError("No proper locator found for click action") + case ActionTypes.HOVER: + if action["element_id"]: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_hover(element_center[0], element_center[1], page) + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = action["element_name"] + nth = action["nth"] + execute_focus(element_role, element_name, nth, page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + # [shuyanzh], don't support action args and kwargs now + execute_playwright_hover(locator_code=locator_code, page=page) + else: + raise NotImplementedError( + "No proper locator found for hover action" + ) + case ActionTypes.TYPE: + if action["element_id"]: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + execute_mouse_click(element_center[0], element_center[1], page) + execute_type(action["text"], page) + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = 
action["element_name"] + nth = action["nth"] + execute_focus(element_role, element_name, nth, page) + execute_type(action["text"], page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + text = parsed_code[-1]["arguments"][0] + # [shuyanzh], don't support action args and kwargs now + execute_playwright_type( + text=text, locator_code=locator_code, page=page + ) + else: + raise NotImplementedError( + "No proper locator found for type action" + ) + + case ActionTypes.PAGE_FOCUS: + page = browser_ctx.pages[action["page_number"]] + page.bring_to_front() + case ActionTypes.NEW_TAB: + page = browser_ctx.new_page() + case ActionTypes.GO_BACK: + page.go_back() + case ActionTypes.GO_FORWARD: + page.go_forward() + case ActionTypes.GOTO_URL: + page.goto(action["url"]) + case ActionTypes.PAGE_CLOSE: + page.close() + if len(browser_ctx.pages) > 0: + page = browser_ctx.pages[-1] + else: + page = browser_ctx.new_page() + + case ActionTypes.SELECT_OPTION: + if action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + execute_playwright_select_option(locator_code, page) + else: + raise NotImplementedError( + "No proper locator found for select option action" + ) + case ActionTypes.CHECK: + if action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + execute_playwright_check(locator_code, page) + else: + raise NotImplementedError( + "No proper locator found for select option action" + ) + + case _: + raise ValueError(f"Unknown action type: {action_type}") + + page.wait_for_timeout(int(sleep_after_execution * 1000)) + num_tabs_now = len(browser_ctx.pages) + # if a new tab is opened by clicking, switch to the new tab + if num_tabs_now > num_tabs_before: + page = browser_ctx.pages[-1] + page.bring_to_front() + + return page + + + +@beartype +async def aexecute_action( + action: Action, page: APage, browser_ctx: 
ABrowserContext +) -> APage: + """Execute the async action on the ChromeDriver.""" + action_type = action["action_type"] + match action_type: + case ActionTypes.NONE: + pass + case ActionTypes.SCROLL: + direction = "up" if "up" in action["direction"] else "down" + await aexecute_scroll(direction, page) + case ActionTypes.KEY_PRESS: + keys = action["key_comb"] + await aexecute_key_press(keys, page) + + case ActionTypes.MOUSE_CLICK: + await aexecute_mouse_click( + action["coords"][0], action["coords"][1], page + ) + case ActionTypes.CLEAR: + element_id = action["element_id"] + element_center = obseration_processor.get_element_center(element_id) # type: ignore[attr-defined] + await execute_mouse_click(element_center[0], element_center[1], page) + await execute_key_press("Meta+A", page) + await execute_key_press('Backspace', page) + case ActionTypes.MOUSE_HOVER: + await aexecute_mouse_hover( + action["coords"][0], action["coords"][1], page + ) + case ActionTypes.KEYBOARD_TYPE: + await aexecute_type(action["text"], page) + + case ActionTypes.CLICK: + # check each kind of locator in order + # TODO[shuyanzh]: order is temp now + if action["element_id"]: + raise NotImplementedError + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = action["element_name"] + nth = action["nth"] + await aexecute_focus(element_role, element_name, nth, page) + await aexecute_click_current(page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + # [shuyanzh], don't support action args and kwargs now + await aexecute_playwright_click( + locator_code=locator_code, page=page + ) + else: + raise ValueError("No proper locator found for click action") + case ActionTypes.HOVER: + if action["element_id"]: + raise NotImplementedError + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = 
action["element_name"] + nth = action["nth"] + await aexecute_focus(element_role, element_name, nth, page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + # [shuyanzh], don't support action args and kwargs now + await aexecute_playwright_hover( + locator_code=locator_code, page=page + ) + else: + raise NotImplementedError( + "No proper locator found for hover action" + ) + case ActionTypes.TYPE: + if action["element_id"]: + raise NotImplementedError + elif action["element_role"] and action["element_name"]: + element_role = int(action["element_role"]) + element_name = action["element_name"] + nth = action["nth"] + await aexecute_focus(element_role, element_name, nth, page) + await aexecute_type(action["text"], page) + elif action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + text = parsed_code[-1]["arguments"][0] + # [shuyanzh], don't support action args and kwargs now + await aexecute_playwright_type( + text=text, locator_code=locator_code, page=page + ) + else: + raise NotImplementedError( + "No proper locator found for type action" + ) + + case ActionTypes.PAGE_FOCUS: + page = browser_ctx.pages[action["page_number"]] + await page.bring_to_front() + case ActionTypes.NEW_TAB: + page = await browser_ctx.new_page() + case ActionTypes.GO_BACK: + await page.go_back() + case ActionTypes.GO_FORWARD: + await page.go_forward() + case ActionTypes.GOTO_URL: + await page.goto(action["url"]) + case ActionTypes.PAGE_CLOSE: + await page.close() + if len(browser_ctx.pages) > 0: + page = browser_ctx.pages[-1] + else: + page = await browser_ctx.new_page() + + case ActionTypes.SELECT_OPTION: + if action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + await aexecute_playwright_select_option(locator_code, page) + else: + raise NotImplementedError( + "No proper locator found for select option action" + 
) + case ActionTypes.CHECK: + if action["pw_code"]: + parsed_code = parse_playwright_code(action["pw_code"]) + locator_code = parsed_code[:-1] + await aexecute_playwright_check(locator_code, page) + else: + raise NotImplementedError( + "No proper locator found for select option action" + ) + + case _: + raise ValueError(f"Unknown action type: {action_type}") + + return page + + +@beartype +def parse_playwright_code(code: str) -> list[ParsedPlaywrightCode]: + # extract function calls + if not code.startswith("page."): + raise ValueError( + f'Playwright action must start with "page.", but got {code}' + ) + + regex = r"\.(?![^\(\)]*\))" + chain = re.split(regex, code)[1:] + + parsed_chain = [] + + for item in chain: + tree = ast.parse(item) + funcs = [] + for node in ast.walk(tree): + if isinstance(node, ast.Call): + function_name = node.func.id # type: ignore[attr-defined] + arguments = [ + ast.literal_eval(arg) if isinstance(arg, ast.Str) else arg + for arg in node.args + ] + keywords = { + str(kw.arg): ast.literal_eval(kw.value) + for kw in node.keywords + } + funcs.append( + ParsedPlaywrightCode( + { + "function_name": function_name, + "arguments": arguments, + "keywords": keywords, + } + ) + ) + + if len(funcs) != 1: + raise ValueError(f"Fail to parse {item} in {code}") + + if ( + funcs[0]["function_name"] + not in PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS + ): + raise ValueError( + f"Invalid playwright code {item}, ", + f"the function needs to be one of {PLAYWRIGHT_LOCATORS + PLAYWRIGHT_ACTIONS}", + ) + + parsed_chain.append(funcs[0]) + + last_action = parsed_chain[-1] + if last_action["function_name"] not in PLAYWRIGHT_ACTIONS: + raise ValueError( + f"Invalid playwright action {last_action},", + f"the action needs to be one of {PLAYWRIGHT_ACTIONS}", + ) + + return parsed_chain + + +@beartype +class ActionParsingError(Exception): + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + +@beartype +def 
create_playwright_action(playwright_code: str) -> Action: + """Main function to return individual playwright action""" + # get the last action + regex = r"\.(?![^\(\)]*\))" + action = re.split(regex, playwright_code)[-1].split("(")[0] + match action: + case "press": + p = r'press\((?:"|\')(.+?)(?:"|\')\)' + match = re.search(p, playwright_code) + if not match: + raise ActionParsingError( + f"Invalid press action, required to be page.press(KEY_COMB_STR)" + ) + key_comb = match.group(1) + return create_key_press_action(key_comb=key_comb) + case "scroll": + direction = "up" if "up" in playwright_code else "down" + return create_scroll_action(direction=direction) + case "click": + return create_click_action(pw_code=playwright_code) + case "clear": + return create_clear_action(pw_code=playwright_code) + case "hover": + return create_hover_action(pw_code=playwright_code) + case "type" | "fill": + p = r'type|fill\((?:"|\')(.+?)(?:"|\')\)' + match = re.search(p, playwright_code) + if not match: + raise ActionParsingError( + f"Invalid type/fill action, required to be page.type(TEXT)" + ) + text = match.group(1) + return create_type_action(text=text, pw_code=playwright_code) + case "select_option": + return create_select_option_action(pw_code=playwright_code) + case "check": + return create_check_action(pw_code=playwright_code) + case "goto": + p = r'goto\((?:"|\')(.+?)(?:"|\')\)' + match = re.search(p, playwright_code) + if not match: + raise ActionParsingError( + f"Invalid goto action, required to be page.goto(URL_STR)" + ) + url = match.group(1) + return create_goto_url_action(url) + case "page_focus": + # get the page number + p = r"page_focus\((\d+)\)" + match = re.search(p, playwright_code) + if not match: + raise ActionParsingError("page focus requires a page number") + page_num = int(match.group(1)) + return create_page_focus_action(page_num) + case "new_tab": + return create_new_tab_action() + case "go_back": + return create_go_back_action() + case "go_forward": + 
return create_go_forward_action() + case "page_close": + return create_page_close_action() + case "stop": # page.stop(answer) + p = r'stop\(?"(.+)?"\)' + match = re.search(p, playwright_code) + if not match: + answer = "" + else: + answer = match.group(1) + return create_stop_action(answer) + + raise ActionParsingError(f"Unknown playwright action {action}") + + +@beartype +def create_id_based_action(action_str: str) -> Action: + """Main function to return individual id based action""" + action_str = action_str.strip() + if "[" in action_str: + action = action_str.split("[")[0].strip() + else: + actions = action_str.split() + if actions: + action = actions[0].strip() + else: + raise ActionParsingError(f"No action specified: {action_str}") + match action: + case "click": + match = re.search(r"click ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid click action {action_str}") + element_id = match.group(1) + return create_click_action(element_id=element_id) + case "clear": + match = re.search(r"clear ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid clear action {action_str}") + element_id = match.group(1) + return create_clear_action(element_id=element_id) + case "hover": + match = re.search(r"hover ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid hover action {action_str}") + element_id = match.group(1) + return create_hover_action(element_id=element_id) + case "type": + # add default enter flag + if not (action_str.endswith("[0]") or action_str.endswith("[1]")): + action_str += " [1]" + + match = re.search( + r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str + ) + if not match: + raise ActionParsingError(f"Invalid type action {action_str}") + element_id, text, enter_flag = ( + match.group(1), + match.group(2), + match.group(3), + ) + if enter_flag == "1": + text += "\n" + return create_type_action(text=text, element_id=element_id) + case "press": + match = re.search(r"press 
?\[(.+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid press action {action_str}") + key_comb = match.group(1) + return create_key_press_action(key_comb=key_comb) + case "scroll": + # up or down + match = re.search(r"scroll ?\[?(up|down)\]?", action_str) + if not match: + raise ActionParsingError(f"Invalid scroll action {action_str}") + direction = match.group(1) + return create_scroll_action(direction=direction) + case "goto": + match = re.search(r"goto ?\[(.+)\]", action_str) + if not match: + raise ActionParsingError(f"Invalid goto action {action_str}") + url = match.group(1) + return create_goto_url_action(url=url) + case "new_tab": + return create_new_tab_action() + case "go_back": + return create_go_back_action() + case "go_forward": + return create_go_forward_action() + case "tab_focus": + match = re.search(r"tab_focus ?\[(\d+)\]", action_str) + if not match: + raise ActionParsingError( + f"Invalid tab_focus action {action_str}" + ) + page_number = int(match.group(1)) + return create_page_focus_action(page_number) + case "close_tab": + return create_page_close_action() + case "stop": # stop answer + match = re.search(r"stop ?\[(.+)\]", action_str) + if not match: # some tasks don't require an answer + answer = "" + else: + answer = match.group(1) + return create_stop_action(answer) + + raise ActionParsingError(f"Invalid action {action_str}") + + +@beartype +def create_webrl_id_based_action(action_str: str) -> Action: + """Main function to return individual webrl id based action""" + import ast + def remove_comments(code): + # 按行分割代码 + for key in ['exit(','do(','go_backward(']: + if key in code: + return key + code.split(key)[-1] + lines = code.split('\n') + for i, line in enumerate(lines): + if line.strip().startswith('#'): + # 跳过注释行 + continue + else: + # 返回非注释行及其后面的部分 + return '\n'.join(lines[i:]) + return '' + + def parse_function_call(expression): + expression = remove_comments(expression) + # 将字符串解析为 AST + expression = 
expression.strip() + tree = ast.parse(expression, mode='eval') + # 提取函数名称 + func_call = tree.body + if not isinstance(func_call, ast.Call): + return { + "operation": expression, + } + func_name = func_call.func.id + result = { + "operation": func_name, + } + # 提取参数 + args = func_call.args + kwargs = func_call.keywords + for kw in kwargs: + if func_name == "do" and kw.arg == "action": + result["action"] = ast.literal_eval(kw.value) + # elif func_name == "do" and kw.arg == "argument": + # result["argument"] = ast.literal_eval(kw.value) + else: + if "kwargs" not in result: + result["kwargs"] = {} + if kw.arg == "element": + try: + # 解析元素的内部函数 + inner_func = kw.value + if isinstance(inner_func, ast.Call) and inner_func.func.id == 'find_element_by_instruction': + for inner_kw in inner_func.keywords: + if inner_kw.arg == "instruction": + result["kwargs"]["instruction"] = ast.literal_eval(inner_kw.value) + else: + result["kwargs"][kw.arg] = ast.literal_eval(inner_func) + except Exception: + result["kwargs"][kw.arg] = ast.literal_eval(kw.value) + else: + result["kwargs"][kw.arg] = ast.literal_eval(kw.value) + return result + + action_str = action_str.strip() + try: + action = parse_function_call(action_str) + except Exception as e: + raise ActionParsingError(f"No action specified: {action_str}") + operation = action["operation"] + match operation: + case "do": + action_type = action["action"].lower() + match action_type: + case "press enter": + return create_key_press_action(key_comb='enter') + case "scroll up": + return create_scroll_action(direction='up') + case "scroll down": + return create_scroll_action(direction='down') + case "click": + element_id = action["kwargs"]["element"] + return create_click_action(element_id=element_id) + case "type": + element_id = action["kwargs"]["element"] + text = action["kwargs"]["argument"] + return create_type_action_webrl(text=text, element_id=element_id) + case "hover": + element_id = action["kwargs"]["element"] + return 
create_hover_action(element_id=element_id) + case "select dropdown option": + element_id = action["kwargs"]["element"] + argument = action["kwargs"]["argument"] + return create_select_dropdown_option_action(argument=argument, element_id=element_id) + case "go forward": + return create_go_forward_action() + case "go backward": + return create_go_back_action() + case "search": + element_id = action["kwargs"]["element"] + text = action["kwargs"]["argument"] + return create_search_action(text=text, element_id=element_id) + case "exit": # stop answer + answer = action['kwargs']['message'] + return create_stop_action(answer) + + raise ActionParsingError(f"Invalid action {action_str}") \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/agent.py b/VAB-WebArena-Lite/new/agent.py index 461ba7e..96caf17 100644 --- a/VAB-WebArena-Lite/new/agent.py +++ b/VAB-WebArena-Lite/new/agent.py @@ -14,6 +14,7 @@ from browser_env.actions import ( create_id_based_action, create_none_action, create_playwright_action, + create_webrl_id_based_action ) from browser_env.utils import Observation, StateInfo from llms import ( @@ -108,13 +109,15 @@ class PromptAgent(Agent): lm_config: lm_config.LMConfig, prompt_constructor: PromptConstructor, captioning_fn = None, + planner_ip = None ) -> None: super().__init__() self.lm_config = lm_config self.prompt_constructor = prompt_constructor self.action_set_tag = action_set_tag self.captioning_fn = captioning_fn - + self.planner_ip = planner_ip + # Check if the model is multimodal. 
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor: self.multimodal_inputs = True @@ -165,7 +168,10 @@ class PromptAgent(Agent): lm_config = self.lm_config n = 0 while True: - response = call_llm(lm_config, prompt) + if self.planner_ip is not None and self.planner_ip != "": + response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip) + else: + response = call_llm(lm_config, prompt) force_prefix = self.prompt_constructor.instruction[ "meta_data" ].get("force_prefix", "") @@ -183,6 +189,8 @@ class PromptAgent(Agent): action = create_playwright_action(parsed_response) elif self.action_set_tag == "som": action = create_id_based_action(parsed_response) + elif self.action_set_tag == 'webrl_id': + action = create_webrl_id_based_action(parsed_response) else: raise ValueError( f"Unknown action type {self.action_set_tag}" @@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent: action_set_tag=args.action_set_tag, lm_config=llm_config, prompt_constructor=prompt_constructor, - captioning_fn=captioning_fn + captioning_fn=captioning_fn, + planner_ip=args.planner_ip ) else: raise NotImplementedError( diff --git a/VAB-WebArena-Lite/new/envs.py b/VAB-WebArena-Lite/new/envs.py new file mode 100644 index 0000000..2dfe7dc --- /dev/null +++ b/VAB-WebArena-Lite/new/envs.py @@ -0,0 +1,319 @@ +import json +import os +import re +import subprocess +import time +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Union + +import numpy as np +import numpy.typing as npt +import requests +from beartype import beartype +from gymnasium import Env +from gymnasium.spaces import Box, Text +from playwright.sync_api import ( + CDPSession, + Page, + Playwright, + ViewportSize, + expect, + sync_playwright, +) + +DATASET = 
os.environ["DATASET"] +if DATASET == "visualwebarena": + from browser_env.env_config import ( + CLASSIFIEDS, + CLASSIFIEDS_RESET_TOKEN, + ) + +from .actions import Action, execute_action, get_action_space, execute_action_webrl +from .processors import ObservationHandler, ObservationMetadata +from .utils import ( + AccessibilityTree, + DetachedPage, + Observation, + png_bytes_to_numpy, +) + + +@dataclass +class PlaywrightScript: + function: str # goto, get_by_role + destination: str # https://www.google.com/, combobox + name: str | None = None # Search, Avatar 2009 + operation: str | None = None # click, fill, press + value: str | None = None # avatar movie, Enter + + +def parse_action(action: str) -> PlaywrightScript: + splitted = action.strip().split(" ") + assert len(splitted) >= 2 + match splitted[:2]: + case ["goto", url]: + assert len(splitted) == 2 + return PlaywrightScript("goto", url) + case ["get_by_role", destination]: + assert len(splitted) >= 4 + match splitted[2:]: + case [name, operation]: + return PlaywrightScript( + "get_by_role", destination, name, operation + ) + case [name, operation, value]: + return PlaywrightScript( + "get_by_role", destination, name, operation, value + ) + case _: + raise ValueError("Invalid action") + case _: + raise ValueError(f"Invalid action {action}") + + +class ScriptBrowserEnv(Env[dict[str, Observation], Action]): + """ + The goal of this environment is to produce a prototype of a browser environment. + In the end, we want to support a fully configurable browser environment with wide + range of action spaces and observation spaces, both structured and unstructured. + But in this prototype, we just support action space specified by Playwright script, + and observation space is the html content of the page. 
+ """ + + @beartype + def __init__( + self, + max_page_length: int = 8192, + headless: bool = True, + slow_mo: int = 0, + observation_type: str = "html", + current_viewport_only: bool = False, + viewport_size: ViewportSize = {"width": 1280, "height": 720}, + save_trace_enabled: bool = False, + sleep_after_execution: float = 0.0, + captioning_fn=None, + ): + # TODO: make Space[Action] = ActionSpace + self.action_space = get_action_space() # type: ignore[assignment] + self.headless = headless + self.slow_mo = slow_mo + self.current_viewport_only = current_viewport_only + self.reset_finished = False + self.viewport_size = viewport_size + self.save_trace_enabled = save_trace_enabled + self.sleep_after_execution = sleep_after_execution + + match observation_type: + case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl": + self.text_observation_type = observation_type + self.image_observation_type = "" + self.main_observation_type = "text" + case "image": + self.image_observation_type = observation_type + self.text_observation_type = "" # type: ignore[assignment] + self.main_observation_type = "image" + case "image_som": + self.image_observation_type = observation_type + self.text_observation_type = observation_type # type: ignore[assignment] + self.main_observation_type = "image" + case _: + raise ValueError( + f"Unsupported observation type: {observation_type}" + ) + + self.observation_handler = ObservationHandler( + self.main_observation_type, + self.text_observation_type, + self.image_observation_type, + self.current_viewport_only, + self.viewport_size, + captioning_fn, + ) + + self.observation_space = ( + self.observation_handler.get_observation_space() + ) + + @beartype + def setup(self, config_file: Path | None = None) -> None: + self.context_manager = sync_playwright() + self.playwright = self.context_manager.__enter__() + self.browser = self.playwright.chromium.launch( + headless=self.headless, slow_mo=self.slow_mo + ) + + if 
config_file: + with open(config_file, "r") as f: + instance_config = json.load(f) + else: + instance_config = {} + + # Reset site if needed. Currently only supported for Classifieds. + # TODO(jykoh): Add reset functionality for Shopping/Reddit. + if instance_config.get("require_reset", False): + if "classifieds" in instance_config["sites"]: + # Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN + response = requests.post( + f"{CLASSIFIEDS}/index.php?page=reset", + data={"token": CLASSIFIEDS_RESET_TOKEN}, + ) + + # Check if the request was successful + if response.status_code == 200: + print("Reset Classifieds site.") + else: + print( + "Failed to reset Classifieds site:", + response.status_code, + ) + else: + print( + "WARNING: Reset is not supported for this site. Please manually reset the site." + ) + + storage_state = instance_config.get("storage_state", None) + start_url = instance_config.get("start_url", None) + geolocation = instance_config.get("geolocation", None) + + # Use custom viewport size if specified in the config, otherwise use the default. 
+ viewport_size = self.viewport_size.copy() + viewport_size.update(instance_config.get("viewport_size", {})) + self.observation_handler.viewport_size = viewport_size + + self.context = self.browser.new_context( + viewport=viewport_size, + storage_state=storage_state, + geolocation=geolocation, + device_scale_factor=1, + ) + if self.save_trace_enabled: + self.context.tracing.start(screenshots=True, snapshots=True) + + if start_url: + start_urls = start_url.split(" |AND| ") + for url in start_urls: + page = self.context.new_page() + if self.text_observation_type in [ + "accessibility_tree", + "accessibility_tree_with_captioner", + ]: + client = page.context.new_cdp_session(page) + client.send("Accessibility.enable") + client.detach() + page.goto(url) + # set the first page as the current page + self.page = self.context.pages[0] + self.page.bring_to_front() + else: + self.page = self.context.new_page() + if self.text_observation_type in [ + "accessibility_tree", + "accessibility_tree_with_captioner", + ]: + client = self.page.context.new_cdp_session(self.page) + client.send("Accessibility.enable") + client.detach() + + def _get_obs(self) -> dict[str, Observation]: + obs = self.observation_handler.get_observation(self.page) + return obs + + def _get_obs_metadata(self) -> dict[str, ObservationMetadata]: + metadata = self.observation_handler.get_observation_metadata() + return metadata + + @beartype + def reset( + self, + *, + seed: int | None = None, + options: dict[str, str] | None = None, + ) -> tuple[dict[str, Observation], dict[str, Any]]: + """ + Reset the environment. + :param options: options for the environment. The current supported options are: + - "storage_state": the storage state of the browser. It is a file path to a json file. 
+ """ + super().reset(seed=seed, options=options) + if self.reset_finished: + self.context_manager.__exit__() + + if options is not None and "config_file" in options: + config_file = Path(options["config_file"]) + if config_file.exists(): + self.setup(config_file=config_file) + else: + raise ValueError(f"Config file {config_file} does not exist.") + else: + self.setup() + self.reset_finished = True + timeout_in_ms = 120000 + self.page.set_default_timeout(timeout_in_ms) + self.page.set_default_navigation_timeout(timeout_in_ms) + self.page.wait_for_timeout(int(self.sleep_after_execution * 1000)) + + observation = self._get_obs() + observation_metadata = self._get_obs_metadata() + info = { + "page": DetachedPage(self.page.url, ""), + "fail_error": "", + "observation_metadata": observation_metadata, + } + + return (observation, info) + + def save_trace(self, trace_path: str | Path) -> None: + if self.save_trace_enabled: + self.context.tracing.stop(path=trace_path) + + def close(self) -> None: + if self.reset_finished: + self.context_manager.__exit__() + + def step( + self, action: Action + ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]: + if not self.reset_finished: + raise RuntimeError("Call reset first before calling step.") + + success = False + fail_error = "" + try: + if self.text_observation_type == 'webrl': + self.page = execute_action_webrl( + action, + self.page, + self.context, + self.observation_handler.action_processor, + self.sleep_after_execution, + ) + else: + self.page = execute_action( + action, + self.page, + self.context, + self.observation_handler.action_processor, + self.sleep_after_execution, + ) + success = True + except Exception as e: + fail_error = str(e) + + observation = self._get_obs() + observation_metadata = self._get_obs_metadata() + + info = { + "page": DetachedPage(self.page.url, self.page.content()), + "fail_error": fail_error, + "observation_metadata": observation_metadata, + } + msg = ( + observation, + 
float(success), # reward + False, # terminated + False, # truncated + info, + ) + return msg diff --git a/VAB-WebArena-Lite/new/evaluators.py b/VAB-WebArena-Lite/new/evaluators.py index 85b5363..aac511b 100644 --- a/VAB-WebArena-Lite/new/evaluators.py +++ b/VAB-WebArena-Lite/new/evaluators.py @@ -172,6 +172,10 @@ class StringEvaluator(Evaluator): # prevent false positive (e.g, 0) if len(word_tokenize(clean_ref)) == 1: tok_pred = word_tokenize(clean_pred) + for token in tok_pred: + if '/' in token: + sub_tokens = token.split('/') + tok_pred.extend(sub_tokens) return float(clean_ref in tok_pred) else: return float(clean_ref in clean_pred) diff --git a/VAB-WebArena-Lite/new/helper_functions_browser.py b/VAB-WebArena-Lite/new/helper_functions_browser.py new file mode 100644 index 0000000..cb902ba --- /dev/null +++ b/VAB-WebArena-Lite/new/helper_functions_browser.py @@ -0,0 +1,239 @@ +import base64 +import io +import json +import re +from pathlib import Path +from typing import Any + +from PIL import Image + +from agent.prompts import * +from browser_env import ( + Action, + ActionTypes, + ObservationMetadata, + StateInfo, + action2str, +) + +HTML_TEMPLATE = """ + + + + + + + {body} + + +""" + + +def get_render_action( + action: Action, + observation_metadata: dict[str, ObservationMetadata], + action_set_tag: str, +) -> str: + """Parse the predicted actions for rendering purpose. More comprehensive information""" + match action_set_tag: + case "id_accessibility_tree": + text_meta_data = observation_metadata["text"] + if action["element_id"] in text_meta_data["obs_nodes_info"]: + node_content = text_meta_data["obs_nodes_info"][ + action["element_id"] + ]["text"] + else: + node_content = "No match found" + + action_str = f"
{action['raw_prediction']}
" + action_str += f"
{repr(action)}
" + action_str += f"
{action2str(action, action_set_tag, node_content)}
" + + case "som": + text_meta_data = observation_metadata["text"] + if action["element_id"] in text_meta_data["obs_nodes_info"]: + node_content = text_meta_data["obs_nodes_info"][ + action["element_id"] + ] + else: + node_content = "No match found" + action_str = f"
{action['raw_prediction']}
" + action_str += f"
{repr(action)}
" + action_str += f"
{action2str(action, action_set_tag, node_content)}
" + + case "playwright": + action_str = action["pw_code"] + case "webrl_id": + action_str = action["raw_prediction"] + case _: + raise ValueError( + f"Unknown action type {action['action_type'], action_set_tag}" + ) + return action_str + + +def get_action_description( + action: Action, + observation_metadata: dict[str, ObservationMetadata], + action_set_tag: str, + prompt_constructor: PromptConstructor | None, +) -> str: + """Generate the text version of the predicted actions to store in action history for prompt use. + May contain hint information to recover from the failures""" + + match action_set_tag: + case "id_accessibility_tree": + text_meta_data = observation_metadata["text"] + if action["action_type"] in [ + ActionTypes.CLICK, + ActionTypes.HOVER, + ActionTypes.TYPE, + ]: + action_name = str(action["action_type"]).split(".")[1].lower() + if action["element_id"] in text_meta_data["obs_nodes_info"]: + node_content = text_meta_data["obs_nodes_info"][ + action["element_id"] + ]["text"] + node_content = " ".join(node_content.split()[1:]) + action_str = action2str( + action, action_set_tag, node_content + ) + else: + action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully." + else: + if ( + action["action_type"] == ActionTypes.NONE + and prompt_constructor is not None + ): + action_splitter = prompt_constructor.instruction[ + "meta_data" + ]["action_splitter"] + action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.' 
+ else: + action_str = action2str(action, action_set_tag, "") + + case "som": + text_meta_data = observation_metadata["image"] + if action["action_type"] in [ + ActionTypes.CLICK, + ActionTypes.HOVER, + ActionTypes.TYPE, + ]: + action_name = str(action["action_type"]).split(".")[1].lower() + if action["element_id"] in text_meta_data["obs_nodes_info"]: + action_str = action2str(action, action_set_tag, "") + else: + print( + 'action["element_id"], text_meta_data["obs_nodes_info"]', + action["element_id"], + text_meta_data["obs_nodes_info"], + ) + action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully." + else: + if ( + action["action_type"] == ActionTypes.NONE + and prompt_constructor is not None + ): + action_splitter = prompt_constructor.instruction[ + "meta_data" + ]["action_splitter"] + action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.' + else: + action_str = action2str(action, action_set_tag, "") + + case "playwright": + action_str = action["pw_code"] + + case "webrl_id": + action_str = action["raw_prediction"] + + case _: + raise ValueError(f"Unknown action type {action['action_type']}") + + return action_str + + +class RenderHelper(object): + """Helper class to render text and image observations and meta data in the trajectory""" + + def __init__( + self, config_file: str, result_dir: str, action_set_tag: str + ) -> None: + with open(config_file, "r") as f: + _config = json.load(f) + _config_str = "" + for k, v in _config.items(): + _config_str += f"{k}: {v}\n" + _config_str = f"
{_config_str}
\n" + task_id = _config["task_id"] + + self.action_set_tag = action_set_tag + + self.render_file = open( + Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8" + ) + self.render_file.truncate(0) + # write init template + self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}")) + self.render_file.read() + self.render_file.flush() + + def render( + self, + action: Action, + state_info: StateInfo, + meta_data: dict[str, Any], + render_screenshot: bool = False, + ) -> None: + """Render the trajectory""" + # text observation + observation = state_info["observation"] + text_obs = observation["text"] + info = state_info["info"] + new_content = f"

New Page

\n" + new_content += f"

URL: {state_info['info']['page'].url}

\n" + new_content += f"
{text_obs}
\n" + + if render_screenshot: + # image observation + img_obs = observation["image"] + image = Image.fromarray(img_obs) + byte_io = io.BytesIO() + image.save(byte_io, format="PNG") + byte_io.seek(0) + image_bytes = base64.b64encode(byte_io.read()) + image_str = image_bytes.decode("utf-8") + new_content += f"\n" + + # meta data + new_content += f"
{meta_data['action_history'][-1]}
\n" + + # action + action_str = get_render_action( + action, + info["observation_metadata"], + action_set_tag=self.action_set_tag, + ) + # with yellow background + action_str = f"
{action_str}
" + new_content += f"{action_str}\n" + + # add new content + self.render_file.seek(0) + html = self.render_file.read() + html_body = re.findall(r"(.*?)", html, re.DOTALL)[0] + html_body += new_content + + html = HTML_TEMPLATE.format(body=html_body) + self.render_file.seek(0) + self.render_file.truncate() + self.render_file.write(html) + self.render_file.flush() + + def close(self) -> None: + self.render_file.close() diff --git a/VAB-WebArena-Lite/new/helper_functions.py b/VAB-WebArena-Lite/new/helper_functions_eval.py similarity index 100% rename from VAB-WebArena-Lite/new/helper_functions.py rename to VAB-WebArena-Lite/new/helper_functions_eval.py diff --git a/VAB-WebArena-Lite/new/html_tools/__init__.py b/VAB-WebArena-Lite/new/html_tools/__init__.py new file mode 100755 index 0000000..6aac581 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/__init__.py @@ -0,0 +1,7 @@ +from .identifier import IdentifierTool +from .prompt import HtmlPrompt +from .html_parser import HtmlParser + +from .utils import print_html_object +from .configs import basic_attrs, mind2web_keep_attrs +from .fetch import get_parsed_html \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/configs/__init__.py b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py new file mode 100755 index 0000000..b900a67 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/configs/__init__.py @@ -0,0 +1,3 @@ +from .html_prompt import prompts +from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs +from .config import config_meta \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/configs/config.py b/VAB-WebArena-Lite/new/html_tools/configs/config.py new file mode 100755 index 0000000..d7da96c --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/configs/config.py @@ -0,0 +1,56 @@ +basic_attrs = [ + 'title', + 'value', + 'type', + 'placeholder', + 'selected', + 'data-value', + 'data-text', + 'data-testid', + 'data-label', + 'data-bbox', + 'data-status' +] 
+ +mind2web_keep_attrs = [ + 'alt', + 'aria_description', + 'aria_label', + 'aria_role', + 'input_checked', + 'input_value', + 'label', + 'name', + 'option_selected', + 'placeholder', + 'role', + 'text_value', + 'title', + 'type', + 'value', +] + +miniwob_attrs = [ + 'id', + 'type', + 'value', +] + +config_meta = """ +======= Configs ======= +Columns: + - id: {id_attr} + - label: {label_attr} +Position: {use_position} + - window: {window_size} + - rect_dict: {rect} +Keep: + - parents: {parent_chain} + - attrs: {keep_attrs} + - elems: {keep_elem} + - obs_elem: {obs_elem} +Generator: + - prompt: {prompt_name} + - label: {identifier_name} +======================== +""" \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py new file mode 100755 index 0000000..904d021 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py @@ -0,0 +1,22 @@ +refine_prompt = { + 'dom': '<{tag}{label}|{attr}{content}{subtree} >', + 'label': '[{label}]', + 'attr': '{attr}', + 'attr_splitter': '; ', + 'subtree_splitter': ' ', +} + +xml_prompt = { + 'dom': '<{tag}{label}{attr}>{content}{subtree} ', + 'label': ' id="{label}"', + 'attr': '{key}="{attr}"', + 'attr_splitter': ' ', + 'subtree_splitter': ' ', +} + +prompts = { + 'refine': refine_prompt, + 'xml': xml_prompt, + 'new_data': refine_prompt, +} + \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/fetch.py b/VAB-WebArena-Lite/new/html_tools/fetch.py new file mode 100755 index 0000000..25d5637 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/fetch.py @@ -0,0 +1,108 @@ +import os +import json +import base64 +from .html_parser import HtmlParser +from .configs import basic_attrs +from .scripts import * + +def get_window(page): + x = page.evaluate("window.scrollX") + y = page.evaluate("window.scrollY") + w = page.evaluate("window.innerWidth") + h = page.evaluate("window.innerHeight") + return (x, y, 
w, h) + +def modify_page(page): + page.wait_for_timeout(500) + + try: + page.evaluate(remove_id_script) + except: + pass + + packet = { + "raw_html": page.evaluate("document.documentElement.outerHTML"), + "window": get_window(page) + } + + page.evaluate(prepare_script) + page.wait_for_timeout(100) + + img_bytes = page.screenshot(path="debug_info/screenshot_raw.png") + raw_image = base64.b64encode(img_bytes).decode() + + page.evaluate(clickable_checker_script) + page.wait_for_timeout(50) + + # get all clickable elements + start_id = 0 + items, start_id = page.evaluate(label_script, { + "selector": ".possible-clickable-element", + "startIndex": start_id + }) + page.wait_for_timeout(50) + + # mark our own labels and get the images + items = page.evaluate(label_marker_script, items) + page.wait_for_timeout(100) + img_bytes = page.screenshot(path="debug_info/marked.png") + marked_image = base64.b64encode(img_bytes).decode() + + # remove markers on the page + page.evaluate(remove_label_mark_script) + + packet.update({ + "raw_image": raw_image, + "marked_image": marked_image, + "modified_html": page.evaluate("document.documentElement.outerHTML") + }) + + # element_info, include "all_elements" and "clickable_elements" + element_info = page.evaluate(element_info_script) + page.wait_for_timeout(100) + packet.update(element_info) + return packet + +def save_debug_info(packet): + with open("debug_info/raw.html", "w") as f: + f.write(packet["modified_html"]) + with open("debug_info/parsed.html", "w") as f: + f.write(packet["html"]) + with open("debug_info/all_element.json", "w") as f: + f.write(json.dumps(packet["all_elements"])) + +def get_parsed_html(page): + if not os.path.exists("debug_info"): + os.makedirs("debug_info") + + print("parsing html...") + + packet = modify_page(page) + raw_html = packet["modified_html"] + + args = { + "use_position": True, + "rect_dict": {}, + "window_size": packet["window"], + "id-attr": "data-backend-node-id", + "label_attr": "data-label-id", 
+ "label_generator": "order", + "regenerate_label": False, + "attr_list": basic_attrs, + "prompt": "xml", + "dataset": "pipeline" + } + + hp = HtmlParser(raw_html, args) + res = hp.parse_tree() + page_html = res.get("html", "") + + packet["html"] = page_html + + # for debug + save_debug_info(packet) + + print("parsing finished.") + + return packet + diff --git a/VAB-WebArena-Lite/new/html_tools/html_parser.py b/VAB-WebArena-Lite/new/html_tools/html_parser.py new file mode 100755 index 0000000..080342f --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/html_parser.py @@ -0,0 +1,447 @@ +from lxml import html +import time, copy, random +import json, re, os + +from .identifier import IdentifierTool +from .prompt import HtmlPrompt +from .configs import config_meta +from .utils import get_xpath_top_down, rect2tuple + +class HtmlParser(): + def __init__(self, ctx: str, args: dict[str]={}) -> None: + stt = time.time() + self.dom_tree = self.ctx2tree(ctx) + # tool related + self.bids2label = {} + self.bids2xpath = {} + self.used_labels = {} + + # parse args + self.parse_args(args) + self.init_time = time.time() - stt + + def parse_args(self, args: dict[str]={}) -> None: + def attr_check(attr, type_model='str'): + if attr is None: + return False + attr_type = type(attr) + if attr_type != type(type_model): + return False + if attr_type == type('str') and len(attr) == 0: + return False + return True + + args = {} if args is None else args + + # [Position] use_pos: False -> use full page, otherwise use window_size + dataset = args.get('dataset', '') + use_position = args.get('use_position', False) + window_size = args.get('window_size', None) + rect = args.get('rect_dict', None) + if use_position: + if not attr_check(window_size, ()): + raise ValueError('window_size must be set when use_position is True') + if not attr_check(rect, {}): + raise ValueError('rect_dict must be set when use_position is True') + + if not attr_check(rect, {}): + rect = {} + + # [Label] for vimium is 
temp_clickable_label, otherwise keep all of it + label_attr = args.get('label_attr', '') + get_new_label = args.get('regenerate_label', False) + label_method = args.get('label_generator', None) + regen_label = not attr_check(label_method) + + # [id] for mind2web is backend_node_id, for normal website use our method + id_attr = args.get('id_attr', '') + regen_id = not attr_check(id_attr) + + if regen_id: + id_attr = 'temp_id' + + # [attributes] + keep_attrs = args.get('attr_list', []) + if not attr_check(keep_attrs, []): + keep_attrs = [] + + # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule + parent_chain = args.get('parent_chain', False) + keep_elem = args.get('keep_elem', []) + obs_elem = args.get('obs_elem', []) + + # sanity check + self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset) + + # [Prompt] + prompt = args.get('prompt', None) + self.prompt = HtmlPrompt(prompt) + + # traverse and get special data + if regen_id or regen_label: + self.mark_id() + + if get_new_label: + self.used_labels = {} + + self.identifier = IdentifierTool(label_method, self.used_labels) + + def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='', + id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[], + parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None: + + self.use_position = use_position + self.window_size = window_size + self.rect = rect_dict + self.label_attr = label_attr + self.id_attr = id_attr + self.keep_attrs = keep_attrs + self.keep = keep_elem + self.obs = obs_elem + self.parent_chain = parent_chain + self.get_new_label = get_new_label + self.dataset = dataset + + def get_config(self): + config = { + 'id_attr': self.id_attr, + 'keep_attrs': self.keep_attrs[:5], + 'label_attr': self.label_attr, + 'use_position': self.use_position, + 
'window_size': self.window_size, + 'rect': dict(list(self.rect.items())[:3]), + 'keep_elem': self.keep[:5], + 'obs_elem': self.obs[:5], + 'parent_chain': self.parent_chain, + 'prompt_name': self.prompt.name, + 'identifier_name': self.identifier.name + } + + return config, config_meta.format(**config) + + def update_rect_dict(self, rect_dict: dict[str]={}) -> None: + self.rect = rect_dict + + @staticmethod + def ctx2tree(ctx: str) -> html.HtmlElement: + # remove useless tags, eg. style and script + ctx = re.sub('', '', ctx) + ctx = re.sub('[\W\w]*?', '', ctx) + ctx = re.sub('[\W\w]*?', '', ctx) + ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip() + dom_tree = html.fromstring(ctx.encode('utf-8')) + match = re.search(' html.HtmlElement: + node = tree.xpath('//*')[0] + while True: + parent = node.getparent() + if parent is None: + break + node = parent + return node + + def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement: + nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]') + if len(nodes) == 0: + return None + return nodes[0] + + def id_label_converter(self, label: str) -> str: + return self.bids2label.get(label, '') + + def id_xpath_converter(self, label: str) -> str: + return self.bids2xpath.get(label, '') + + def mark_id(self) -> None: + root = self.get_root(self.dom_tree) + _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr) + self.used_labels = used_labels + self.bids2xpath = i2xpath + + def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]: + def get_text(str: str) -> str: + return '' if str is None else str.strip()[:500] + + def check_attr(attr: str, node: html.HtmlElement) -> bool: + tag = node.tag + if ( + ( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] ) + or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' ) + or ( attr == 'aria-hidden' and 
node.attrib.get(attr, '') == 'true') + # or ( attr == 'value' and tag in ['option'] ) + ): + return False + return True + + def is_visible(node: html.HtmlElement, bid: str) -> bool: + if self.dataset == 'mind2web': + bound = node.attrib.get('bounding_box_rect', None) + self.rect[bid] = rect2tuple(bound) + + if self.dataset == 'pipeline': + bound = node.attrib.get('data-bbox', None) + self.rect[bid] = rect2tuple(bound) + + if node.attrib.get('aria-hidden', 'false') == 'true': + return False + + if not self.use_position: + return True + + rect = self.rect.get(bid, None) + if rect is None: + return False + + if self.window_size is None: + return True + + # get window size + wx, wy, ww, wh = self.window_size + wx, wy = 0, 0 + x, y, w, h = rect + if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh: + return False + + return True + + def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[], + parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]): + # basic information + bid = node.attrib.get(self.id_attr, '') + tag = node.tag + label = node.attrib.get(self.label_attr, '') + + # element which is keeped equivalent to visible + visible = is_visible(node, bid) + in_keep_list = bid in keep + in_obs_list = (bid in obs or len(label) > 0) and visible + par_keep = par_keep and tag == "option" + keep_element = in_keep_list or in_obs_list or visible or par_keep + + if label: + keep_element = True + + # mark label + bids2label, labeled_elems = {}, [] + have_label = False + if in_keep_list or in_obs_list: + if label is None or len(label) == 0 or get_new_label: + label = self.identifier.generate() + node.attrib[self.label_attr] = label + bids2label[bid] = label + bids2label[label] = bid + have_label = True + + # get text or alt_text of current element + text = get_text(node.text) + + classes = {} + # keep attributes if needed + keep_all_attrs = len(self.keep_attrs) == 0 + keep_attrs = node.attrib.keys() if keep_all_attrs 
else self.keep_attrs + + # traverse attributes + for attr in keep_attrs: + if attr not in node.attrib or not check_attr(attr, node): + continue + if attr in [self.id_attr, self.label_attr]: + continue + val = get_text(node.attrib[attr]) + if len(val) > 0 or keep_all_attrs: + classes[attr] = val + + have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0 + par_keep = keep_element and tag == 'select' + + parts = [] + clickable_count = 0 + children = node.getchildren() + for child in children: + cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep) + clickable_count += 1 if cmsg.get('have_clickable', False) else 0 + bids2label.update(cmsg.get('bids2label', {})) + labeled_elems.extend(cmsg.get('label_element', [])) + if len(cres) != 0: + parts.append(cres) + + dom = self.prompt.subtree_constructor(parts) + + # remove if all children are text + keep_as_all_text = (dom.count('<') == dom.count(' 0 + if keep_as_all_text: + matches = re.findall(r']+) >', dom) + dom = self.prompt.subtree_constructor(matches) + + keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text) + keep_as_parent = len(dom) > 0 and parent_chain + if in_keep_list or keep_element or keep_as_parent: + dom = self.prompt.prompt_constructor(tag, label, text, dom, classes) + + if have_label: + labeled_elems.append(bid) + + control_msg = { + 'have_clickable': bool(clickable_count or have_text), + 'bids2label': bids2label, + 'label_element': labeled_elems, + } + + return dom, control_msg + + dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label) + return dom, cmsg + + def parse_tree(self) -> dict[str]: + # start from here + stt = time.time() + root = self.get_root(self.dom_tree) + dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label) + self.bids2label = cmsg.get('bids2label', {}) + self.keep = list(set(self.keep + cmsg.get('label_element', []))) + + obj = { + 'html': dom, + 
'parse_time': time.time() - stt + } + + return obj + + # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py + def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int, + max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]: + def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]: + if current_depth > max_depth: + return [] + + anscendants = [] + parent = node.getparent() + if parent is not None: + anscendants.append(parent) + anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1)) + + return anscendants + + def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]: + if current_depth > max_depth: + return [] + + descendants = [] + for child in node: + descendants.append(child) + descendants.extend(get_descendants(child, max_depth, current_depth + 1)) + + return descendants + + to_keep = set(copy.deepcopy(keep)) + nodes_to_keep = set() + + for _ in range(max(1, dfs_count)): + for bid in to_keep: + candidate_node = self.get_node_by_bid(tree, bid) + if candidate_node is None: + continue + + nodes_to_keep.add(candidate_node.attrib[self.id_attr]) + # get all ancestors or with max depth + nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)]) + + # get descendants with max depth + nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children]) + # get siblings within range + parent = candidate_node.getparent() + if parent is None: + continue + + siblings = [x for x in parent.getchildren() if x.tag != 'text'] + if candidate_node not in siblings: + continue + + idx_in_sibling = siblings.index(candidate_node) + nodes_to_keep.update([x.attrib.get(self.id_attr, '') + for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]]) + + 
max_children = int(max_children * 0.5) + max_depth = int(max_depth * 0.5) + max_sibling = int(max_sibling * 0.7) + + to_keep = copy.deepcopy(nodes_to_keep) + + if keep_parent: + for bid in keep: + candidate_node = self.get_node_by_bid(tree, bid) + if candidate_node is None: + continue + nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")]) + + return list(nodes_to_keep) + + def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement: + # remove nodes not in nodes_to_keep + for node in tree.xpath('//*')[::-1]: + if node.tag != 'text': + is_keep = node.attrib.get(self.id_attr, '') in nodes_to_keep + is_candidate = node.attrib.get(self.id_attr, '') in self.keep + else: + is_keep = (node.getparent().attrib.get(self.id_attr, '') in nodes_to_keep) + is_candidate = (node.getparent().attrib.get(self.id_attr, '') in self.keep) + + if not is_keep and node.getparent() is not None: + # insert all children into parent + for child in node.getchildren(): + node.addprevious(child) + node.getparent().remove(node) + else: + # if not is_candidate or node.tag == 'text': + # node.attrib.pop(self.id_attr, None) + if ( + len(node.attrib) == 0 + and not any([x.tag == 'text' for x in node.getchildren()]) + and node.getparent() is not None + and node.tag != "text" + and len(node.getchildren()) <= 1 + ): + # insert all children into parent + for child in node.getchildren(): + node.addprevious(child) + node.getparent().remove(node) + + return tree + + def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30, + max_sibling: int=3, keep_parent: bool=False) -> None: + # clone the tree + new_tree = copy.deepcopy(self.dom_tree) + nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent) + new_tree = self.prune(new_tree, nodes_to_keep) + + self.dom_tree = new_tree + + def get_segment(self, bid: str) -> str: + # clone the tree + new_tree = 
copy.deepcopy(self.dom_tree) + nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1) + new_tree = self.prune(new_tree, nodes_to_keep) + dom, _ = self.parse(new_tree, self.keep, [], False) + return dom + + def get_rect_data(self, bids: list[str]) -> list[dict[str]]: + res = [] + for bid in bids: + label = self.bids2label.get(bid, '') + rect = self.rect.get(bid, None) + res.append({ + 'bid': bid, + 'label': label, + 'rect': rect + }) + return res + \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/identifier.py b/VAB-WebArena-Lite/new/html_tools/identifier.py new file mode 100755 index 0000000..793ebed --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/identifier.py @@ -0,0 +1,64 @@ +import secrets + +class IdentifierTool: + def __init__(self, method: str='order', existing_labels: dict[str]={}) -> None: + self.methods = { + 'order': self.get_identifier_in_order, + 'random': self.get_random_identifier, + } + + if method is None: + method = 'order' + + self.func = self.methods.get(method, None) + self.name = method + if self.func is None: + raise ValueError(f'Invalid method for identifier: {method}') + + self.reset(existing_labels) + + def reset(self, exists: dict[str]={}) -> None: + self.identifier = -1 + self.exists = {} if exists is None else exists + + def get_identifier_in_order(self) -> str: + def id2str(id: int) -> str: + if id < 26: + return chr(id + 65) + id -= 26 + c0 = id // 676 + c1 = (id // 26) % 26 + c2 = id % 26 + label = f'{chr(c1 + 65)}{chr(c2 + 65)}' + return label if c0 == 0 else f'{chr(c0 + 64)}{label}' + + self.identifier += 1 + label = id2str(self.identifier) + + while label in self.exists: + self.identifier += 1 + label = id2str(self.identifier) + + self.exists[label] = True + return label + + def get_random_identifier(self) -> str: + secret_generator = secrets.SystemRandom() + + def get_random_label(n: int=2) -> str: + tmp = '' + for _ in range(n): + tmp += chr(secret_generator.randint(65, 90)) + return tmp + 
+ wc = 3 if len(self.exists) > 280 else 2 + + label = get_random_label(wc) + while label in self.exists: + label = get_random_label(wc) + + self.exists[label] = True + return label + + def generate(self): + return self.func() \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/prompt.py b/VAB-WebArena-Lite/new/html_tools/prompt.py new file mode 100755 index 0000000..38d6b94 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/prompt.py @@ -0,0 +1,97 @@ +from .configs import prompts + +class HtmlPrompt: + def __init__(self, prompt: str='') -> None: + prompt = self.extract(prompt, 'xml') + if prompt not in prompts: + raise Exception('Unknown prompt: ' + prompt) + + constructors = { + 'refine': self.normal_prompt_constructor, + 'xml': self.normal_prompt_constructor, + 'new_data': self.new_data_prompt_constructor, + } + + self.name = prompt + self.prompt = prompts[prompt] + self.constructor = constructors[prompt] + + @staticmethod + def extract(data, default=''): + return data if data is not None else default + + def subtree_constructor(self, subtree: list[str]=[]) -> str: + return self.prompt['subtree_splitter'].join(subtree) + + def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str: + def add_prefix(data, prefix): + return prefix + data if len(data) > 0 else '' + + tag = self.extract(tag) + label = self.extract(label) + content = self.extract(content) + subtree_str = self.extract(subtree_str, '') + class_dict = self.extract(class_dict, {}) + + label_str = '' + if len(label) > 0: + label_str = self.prompt['label'].format(label=label) + + classes = [] + values = set() + for key, val in class_dict.items(): + if val in values: + continue + values.add(val) + classes.append(self.prompt['attr'].format(key=key, attr=val)) + classes_str = self.prompt['attr_splitter'].join(classes) + + content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter'] + 
classes_str = add_prefix(classes_str, ' ') + content_str = add_prefix(content, content_splitter) + subtree_str = add_prefix(subtree_str, ' ') + + return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str) + + def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str: + def add_prefix(data, prefix): + return prefix + data if len(data) > 0 else '' + + tag = self.extract(tag) + label = self.extract(label) + content = self.extract(content) + subtree_str = self.extract(subtree_str, '') + class_dict = self.extract(class_dict, {}) + + label_str = '' + if len(label) > 0: + label_str = self.prompt['label'].format(label=label) + + classes = [] + values = set() + + message = [] + for key, val in class_dict.items(): + if val == '': + message.append(key) + continue + if val in values: + continue + values.add(val) + classes.append(self.prompt['attr'].format(key=key, attr=val)) + + if len(message) > 0: + message_str = ' '.join(message) + classes.append(self.prompt['attr'].format(key='message', attr=message_str)) + + classes_str = self.prompt['attr_splitter'].join(classes) + + content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter'] + classes_str = add_prefix(classes_str, ' ') + content_str = add_prefix(content, content_splitter) + subtree_str = add_prefix(subtree_str, ' ') + + return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str) + + def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str: + return self.constructor(tag, label, content, subtree_str, class_dict) \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py new file mode 100755 index 0000000..0d36c9a --- /dev/null +++ 
b/VAB-WebArena-Lite/new/html_tools/scripts/__init__.py @@ -0,0 +1,43 @@ +import os +from pathlib import Path +rootdir = Path(__file__).parent + +with open(os.path.join(rootdir,'prepare.js'), 'r') as f: + prepare_script = f.read() + +with open(os.path.join(rootdir, 'clickable_checker.js'), 'r') as f: + clickable_checker_script = f.read() + +with open(os.path.join(rootdir, 'label.js'), 'r') as f: + label_script = f.read() + +with open(os.path.join(rootdir, 'element_info.js'), 'r') as f: + element_info_script = f.read() + +# draw label on page +with open(os.path.join(rootdir, 'label_marker.js'), 'r') as f: + label_marker_script = f.read() + +# remove label draw on page +remove_label_mark_script = """ + () => { + document.querySelectorAll(".our-dom-marker").forEach(item => { + document.body.removeChild(item); + }); + } +""" + +remove_id_script = """ + () => { + Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => { + element.classList.remove('possible-clickable-element'); + element.removeAttribute('data-value'); + element.removeAttribute('data-text'); + element.removeAttribute('data-label'); + element.removeAttribute('data-bbox'); + element.removeAttribute('data-status'); + element.removeAttribute('data-backend-node-id'); + element.removeAttribute('data-label-id'); + }); + } +""" diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js new file mode 100755 index 0000000..42267f0 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js @@ -0,0 +1,148 @@ +() => { + var items = Array.prototype.slice.call( + document.querySelectorAll('*') + ).map(function(element) { + var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0); + var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0); + + var rects = [...element.getClientRects()].filter(bb => { + var center_x = bb.left + 
bb.width / 2; + var center_y = bb.top + bb.height / 2; + var elAtCenter = document.elementFromPoint(center_x, center_y); + + if (!elAtCenter) return false; + return elAtCenter === element || element.contains(elAtCenter) + }).map(bb => { + const rect = { + left: Math.max(0, bb.left), + top: Math.max(0, bb.top), + right: Math.min(vw, bb.right), + bottom: Math.min(vh, bb.bottom) + }; + return { + ...rect, + width: rect.right - rect.left, + height: rect.bottom - rect.top + } + }); + // var rects = []; + var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0); + + const tagName = element.tagName.toLowerCase?.() || ""; + let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer"); + + // Insert area elements that provide click functionality to an img. + if (tagName === "img") { + let mapName = element.getAttribute("usemap"); + if (mapName) { + const imgClientRects = element.getClientRects(); + mapName = mapName.replace(/^#/, "").replace('"', '\\"'); + const map = document.querySelector(`map[name=\"${mapName}\"]`); + if (map && (imgClientRects.length > 0)) isClickable = true; + } + } + + if (!isClickable) { + const role = element.getAttribute("role"); + const clickableRoles = [ + "button", + "tab", + "link", + "checkbox", + "menuitem", + "menuitemcheckbox", + "menuitemradio", + "radio", + ]; + if (role != null && clickableRoles.includes(role.toLowerCase())) { + isClickable = true; + } else { + const contentEditable = element.getAttribute("contentEditable"); + if ( + contentEditable != null && + ["", "contenteditable", "true"].includes(contentEditable.toLowerCase()) + ) { + isClickable = true; + } + } + } + + // Check for jsaction event listeners on the element. 
+ if (!isClickable && element.hasAttribute("jsaction")) { + const jsactionRules = element.getAttribute("jsaction").split(";"); + for (let jsactionRule of jsactionRules) { + const ruleSplit = jsactionRule.trim().split(":"); + if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) { + const [eventType, namespace, actionName] = ruleSplit.length === 1 + ? ["click", ...ruleSplit[0].trim().split("."), "_"] + : [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"]; + if (!isClickable) { + isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_"); + } + } + } + } + + if (!isClickable) { + const clickableTags = [ + "input", + "textarea", + "select", + "button", + "a", + "iframe", + "video", + "object", + "embed", + "details" + ]; + isClickable = clickableTags.includes(tagName); + } + + if (!isClickable) { + if (tagName === "label") + isClickable = (element.control != null) && !element.control.disabled; + else if (tagName === "img") + isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor); + } + + // An element with a class name containing the text "button" might be clickable. However, real + // clickables are often wrapped in elements with such class names. So, when we find clickables + // based only on their class name, we mark them as unreliable. + const className = element.getAttribute("class"); + if (!isClickable && className && className.toLowerCase().includes("button")) { + isClickable = true; + } + + // Elements with tabindex are sometimes useful, but usually not. We can treat them as second + // class citizens when it improves UX, so take special note of them. + const tabIndexValue = element.getAttribute("tabindex"); + const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1; + if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) { + isClickable = true; + } + + const idValue = element.getAttribute("id"); + const id = idValue ? 
idValue.toLowerCase() : ""; + if (isClickable && area == 0) { + const textValue = element.textContent.trim().replace(/\s{2,}/g, ' '); + clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}` + } + + return { + element: element, + include: isClickable, + area, + rects, + text: element.textContent.trim().replace(/\s{2,}/g, ' ') + }; + }).filter(item => + item.include && (item.area >= 1) + ); + + items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y))) + + items.forEach(item => { + item.element.classList.add('possible-clickable-element'); + }); +} \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js new file mode 100755 index 0000000..c396220 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/scripts/element_info.js @@ -0,0 +1,34 @@ +() => { + function getElementInfo(element) { + return { + "bid": element.getAttribute("data-backend-node-id") || "", + "label": element.getAttribute("data-label-id") || "", + "tag": element.tagName.toLowerCase?.() || "", + "area": JSON.parse("[" + (element.getAttribute("data-bbox") || "") + "]"), + "text": element.innerText?.trim().replace(/\s{2,}/g, " ") || "", + "id": element.getAttribute("id") || "", + "role": element.getAttribute("role") || "", + "aria-label": element.getAttribute("aria-label") || "", + "href": element.getAttribute("href") || "", + }; + } + + var all_items = Array.prototype.slice.call( + document.querySelectorAll("*") + ).map((element) => { + return getElementInfo(element); + }); + + var clickable_items = Array.prototype.slice.call( + document.querySelectorAll("*") + ).filter( + element => element.getAttribute("data-label-id") + ).map((element) => { + return getElementInfo(element); + }); + + return { + all_elements: all_items, + clickable_elements: clickable_items + }; +} \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label.js 
b/VAB-WebArena-Lite/new/html_tools/scripts/label.js new file mode 100755 index 0000000..285ca89 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/scripts/label.js @@ -0,0 +1,74 @@ +(packet) => { + function int2str(index) { + var str = ""; + while (index >= 0) { + str = String.fromCharCode(65 + index % 26) + str; + index = Math.floor(index / 26) - 1; + } + return str; + }; + + selector = packet.selector + index = packet.startIndex + var items = Array.prototype.slice.call( + document.querySelectorAll(selector) + ); + + var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0); + var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0); + + items = items.filter( + x => !items.some(y => x.contains(y) && !(x == y)) + ).map(element => { + var bb = element.getClientRects(); + var rect = { + left: 0, + top: 0, + right: 0, + bottom: 0, + width: 0, + height: 0 + }; + var keep = false; + var text = "", id = -1; + if (bb.length > 0) { + bb = bb[0]; + rect = { + left: Math.max(0, bb.left), + top: Math.max(0, bb.top), + right: Math.min(vw, bb.right), + bottom: Math.min(vh, bb.bottom) + }; + rect = { + ...rect, + width: rect.right - rect.left, + height: rect.bottom - rect.top + }; + if (rect.width > 0 || rect.height > 0) { + keep = true; + if (index >= 0) { + // id = int2str(index++); + id = index++; + element.setAttribute("data-label-id", id); + } + var childNodes = element.childNodes; + + for (var i = 0; i < childNodes.length; i++) { + if (childNodes[i].nodeType == Node.TEXT_NODE) { + text += childNodes[i].textContent; + } + } + } + } + + return { + keep: true, + id, + rects: rect, + tag: element.tagName.toLowerCase?.() || "", + text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || "" + }; + }).filter(x => x.keep); + + return [items, index]; +} \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js new file 
mode 100755 index 0000000..6c55af5 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js @@ -0,0 +1,65 @@ +(items) => { + function getRandomColor() { + var letters = '0123456789ABCDEF'; + var color = '#'; + for (var i = 0; i < 6; i++) { + color += letters[Math.floor(Math.random() * 16)]; + } + return color; + } + + items.filter( + item => item.id != "" + ).forEach((item) => { + const bbox = item.rects; + const id_string = `dom-marker-id-${index}`; + + index = item.id; + + outerElement = document.createElement("div"); + outerElement.classList.add("our-dom-marker"); + // var borderColor = getRandomColor(); + var borderColor = "#FFFF00"; + outerElement.style.outline = `2px dashed ${borderColor}`; + outerElement.style.position = "fixed"; + outerElement.style.left = bbox.left - 2 + "px"; + outerElement.style.top = bbox.top - 2 + "px"; + outerElement.style.width = bbox.width + 4 + "px"; + outerElement.style.height = bbox.height + 4 + "px"; + outerElement.style.pointerEvents = "none"; + outerElement.style.boxSizing = "border-box"; + outerElement.style.zIndex = 2147483647; + + innerElement = document.createElement("div"); + innerElement.classList.add("our-dom-marker"); + innerElement.style.outline = `2px dashed #222288`; + innerElement.style.position = "fixed"; + innerElement.style.left = bbox.left + "px"; + innerElement.style.top = bbox.top + "px"; + innerElement.style.width = bbox.width + "px"; + innerElement.style.height = bbox.height + "px"; + innerElement.style.pointerEvents = "none"; + innerElement.style.boxSizing = "border-box"; + innerElement.style.zIndex = 2147483647; + + // Add floating label at the corner + var label = document.createElement("span"); + var topPosition = 25; + if (bbox.top < 25) topPosition = bbox.top; + label.textContent = index; + label.style.position = "absolute"; + label.style.top = `-${topPosition}px`; + label.style.left = "0px"; + label.style.background = borderColor; + label.style.color = "black"; + 
label.style.padding = "2px 4px"; + label.style.fontSize = "16px"; + label.style.borderRadius = "2px"; + label.style.fontWeight = "bold"; + outerElement.appendChild(label); + + document.body.appendChild(outerElement); + document.body.appendChild(innerElement); + }) + return items; +} \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js new file mode 100755 index 0000000..9e76a62 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/scripts/prepare.js @@ -0,0 +1,83 @@ +() => { + // mark backend node id + var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0); + var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0); + + var backendId = 0; + Array.prototype.slice.call( + document.querySelectorAll("*") + ).forEach((element) => { + element.setAttribute("data-backend-node-id", backendId); + backendId++; + + var tag = element.tagName.toLowerCase?.() || ""; + var bb = element.getClientRects(); + var rect = { + left: 0, + top: 0, + right: 0, + bottom: 0, + width: 0, + height: 0 + }; + + if (bb.length > 0) { + bb = bb[0]; + // rect = { + // left: Math.round(Math.max(0, bb.left) * 100) / 100, + // top: Math.round(Math.max(0, bb.top) * 100) / 100, + // right: Math.round(Math.min(vw, bb.right) * 100) / 100, + // bottom: Math.round(Math.min(vh, bb.bottom) * 100) / 100 + // }; + rect = { + left: (Math.round(bb.left) * 100) / 100, + top: (Math.round(bb.top) * 100) / 100, + right: (Math.round(bb.right) * 100) / 100, + bottom: (Math.round(bb.bottom) * 100) / 100 + }; + rect = { + ...rect, + width: Math.round((rect.right - rect.left) * 100) / 100, + height: Math.round((rect.bottom - rect.top) * 100) / 100 + }; + + element.setAttribute("data-bbox", `${rect.left},${rect.top},${rect.width},${rect.height}`); + } + + if (element.hasChildNodes()) { + let children = Array.prototype.slice.call(element.childNodes); + var texts = 
children.filter( + (node) => node.nodeType == Node.TEXT_NODE + ).map( + (node) => node.textContent.trim().replace(/\s{2,}/g, " ") || "" + ).filter( + (text) => text.length > 0 + ) + element.setAttribute("data-text", texts.join(",")); + } + + // fix select issue + if (tag == "select") { + var value = element.value; + var text = element.options[element.selectedIndex]?.text || ""; + element.setAttribute("data-value", value); + element.setAttribute("data-text", text); + element.options[element.selectedIndex]?.setAttribute("data-status", "selected"); + } + + if (tag == "input") { + var input_type = element.getAttribute("type") || ""; + if (input_type == "checkbox") { + var status = element.checked? "checked" : "not-checked"; + element.setAttribute("data-status", status); + } + } + }); + + // fix input and textarea issue + Array.prototype.slice.call( + document.querySelectorAll("input, textarea") + ).forEach(element => { + element.setAttribute("data-value", element.value); + }); +} \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/html_tools/utils.py b/VAB-WebArena-Lite/new/html_tools/utils.py new file mode 100755 index 0000000..9e0be65 --- /dev/null +++ b/VAB-WebArena-Lite/new/html_tools/utils.py @@ -0,0 +1,101 @@ +from lxml import html +def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0, + in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str]]: + used_labels, i2xpath = {}, {} + # path + tag = element.tag.lower() + in_svg = in_svg or (tag == 'svg') + + if not in_svg and 'id' in element.attrib: + node_id = element.attrib['id'] + path = f'//*[@id="{node_id}"]' + else: + suffix = f'[{order}]' if order > 0 else '' + prefix = f'*[name()="{tag}"]' if in_svg else tag + path = path + '/' + prefix + suffix + + # add temp id + element.attrib[id_column] = str(temp_id) + ori_label = element.attrib.get(label_column, '') + if ori_label != '': + 
used_labels[ori_label] = True + + bid = str(temp_id) + i2xpath[bid] = path + i2xpath[path] = bid + i2xpath[f'xpath/{path}'] = bid + i2xpath[f'xpath=/{path}'] = bid + + temp_id += 1 + + # traverse node + children = element.getchildren() + tag_dict = {} + id_list = [] + for child in children: + ctag = child.tag.lower() + if ctag not in tag_dict: + tag_dict[ctag] = 0 + tag_dict[ctag] += 1 + id_list.append(tag_dict[ctag]) + + for cid, child in zip(id_list, children): + ctag = child.tag.lower() + cod = cid if tag_dict[ctag] > 1 else 0 + temp_id, i2x, ulabels = get_xpath_top_down(child, id_column, label_column, path, cod, in_svg, temp_id) + i2xpath.update(i2x) + used_labels.update(ulabels) + + return temp_id, i2xpath, used_labels + +def print_html_object(obj: str='') -> str: + tab_cnt = 0 + result, content, sep = '', '', '' + last_is_left, last_is_right = False, False + for ch in obj: + if ch == '<': + result += '\n' + if len(content.strip()) > 0: + result += sep + content.strip() + '\n' + result += sep + '<' + + tab_cnt += 1 + sep = ' ' * tab_cnt + + content = '' + last_is_right = False + last_is_left = True + elif ch == '>': + if last_is_left: + result += content + else: + if last_is_right: + result += '\n' + if len(content.strip()) > 0: + result += sep + content.strip() + '\n' + + tab_cnt -= 1 + sep = ' ' * tab_cnt + + if not last_is_left: + result += sep + + result += '>' + content = '' + + last_is_right = True + last_is_left = False + else: + content += ch + + return result + +def rect2tuple(rect: str) -> tuple[int, int, int, int]: + if rect is None or type(rect) != type('str'): + return None + rect = rect.strip() + if rect.count(',') != 3: + return None + rect = rect.split(',') + rect = [float(r) for r in rect] + return tuple(rect) diff --git a/VAB-WebArena-Lite/new/openai_utils.py b/VAB-WebArena-Lite/new/openai_utils.py index 8374c7a..c9754f2 100644 --- a/VAB-WebArena-Lite/new/openai_utils.py +++ b/VAB-WebArena-Lite/new/openai_utils.py @@ -36,7 +36,6 @@ def 
retry_with_exponential_backoff( # type: ignore # Initialize variables num_retries = 0 delay = initial_delay - # Loop until a successful response or max_retries is hit or an exception is raised while True: try: @@ -142,27 +141,32 @@ async def agenerate_from_openai_completion( @retry_with_exponential_backoff def generate_from_openai_completion( prompt: str, - engine: str, + model: str, temperature: float, max_tokens: int, top_p: float, - context_length: int, stop_token: str | None = None, + api_key: str | None = None, + base_url: str | None = None ) -> str: - if "OPENAI_API_KEY" not in os.environ: + if "OPENAI_API_KEY" not in os.environ and api_key is None: raise ValueError( "OPENAI_API_KEY environment variable must be set when using OpenAI API." ) - + if api_key is not None: + client = OpenAI(api_key=api_key, base_url=base_url) response = client.completions.create( prompt=prompt, - engine=engine, + model=model, temperature=temperature, max_tokens=max_tokens, top_p=top_p, stop=[stop_token], ) - answer: str = response["choices"][0]["text"] + try: + answer: str = response["choices"][0]["text"] + except: + answer: str = response.choices[0].text return answer diff --git a/VAB-WebArena-Lite/new/p_webrl.json b/VAB-WebArena-Lite/new/p_webrl.json new file mode 100644 index 0000000..c172455 --- /dev/null +++ b/VAB-WebArena-Lite/new/p_webrl.json @@ -0,0 +1,13 @@ +{ + "intro": "", + "examples": [], + "template": "", + "meta_data": { + "observation": "webrl", + "action_type": "webrl_id", + "keywords": [], + "prompt_constructor": "WebRLPromptConstructor", + "answer_phrase": "", + "action_splitter": "" + } + } \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/processors.py b/VAB-WebArena-Lite/new/processors.py new file mode 100644 index 0000000..d3f422d --- /dev/null +++ b/VAB-WebArena-Lite/new/processors.py @@ -0,0 +1,1351 @@ +import json +import pkgutil +import re +from collections import defaultdict +from dataclasses import dataclass +from io import BytesIO, 
StringIO +from typing import Any, Optional, TypedDict, Union +from urllib.parse import urljoin, urlparse + +import matplotlib.pyplot as plt +import numpy as np +import numpy.typing as npt +import pandas as pd +import playwright +import requests +from gymnasium import spaces +from PIL import Image, ImageDraw, ImageFont +from playwright.sync_api import CDPSession, Page, ViewportSize +from .html_tools.fetch import get_parsed_html + +from browser_env.constants import ( + ASCII_CHARSET, + FREQ_UNICODE_CHARSET, + IGNORED_ACTREE_PROPERTIES, + INJECTED_ATTR_NAME, + UTTERANCE_MAX_LENGTH, + BID_ATTR, + DATA_REGEXP, + IN_VIEWPORT_RATIO_THRESHOLD, +) + +from .utils import ( + AccessibilityTree, + AccessibilityTreeNode, + BrowserConfig, + BrowserInfo, + DOMNode, + DOMTree, + Observation, + png_bytes_to_numpy, +) + + +def remove_unicode(input_string): + # Define a regex pattern to match Unicode characters + unicode_pattern = re.compile(r"[^\x00-\x7F]+") + + # Use the pattern to replace Unicode characters with an empty string + cleaned_string = unicode_pattern.sub("", input_string) + + return cleaned_string + + +class ObservationProcessor: + def process(self, page: Page) -> Observation: + raise NotImplementedError + + +class ObservationMetadata(TypedDict): + obs_nodes_info: dict[str, Any] + + +def create_empty_metadata() -> ObservationMetadata: + return { + "obs_nodes_info": {}, + } + + +def extract_data_items_from_aria(string: str) -> tuple[list[str], str]: + """ + Utility function to extract temporary data stored in the "aria-roledescription" attribute of a node + """ + + match = DATA_REGEXP.fullmatch(string) + if not match: + return [], string + + groups = match.groups() + data_items = groups[:-1] + original_aria = groups[-1] + return data_items, original_aria + + +class TextObervationProcessor(ObservationProcessor): + def __init__( + self, + observation_type: str, + current_viewport_only: bool, + viewport_size: ViewportSize, + captioning_fn=None, + ): + self.observation_type 
= observation_type + self.current_viewport_only = current_viewport_only + self.viewport_size = viewport_size + self.observation_tag = "text" + self.meta_data = ( + create_empty_metadata() + ) # use the store meta data of this observation type + + if self.observation_type in [ + "accessibility_tree_with_captioner", + "image_som", + ]: + self.captioning_fn = captioning_fn + # Cache captions. + self.url2caption = {} + + def fetch_browser_info( + self, + page: Page, + ) -> BrowserInfo: + # extract domtree + client = page.context.new_cdp_session(page) + tree = client.send( + "DOMSnapshot.captureSnapshot", + { + "computedStyles": [], + "includeDOMRects": True, + "includePaintOrder": True, + }, + ) + client.detach() + + # calibrate the bounds, in some cases, the bounds are scaled somehow + bounds = tree["documents"][0]["layout"]["bounds"] + b = bounds[0] + n = b[2] / self.viewport_size["width"] + bounds = [[x / n for x in bound] for bound in bounds] + tree["documents"][0]["layout"]["bounds"] = bounds + # add union bound placeholder + tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds] + + # extract browser info + win_upper_bound = page.evaluate("window.pageYOffset") + win_left_bound = page.evaluate("window.pageXOffset") + win_width = page.evaluate("window.screen.width") + win_height = page.evaluate("window.screen.height") + win_right_bound = win_left_bound + win_width + win_lower_bound = win_upper_bound + win_height + device_pixel_ratio = page.evaluate("window.devicePixelRatio") + assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0" + + config: BrowserConfig = { + "win_upper_bound": win_upper_bound, + "win_left_bound": win_left_bound, + "win_width": win_width, + "win_height": win_height, + "win_right_bound": win_right_bound, + "win_lower_bound": win_lower_bound, + "device_pixel_ratio": device_pixel_ratio, + } + + # assert len(tree['documents']) == 1, "More than one document in the DOM tree" + info: BrowserInfo = {"DOMTree": tree, "config": 
config} + + return info + + @staticmethod + def get_bounding_client_rect( + client: CDPSession, backend_node_id: str + ) -> dict[str, Any]: + try: + remote_object = client.send( + "DOM.resolveNode", {"backendNodeId": int(backend_node_id)} + ) + remote_object_id = remote_object["object"]["objectId"] + response = client.send( + "Runtime.callFunctionOn", + { + "objectId": remote_object_id, + "functionDeclaration": """ + function() { + if (this.nodeType == 3) { + var range = document.createRange(); + range.selectNode(this); + var rect = range.getBoundingClientRect().toJSON(); + range.detach(); + return rect; + } else { + return this.getBoundingClientRect().toJSON(); + } + } + """, + "returnByValue": True, + }, + ) + return response + except Exception as e: + return {"result": {"subtype": "error"}} + + @staticmethod + def get_element_in_viewport_ratio( + elem_left_bound: float, + elem_top_bound: float, + width: float, + height: float, + config: BrowserConfig, + ) -> float: + elem_right_bound = elem_left_bound + width + elem_lower_bound = elem_top_bound + height + + win_left_bound = 0 + win_right_bound = config["win_width"] + win_top_bound = 0 + win_lower_bound = config["win_height"] + + # Compute the overlap in x and y axes + overlap_width = max( + 0, + min(elem_right_bound, win_right_bound) + - max(elem_left_bound, win_left_bound), + ) + overlap_height = max( + 0, + min(elem_lower_bound, win_lower_bound) - max(elem_top_bound, win_top_bound), + ) + + # Compute the overlap area + ratio = overlap_width * overlap_height / width * height + return ratio + + def fetch_page_html( + self, + info: BrowserInfo, + page: Page, + current_viewport_only: bool, + ) -> DOMTree: + # adopted from [natbot](https://github.com/nat/natbot) + tree = info["DOMTree"] + strings = tree["strings"] + document = tree["documents"][0] + nodes = document["nodes"] + + # make a dom tree that is easier to navigate + dom_tree: DOMTree = [] + graph = defaultdict(list) + client = 
page.context.new_cdp_session(page) + for node_idx in range(len(nodes["nodeName"])): + cur_node: DOMNode = { + "nodeId": "", + "nodeType": "", + "nodeName": "", + "nodeValue": "", + "attributes": "", + "backendNodeId": "", + "parentId": "", + "childIds": [], + "cursor": 0, + "union_bound": None, + } + + node_type_idx = nodes["nodeType"][node_idx] + node_type = "generic" + if node_type_idx >= 0 and node_type_idx < len(strings): + node_type = strings[node_type_idx] + + node_name = strings[nodes["nodeName"][node_idx]] + + node_value_idx = nodes["nodeValue"][node_idx] + node_value = "" + if node_value_idx >= 0 and node_value_idx < len(strings): + node_value = " ".join(strings[node_value_idx].split()) + + node_attributes = [strings[i] for i in nodes["attributes"][node_idx]] + node_attributes_str = "" + for i in range(0, len(node_attributes), 2): + a = node_attributes[i] + b = node_attributes[i + 1] + b = " ".join(b.split()) + node_attributes_str += f'{a}="{b}" ' + node_attributes_str = node_attributes_str.strip() + + cur_node["nodeId"] = str(node_idx) + cur_node["nodeType"] = node_type + cur_node["nodeName"] = node_name + cur_node["nodeValue"] = node_value + cur_node["attributes"] = node_attributes_str + cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx]) + cur_node["parentId"] = str(nodes["parentIndex"][node_idx]) + + if cur_node["parentId"] != "-1": + graph[cur_node["parentId"]].append(str(cur_node["nodeId"])) + + # get the bound + if cur_node["parentId"] == "-1": + cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0] + else: + response = self.get_bounding_client_rect( + client, cur_node["backendNodeId"] + ) + if response.get("result", {}).get("subtype", "") == "error": + cur_node["union_bound"] = None + else: + x = response["result"]["value"]["x"] + y = response["result"]["value"]["y"] + width = response["result"]["value"]["width"] + height = response["result"]["value"]["height"] + cur_node["union_bound"] = [x, y, width, height] + + dom_tree.append(cur_node) 
@staticmethod
def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
    """Render the flattened DOM tree as indented text.

    Starting from the node at position 0, nodes are visited depth-first.
    A node is emitted as ``[cursor] <name attributes> value`` only when it
    has attributes or a node value; skipped nodes do not add an
    indentation level for their children.

    Returns:
        A pair ``(text, nodes)`` where ``text`` is the rendered tree (one
        tab-indented line per emitted node) and ``nodes`` maps the string
        form of each emitted node's position to its ``backend_id``,
        ``union_bound`` and rendered ``text``.
    """
    collected = {}
    cursor_of = {
        entry["nodeId"]: position for position, entry in enumerate(dom_tree)
    }

    def render(cursor: int, level: int) -> str:
        current = dom_tree[cursor]
        pieces = []
        keep = True
        try:
            label = f"[{cursor}] <{current['nodeName']}"
            if current["attributes"]:
                label += f" {current['attributes']}"
            label += f"> {current['nodeValue']}"
            keep = bool(current["attributes"] or current["nodeValue"])

            if keep:
                collected[str(cursor)] = {
                    "backend_id": current["backendNodeId"],
                    "union_bound": current["union_bound"],
                    "text": label,
                }
                pieces.append("\t" * level + f"{label}\n")
        except Exception:
            # A node missing any expected field is treated as not renderable.
            keep = False

        # Children of a skipped node stay at the parent's depth.
        next_level = level + 1 if keep else level
        for child_id in current["childIds"]:
            pieces.append(render(cursor_of[child_id], next_level))

        return "".join(pieces)

    return render(0, 0), collected
None + else: + x = response["result"]["value"]["x"] + y = response["result"]["value"]["y"] + width = response["result"]["value"]["width"] + height = response["result"]["value"]["height"] + node["union_bound"] = [x, y, width, height] + + client.detach() + # filter nodes that are not in the current viewport + if current_viewport_only: + + def remove_node_in_graph(node: AccessibilityTreeNode) -> None: + # update the node information in the accessibility tree + nodeid = node["nodeId"] + node_cursor = nodeid_to_cursor[nodeid] + parent_nodeid = node["parentId"] + children_nodeids = node["childIds"] + parent_cursor = nodeid_to_cursor[parent_nodeid] + # update the children of the parent node + assert ( + accessibility_tree[parent_cursor].get("parentId", "Root") + is not None + ) + # remove the nodeid from parent's childIds + index = accessibility_tree[parent_cursor]["childIds"].index(nodeid) + accessibility_tree[parent_cursor]["childIds"].pop(index) + # Insert children_nodeids in the same location + for child_nodeid in children_nodeids: + accessibility_tree[parent_cursor]["childIds"].insert( + index, child_nodeid + ) + index += 1 + # update children node's parent + for child_nodeid in children_nodeids: + child_cursor = nodeid_to_cursor[child_nodeid] + accessibility_tree[child_cursor]["parentId"] = parent_nodeid + # mark as removed + accessibility_tree[node_cursor]["parentId"] = "[REMOVED]" + + config = info["config"] + for node in accessibility_tree: + if not node["union_bound"]: + remove_node_in_graph(node) + continue + + [x, y, width, height] = node["union_bound"] + + # invisible node + if width == 0 or height == 0: + remove_node_in_graph(node) + continue + + in_viewport_ratio = self.get_element_in_viewport_ratio( + elem_left_bound=float(x), + elem_top_bound=float(y), + width=float(width), + height=float(height), + config=config, + ) + + if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD: + remove_node_in_graph(node) + + accessibility_tree = [ + node + for node in 
accessibility_tree + if node.get("parentId", "Root") != "[REMOVED]" + ] + + return accessibility_tree + + @staticmethod + def parse_accessibility_tree( + accessibility_tree: AccessibilityTree, + ) -> tuple[str, dict[str, Any]]: + """Parse the accessibility tree into a string text""" + node_id_to_idx = {} + for idx, node in enumerate(accessibility_tree): + node_id_to_idx[node["nodeId"]] = idx + + obs_nodes_info = {} + + def dfs(idx: int, obs_node_id: str, depth: int) -> str: + tree_str = "" + node = accessibility_tree[idx] + indent = "\t" * depth + valid_node = True + try: + role = node["role"]["value"] + name = node["name"]["value"] + node_str = f"[{obs_node_id}] {role} {repr(name)}" + properties = [] + for property in node.get("properties", []): + try: + if property["name"] in IGNORED_ACTREE_PROPERTIES: + continue + properties.append( + f'{property["name"]}: {property["value"]["value"]}' + ) + except KeyError: + pass + + if properties: + node_str += " " + " ".join(properties) + + # check valid + if not node_str.strip(): + valid_node = False + + # empty generic node + if not name.strip(): + if not properties: + if role in [ + "generic", + "img", + "list", + "strong", + "paragraph", + "banner", + "navigation", + "Section", + "LabelText", + "Legend", + "listitem", + ]: + valid_node = False + elif role in ["listitem"]: + valid_node = False + + if valid_node: + tree_str += f"{indent}{node_str}" + obs_nodes_info[obs_node_id] = { + "backend_id": node["backendDOMNodeId"], + "union_bound": node["union_bound"], + "text": node_str, + } + + except Exception as e: + valid_node = False + + for _, child_node_id in enumerate(node["childIds"]): + if child_node_id not in node_id_to_idx: + continue + # mark this to save some tokens + child_depth = depth + 1 if valid_node else depth + child_str = dfs( + node_id_to_idx[child_node_id], child_node_id, child_depth + ) + if child_str.strip(): + if tree_str.strip(): + tree_str += "\n" + tree_str += child_str + + return tree_str + + 
@staticmethod
def clean_accesibility_tree(tree_str: str) -> str:
    """Drop StaticText lines that merely repeat text from nearby lines.

    A line containing "statictext" (case-insensitive) is kept only when it
    matches ``[<digits>] StaticText <quoted text>``, the text between the
    quotes is non-empty, and that text does not occur in any of the three
    most recently kept lines. All other lines are kept unchanged.
    """
    static_text_pattern = r"\[\d+\] StaticText (.+)"
    kept: list[str] = []

    for raw_line in tree_str.split("\n"):
        if "statictext" not in raw_line.lower():
            kept.append(raw_line)
            continue

        found = re.search(static_text_pattern, raw_line, re.DOTALL)
        if not found:
            # Mentions StaticText but not in the expected shape: dropped.
            continue

        inner_text = found.group(1)[1:-1]  # strip the surrounding quotes
        if not inner_text:
            continue

        recent_lines = kept[-3:]
        if any(inner_text in earlier for earlier in recent_lines):
            continue

        kept.append(raw_line)

    return "\n".join(kept)
+ if len(image_urls) > 0: + image_pixels = [] + valid_urls = [] + for url in image_urls: + if "data:image/svg" in url: + continue + else: + try: + image = Image.open(requests.get(url, stream=True).raw) + image_pixels.append(image) + valid_urls.append(url) + except Exception as e: + print("L616 WARNING: ", e) + + # Caption images. + if image_pixels: + # Run in batches of 4. + bs = 4 + captions = [] + for i in range(0, len(image_pixels), bs): + try: + captions.extend( + self.captioning_fn(image_pixels[i : i + bs]) + ) + except Exception as e: + print("L628 WARNING: ", e) + captions.extend([""] * len(image_pixels[i : i + bs])) + assert len(valid_urls) == len( + captions + ), f"len(images)={len(valid_urls)}, len(captions)={len(captions)}" + for image_url, caption in zip(valid_urls, captions): + self.url2caption[image_url] = remove_unicode( + caption.strip() + ) + + image_idx = 0 + for image in images: + try: + original_alt = image.get_attribute("alt") or "" + image_url = image.get_attribute("src") + if not image_url.startswith(("http://", "https://", "www.")): + image_url = urljoin(page.url, image_url) + + updated_alt = original_alt + + if image_url in self.url2caption: + if self.url2caption[image_url] not in updated_alt: + updated_alt = f"{updated_alt}, description: {self.url2caption[image_url]}" + elif "data:image/svg" not in image_url: + print(f"WARNING: {image_url} not in self.url2caption") + + if "url:" not in updated_alt: + updated_alt = f"{updated_alt}, url: {image_url}" + + safe_updated_alt = json.dumps(updated_alt) + image.evaluate(f"node => node.alt = {safe_updated_alt}") + except Exception as e: + print("L653 WARNING:", e) + + if self.observation_type == "accessibility_tree_with_captioner": + frame_ax_trees = self.fetch_page_accessibility_tree( + page, + browser_info, + current_viewport_only=self.current_viewport_only + ) + content, obs_nodes_info = self.parse_accessibility_tree(frame_ax_trees) + content = self.clean_accesibility_tree(content) + 
def process(self, page: Page) -> str:
    """Build the text observation for ``page``.

    The result is a header listing the open tabs (the current one marked
    ``(current)``), a blank line, then the page content rendered according
    to ``self.observation_type``. As side effects, ``self.browser_config``
    is refreshed and, for the html / accessibility-tree types handled
    here, ``self.obs_nodes_info`` and ``self.meta_data["obs_nodes_info"]``
    are updated.

    Raises:
        ValueError: if ``self.observation_type`` is not a supported value.
    """
    tabs = page.context.pages
    try:
        labels = [tab.title() for tab in tabs]
        active = tabs.index(page)
        for position, tab in enumerate(tabs):
            marker = " (current)" if position == active else ""
            labels[position] = f"Tab {position}{marker}: {tab.title()}"
        tab_title_str = " | ".join(labels)
    except Exception:
        # Titles are best-effort; fall back to bare tab numbers.
        tab_title_str = " | ".join(
            [f"Tab {position}" for position in range(len(tabs))]
        )

    try:
        browser_info = self.fetch_browser_info(page)
    except Exception:
        # Give the page a moment to finish loading, then retry once.
        page.wait_for_load_state("load", timeout=500)
        browser_info = self.fetch_browser_info(page)

    if self.observation_type == "html":
        dom_tree = self.fetch_page_html(
            browser_info, page, self.current_viewport_only
        )
        content, nodes = self.parse_html(dom_tree)
        self.obs_nodes_info = nodes
        self.meta_data["obs_nodes_info"] = nodes
    elif self.observation_type == "accessibility_tree":
        ax_tree = self.fetch_page_accessibility_tree(
            page, browser_info, self.current_viewport_only
        )
        content, nodes = self.parse_accessibility_tree(ax_tree)
        content = self.clean_accesibility_tree(content)
        self.obs_nodes_info = nodes
        self.meta_data["obs_nodes_info"] = nodes
    elif self.observation_type in (
        "accessibility_tree_with_captioner",
        "image_som",
    ):
        content = self.fetch_image_related(page, browser_info)
    elif self.observation_type == "":
        content = ""
    else:
        raise ValueError(f"Invalid observation type: {self.observation_type}")

    self.browser_config = browser_info["config"]
    return f"{tab_title_str}\n\n{content}"
tuple[float, float]: + node_info = self.obs_nodes_info[element_id] + node_bound = node_info["union_bound"] + x, y, width, height = node_bound + center_x = x + width / 2 + center_y = y + height / 2 + return ( + center_x / self.viewport_size["width"], + center_y / self.viewport_size["height"], + ) + + +class ImageObservationProcessor(ObservationProcessor): + def __init__( + self, + observation_type: str, + viewport_size: Optional[ViewportSize] = None, + ): + self.observation_type = observation_type + self.observation_tag = "image" + self.viewport_size = viewport_size + self.meta_data = create_empty_metadata() + + def get_page_bboxes(self, page: Page) -> list[list[float]]: + """JavaScript code to return bounding boxes and other metadata from HTML elements.""" + js_script = """ + (() => { + const interactableSelectors = [ + 'a[href]:not(:has(img))', 'a[href] img', 'button', 'input:not([type="hidden"])', 'textarea', 'select', + '[tabindex]:not([tabindex="-1"])', '[contenteditable="true"]', '[role="button"]', '[role="link"]', + '[role="checkbox"]', '[role="menuitem"]', '[role="tab"]', '[draggable="true"]', + '.btn', 'a[href="/notifications"]', 'a[href="/submit"]', '.fa.fa-star.is-rating-item', 'input[type="checkbox"]' + + ]; + + const textSelectors = ['p', 'span', 'div:not(:has(*))', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'article']; + const modifiedTextSelectors = textSelectors.map(selector => + `:not(${interactableSelectors.join(', ')}):not(style) > ${selector}` + ); + + const combinedSelectors = [...interactableSelectors, ...modifiedTextSelectors]; + const elements = document.querySelectorAll(combinedSelectors.join(', ')); + + const pixelRatio = window.devicePixelRatio; + let csvContent = "ID,Element,Top,Right,Bottom,Left,Width,Height,Alt,Class,Id,TextContent,Interactable\\n"; + let counter = 1; + + elements.forEach(element => { + const rect = element.getBoundingClientRect(); + if (rect.width === 0 || rect.height === 0) return; + let altText = 
element.getAttribute('alt') || ''; + altText = altText.replace(/"/g, ''); // Escape double quotes in alt text + const classList = element.className || ''; + const id = element.id || ''; + let textContent = element.textContent || ''; + textContent = textContent.replace(/"/g, ''); // Escape double quotes in textContent + + // Determine if the element is interactable + const isInteractable = interactableSelectors.some(selector => element.matches(selector)); + + const dataString = [ + counter, element.tagName, (rect.top + window.scrollY) * pixelRatio, + (rect.right + window.scrollX) * pixelRatio, (rect.bottom + window.scrollY) * pixelRatio, + (rect.left + window.scrollX) * pixelRatio, rect.width * pixelRatio, rect.height * pixelRatio, + altText, classList, id, textContent, isInteractable + ].map(value => `"${value}"`).join(","); + + csvContent += dataString + "\\n"; + counter++; + }); + + return csvContent; + })(); + """ + # Save the bbox as a CSV + csv_content = page.evaluate(js_script) + return csv_content + + def draw_bounding_boxes( + self, + data_string, + screenshot_img, + viewport_size=None, + add_ids=True, + bbox_color=None, + min_width=8, + min_height=8, + bbox_padding=0, + bbox_border=2, + plot_ids=None, + ): + """ + min_width and min_height: Minimum dimensions of the bounding box to be plotted. + """ + # Read CSV data + df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"') + df["Area"] = df["Width"] * df["Height"] + # Remove bounding boxes that are clipped. 
+ b_x, b_y = ( + self.browser_config["win_left_bound"], + self.browser_config["win_upper_bound"], + ) + if viewport_size is not None: + df = df[ + (df["Bottom"] - b_y >= 0) + & (df["Top"] - b_y <= viewport_size["height"]) + & (df["Right"] - b_x >= 0) + & (df["Left"] - b_x <= viewport_size["width"]) + ] + viewport_area = viewport_size["width"] * viewport_size["height"] + # Filter out bounding boxes that too large (more than 80% of the viewport) + df = df[df["Area"] <= 0.8 * viewport_area] + + # Open the screenshot image + img = screenshot_img.copy() + draw = ImageDraw.Draw(img) + + # Load a TTF font with a larger size + font_path = "media/SourceCodePro-SemiBold.ttf" + font_size, padding = 16, 2 + font = ImageFont.truetype(font_path, font_size) + + # Create a color cycle using one of the categorical color palettes in matplotlib + color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"] + bbox_id2visid = {} + bbox_id2desc = {} + index = 0 + id2center = {} + existing_text_rectangles = [] + text_to_draw = [] + # Provide [id] textContent inputs to the model as text. + text_content_elements = [] + text_content_text = set() # Store text of interactable elements + + # Iterate through each row in the CSV and draw bounding boxes + for _, row in df.iterrows(): + if not row["Interactable"]: + content = "" + # Add image alt-text to the text representation. + if row["Element"] == "IMG" and pd.notna(row["Alt"]): + content += row["Alt"] + # Add HTML textContent (if any) to the text representation. 
+ if pd.notna(row["TextContent"]): + content += ( + row["TextContent"].strip().replace("\n", "").replace("\t", "") + )[ + :200 + ] # Limit to 200 characters to avoid having too much text + + # Check if the text is a CSS selector + if content and not (content.startswith(".") and "{" in content): + # Add elements which are not interactable as StaticText + if content not in text_content_text: + text_content_elements.append(f"[] [StaticText] [{content}]") + text_content_text.add(content) + continue + + if (plot_ids is not None) and (row["ID"] not in plot_ids): + continue + + unique_id = str(index + 1) + bbox_id2visid[row["ID"]] = ( + unique_id # map the bounding box ID to the unique character ID + ) + top, right, bottom, left, width, height = ( + row["Top"], + row["Right"], + row["Bottom"], + row["Left"], + row["Width"], + row["Height"], + ) + left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y + id2center[unique_id] = ( + (left + right) / 2, + (bottom + top) / 2, + width, + height, + ) + + if width >= min_width and height >= min_height: + # Get the next color in the cycle + color = bbox_color or color_cycle[index % len(color_cycle)] + draw.rectangle( + [ + left - bbox_padding, + top - bbox_padding, + right + bbox_padding, + bottom + bbox_padding, + ], + outline=color, + width=bbox_border, + ) + bbox_id2desc[row["ID"]] = color + + # Draw the text on top of the rectangle + if add_ids: + # Calculate list of possible text positions + text_positions = [ + (left - font_size, top - font_size), # Top-left corner + ( + left, + top - font_size, + ), # A little to the right of the top-left corner + (right, top - font_size), # Top-right corner + ( + right - font_size - 2 * padding, + top - font_size, + ), # A little to the left of the top-right corner + (left - font_size, bottom), # Bottom-left corner + ( + left, + bottom, + ), # A little to the right of the bottom-left corner + ( + right - font_size - 2 * padding, + bottom, + ), # A little to the left of 
the bottom-right corner + ( + left, + bottom, + ), # A little to the right of the bottom-left corner + ( + right - font_size - 2 * padding, + bottom, + ), # A little to the left of the bottom-right corner + ] + text_width = draw.textlength(unique_id, font=font) + text_height = font_size # Assume the text is one line + + if viewport_size is not None: + for text_position in text_positions: + new_text_rectangle = [ + text_position[0] - padding, + text_position[1] - padding, + text_position[0] + text_width + padding, + text_position[1] + text_height + padding, + ] + + # Check if the new text rectangle is within the viewport + if ( + new_text_rectangle[0] >= 0 + and new_text_rectangle[1] >= 0 + and new_text_rectangle[2] <= viewport_size["width"] + and new_text_rectangle[3] <= viewport_size["height"] + ): + # If the rectangle is within the viewport, check for overlaps + overlaps = False + for existing_rectangle in existing_text_rectangles: + if self.rectangles_overlap( + new_text_rectangle, + existing_rectangle, + padding * 2, + ): + overlaps = True + break + + if not overlaps: + break + else: + # If the rectangle is outside the viewport, try the next position + continue + else: + # If none of the corners work, move the text rectangle by a fixed amount + text_position = ( + text_positions[0][0] + padding, + text_positions[0][1], + ) + new_text_rectangle = [ + text_position[0] - padding, + text_position[1] - padding, + text_position[0] + text_width + padding, + text_position[1] + text_height + padding, + ] + + existing_text_rectangles.append(new_text_rectangle) + text_to_draw.append( + (new_text_rectangle, text_position, unique_id, color) + ) + + content = "" + if row["Element"] == "IMG" and pd.notna(row["Alt"]): + content += row["Alt"] + if pd.notna(row["TextContent"]): + content += ( + row["TextContent"] + .strip() + .replace("\n", "") + .replace("\t", "") + )[ + :200 + ] # Limit to 200 characters + text_content_elements.append( + f"[{unique_id}] [{row['Element']}] 
def rectangles_overlap(self, rect1, rect2, padding):
    """Report whether two ``[x1, y1, x2, y2]`` rectangles overlap.

    ``padding`` is added to ``rect2``'s near edges and subtracted from its
    far edges before comparing, so ``rect1`` must reach at least
    ``padding`` inside ``rect2`` on every side to count as overlapping.
    """
    # rect1 ends before rect2 (plus padding) begins horizontally.
    if rect1[2] < rect2[0] + padding:
        return False
    # rect1 begins after rect2 (minus padding) ends horizontally.
    if rect1[0] > rect2[2] - padding:
        return False
    # rect1 begins after rect2 (minus padding) ends vertically.
    if rect1[1] > rect2[3] - padding:
        return False
    # rect1 ends before rect2 (plus padding) begins vertically.
    if rect1[3] < rect2[1] + padding:
        return False
    return True
viewport_size=self.viewport_size, + ) + self.som_id_info = id2center + self.meta_data["obs_nodes_info"] = id2center + screenshot_som = np.array(bbox_img) + return screenshot_som, content_str + else: + try: + screenshot = png_bytes_to_numpy(page.screenshot()) + except: + page.wait_for_event("load") + screenshot = png_bytes_to_numpy(page.screenshot()) + return screenshot, "" + + def fetch_browser_info(self, page: Page) -> BrowserInfo: + client = page.context.new_cdp_session(page) + # extract domtree + tree = client.send( + "DOMSnapshot.captureSnapshot", + { + "computedStyles": [], + "includeDOMRects": True, + "includePaintOrder": True, + }, + ) + client.detach() + # calibrate the bounds, in some cases, the bounds are scaled somehow + bounds = tree["documents"][0]["layout"]["bounds"] + b = bounds[0] + n = b[2] / self.viewport_size["width"] + bounds = [[x / n for x in bound] for bound in bounds] + tree["documents"][0]["layout"]["bounds"] = bounds + # add union bound placeholder + tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds] + + # extract browser info + win_upper_bound = page.evaluate("window.pageYOffset") + win_left_bound = page.evaluate("window.pageXOffset") + win_width = page.evaluate("window.screen.width") + win_height = page.evaluate("window.screen.height") + win_right_bound = win_left_bound + win_width + win_lower_bound = win_upper_bound + win_height + device_pixel_ratio = page.evaluate("window.devicePixelRatio") + assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0" + + config: BrowserConfig = { + "win_upper_bound": win_upper_bound, + "win_left_bound": win_left_bound, + "win_width": win_width, + "win_height": win_height, + "win_right_bound": win_right_bound, + "win_lower_bound": win_lower_bound, + "device_pixel_ratio": device_pixel_ratio, + } + + # assert len(tree['documents']) == 1, "More than one document in the DOM tree" + info: BrowserInfo = {"DOMTree": tree, "config": config} + + return info + + def 
class TextObervationProcessorWebRL(TextObervationProcessor):
    """Text observation processor used in WebRL mode.

    The observation is the HTML string returned by ``get_parsed_html``.
    Elements of that HTML carrying both an ``id`` and a ``data-bbox``
    attribute are indexed in ``obs_nodes_info`` so their centers can be
    looked up later.
    """

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ):
        super().__init__(
            observation_type,
            current_viewport_only,
            viewport_size,
            captioning_fn,
        )

    def process(self, page: Page) -> str:
        """Return the parsed HTML of ``page`` and refresh the node index.

        Every tag with ``id`` and ``data-bbox`` attributes is recorded under
        ``str(id)`` with its ``backend_id`` (the id), ``union_bound`` (the
        comma-separated ``data-bbox`` values as floats) and ``text`` (the
        tag's markup). The index is stored on ``self.obs_nodes_info`` and
        ``self.meta_data["obs_nodes_info"]``.
        """
        page_info = get_parsed_html(page)
        html = page_info["html"]

        # Imported lazily, as in the original, so bs4 is only needed in WebRL mode.
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html, 'html.parser')
        obs_nodes_info = {}
        for tag in soup.find_all(True):
            if not (tag.has_attr('id') and tag.has_attr('data-bbox')):
                continue
            backend_id = tag['id']
            union_bound = [float(num) for num in tag['data-bbox'].split(',')]
            obs_nodes_info[str(backend_id)] = {
                "backend_id": backend_id,
                "union_bound": union_bound,
                "text": str(tag),
            }

        self.obs_nodes_info = obs_nodes_info
        self.meta_data["obs_nodes_info"] = obs_nodes_info
        return html

    def get_element_center(
        self, element_id: str, page: Optional[Page] = None
    ) -> tuple[float, float]:
        """Return the element's center as fractions of the viewport size.

        When ``page`` is given, the live bounding box of the element matching
        ``[data-label-id='<element_id>']`` is used. If no page is given, the
        element cannot be found, or it has no bounding box, the bounds cached
        by the last ``process`` call are used instead.

        Raises:
            KeyError: if the cached bounds are needed but ``element_id`` was
                not indexed by the last ``process`` call.
        """
        bbox = None
        if page is not None:
            element = page.query_selector(f"[data-label-id='{element_id}']")
            # query_selector yields None when nothing matches, and
            # bounding_box yields None for elements that are not visible;
            # both previously crashed here.
            if element is not None:
                bbox = element.bounding_box()

        if bbox is not None:
            left, top = bbox['x'], bbox['y']
            right, bottom = left + bbox['width'], top + bbox['height']
            center_x = (left + right) / 2
            center_y = (top + bottom) / 2
        else:
            x, y, width, height = self.obs_nodes_info[element_id]["union_bound"]
            center_x = x + width / 2
            center_y = y + height / 2

        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )
+ np.zeros( + (self.viewport_size["height"], self.viewport_size["width"], 3), + dtype=np.uint8, + ), + np.ones( + (self.viewport_size["height"], self.viewport_size["width"], 3), + dtype=np.uint8, + ) + * 255.0, + dtype=np.uint8, + ) + + return spaces.Dict({"text": text_space, "image": image_space}) + + def get_observation(self, page: Page) -> dict[str, Observation]: + text_obs = self.text_processor.process(page) + image_obs, content_str = self.image_processor.process(page) + if content_str != "": + text_obs = content_str + return {"text": text_obs, "image": image_obs} + + def get_observation_metadata(self) -> dict[str, ObservationMetadata]: + return { + "text": self.text_processor.meta_data, + "image": self.image_processor.meta_data, + } + + @property + def action_processor(self) -> ObservationProcessor: + """Return the main processor that is associated with the action space""" + if self.main_observation_type == "text": + return self.text_processor + elif self.main_observation_type == "image": + return self.image_processor + else: + raise ValueError("Invalid main observation type") diff --git a/VAB-WebArena-Lite/new/prompt_constructor.py b/VAB-WebArena-Lite/new/prompt_constructor.py index 9cd5a2a..fc59b06 100644 --- a/VAB-WebArena-Lite/new/prompt_constructor.py +++ b/VAB-WebArena-Lite/new/prompt_constructor.py @@ -499,3 +499,59 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor): raise NotImplementedError( f"Provider {self.lm_config.provider} not implemented" ) + +class WebRLPromptConstructor(PromptConstructor): + """The agent will direct predict the action""" + + def __init__( + self, + instruction_path: str | Path, + lm_config: lm_config.LMConfig, + tokenizer: Tokenizer, + ): + super().__init__(instruction_path, lm_config, tokenizer) + + def construct( + self, + trajectory: Trajectory, + intent: str, + meta_data: dict[str, Any] = {}, + ) -> APIInput: + """Construct prompt given the trajectory""" + state_info: StateInfo = trajectory[-1] # type: 
ignore[assignment] + + obs = state_info["observation"][self.obs_modality] + max_obs_length = self.lm_config.gen_config["max_obs_length"] + if max_obs_length: + if self.lm_config.provider == "google": + print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.") + obs = obs[:max_obs_length] + else: + try: + obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type] + except: + print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.") + obs = obs[:max_obs_length] + + turn_num = len(meta_data["action_history"]) + if turn_num == 1: + previous_action_str = [] + else: + previous_action_str = meta_data["action_history"][1:] + + index = turn_num - 1 + history = "" + for i in range(index - 1, -1, -1): + if i == 0: + history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history + else: + history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n** Simplified html **\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history + if len(history) + len(obs) > (16384 - 512): + obs = obs[:(16384 - 512)-len(history)] + current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n" + prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}" + + return prompt + + def extract_action(self, response: str) -> str: + return response \ No newline at end of file diff --git a/VAB-WebArena-Lite/new/run.py b/VAB-WebArena-Lite/new/run.py index e406e2c..6f16161 100644 --- a/VAB-WebArena-Lite/new/run.py +++ b/VAB-WebArena-Lite/new/run.py @@ -13,11 +13,13 @@ import tempfile import time from pathlib import Path from typing import List +import cv2 +import shutil import openai import requests import 
torch -from PIL import Image +from PIL import Image, ImageDraw, ImageFont from agent import ( PromptAgent, @@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") console_handler.setFormatter(formatter) file_handler.setFormatter(formatter) +def text_wrap(text, font, max_width): + lines = [] + paragraphs = text.split('\n') # 按照 \n 分割文本为段落 + for paragraph in paragraphs: + words = paragraph.split(' ') + line = '' + for word in words: + # 临时行 + test_line = f"{line} {word}".strip() + # 获取临时行的宽度 + test_line_bbox = font.getbbox(test_line) + test_line_width = test_line_bbox[2] - test_line_bbox[0] + if test_line_width <= max_width: + # 如果临时行的宽度不超过图片宽度,继续添加单词 + line = test_line + else: + # 如果超过了最大宽度,保存当前行,开始新的一行 + lines.append(line) + line = word + # 添加每段的最后一行 + if line: + lines.append(line) + # 每个段落后添加一个空行,以保留段落的换行 + lines.append('') + # 移除最后一个空行(不需要额外的空行) + if lines[-1] == '': + lines.pop() + return lines def config() -> argparse.Namespace: parser = argparse.ArgumentParser( @@ -88,6 +118,7 @@ def config() -> argparse.Namespace: "html", "image", "image_som", + "webrl", ], default="accessibility_tree", help="Observation type", @@ -176,6 +207,9 @@ def config() -> argparse.Namespace: # logging related parser.add_argument("--result_dir", type=str, default="") + + # if use self-deployed model + parser.add_argument("--planner_ip", type=str, default=None) args = parser.parse_args() # check the whether the action space is compatible with the observation space @@ -196,7 +230,7 @@ def config() -> argparse.Namespace: def early_stop( - trajectory: Trajectory, max_steps: int, thresholds: dict[str, int] + trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None ) -> tuple[bool, str]: """Check whether need to stop early""" @@ -228,28 +262,38 @@ def early_stop( if len(action_seq) == 0: return False, "" - last_action: Action = action_seq[-1] + if actions is None: + last_action: Action = action_seq[-1] + if 
last_action["action_type"] != ActionTypes.TYPE: + if len(last_k_actions) >= k: + if all( + [ + is_equivalent(action, last_action) + for action in last_k_actions + ] + ): + return True, f"Same action for {k} times" + else: + # check the action sequence + if ( + sum([is_equivalent(action, last_action) for action in action_seq]) + >= k + ): + return True, f"Same typing action for {k} times" + return False, "" - if last_action["action_type"] != ActionTypes.TYPE: + else: + last_k_actions = actions[-k:] + last_action = actions[-1] if len(last_k_actions) >= k: if all( [ - is_equivalent(action, last_action) + action == last_action for action in last_k_actions ] ): return True, f"Same action for {k} times" - - else: - # check the action sequence - if ( - sum([is_equivalent(action, last_action) for action in action_seq]) - >= k - ): - return True, f"Same typing action for {k} times" - - return False, "" - + return False, "" def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1): obj = { @@ -387,18 +431,22 @@ def test( obs, info = env.reset(options={"config_file": config_file}) state_info: StateInfo = {"observation": obs, "info": info} trajectory.append(state_info) - meta_data = {"action_history": ["None"]} out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json") actions = [] + os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True) + if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")): + shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}")) + os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}")) + while True: update_action_history(out_path, task_id, actions=actions) - + # If no actions variable is passed, the behavior of early_stop is the same as the original one. 
early_stop_flag, stop_info = early_stop( - trajectory, max_steps, early_stop_thresholds + trajectory, max_steps, early_stop_thresholds, actions ) - + if early_stop_flag: action = create_stop_action(f"Early stop: {stop_info}") else: @@ -407,14 +455,14 @@ def test( trajectory, intent, images=images, - meta_data=meta_data, + meta_data=meta_data ) except ValueError as e: # get the error message action = create_stop_action(f"ERROR: {str(e)}") - + trajectory.append(action) - + action_str = get_action_description( action, state_info["info"]["observation_metadata"], @@ -426,13 +474,50 @@ def test( render_helper.render( action, state_info, meta_data, args.render_screenshot ) + + current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png") + _ = env.page.viewport_size + env.page.screenshot(path="/dev/null") + env.page.screenshot(path=current_screenshot) + element_id = action["element_id"] + if element_id != "": + element = env.page.query_selector(f"[data-label-id='{element_id}']") + if element: + bbox = element.bounding_box() + bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']),int(bbox['height'])] + image = cv2.imread(current_screenshot) + cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2) + cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2) + cv2.imwrite(current_screenshot, image) + font_path = "/workspace/qzh/Arial.ttf" # 使用FreeSans字体 + font_size = 30 # 你可以调整这个值来增大或缩小字体 + image = Image.open(current_screenshot) + font = ImageFont.truetype(font_path, font_size) + draw = ImageDraw.Draw(image) + image_width = image.width + wrapped_text = text_wrap(action_str, font, image_width) + line_height = font.getbbox('hg')[3] - font.getbbox('hg')[1] + text_height = line_height * len(wrapped_text) + new_image_height = image.height + text_height + 20 # 20 is extra white space + new_image = Image.new('RGB', (image.width, 
new_image_height), color=(255, 255, 255)) # white background + draw_new = ImageDraw.Draw(new_image) + y_text = 10 # Initial position of text from top + for line in wrapped_text: + text_bbox = draw_new.textbbox((0, 0), line, font=font) + text_width = text_bbox[2] - text_bbox[0] + text_position = ((image_width - text_width) // 2, y_text) # Center text horizontally + draw_new.text(text_position, line, font=font, fill=(0, 0, 0)) # black text + y_text += line_height # move to next line + new_image.paste(image, (0, text_height + 20)) + new_image.save(current_screenshot) + meta_data["action_history"].append(action_str) actions.append(action_str) - print(action_str) + print('Action String: ', action_str) if action["action_type"] == ActionTypes.STOP: break - + obs, _, terminated, _, info = env.step(action) state_info = {"observation": obs, "info": info} trajectory.append(state_info) @@ -540,7 +625,7 @@ if __name__ == "__main__": os.environ["TOKENIZERS_PARALLELISM"] = "false" args = config() - args.sleep_after_execution = 2.5 + args.sleep_after_execution = 3.0 prepare(args) test_config_base_dir = args.test_config_base_dir diff --git a/VAB-WebArena-Lite/new/score.py b/VAB-WebArena-Lite/new/score.py new file mode 100644 index 0000000..95235e5 --- /dev/null +++ b/VAB-WebArena-Lite/new/score.py @@ -0,0 +1,86 @@ +import os, json, sys, copy + +USE_TASKS = [i for i in range(165)] + +def get_result(res_dict, src="all"): + if len(res_dict) == 0: + return '' + + success_id = [k for k, v in res_dict.items() if v >= 1.0] + score = len(success_id) + finish_count = len(res_dict) + pacc, acc = score / finish_count * 100, score / TASKS * 100 + + print(sorted(success_id)) + + meta = """ +-------- +src file: {} +successed: {:3} / {:4} (812) +partial accuracy: {:7} +overall accuracy: {:7} +-------- +""".format(src, int(score), finish_count, round(pacc, 2), round(acc, 2)) + + print(meta) + +def export_result(res_dict, src=".", note=["1.0", "0.0"], show_all=False): + out_string = "" + for id in 
USE_TASKS: + # with open(f"Pipeline/config_files/{id}.json", "r") as f: + # jd = json.load(f) + + # if "map" in jd["sites"]: + # continue + if id in res_dict: + if res_dict[id] >= 1.0: + out_string += note[0] + else: + out_string += note[1] + elif show_all: + out_string += note[1] + out_string += "\n" + + with open(os.path.join(src, 'export.txt'), 'w') as f: + f.write(out_string) + +TASKS = 165 + +files = sys.argv[1] +file_list = files.split(',') + +all_result = {} + +for src in file_list: + path = os.path.join(src, 'actions') + + result = {} + finished = os.listdir(path) + + for file in finished: + if not file.endswith('.json'): + continue + with open(os.path.join(path, file), 'r') as f: + data = json.load(f) + + if not isinstance(data, dict): + continue + + task_id = data.get('task_id', 1000) + # if task_id >= TASKS: + # continue + + task_score = data.get('score', 0) + if task_score < 0: + continue + + result[task_id] = task_score + if task_id not in all_result or task_score > all_result[task_id]: + all_result[task_id] = task_score + + get_result(result, src) + export_result(result, src=src) + +if len(file_list) > 1: + get_result(all_result) +export_result(all_result, show_all=True) diff --git a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json index c77225a..adb1822 100644 --- a/VAB-WebArena-Lite/new/test_webarena_lite.raw.json +++ b/VAB-WebArena-Lite/new/test_webarena_lite.raw.json @@ -187,7 +187,7 @@ "string_match" ], "reference_answers": { - "must_include": [ + "fuzzy_match": [ "0" ] }, @@ -774,8 +774,8 @@ "string_match" ], "reference_answers": { - "fuzzy_match": [ - "914km" + "must_include": [ + "914" ] }, "reference_url": "", @@ -1139,7 +1139,7 @@ "string_match" ], "reference_answers": { - "must_include": [ + "fuzzy_match": [ "0" ] }, @@ -1679,7 +1679,7 @@ "string_match" ], "reference_answers": { - "exact_match": "N/A" + "fuzzy_match": ["N/A"] }, "reference_url": "", "program_html": [], @@ -2605,7 +2605,7 
@@ "string_match" ], "reference_answers": { - "exact_match": "yjlou" + "must_include": "yjlou" }, "reference_url": "", "program_html": [], @@ -3712,7 +3712,7 @@ "string_match" ], "reference_answers": { - "exact_match": "N/A" + "fuzzy_match": "N/A" }, "reference_url": "", "program_html": [], diff --git a/VAB-WebArena-Lite/new/tokenizers.py b/VAB-WebArena-Lite/new/tokenizers.py index 2234040..b6e8e74 100644 --- a/VAB-WebArena-Lite/new/tokenizers.py +++ b/VAB-WebArena-Lite/new/tokenizers.py @@ -1,15 +1,18 @@ from typing import Any import tiktoken -from transformers import LlamaTokenizer # type: ignore +from transformers import LlamaTokenizer, AutoTokenizer # type: ignore class Tokenizer(object): def __init__(self, provider: str, model_name: str) -> None: if provider == "openai": - self.tokenizer = tiktoken.encoding_for_model(model_name) + try: + self.tokenizer = tiktoken.encoding_for_model(model_name) + except: # The provider is in openai format but the model is a finetuned model + self.tokenizer = None elif provider == "huggingface": - self.tokenizer = LlamaTokenizer.from_pretrained(model_name) + self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True) # turn off adding special tokens automatically self.tokenizer.add_special_tokens = False # type: ignore[attr-defined] self.tokenizer.add_bos_token = False # type: ignore[attr-defined] diff --git a/VAB-WebArena-Lite/new/utils.py b/VAB-WebArena-Lite/new/utils.py index 4558585..2a8b87c 100644 --- a/VAB-WebArena-Lite/new/utils.py +++ b/VAB-WebArena-Lite/new/utils.py @@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any] def call_llm( lm_config: lm_config.LMConfig, prompt: APIInput, + api_key = None, + base_url = None ) -> str: response: str if lm_config.provider == "openai": @@ -39,11 +41,13 @@ def call_llm( assert isinstance(prompt, str) response = generate_from_openai_completion( prompt=prompt, - engine=lm_config.model, + model=lm_config.model, 
temperature=lm_config.gen_config["temperature"], max_tokens=lm_config.gen_config["max_tokens"], top_p=lm_config.gen_config["top_p"], stop_token=lm_config.gen_config["stop_token"], + api_key=api_key, + base_url=base_url ) else: raise ValueError( diff --git a/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh new file mode 100644 index 0000000..f8c4784 --- /dev/null +++ b/VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh @@ -0,0 +1,97 @@ +#!/bin/bash +DATASET='webarena' # TODO: select from ['webarena', 'visualwebarena'] +result_dir='' # TODO: set your result_dir +provider='openai' # TODO: select from ['openai', 'finetune', ...] +model='' # TODO: assign model name, which is used for action generation +planner_ip='' # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm) +instruction_path='agent/prompts/jsons/p_webrl.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json +test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite +temperature=0.0 + +SERVER='' # TODO: your server address +MAP_SERVER='' # TODO: the server address for MAP tasks +OPENAI_API_KEY='' # TODO: if you test OpenAI APIs +OPENAI_ORGANIZATION='' +CONDA_ENV_NAME='' # TODO: the name of your conda environment for testing WebArena + +ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}" + +# get the number of tmux panes +num_panes=$(tmux list-panes | wc -l) + +# calculate how many panes need to be created +let "panes_to_create = 7 - 
num_panes" + +# array of tmux commands to create each pane +tmux_commands=( + 'tmux split-window -h' + 'tmux split-window -v' + 'tmux select-pane -t 0; tmux split-window -v' + 'tmux split-window -v' + 'tmux select-pane -t 3; tmux split-window -v' + 'tmux select-pane -t 5; tmux split-window -v' +) + +# create panes up to 7 +for ((i=0; i<$panes_to_create; i++)); do + eval ${tmux_commands[$i]} +done + +#!/bin/bash + +# Function to run a job +run_job() { + tmux select-pane -t $1 + COMMAND="python run.py \ + --instruction_path ${instruction_path} \ + --test_start_idx $2 \ + --test_end_idx $3 \ + --result_dir ${result_dir} \ + --test_config_base_dir ${test_config_base_dir} \ + --provider ${provider} \ + --model ${model} \ + --mode completion \ + --planner_ip ${planner_ip} \ + --stop_token \"<|eot_id|>\" \ + --temperature ${temperature} \ + --max_obs_length 0 \ + --max_tokens 2048 \ + --viewport_width 1280 \ + --viewport_height 720 \ + --parsing_failure_th 5 \ + --repeating_action_failure_th 5 \ + --action_set_tag webrl_id --observation_type webrl" + tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m + sleep 3 +} + +TOLERANCE=2 +run_batch() { + args=("$@") # save all arguments in an array + num_jobs=${#args[@]} # get number of arguments + + for ((i=1; i<$num_jobs; i++)); do + run_job $i ${args[i-1]} ${args[i]} + done + + # Wait for all jobs to finish + while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do + sleep 100 # wait for 10 seconds before checking again + done + + # Run checker + while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do + echo "Check failed, rerunning jobs..." 
+ for ((i=1; i<$num_jobs; i++)); do + run_job $i ${args[i-1]} ${args[i]} + done + + # Wait for all jobs to finish + while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do + sleep 100 # wait for 10 seconds before checking again + done + done + +} +run_batch 0 28 56 84 112 140 165 + diff --git a/VAB-WebArena-Lite/replace.sh b/VAB-WebArena-Lite/replace.sh index add244c..7a674f0 100644 --- a/VAB-WebArena-Lite/replace.sh +++ b/VAB-WebArena-Lite/replace.sh @@ -14,6 +14,14 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py cp -f new/run.py visualwebarena/run.py cp -f new/agent.py visualwebarena/agent/agent.py cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py +cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json + +# browser_env +cp -f new/actions.py visualwebarena/browser_env/actions.py +cp -f new/envs.py visualwebarena/browser_env/envs.py +cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py +cp -f new/processors.py visualwebarena/browser_env/processors.py +cp -rf new/html_tools visualwebarena/browser_env/ # llms cp -f new/utils.py visualwebarena/llms/utils.py @@ -22,15 +30,19 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py +cp -f new/utils.py visualwebarena/llms/utils.py # eval cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py -cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py +cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py # misc cp -f README.md visualwebarena/README.md cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh +cp -f new/score.py visualwebarena/score.py +cp -f new/wa_parallel_run_webrl.sh 
visualwebarena/wa_parallel_run_webrl.sh + # 3. remove temporary files mv visualwebarena/* . rm -rf new