webrl/VAB-WebArena-Lite/new/processors.py
2024-11-14 15:51:41 +08:00

1352 lines
52 KiB
Python

import json
import pkgutil
import re
from collections import defaultdict
from dataclasses import dataclass
from io import BytesIO, StringIO
from typing import Any, Optional, TypedDict, Union
from urllib.parse import urljoin, urlparse
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import playwright
import requests
from gymnasium import spaces
from PIL import Image, ImageDraw, ImageFont
from playwright.sync_api import CDPSession, Page, ViewportSize
from .html_tools.fetch import get_parsed_html
from browser_env.constants import (
ASCII_CHARSET,
FREQ_UNICODE_CHARSET,
IGNORED_ACTREE_PROPERTIES,
INJECTED_ATTR_NAME,
UTTERANCE_MAX_LENGTH,
BID_ATTR,
DATA_REGEXP,
IN_VIEWPORT_RATIO_THRESHOLD,
)
from .utils import (
AccessibilityTree,
AccessibilityTreeNode,
BrowserConfig,
BrowserInfo,
DOMNode,
DOMTree,
Observation,
png_bytes_to_numpy,
)
def remove_unicode(input_string):
    """Return ``input_string`` with every non-ASCII character deleted.

    Characters whose code point is above 0x7F are dropped outright (not
    transliterated); ASCII characters are kept in their original order.
    """
    return "".join(ch for ch in input_string if ch.isascii())
class ObservationProcessor:
    """Base class for components that turn a Playwright page into an observation.

    Concrete subclasses override :meth:`process`.
    """

    def process(self, page: Page) -> Observation:
        """Build an observation from ``page``.

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError
# Metadata attached to an observation. ``obs_nodes_info`` maps the element ids
# that appear in an observation to per-element details (filled in by the
# observation processors).
ObservationMetadata = TypedDict(
    "ObservationMetadata",
    {"obs_nodes_info": dict[str, Any]},
)
def create_empty_metadata() -> ObservationMetadata:
    """Return a fresh, empty observation-metadata dict.

    A new dict (with a new inner dict) is built on every call, so callers may
    mutate the result without affecting other processors.
    """
    metadata = dict(obs_nodes_info={})
    return metadata
def extract_data_items_from_aria(
    string: str, data_regexp: Optional[re.Pattern] = None
) -> tuple[list[str], str]:
    """Extract temporary data stored in the "aria-roledescription" attribute of a node.

    The attribute value must *fully* match ``data_regexp`` (``DATA_REGEXP`` by
    default). Every capture group except the last is one data item; the last
    group is the original aria text.

    Args:
        string: the attribute value to decode.
        data_regexp: optional compiled pattern used instead of ``DATA_REGEXP``.

    Returns:
        ``(data_items, original_aria)`` where ``data_items`` is a list, as the
        annotation promises (the previous version returned a tuple). When the
        pattern does not match, ``([], string)`` is returned unchanged.
    """
    pattern = DATA_REGEXP if data_regexp is None else data_regexp
    match = pattern.fullmatch(string)
    if not match:
        return [], string
    groups = match.groups()
    data_items = list(groups[:-1])
    original_aria = groups[-1]
    return data_items, original_aria
class TextObervationProcessor(ObservationProcessor):
    """Turn a Playwright page into a text observation.

    ``observation_type`` selects what ``process`` renders after the tab header:

    * ``"html"``: a pruned DOM dump (``fetch_page_html`` + ``parse_html``).
    * ``"accessibility_tree"``: Chrome's full accessibility tree flattened to
      one tab-indented line per node.
    * ``"accessibility_tree_with_captioner"``: the accessibility tree after
      ``<img>`` alt texts were augmented with captions from ``captioning_fn``.
    * ``"image_som"``: runs only the image-captioning step; content is empty.
    * ``""``: no content besides the tab header.

    For the tree-based types, ``self.obs_nodes_info`` maps each id shown in the
    text to its backend node id, ``union_bound`` (``[x, y, width, height]``)
    and rendered line.
    """

    # Per-request timeout (seconds) when downloading images for captioning, so a
    # stalled image server cannot hang the observation forever.
    _IMAGE_REQUEST_TIMEOUT = 10

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        # use the store meta data of this observation type
        self.meta_data = create_empty_metadata()
        # Filled by ``process``; initialised so early lookups raise KeyError
        # instead of AttributeError.
        self.obs_nodes_info: dict[str, Any] = {}
        if self.observation_type in [
            "accessibility_tree_with_captioner",
            "image_som",
        ]:
            self.captioning_fn = captioning_fn
            # Cache captions.
            self.url2caption = {}

    def fetch_browser_info(
        self,
        page: Page,
    ) -> BrowserInfo:
        """Snapshot the DOM of ``page`` and read its window geometry.

        Returns a ``BrowserInfo`` dict: ``"DOMTree"`` holds the raw
        ``DOMSnapshot.captureSnapshot`` result, with layout bounds rescaled so
        the first bound's width equals the configured viewport width, and
        ``"config"`` holds scroll offsets and screen size.

        Raises:
            AssertionError: if ``window.devicePixelRatio`` is not 1.0.
        """
        # extract domtree
        client = page.context.new_cdp_session(page)
        try:
            tree = client.send(
                "DOMSnapshot.captureSnapshot",
                {
                    "computedStyles": [],
                    "includeDOMRects": True,
                    "includePaintOrder": True,
                },
            )
        finally:
            # Release the CDP session even when the snapshot fails.
            client.detach()

        # calibrate the bounds, in some cases, the bounds are scaled somehow
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds
        # add union bound placeholder
        tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]

        # extract browser info
        win_upper_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_upper_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_upper_bound": win_upper_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
        info: BrowserInfo = {"DOMTree": tree, "config": config}
        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        """Return the bounding client rect of a backend DOM node via CDP.

        Text nodes (``nodeType == 3``) are measured through a ``Range``; other
        nodes use ``getBoundingClientRect``. On success the raw
        ``Runtime.callFunctionOn`` response is returned (the rect is under
        ``["result"]["value"]``). If resolving or calling fails, the sentinel
        ``{"result": {"subtype": "error"}}`` is returned instead.
        """
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                        function() {
                            if (this.nodeType == 3) {
                                var range = document.createRange();
                                range.selectNode(this);
                                var rect = range.getBoundingClientRect().toJSON();
                                range.detach();
                                return rect;
                            } else {
                                return this.getBoundingClientRect().toJSON();
                            }
                        }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        """Return the fraction (0.0-1.0) of the element's area inside the window.

        The window spans ``[0, win_width] x [0, win_height]``, so the element
        coordinates must be window-relative (as ``getBoundingClientRect``
        returns them). Elements with a non-positive width or height yield 0.0.
        """
        if width <= 0 or height <= 0:
            return 0.0

        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        # Compute the overlap in x and y axes
        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound) - max(elem_top_bound, win_top_bound),
        )

        # The denominator must be parenthesised: ``a / width * height`` would
        # multiply by ``height`` and make the value no longer a ratio.
        return overlap_width * overlap_height / (width * height)

    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        current_viewport_only: bool,
    ) -> DOMTree:
        """Convert the DOM snapshot in ``info`` into a flat list of ``DOMNode`` dicts.

        Each node gets its name, value, attributes, parent/child ids and a
        ``union_bound`` (``[x, y, width, height]`` from the live page, or None
        if it could not be measured). With ``current_viewport_only``, nodes
        that are unmeasurable, zero-sized, or less than
        ``IN_VIEWPORT_RATIO_THRESHOLD`` visible are spliced out: their children
        are re-parented to the removed node's parent.
        """
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]

        # make a dom tree that is easier to navigate
        dom_tree: DOMTree = []
        graph = defaultdict(list)
        client = page.context.new_cdp_session(page)
        try:
            for node_idx in range(len(nodes["nodeName"])):
                cur_node: DOMNode = {
                    "nodeId": "",
                    "nodeType": "",
                    "nodeName": "",
                    "nodeValue": "",
                    "attributes": "",
                    "backendNodeId": "",
                    "parentId": "",
                    "childIds": [],
                    "cursor": 0,
                    "union_bound": None,
                }

                # NOTE(review): CDP documents NodeTreeSnapshot.nodeType as the
                # integer DOM nodeType, not a string index, so this lookup looks
                # suspect. The field is not consumed in this file; left as-is.
                node_type_idx = nodes["nodeType"][node_idx]
                node_type = "generic"
                if node_type_idx >= 0 and node_type_idx < len(strings):
                    node_type = strings[node_type_idx]

                node_name = strings[nodes["nodeName"][node_idx]]

                node_value_idx = nodes["nodeValue"][node_idx]
                node_value = ""
                if node_value_idx >= 0 and node_value_idx < len(strings):
                    node_value = " ".join(strings[node_value_idx].split())

                # Attributes come as a flat [name, value, name, value, ...] list.
                node_attributes = [strings[i] for i in nodes["attributes"][node_idx]]
                node_attributes_str = ""
                for i in range(0, len(node_attributes), 2):
                    a = node_attributes[i]
                    b = node_attributes[i + 1]
                    b = " ".join(b.split())
                    node_attributes_str += f'{a}="{b}" '
                node_attributes_str = node_attributes_str.strip()

                cur_node["nodeId"] = str(node_idx)
                cur_node["nodeType"] = node_type
                cur_node["nodeName"] = node_name
                cur_node["nodeValue"] = node_value
                cur_node["attributes"] = node_attributes_str
                cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
                cur_node["parentId"] = str(nodes["parentIndex"][node_idx])

                if cur_node["parentId"] != "-1":
                    graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))

                # get the bound
                if cur_node["parentId"] == "-1":
                    # The root is always treated as inside the viewport.
                    cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
                else:
                    response = self.get_bounding_client_rect(
                        client, cur_node["backendNodeId"]
                    )
                    if response.get("result", {}).get("subtype", "") == "error":
                        cur_node["union_bound"] = None
                    else:
                        x = response["result"]["value"]["x"]
                        y = response["result"]["value"]["y"]
                        width = response["result"]["value"]["width"]
                        height = response["result"]["value"]["height"]
                        cur_node["union_bound"] = [x, y, width, height]

                dom_tree.append(cur_node)
        finally:
            client.detach()

        # add parent children index to the node
        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids

        # remove the nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                # Splice ``node`` out: its children take its place in the
                # parent's child list and point at the parent.
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]

                # update the children of the parent node
                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
                # remove the nodeid from parent
                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
                dom_tree[int(parent_id)]["childIds"].pop(index)

                # Insert children_nodeids in the same location
                for child_id in child_ids:
                    dom_tree[int(parent_id)]["childIds"].insert(index, child_id)
                    index += 1

                # update children node's parent
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                # mark as removed
                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in dom_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0.0 or height == 0.0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            dom_tree = [
                node for node in dom_tree if node.get("parentId", "-1") != "[REMOVED]"
            ]

        return dom_tree

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Parse the html tree into a string text.

        Walks depth-first from ``dom_tree[0]`` and emits one tab-indented
        ``[cursor] <name attrs> value`` line per node that has attributes or a
        value. Nodes without either are omitted, and their children are not
        indented further. Returns the text and ``obs_nodes_info``, keyed by
        each node's position in ``dom_tree`` (as a string). An empty tree
        yields ``("", {})``.
        """
        if not dom_tree:
            return "", {}

        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"

            except Exception:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(child_cursor, child_depth)
                tree_str += child_str

            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        page: Page,
        info: BrowserInfo,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch Chrome's full accessibility tree for ``page``, with bounds.

        Duplicate node ids are dropped (first occurrence kept). Each node gets
        a ``union_bound`` (``[x, y, width, height]``, or None when it has no
        backing DOM node or could not be measured); the ``RootWebArea`` is
        pinned inside the viewport. With ``current_viewport_only``, nodes that
        are unmeasurable, zero-sized, or less than
        ``IN_VIEWPORT_RATIO_THRESHOLD`` visible are spliced out of the tree.
        """
        client = page.context.new_cdp_session(page)
        try:
            accessibility_tree: AccessibilityTree = client.send(
                "Accessibility.getFullAXTree", {}
            )["nodes"]

            # a few nodes are repeated in the accessibility tree
            seen_ids = set()
            _accessibility_tree = []
            for node in accessibility_tree:
                if node["nodeId"] not in seen_ids:
                    _accessibility_tree.append(node)
                    seen_ids.add(node["nodeId"])
            accessibility_tree = _accessibility_tree

            nodeid_to_cursor = {}
            for cursor, node in enumerate(accessibility_tree):
                nodeid_to_cursor[node["nodeId"]] = cursor
                # usually because the node is not visible etc
                if "backendDOMNodeId" not in node:
                    node["union_bound"] = None
                    continue
                backend_node_id = str(node["backendDOMNodeId"])
                if node["role"]["value"] == "RootWebArea":
                    # always inside the viewport
                    node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
                else:
                    response = self.get_bounding_client_rect(client, backend_node_id)
                    if response.get("result", {}).get("subtype", "") == "error":
                        node["union_bound"] = None
                    else:
                        x = response["result"]["value"]["x"]
                        y = response["result"]["value"]["y"]
                        width = response["result"]["value"]["width"]
                        height = response["result"]["value"]["height"]
                        node["union_bound"] = [x, y, width, height]
        finally:
            client.detach()

        # filter nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                # Splice ``node`` out: its children take its place in the
                # parent's childIds and point at the parent.
                nodeid = node["nodeId"]
                node_cursor = nodeid_to_cursor[nodeid]
                parent_nodeid = node["parentId"]
                children_nodeids = node["childIds"]
                parent_cursor = nodeid_to_cursor[parent_nodeid]
                # update the children of the parent node
                assert (
                    accessibility_tree[parent_cursor].get("parentId", "Root")
                    is not None
                )
                # remove the nodeid from parent's childIds
                index = accessibility_tree[parent_cursor]["childIds"].index(nodeid)
                accessibility_tree[parent_cursor]["childIds"].pop(index)
                # Insert children_nodeids in the same location
                for child_nodeid in children_nodeids:
                    accessibility_tree[parent_cursor]["childIds"].insert(
                        index, child_nodeid
                    )
                    index += 1
                # update children node's parent
                for child_nodeid in children_nodeids:
                    child_cursor = nodeid_to_cursor[child_nodeid]
                    accessibility_tree[child_cursor]["parentId"] = parent_nodeid
                # mark as removed
                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any]]:
        """Parse the accessibility tree into a string text.

        Walks depth-first from ``accessibility_tree[0]`` and emits one
        tab-indented ``[nodeId] role 'name' prop: value ...`` line per node.
        Properties in ``IGNORED_ACTREE_PROPERTIES`` are left out. Nameless
        nodes without properties whose role is in a fixed list, and nameless
        ``listitem`` nodes, are omitted, and their children are not indented
        further. Returns the text and ``obs_nodes_info`` keyed by nodeId. An
        empty tree yields ``("", {})``.
        """
        if not accessibility_tree:
            return "", {}

        node_id_to_idx = {}
        for idx, node in enumerate(accessibility_tree):
            node_id_to_idx[node["nodeId"]] = idx

        obs_nodes_info = {}

        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for property in node.get("properties", []):
                    try:
                        if property["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(
                            f'{property["name"]}: {property["value"]["value"]}'
                        )
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                # check valid
                if not node_str.strip():
                    valid_node = False

                # empty generic node
                if not name.strip():
                    if not properties:
                        if role in [
                            "generic",
                            "img",
                            "list",
                            "strong",
                            "paragraph",
                            "banner",
                            "navigation",
                            "Section",
                            "LabelText",
                            "Legend",
                            "listitem",
                        ]:
                            valid_node = False
                    elif role in ["listitem"]:
                        valid_node = False

                if valid_node:
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }

            except Exception:
                valid_node = False

            for child_node_id in node["childIds"]:
                if child_node_id not in node_id_to_idx:
                    continue
                # mark this to save some tokens
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(
                    node_id_to_idx[child_node_id], child_node_id, child_depth
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str

            return tree_str

        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
        return tree_str, obs_nodes_info

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Further clean the accessibility tree text.

        A ``StaticText`` line is kept only if its quoted text is non-empty and
        does not already appear in any of the three previously kept lines.
        Lines mentioning "statictext" that do not match the
        ``[id] StaticText ...`` shape are dropped. Other lines are kept as-is.
        """
        static_text_pattern = re.compile(r"\[\d+\] StaticText (.+)", re.DOTALL)
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            # remove statictext if the content already appears in the previous line
            if "statictext" in line.lower():
                prev_lines = clean_lines[-3:]
                match = static_text_pattern.search(line)
                if match:
                    static_text = match.group(1)[1:-1]  # remove the quotes
                    if static_text and all(
                        static_text not in prev_line for prev_line in prev_lines
                    ):
                        clean_lines.append(line)
            else:
                clean_lines.append(line)

        return "\n".join(clean_lines)

    def fetch_image_related(self, page: Page, browser_info: BrowserInfo) -> str:
        """Caption images on ``page``; for the captioner type, return the AX tree.

        If the URL ends in ``.jpg``/``.jpeg``/``.png``, the page itself is
        captioned once (cached in ``url2caption``) and the caption, or
        ``"Image"``, is returned. Otherwise, when ``captioning_fn`` is set, the
        method does two things:

        * Every not-yet-captioned ``<img>`` source is downloaded and captioned
          in batches of 4.
        * Each ``<img>``'s ``alt`` in the live page is rewritten to include
          its caption and URL.

        For ``"accessibility_tree_with_captioner"`` the accessibility tree is
        then returned as text; for ``"image_som"`` the content is empty.
        """
        # Check if the current page is an image url
        if page.url.endswith((".jpg", ".jpeg", ".png")):
            print("NOTE: We are on an image page!!!")
            # Load image from current url and run captioning on it.
            if page.url not in self.url2caption and self.captioning_fn is not None:
                try:
                    image = Image.open(
                        requests.get(
                            page.url, stream=True, timeout=self._IMAGE_REQUEST_TIMEOUT
                        ).raw
                    )
                    caption = self.captioning_fn([image])[0].strip()
                    self.url2caption[page.url] = remove_unicode(caption)
                except Exception as e:
                    print("L579 WARNING: ", e)

            content = self.url2caption.get(page.url, "Image")
        else:
            if self.captioning_fn is not None:
                images = page.query_selector_all("img")
                image_urls = []
                for image in images:
                    try:
                        image_url = image.get_attribute("src")
                        if not image_url.startswith(("http://", "https://", "www.")):
                            image_url = urljoin(page.url, image_url)
                        if image_url not in self.url2caption:
                            image_urls.append(image_url)
                    except Exception as e:
                        print("L604 WARNING: ", e)

                # Run image captioning on image_url pixels. This is for models which use captioning as a baseline.
                if len(image_urls) > 0:
                    image_pixels = []
                    valid_urls = []
                    for url in image_urls:
                        if "data:image/svg" in url:
                            continue
                        try:
                            image = Image.open(
                                requests.get(
                                    url, stream=True, timeout=self._IMAGE_REQUEST_TIMEOUT
                                ).raw
                            )
                            image_pixels.append(image)
                            valid_urls.append(url)
                        except Exception as e:
                            print("L616 WARNING: ", e)

                    # Caption images.
                    if image_pixels:
                        # Run in batches of 4.
                        bs = 4
                        captions = []
                        for i in range(0, len(image_pixels), bs):
                            try:
                                captions.extend(
                                    self.captioning_fn(image_pixels[i : i + bs])
                                )
                            except Exception as e:
                                print("L628 WARNING: ", e)
                                # Keep captions aligned with valid_urls.
                                captions.extend([""] * len(image_pixels[i : i + bs]))
                        assert len(valid_urls) == len(
                            captions
                        ), f"len(images)={len(valid_urls)}, len(captions)={len(captions)}"
                        for image_url, caption in zip(valid_urls, captions):
                            self.url2caption[image_url] = remove_unicode(
                                caption.strip()
                            )

                for image in images:
                    try:
                        original_alt = image.get_attribute("alt") or ""
                        image_url = image.get_attribute("src")
                        if not image_url.startswith(("http://", "https://", "www.")):
                            image_url = urljoin(page.url, image_url)

                        updated_alt = original_alt

                        if image_url in self.url2caption:
                            if self.url2caption[image_url] not in updated_alt:
                                updated_alt = f"{updated_alt}, description: {self.url2caption[image_url]}"
                        elif "data:image/svg" not in image_url:
                            print(f"WARNING: {image_url} not in self.url2caption")

                        if "url:" not in updated_alt:
                            updated_alt = f"{updated_alt}, url: {image_url}"

                        # json.dumps yields a safely quoted JS string literal.
                        safe_updated_alt = json.dumps(updated_alt)
                        image.evaluate(f"node => node.alt = {safe_updated_alt}")
                    except Exception as e:
                        print("L653 WARNING:", e)

            if self.observation_type == "accessibility_tree_with_captioner":
                frame_ax_trees = self.fetch_page_accessibility_tree(
                    page,
                    browser_info,
                    current_viewport_only=self.current_viewport_only,
                )
                content, obs_nodes_info = self.parse_accessibility_tree(frame_ax_trees)
                content = self.clean_accesibility_tree(content)
                self.obs_nodes_info = obs_nodes_info
                self.meta_data["obs_nodes_info"] = obs_nodes_info
            else:
                content = ""  # Not used for SoM

        return content

    def process(self, page: Page) -> str:
        """Return the text observation for ``page``.

        The result is a ``Tab i: title | ...`` header, with the active tab
        marked ``(current)``, then a blank line, then the content for
        ``self.observation_type``. Also updates ``self.browser_config`` and,
        for tree types, ``self.obs_nodes_info``/``self.meta_data``.

        Raises:
            ValueError: for an unknown ``observation_type``.
        """
        # get the tab info
        open_tabs = page.context.pages
        try:
            current_tab_idx = open_tabs.index(page)
            tab_titles = [
                f"Tab {idx} (current): {tab.title()}"
                if idx == current_tab_idx
                else f"Tab {idx}: {tab.title()}"
                for idx, tab in enumerate(open_tabs)
            ]
            tab_title_str = " | ".join(tab_titles)
        except Exception:
            tab_title_str = " | ".join([f"Tab {idx}" for idx in range(len(open_tabs))])

        try:
            browser_info = self.fetch_browser_info(page)
        except Exception:
            # Retry once after giving the page a moment to finish loading.
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page)

        if self.observation_type == "html":
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_html(dom_tree)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                page,
                browser_info,
                self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_accessibility_tree(accessibility_tree)
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        elif self.observation_type in [
            "accessibility_tree_with_captioner",
            "image_som",
        ]:
            content = self.fetch_image_related(
                page,
                browser_info,
            )

        elif self.observation_type == "":
            content = ""

        else:
            raise ValueError(f"Invalid observation type: {self.observation_type}")

        self.browser_config = browser_info["config"]
        content = f"{tab_title_str}\n\n{content}"
        return content

    def get_element_center(self, element_id: str) -> tuple[float, float]:
        """Return the element's center as fractions of the viewport width/height.

        Raises:
            KeyError: if ``element_id`` is not in the last observation.
        """
        node_info = self.obs_nodes_info[element_id]
        node_bound = node_info["union_bound"]
        x, y, width, height = node_bound
        center_x = x + width / 2
        center_y = y + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )
class ImageObservationProcessor(ObservationProcessor):
def __init__(
self,
observation_type: str,
viewport_size: Optional[ViewportSize] = None,
):
self.observation_type = observation_type
self.observation_tag = "image"
self.viewport_size = viewport_size
self.meta_data = create_empty_metadata()
def get_page_bboxes(self, page: Page) -> list[list[float]]:
"""JavaScript code to return bounding boxes and other metadata from HTML elements."""
js_script = """
(() => {
const interactableSelectors = [
'a[href]:not(:has(img))', 'a[href] img', 'button', 'input:not([type="hidden"])', 'textarea', 'select',
'[tabindex]:not([tabindex="-1"])', '[contenteditable="true"]', '[role="button"]', '[role="link"]',
'[role="checkbox"]', '[role="menuitem"]', '[role="tab"]', '[draggable="true"]',
'.btn', 'a[href="/notifications"]', 'a[href="/submit"]', '.fa.fa-star.is-rating-item', 'input[type="checkbox"]'
];
const textSelectors = ['p', 'span', 'div:not(:has(*))', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'article'];
const modifiedTextSelectors = textSelectors.map(selector =>
`:not(${interactableSelectors.join(', ')}):not(style) > ${selector}`
);
const combinedSelectors = [...interactableSelectors, ...modifiedTextSelectors];
const elements = document.querySelectorAll(combinedSelectors.join(', '));
const pixelRatio = window.devicePixelRatio;
let csvContent = "ID,Element,Top,Right,Bottom,Left,Width,Height,Alt,Class,Id,TextContent,Interactable\\n";
let counter = 1;
elements.forEach(element => {
const rect = element.getBoundingClientRect();
if (rect.width === 0 || rect.height === 0) return;
let altText = element.getAttribute('alt') || '';
altText = altText.replace(/"/g, ''); // Escape double quotes in alt text
const classList = element.className || '';
const id = element.id || '';
let textContent = element.textContent || '';
textContent = textContent.replace(/"/g, ''); // Escape double quotes in textContent
// Determine if the element is interactable
const isInteractable = interactableSelectors.some(selector => element.matches(selector));
const dataString = [
counter, element.tagName, (rect.top + window.scrollY) * pixelRatio,
(rect.right + window.scrollX) * pixelRatio, (rect.bottom + window.scrollY) * pixelRatio,
(rect.left + window.scrollX) * pixelRatio, rect.width * pixelRatio, rect.height * pixelRatio,
altText, classList, id, textContent, isInteractable
].map(value => `"${value}"`).join(",");
csvContent += dataString + "\\n";
counter++;
});
return csvContent;
})();
"""
# Save the bbox as a CSV
csv_content = page.evaluate(js_script)
return csv_content
def draw_bounding_boxes(
self,
data_string,
screenshot_img,
viewport_size=None,
add_ids=True,
bbox_color=None,
min_width=8,
min_height=8,
bbox_padding=0,
bbox_border=2,
plot_ids=None,
):
"""
min_width and min_height: Minimum dimensions of the bounding box to be plotted.
"""
# Read CSV data
df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"')
df["Area"] = df["Width"] * df["Height"]
# Remove bounding boxes that are clipped.
b_x, b_y = (
self.browser_config["win_left_bound"],
self.browser_config["win_upper_bound"],
)
if viewport_size is not None:
df = df[
(df["Bottom"] - b_y >= 0)
& (df["Top"] - b_y <= viewport_size["height"])
& (df["Right"] - b_x >= 0)
& (df["Left"] - b_x <= viewport_size["width"])
]
viewport_area = viewport_size["width"] * viewport_size["height"]
# Filter out bounding boxes that too large (more than 80% of the viewport)
df = df[df["Area"] <= 0.8 * viewport_area]
# Open the screenshot image
img = screenshot_img.copy()
draw = ImageDraw.Draw(img)
# Load a TTF font with a larger size
font_path = "media/SourceCodePro-SemiBold.ttf"
font_size, padding = 16, 2
font = ImageFont.truetype(font_path, font_size)
# Create a color cycle using one of the categorical color palettes in matplotlib
color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
bbox_id2visid = {}
bbox_id2desc = {}
index = 0
id2center = {}
existing_text_rectangles = []
text_to_draw = []
# Provide [id] textContent inputs to the model as text.
text_content_elements = []
text_content_text = set() # Store text of interactable elements
# Iterate through each row in the CSV and draw bounding boxes
for _, row in df.iterrows():
if not row["Interactable"]:
content = ""
# Add image alt-text to the text representation.
if row["Element"] == "IMG" and pd.notna(row["Alt"]):
content += row["Alt"]
# Add HTML textContent (if any) to the text representation.
if pd.notna(row["TextContent"]):
content += (
row["TextContent"].strip().replace("\n", "").replace("\t", "")
)[
:200
] # Limit to 200 characters to avoid having too much text
# Check if the text is a CSS selector
if content and not (content.startswith(".") and "{" in content):
# Add elements which are not interactable as StaticText
if content not in text_content_text:
text_content_elements.append(f"[] [StaticText] [{content}]")
text_content_text.add(content)
continue
if (plot_ids is not None) and (row["ID"] not in plot_ids):
continue
unique_id = str(index + 1)
bbox_id2visid[row["ID"]] = (
unique_id # map the bounding box ID to the unique character ID
)
top, right, bottom, left, width, height = (
row["Top"],
row["Right"],
row["Bottom"],
row["Left"],
row["Width"],
row["Height"],
)
left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y
id2center[unique_id] = (
(left + right) / 2,
(bottom + top) / 2,
width,
height,
)
if width >= min_width and height >= min_height:
# Get the next color in the cycle
color = bbox_color or color_cycle[index % len(color_cycle)]
draw.rectangle(
[
left - bbox_padding,
top - bbox_padding,
right + bbox_padding,
bottom + bbox_padding,
],
outline=color,
width=bbox_border,
)
bbox_id2desc[row["ID"]] = color
# Draw the text on top of the rectangle
if add_ids:
# Calculate list of possible text positions
text_positions = [
(left - font_size, top - font_size), # Top-left corner
(
left,
top - font_size,
), # A little to the right of the top-left corner
(right, top - font_size), # Top-right corner
(
right - font_size - 2 * padding,
top - font_size,
), # A little to the left of the top-right corner
(left - font_size, bottom), # Bottom-left corner
(
left,
bottom,
), # A little to the right of the bottom-left corner
(
right - font_size - 2 * padding,
bottom,
), # A little to the left of the bottom-right corner
(
left,
bottom,
), # A little to the right of the bottom-left corner
(
right - font_size - 2 * padding,
bottom,
), # A little to the left of the bottom-right corner
]
text_width = draw.textlength(unique_id, font=font)
text_height = font_size # Assume the text is one line
if viewport_size is not None:
for text_position in text_positions:
new_text_rectangle = [
text_position[0] - padding,
text_position[1] - padding,
text_position[0] + text_width + padding,
text_position[1] + text_height + padding,
]
# Check if the new text rectangle is within the viewport
if (
new_text_rectangle[0] >= 0
and new_text_rectangle[1] >= 0
and new_text_rectangle[2] <= viewport_size["width"]
and new_text_rectangle[3] <= viewport_size["height"]
):
# If the rectangle is within the viewport, check for overlaps
overlaps = False
for existing_rectangle in existing_text_rectangles:
if self.rectangles_overlap(
new_text_rectangle,
existing_rectangle,
padding * 2,
):
overlaps = True
break
if not overlaps:
break
else:
# If the rectangle is outside the viewport, try the next position
continue
else:
# If none of the corners work, move the text rectangle by a fixed amount
text_position = (
text_positions[0][0] + padding,
text_positions[0][1],
)
new_text_rectangle = [
text_position[0] - padding,
text_position[1] - padding,
text_position[0] + text_width + padding,
text_position[1] + text_height + padding,
]
existing_text_rectangles.append(new_text_rectangle)
text_to_draw.append(
(new_text_rectangle, text_position, unique_id, color)
)
content = ""
if row["Element"] == "IMG" and pd.notna(row["Alt"]):
content += row["Alt"]
if pd.notna(row["TextContent"]):
content += (
row["TextContent"]
.strip()
.replace("\n", "")
.replace("\t", "")
)[
:200
] # Limit to 200 characters
text_content_elements.append(
f"[{unique_id}] [{row['Element']}] [{content}]"
)
if content in text_content_text:
# Remove text_content_elements with content
text_content_elements = [
element
for element in text_content_elements
if element.strip() != content
]
text_content_text.add(content)
index += 1
for text_rectangle, text_position, unique_id, color in text_to_draw:
# Draw a background rectangle for the text
draw.rectangle(text_rectangle, fill=color)
draw.text(text_position, unique_id, font=font, fill="white")
content_str = "\n".join(text_content_elements)
return img, id2center, content_str
def rectangles_overlap(self, rect1, rect2, padding):
"""
Check if two rectangles overlap.
Each rectangle is represented as a list [x1, y1, x2, y2].
"""
return not (
rect1[2] < rect2[0] + padding
or rect1[0] > rect2[2] - padding
or rect1[1] > rect2[3] - padding
or rect1[3] < rect2[1] + padding
)
def process(self, page: Page) -> npt.NDArray[np.uint8]:
try:
browser_info = self.fetch_browser_info(page)
except Exception:
page.wait_for_load_state("load", timeout=500)
browser_info = self.fetch_browser_info(page)
self.browser_config = browser_info["config"]
if self.observation_type == "image_som":
# Produce the SoM image, with bounding boxes
try:
screenshot_bytes = page.screenshot()
som_bboxes = self.get_page_bboxes(page)
screenshot_img = Image.open(BytesIO(screenshot_bytes))
bbox_img, id2center, content_str = self.draw_bounding_boxes(
som_bboxes,
screenshot_img,
viewport_size=self.viewport_size,
)
self.som_id_info = id2center
self.meta_data["obs_nodes_info"] = id2center
screenshot_som = np.array(bbox_img)
return screenshot_som, content_str
except:
page.wait_for_event("load")
screenshot_bytes = page.screenshot()
som_bboxes = self.get_page_bboxes(page)
screenshot_img = Image.open(BytesIO(screenshot_bytes))
bbox_img, id2center, content_str = self.draw_bounding_boxes(
som_bboxes,
screenshot_img,
viewport_size=self.viewport_size,
)
self.som_id_info = id2center
self.meta_data["obs_nodes_info"] = id2center
screenshot_som = np.array(bbox_img)
return screenshot_som, content_str
else:
try:
screenshot = png_bytes_to_numpy(page.screenshot())
except:
page.wait_for_event("load")
screenshot = png_bytes_to_numpy(page.screenshot())
return screenshot, ""
def fetch_browser_info(self, page: Page) -> BrowserInfo:
client = page.context.new_cdp_session(page)
# extract domtree
tree = client.send(
"DOMSnapshot.captureSnapshot",
{
"computedStyles": [],
"includeDOMRects": True,
"includePaintOrder": True,
},
)
client.detach()
# calibrate the bounds, in some cases, the bounds are scaled somehow
bounds = tree["documents"][0]["layout"]["bounds"]
b = bounds[0]
n = b[2] / self.viewport_size["width"]
bounds = [[x / n for x in bound] for bound in bounds]
tree["documents"][0]["layout"]["bounds"] = bounds
# add union bound placeholder
tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]
# extract browser info
win_upper_bound = page.evaluate("window.pageYOffset")
win_left_bound = page.evaluate("window.pageXOffset")
win_width = page.evaluate("window.screen.width")
win_height = page.evaluate("window.screen.height")
win_right_bound = win_left_bound + win_width
win_lower_bound = win_upper_bound + win_height
device_pixel_ratio = page.evaluate("window.devicePixelRatio")
assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"
config: BrowserConfig = {
"win_upper_bound": win_upper_bound,
"win_left_bound": win_left_bound,
"win_width": win_width,
"win_height": win_height,
"win_right_bound": win_right_bound,
"win_lower_bound": win_lower_bound,
"device_pixel_ratio": device_pixel_ratio,
}
# assert len(tree['documents']) == 1, "More than one document in the DOM tree"
info: BrowserInfo = {"DOMTree": tree, "config": config}
return info
def get_element_center(self, element_id: str) -> tuple[float, float]:
    """Return the center of set-of-marks element ``element_id`` as viewport fractions.

    Only valid for the ``"image_som"`` observation type. ``self.som_id_info``
    maps each element id to a 4-tuple; its first two entries are the center's
    x and y in pixels.

    Args:
        element_id: Key into ``self.som_id_info``.

    Returns:
        ``(x, y)`` divided by the viewport width and height respectively.

    Raises:
        ValueError: if the observation type is not ``"image_som"``.
        KeyError: if ``element_id`` is not in ``self.som_id_info``.
    """
    if self.observation_type != "image_som":
        raise ValueError(
            "get_element_center() is only supported for 'image_som' observation type."
        )
    # Width/height are not needed for the center, but unpacking all four
    # entries keeps the shape check on the stored tuple.
    center_x, center_y, _width, _height = self.som_id_info[element_id]
    return (
        center_x / self.viewport_size["width"],
        center_y / self.viewport_size["height"],
    )
class TextObervationProcessorWebRL(TextObervationProcessor):
    """Text observation processor for the WebRL agent.

    The observation is the HTML returned by ``get_parsed_html(page)["html"]``
    rather than an accessibility tree. Elements carrying both an ``id`` and a
    ``data-bbox`` attribute are indexed in ``obs_nodes_info`` so that
    ``get_element_center`` can resolve an element id to a screen position.
    """

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ):
        super().__init__(
            observation_type,
            current_viewport_only,
            viewport_size,
            captioning_fn,
        )
        # Filled by ``process``. Initialised here so that a lookup made before
        # the first observation raises KeyError rather than AttributeError.
        self.obs_nodes_info = {}

    def process(self, page: Page) -> str:
        """Return the parsed HTML of ``page`` and index its element bounds.

        Replaces both ``self.obs_nodes_info`` and
        ``self.meta_data["obs_nodes_info"]`` with a mapping from each element's
        ``id`` to a dict with these keys:

        * ``backend_id``: the ``id`` attribute.
        * ``union_bound``: the comma-separated ``data-bbox`` attribute parsed
          into floats.
        * ``text``: the element's serialized HTML.
        """
        # get the tab info
        page_info = get_parsed_html(page)
        html = page_info["html"]
        # Deferred import: bs4 is only needed by this processor.
        from bs4 import BeautifulSoup

        soup = BeautifulSoup(html, "html.parser")
        obs_nodes_info = {}
        for tag in soup.find_all(True):
            if not (tag.has_attr("id") and tag.has_attr("data-bbox")):
                continue
            backend_id = tag["id"]
            obs_nodes_info[str(backend_id)] = {
                "backend_id": backend_id,
                "union_bound": [float(num) for num in tag["data-bbox"].split(",")],
                "text": str(tag),
            }
        self.obs_nodes_info = obs_nodes_info
        self.meta_data["obs_nodes_info"] = obs_nodes_info
        return html

    def _cached_center(self, element_id: str) -> tuple[float, float]:
        """Pixel center of ``element_id`` from the bound recorded by ``process``."""
        x, y, width, height = self.obs_nodes_info[element_id]["union_bound"]
        return x + width / 2, y + height / 2

    def get_element_center(
        self, element_id: str, page: Optional[Page] = None
    ) -> tuple[float, float]:
        """Return the center of element ``element_id`` as viewport fractions.

        When ``page`` is given, the element is located live through its
        ``data-label-id`` attribute and its current bounding box is used. If
        that lookup finds no element, or the element reports no bounding box,
        the method falls back to the ``data-bbox`` recorded by the last
        ``process`` call. Without ``page``, the recorded bound is used directly.

        Returns:
            ``(x, y)`` divided by the viewport width and height respectively.

        Raises:
            KeyError: if the recorded bound is needed and ``element_id`` was
                not indexed by the last ``process`` call.
        """
        bbox = None
        if page is not None:
            element = page.query_selector(f"[data-label-id='{element_id}']")
            # query_selector returns None for no match; bounding_box() returns
            # None for elements that are not visible.
            if element is not None:
                bbox = element.bounding_box()
        if bbox is not None:
            center_x = bbox["x"] + bbox["width"] / 2
            center_y = bbox["y"] + bbox["height"] / 2
        else:
            center_x, center_y = self._cached_center(element_id)
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )
class ObservationHandler:
    """Bundle a text processor and an image processor behind one interface.

    The text processor is ``TextObervationProcessorWebRL`` when
    ``text_observation_type == "webrl"`` and ``TextObervationProcessor``
    otherwise. The image processor is always ``ImageObservationProcessor``.
    ``main_observation_type`` selects which one ``action_processor`` exposes.
    """

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ) -> None:
        self.main_observation_type = main_observation_type
        text_processor_cls = (
            TextObervationProcessorWebRL
            if text_observation_type == "webrl"
            else TextObervationProcessor
        )
        self.text_processor = text_processor_cls(
            text_observation_type,
            current_viewport_only,
            viewport_size,
            captioning_fn,
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type, viewport_size
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Describe observations as a gymnasium ``Dict`` space.

        ``"text"`` is a string of at most ``UTTERANCE_MAX_LENGTH`` characters.
        ``"image"`` is a ``(height, width, 3)`` uint8 RGB array in [0, 255].
        """
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )
        # Height first: each row of the image holds `width` RGB pixels.
        pixel_shape = (self.viewport_size["height"], self.viewport_size["width"], 3)
        lower = np.zeros(pixel_shape, dtype=np.uint8)
        upper = np.ones(pixel_shape, dtype=np.uint8) * 255.0
        return spaces.Dict(
            {
                "text": text_space,
                "image": spaces.Box(lower, upper, dtype=np.uint8),
            }
        )

    def get_observation(self, page: Page) -> dict[str, Observation]:
        """Run both processors on ``page`` and return their observations.

        A non-empty content string produced by the image processor replaces
        the text processor's output.
        """
        text_obs = self.text_processor.process(page)
        image_obs, content_str = self.image_processor.process(page)
        chosen_text = text_obs if content_str == "" else content_str
        return {"text": chosen_text, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return each processor's ``meta_data`` under its modality key."""
        text_meta = self.text_processor.meta_data
        image_meta = self.image_processor.meta_data
        return {"text": text_meta, "image": image_meta}

    @property
    def action_processor(self) -> ObservationProcessor:
        """The processor whose metadata the action space is resolved against."""
        kind = self.main_observation_type
        if kind == "text":
            return self.text_processor
        if kind == "image":
            return self.image_processor
        raise ValueError("Invalid main observation type")