import json
import pkgutil
import re
from collections import defaultdict
from dataclasses import dataclass
from io import BytesIO, StringIO
from typing import Any, Optional, TypedDict, Union
from urllib.parse import urljoin, urlparse

import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import playwright
import requests
from gymnasium import spaces
from PIL import Image, ImageDraw, ImageFont
from playwright.sync_api import CDPSession, Page, ViewportSize
from .html_tools.fetch import get_parsed_html

from browser_env.constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    INJECTED_ATTR_NAME,
    UTTERANCE_MAX_LENGTH,
    BID_ATTR,
    DATA_REGEXP,
    IN_VIEWPORT_RATIO_THRESHOLD,
)

from .utils import (
    AccessibilityTree,
    AccessibilityTreeNode,
    BrowserConfig,
    BrowserInfo,
    DOMNode,
    DOMTree,
    Observation,
    png_bytes_to_numpy,
)

def remove_unicode(input_string):
    # Define a regex pattern to match Unicode characters
    unicode_pattern = re.compile(r"[^\x00-\x7F]+")

    # Use the pattern to replace Unicode characters with an empty string
    cleaned_string = unicode_pattern.sub("", input_string)

    return cleaned_string

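# Illustrative example (not part of the original module): runs of non-ASCII
# characters are stripped, e.g. remove_unicode("caf\u00e9 ok") -> "caf ok".
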
class ObservationProcessor:
    def process(self, page: Page) -> Observation:
        raise NotImplementedError

class ObservationMetadata(TypedDict):
    obs_nodes_info: dict[str, Any]

def create_empty_metadata() -> ObservationMetadata:
    return {
        "obs_nodes_info": {},
    }

def extract_data_items_from_aria(string: str) -> tuple[list[str], str]:
    """
    Utility function to extract temporary data stored in the "aria-roledescription" attribute of a node
    """

    match = DATA_REGEXP.fullmatch(string)
    if not match:
        return [], string

    groups = match.groups()
    data_items = groups[:-1]
    original_aria = groups[-1]
    return data_items, original_aria

class TextObervationProcessor(ObservationProcessor):
    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        self.meta_data = (
            create_empty_metadata()
        )  # store the metadata for this observation type

        if self.observation_type in [
            "accessibility_tree_with_captioner",
            "image_som",
        ]:
            self.captioning_fn = captioning_fn
            # Cache captions.
            self.url2caption = {}

    def fetch_browser_info(
        self,
        page: Page,
    ) -> BrowserInfo:
        # extract domtree
        client = page.context.new_cdp_session(page)
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )
        client.detach()

        # calibrate the bounds, in some cases, the bounds are scaled somehow
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds
        # add union bound placeholder
        tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]

        # extract browser info
        win_upper_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_upper_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_upper_bound": win_upper_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
        info: BrowserInfo = {"DOMTree": tree, "config": config}

        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                        function() {
                            if (this.nodeType == 3) {
                                var range = document.createRange();
                                range.selectNode(this);
                                var rect = range.getBoundingClientRect().toJSON();
                                range.detach();
                                return rect;
                            } else {
                                return this.getBoundingClientRect().toJSON();
                            }
                        }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception as e:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        # Compute the overlap in x and y axes
        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound) - max(elem_top_bound, win_top_bound),
        )

        # Compute the overlap area as a fraction of the element's own area
        # (the parentheses are needed so the division covers width * height).
        ratio = overlap_width * overlap_height / (width * height)
        return ratio

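    # Illustrative check (not part of the original module): a 100x100 element
    # whose left half sits inside a 1280x720 viewport overlaps 50 * 100 = 5000
    # of its 10000 pixels, so the ratio above evaluates to 0.5.
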
    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        current_viewport_only: bool,
    ) -> DOMTree:
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]

        # make a dom tree that is easier to navigate
        dom_tree: DOMTree = []
        graph = defaultdict(list)
        client = page.context.new_cdp_session(page)
        for node_idx in range(len(nodes["nodeName"])):
            cur_node: DOMNode = {
                "nodeId": "",
                "nodeType": "",
                "nodeName": "",
                "nodeValue": "",
                "attributes": "",
                "backendNodeId": "",
                "parentId": "",
                "childIds": [],
                "cursor": 0,
                "union_bound": None,
            }

            node_type_idx = nodes["nodeType"][node_idx]
            node_type = "generic"
            if node_type_idx >= 0 and node_type_idx < len(strings):
                node_type = strings[node_type_idx]

            node_name = strings[nodes["nodeName"][node_idx]]

            node_value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if node_value_idx >= 0 and node_value_idx < len(strings):
                node_value = " ".join(strings[node_value_idx].split())

            node_attributes = [strings[i] for i in nodes["attributes"][node_idx]]
            node_attributes_str = ""
            for i in range(0, len(node_attributes), 2):
                a = node_attributes[i]
                b = node_attributes[i + 1]
                b = " ".join(b.split())
                node_attributes_str += f'{a}="{b}" '
            node_attributes_str = node_attributes_str.strip()

            cur_node["nodeId"] = str(node_idx)
            cur_node["nodeType"] = node_type
            cur_node["nodeName"] = node_name
            cur_node["nodeValue"] = node_value
            cur_node["attributes"] = node_attributes_str
            cur_node["backendNodeId"] = str(nodes["backendNodeId"][node_idx])
            cur_node["parentId"] = str(nodes["parentIndex"][node_idx])

            if cur_node["parentId"] != "-1":
                graph[cur_node["parentId"]].append(str(cur_node["nodeId"]))

            # get the bound
            if cur_node["parentId"] == "-1":
                cur_node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                response = self.get_bounding_client_rect(
                    client, cur_node["backendNodeId"]
                )
                if response.get("result", {}).get("subtype", "") == "error":
                    cur_node["union_bound"] = None
                else:
                    x = response["result"]["value"]["x"]
                    y = response["result"]["value"]["y"]
                    width = response["result"]["value"]["width"]
                    height = response["result"]["value"]["height"]
                    cur_node["union_bound"] = [x, y, width, height]

            dom_tree.append(cur_node)

        client.detach()
        # add parent children index to the node
        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids

        # remove the nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                # update the node information in the DOM tree
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]

                # update the children of the parent node
                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
                # remove the nodeid from parent
                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
                dom_tree[int(parent_id)]["childIds"].pop(index)

                # Insert children_nodeids in the same location
                for child_id in child_ids:
                    dom_tree[int(parent_id)]["childIds"].insert(index, child_id)
                    index += 1

                # update children node's parent
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                # mark as removed
                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

            config = info["config"]
            for cursor, node in enumerate(dom_tree):
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0.0 or height == 0.0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            dom_tree = [
                node for node in dom_tree if node.get("parentId", "-1") != "[REMOVED]"
            ]

        return dom_tree

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Parse the html tree into a string text"""

        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"

            except Exception as e:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(child_cursor, child_depth)
                tree_str += child_str

            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        page: Page,
        info: BrowserInfo,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        client = page.context.new_cdp_session(page)
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        _accessibility_tree = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                _accessibility_tree.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = _accessibility_tree

        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            # usually because the node is not visible etc
            if "backendDOMNodeId" not in node:
                node["union_bound"] = None
                continue
            backend_node_id = str(node["backendDOMNodeId"])
            if node["role"]["value"] == "RootWebArea":
                # always inside the viewport
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                response = self.get_bounding_client_rect(
                    client, backend_node_id
                )
                if response.get("result", {}).get("subtype", "") == "error":
                    node["union_bound"] = None
                else:
                    x = response["result"]["value"]["x"]
                    y = response["result"]["value"]["y"]
                    width = response["result"]["value"]["width"]
                    height = response["result"]["value"]["height"]
                    node["union_bound"] = [x, y, width, height]

        client.detach()
        # filter nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                # update the node information in the accessibility tree
                nodeid = node["nodeId"]
                node_cursor = nodeid_to_cursor[nodeid]
                parent_nodeid = node["parentId"]
                children_nodeids = node["childIds"]
                parent_cursor = nodeid_to_cursor[parent_nodeid]
                # update the children of the parent node
                assert (
                    accessibility_tree[parent_cursor].get("parentId", "Root")
                    is not None
                )
                # remove the nodeid from parent's childIds
                index = accessibility_tree[parent_cursor]["childIds"].index(nodeid)
                accessibility_tree[parent_cursor]["childIds"].pop(index)
                # Insert children_nodeids in the same location
                for child_nodeid in children_nodeids:
                    accessibility_tree[parent_cursor]["childIds"].insert(
                        index, child_nodeid
                    )
                    index += 1
                # update children node's parent
                for child_nodeid in children_nodeids:
                    child_cursor = nodeid_to_cursor[child_nodeid]
                    accessibility_tree[child_cursor]["parentId"] = parent_nodeid
                # mark as removed
                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any]]:
        """Parse the accessibility tree into a string text"""
        node_id_to_idx = {}
        for idx, node in enumerate(accessibility_tree):
            node_id_to_idx[node["nodeId"]] = idx

        obs_nodes_info = {}

        def dfs(idx: int, obs_node_id: str, depth: int) -> str:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for property in node.get("properties", []):
                    try:
                        if property["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(
                            f'{property["name"]}: {property["value"]["value"]}'
                        )
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                # check valid
                if not node_str.strip():
                    valid_node = False

                # empty generic node
                if not name.strip():
                    if not properties:
                        if role in [
                            "generic",
                            "img",
                            "list",
                            "strong",
                            "paragraph",
                            "banner",
                            "navigation",
                            "Section",
                            "LabelText",
                            "Legend",
                            "listitem",
                        ]:
                            valid_node = False
                    elif role in ["listitem"]:
                        valid_node = False

                if valid_node:
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }

            except Exception as e:
                valid_node = False

            for _, child_node_id in enumerate(node["childIds"]):
                if child_node_id not in node_id_to_idx:
                    continue
                # mark this to save some tokens
                child_depth = depth + 1 if valid_node else depth
                child_str = dfs(
                    node_id_to_idx[child_node_id], child_node_id, child_depth
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str

            return tree_str

        tree_str = dfs(0, accessibility_tree[0]["nodeId"], 0)
        return tree_str, obs_nodes_info

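    # Each emitted line has the form "[<nodeId>] <role> '<name>' <properties>",
    # e.g. "[1234] button 'Add to Cart' focusable: True" (illustrative values,
    # not taken from the original file); child nodes are indented with tabs.
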
    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """further clean the accessibility tree"""
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            # remove statictext if the content already appears in the previous lines
            if "statictext" in line.lower():
                prev_lines = clean_lines[-3:]
                pattern = r"\[\d+\] StaticText (.+)"

                match = re.search(pattern, line, re.DOTALL)
                if match:
                    static_text = match.group(1)[1:-1]  # remove the quotes
                    if static_text and all(
                        static_text not in prev_line for prev_line in prev_lines
                    ):
                        clean_lines.append(line)
            else:
                clean_lines.append(line)

        return "\n".join(clean_lines)

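    # Illustrative behaviour (not from the original file): if "[7] link 'Home'"
    # was just emitted, a following "[8] StaticText 'Home'" line is dropped
    # because its text already appears within the previous three kept lines.
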
    def fetch_image_related(self, page: Page, browser_info: BrowserInfo) -> str:
        # Check if the current page is an image url
        if page.url.endswith((".jpg", ".jpeg", ".png")):
            print("NOTE: We are on an image page!!!")
            # Load image from current url and run captioning on it.
            if page.url not in self.url2caption and self.captioning_fn is not None:
                try:
                    image = Image.open(requests.get(page.url, stream=True).raw)
                    caption = self.captioning_fn([image])[0].strip()
                    self.url2caption[page.url] = remove_unicode(caption)
                except Exception as e:
                    print("L579 WARNING: ", e)
            content = self.url2caption.get(page.url, "Image")

        else:
            if self.captioning_fn is not None:
                images = page.query_selector_all("img")
                image_urls = []
                for image in images:
                    try:
                        image_url = image.get_attribute("src")
                        if not image_url.startswith(("http://", "https://", "www.")):
                            image_url = urljoin(page.url, image_url)
                        if image_url not in self.url2caption:
                            image_urls.append(image_url)
                    except Exception as e:
                        print("L604 WARNING: ", e)

                # Run image captioning on image_url pixels. This is for models which use captioning as a baseline.
                if len(image_urls) > 0:
                    image_pixels = []
                    valid_urls = []
                    for url in image_urls:
                        if "data:image/svg" in url:
                            continue
                        else:
                            try:
                                image = Image.open(requests.get(url, stream=True).raw)
                                image_pixels.append(image)
                                valid_urls.append(url)
                            except Exception as e:
                                print("L616 WARNING: ", e)

                    # Caption images.
                    if image_pixels:
                        # Run in batches of 4.
                        bs = 4
                        captions = []
                        for i in range(0, len(image_pixels), bs):
                            try:
                                captions.extend(
                                    self.captioning_fn(image_pixels[i : i + bs])
                                )
                            except Exception as e:
                                print("L628 WARNING: ", e)
                                captions.extend([""] * len(image_pixels[i : i + bs]))
                        assert len(valid_urls) == len(
                            captions
                        ), f"len(images)={len(valid_urls)}, len(captions)={len(captions)}"
                        for image_url, caption in zip(valid_urls, captions):
                            self.url2caption[image_url] = remove_unicode(
                                caption.strip()
                            )

                image_idx = 0
                for image in images:
                    try:
                        original_alt = image.get_attribute("alt") or ""
                        image_url = image.get_attribute("src")
                        if not image_url.startswith(("http://", "https://", "www.")):
                            image_url = urljoin(page.url, image_url)

                        updated_alt = original_alt

                        if image_url in self.url2caption:
                            if self.url2caption[image_url] not in updated_alt:
                                updated_alt = f"{updated_alt}, description: {self.url2caption[image_url]}"
                        elif "data:image/svg" not in image_url:
                            print(f"WARNING: {image_url} not in self.url2caption")

                        if "url:" not in updated_alt:
                            updated_alt = f"{updated_alt}, url: {image_url}"

                        safe_updated_alt = json.dumps(updated_alt)
                        image.evaluate(f"node => node.alt = {safe_updated_alt}")
                    except Exception as e:
                        print("L653 WARNING:", e)

            if self.observation_type == "accessibility_tree_with_captioner":
                frame_ax_trees = self.fetch_page_accessibility_tree(
                    page,
                    browser_info,
                    current_viewport_only=self.current_viewport_only,
                )
                content, obs_nodes_info = self.parse_accessibility_tree(frame_ax_trees)
                content = self.clean_accesibility_tree(content)
                self.obs_nodes_info = obs_nodes_info
                self.meta_data["obs_nodes_info"] = obs_nodes_info
            else:
                content = ""  # Not used for SoM

        return content

    def process(self, page: Page) -> str:
        # get the tab info
        open_tabs = page.context.pages
        try:
            tab_titles = [tab.title() for tab in open_tabs]
            current_tab_idx = open_tabs.index(page)
            for idx in range(len(open_tabs)):
                if idx == current_tab_idx:
                    tab_titles[idx] = f"Tab {idx} (current): {open_tabs[idx].title()}"
                else:
                    tab_titles[idx] = f"Tab {idx}: {open_tabs[idx].title()}"
            tab_title_str = " | ".join(tab_titles)
        except Exception:
            tab_title_str = " | ".join([f"Tab {idx}" for idx in range(len(open_tabs))])

        try:
            browser_info = self.fetch_browser_info(page)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page)

        if self.observation_type == "html":
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_html(dom_tree)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                page,
                browser_info,
                self.current_viewport_only,
            )
            content, obs_nodes_info = self.parse_accessibility_tree(accessibility_tree)
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        elif self.observation_type in [
            "accessibility_tree_with_captioner",
            "image_som",
        ]:
            content = self.fetch_image_related(
                page,
                browser_info,
            )

        elif self.observation_type == "":
            content = ""

        else:
            raise ValueError(f"Invalid observation type: {self.observation_type}")

        self.browser_config = browser_info["config"]
        content = f"{tab_title_str}\n\n{content}"

        return content

    def get_element_center(self, element_id: str) -> tuple[float, float]:
        node_info = self.obs_nodes_info[element_id]
        node_bound = node_info["union_bound"]
        x, y, width, height = node_bound
        center_x = x + width / 2
        center_y = y + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class ImageObservationProcessor(ObservationProcessor):
    def __init__(
        self,
        observation_type: str,
        viewport_size: Optional[ViewportSize] = None,
    ):
        self.observation_type = observation_type
        self.observation_tag = "image"
        self.viewport_size = viewport_size
        self.meta_data = create_empty_metadata()

    def get_page_bboxes(self, page: Page) -> list[list[float]]:
        """JavaScript code to return bounding boxes and other metadata from HTML elements."""
        js_script = """
        (() => {
            const interactableSelectors = [
                'a[href]:not(:has(img))', 'a[href] img', 'button', 'input:not([type="hidden"])', 'textarea', 'select',
                '[tabindex]:not([tabindex="-1"])', '[contenteditable="true"]', '[role="button"]', '[role="link"]',
                '[role="checkbox"]', '[role="menuitem"]', '[role="tab"]', '[draggable="true"]',
                '.btn', 'a[href="/notifications"]', 'a[href="/submit"]', '.fa.fa-star.is-rating-item', 'input[type="checkbox"]'
            ];

            const textSelectors = ['p', 'span', 'div:not(:has(*))', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'article'];
            const modifiedTextSelectors = textSelectors.map(selector =>
                `:not(${interactableSelectors.join(', ')}):not(style) > ${selector}`
            );

            const combinedSelectors = [...interactableSelectors, ...modifiedTextSelectors];
            const elements = document.querySelectorAll(combinedSelectors.join(', '));

            const pixelRatio = window.devicePixelRatio;
            let csvContent = "ID,Element,Top,Right,Bottom,Left,Width,Height,Alt,Class,Id,TextContent,Interactable\\n";
            let counter = 1;

            elements.forEach(element => {
                const rect = element.getBoundingClientRect();
                if (rect.width === 0 || rect.height === 0) return;
                let altText = element.getAttribute('alt') || '';
                altText = altText.replace(/"/g, ''); // Escape double quotes in alt text
                const classList = element.className || '';
                const id = element.id || '';
                let textContent = element.textContent || '';
                textContent = textContent.replace(/"/g, ''); // Escape double quotes in textContent

                // Determine if the element is interactable
                const isInteractable = interactableSelectors.some(selector => element.matches(selector));

                const dataString = [
                    counter, element.tagName, (rect.top + window.scrollY) * pixelRatio,
                    (rect.right + window.scrollX) * pixelRatio, (rect.bottom + window.scrollY) * pixelRatio,
                    (rect.left + window.scrollX) * pixelRatio, rect.width * pixelRatio, rect.height * pixelRatio,
                    altText, classList, id, textContent, isInteractable
                ].map(value => `"${value}"`).join(",");

                csvContent += dataString + "\\n";
                counter++;
            });

            return csvContent;
        })();
        """
        # Save the bbox as a CSV
        csv_content = page.evaluate(js_script)
        return csv_content

    def draw_bounding_boxes(
        self,
        data_string,
        screenshot_img,
        viewport_size=None,
        add_ids=True,
        bbox_color=None,
        min_width=8,
        min_height=8,
        bbox_padding=0,
        bbox_border=2,
        plot_ids=None,
    ):
        """
        min_width and min_height: Minimum dimensions of the bounding box to be plotted.
        """
        # Read CSV data
        df = pd.read_csv(StringIO(data_string), delimiter=",", quotechar='"')
        df["Area"] = df["Width"] * df["Height"]
        # Remove bounding boxes that are clipped.
        b_x, b_y = (
            self.browser_config["win_left_bound"],
            self.browser_config["win_upper_bound"],
        )
        if viewport_size is not None:
            df = df[
                (df["Bottom"] - b_y >= 0)
                & (df["Top"] - b_y <= viewport_size["height"])
                & (df["Right"] - b_x >= 0)
                & (df["Left"] - b_x <= viewport_size["width"])
            ]
            viewport_area = viewport_size["width"] * viewport_size["height"]
            # Filter out bounding boxes that are too large (more than 80% of the viewport)
            df = df[df["Area"] <= 0.8 * viewport_area]

        # Open the screenshot image
        img = screenshot_img.copy()
        draw = ImageDraw.Draw(img)

        # Load a TTF font with a larger size
        font_path = "media/SourceCodePro-SemiBold.ttf"
        font_size, padding = 16, 2
        font = ImageFont.truetype(font_path, font_size)

        # Create a color cycle using one of the categorical color palettes in matplotlib
        color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        bbox_id2visid = {}
        bbox_id2desc = {}
        index = 0
        id2center = {}
        existing_text_rectangles = []
        text_to_draw = []
        # Provide [id] textContent inputs to the model as text.
        text_content_elements = []
        text_content_text = set()  # Store text of interactable elements

        # Iterate through each row in the CSV and draw bounding boxes
        for _, row in df.iterrows():
            if not row["Interactable"]:
                content = ""
                # Add image alt-text to the text representation.
                if row["Element"] == "IMG" and pd.notna(row["Alt"]):
                    content += row["Alt"]
                # Add HTML textContent (if any) to the text representation.
                if pd.notna(row["TextContent"]):
                    content += (
                        row["TextContent"].strip().replace("\n", "").replace("\t", "")
                    )[
                        :200
                    ]  # Limit to 200 characters to avoid having too much text

                # Check if the text is a CSS selector
                if content and not (content.startswith(".") and "{" in content):
                    # Add elements which are not interactable as StaticText
                    if content not in text_content_text:
                        text_content_elements.append(f"[] [StaticText] [{content}]")
                        text_content_text.add(content)
                continue

            if (plot_ids is not None) and (row["ID"] not in plot_ids):
                continue

            unique_id = str(index + 1)
            bbox_id2visid[row["ID"]] = (
                unique_id  # map the bounding box ID to the unique character ID
            )
            top, right, bottom, left, width, height = (
                row["Top"],
                row["Right"],
                row["Bottom"],
                row["Left"],
                row["Width"],
                row["Height"],
            )
            left, right, top, bottom = left - b_x, right - b_x, top - b_y, bottom - b_y
            id2center[unique_id] = (
                (left + right) / 2,
                (bottom + top) / 2,
                width,
                height,
            )

            if width >= min_width and height >= min_height:
                # Get the next color in the cycle
                color = bbox_color or color_cycle[index % len(color_cycle)]
                draw.rectangle(
                    [
                        left - bbox_padding,
                        top - bbox_padding,
                        right + bbox_padding,
                        bottom + bbox_padding,
                    ],
                    outline=color,
                    width=bbox_border,
                )
                bbox_id2desc[row["ID"]] = color

                # Draw the text on top of the rectangle
                if add_ids:
                    # Calculate list of possible text positions
                    text_positions = [
                        (left - font_size, top - font_size),  # Top-left corner
                        (
                            left,
                            top - font_size,
                        ),  # A little to the right of the top-left corner
                        (right, top - font_size),  # Top-right corner
                        (
                            right - font_size - 2 * padding,
                            top - font_size,
                        ),  # A little to the left of the top-right corner
                        (left - font_size, bottom),  # Bottom-left corner
                        (
                            left,
                            bottom,
                        ),  # A little to the right of the bottom-left corner
                        (
                            right - font_size - 2 * padding,
                            bottom,
                        ),  # A little to the left of the bottom-right corner
                        (
                            left,
                            bottom,
                        ),  # A little to the right of the bottom-left corner
                        (
                            right - font_size - 2 * padding,
                            bottom,
                        ),  # A little to the left of the bottom-right corner
                    ]
                    text_width = draw.textlength(unique_id, font=font)
                    text_height = font_size  # Assume the text is one line

                    if viewport_size is not None:
                        for text_position in text_positions:
                            new_text_rectangle = [
                                text_position[0] - padding,
                                text_position[1] - padding,
                                text_position[0] + text_width + padding,
                                text_position[1] + text_height + padding,
                            ]

                            # Check if the new text rectangle is within the viewport
                            if (
                                new_text_rectangle[0] >= 0
                                and new_text_rectangle[1] >= 0
                                and new_text_rectangle[2] <= viewport_size["width"]
                                and new_text_rectangle[3] <= viewport_size["height"]
                            ):
                                # If the rectangle is within the viewport, check for overlaps
                                overlaps = False
                                for existing_rectangle in existing_text_rectangles:
                                    if self.rectangles_overlap(
                                        new_text_rectangle,
                                        existing_rectangle,
                                        padding * 2,
                                    ):
                                        overlaps = True
                                        break

                                if not overlaps:
                                    break
                            else:
                                # If the rectangle is outside the viewport, try the next position
                                continue
                        else:
                            # If none of the corners work, move the text rectangle by a fixed amount
                            text_position = (
                                text_positions[0][0] + padding,
                                text_positions[0][1],
                            )
                            new_text_rectangle = [
                                text_position[0] - padding,
                                text_position[1] - padding,
                                text_position[0] + text_width + padding,
                                text_position[1] + text_height + padding,
                            ]

                        existing_text_rectangles.append(new_text_rectangle)
                        text_to_draw.append(
                            (new_text_rectangle, text_position, unique_id, color)
                        )

            content = ""
            if row["Element"] == "IMG" and pd.notna(row["Alt"]):
                content += row["Alt"]
            if pd.notna(row["TextContent"]):
                content += (
                    row["TextContent"]
                    .strip()
                    .replace("\n", "")
                    .replace("\t", "")
                )[
                    :200
                ]  # Limit to 200 characters
            text_content_elements.append(
                f"[{unique_id}] [{row['Element']}] [{content}]"
            )
            if content in text_content_text:
                # Remove text_content_elements with content
                text_content_elements = [
                    element
                    for element in text_content_elements
                    if element.strip() != content
                ]
            text_content_text.add(content)

            index += 1

        for text_rectangle, text_position, unique_id, color in text_to_draw:
            # Draw a background rectangle for the text
            draw.rectangle(text_rectangle, fill=color)
            draw.text(text_position, unique_id, font=font, fill="white")

        content_str = "\n".join(text_content_elements)
        return img, id2center, content_str

    def rectangles_overlap(self, rect1, rect2, padding):
        """
        Check if two rectangles overlap.
        Each rectangle is represented as a list [x1, y1, x2, y2].
        """
        return not (
            rect1[2] < rect2[0] + padding
            or rect1[0] > rect2[2] - padding
            or rect1[1] > rect2[3] - padding
            or rect1[3] < rect2[1] + padding
        )

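    # Illustrative check (not from the original file): with padding=0,
    # rectangles_overlap([0, 0, 10, 10], [5, 5, 15, 15], 0) returns True,
    # while rectangles_overlap([0, 0, 10, 10], [20, 20, 30, 30], 0) returns False.
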
    def process(self, page: Page) -> tuple[npt.NDArray[np.uint8], str]:
        try:
            browser_info = self.fetch_browser_info(page)
        except Exception:
            page.wait_for_load_state("load", timeout=30000)  # 500->30000, modified by yuyr
            browser_info = self.fetch_browser_info(page)

        self.browser_config = browser_info["config"]

        if self.observation_type == "image_som":
            # Produce the SoM image, with bounding boxes
            try:
                screenshot_bytes = page.screenshot()
                som_bboxes = self.get_page_bboxes(page)
                screenshot_img = Image.open(BytesIO(screenshot_bytes))
                bbox_img, id2center, content_str = self.draw_bounding_boxes(
                    som_bboxes,
                    screenshot_img,
                    viewport_size=self.viewport_size,
                )
                self.som_id_info = id2center
                self.meta_data["obs_nodes_info"] = id2center
                screenshot_som = np.array(bbox_img)
                return screenshot_som, content_str
            except Exception:
                page.wait_for_event("load")
                screenshot_bytes = page.screenshot()
                som_bboxes = self.get_page_bboxes(page)
                screenshot_img = Image.open(BytesIO(screenshot_bytes))
                bbox_img, id2center, content_str = self.draw_bounding_boxes(
                    som_bboxes,
                    screenshot_img,
                    viewport_size=self.viewport_size,
                )
                self.som_id_info = id2center
                self.meta_data["obs_nodes_info"] = id2center
                screenshot_som = np.array(bbox_img)
                return screenshot_som, content_str
        else:
            try:
                screenshot = png_bytes_to_numpy(page.screenshot())
            except Exception:
                page.wait_for_event("load")
                screenshot = png_bytes_to_numpy(page.screenshot())
            return screenshot, ""

    def fetch_browser_info(self, page: Page) -> BrowserInfo:
        client = page.context.new_cdp_session(page)
        # extract domtree
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )
        client.detach()
        # calibrate the bounds, in some cases, the bounds are scaled somehow
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds
        # add union bound placeholder
        tree["documents"][0]["layout"]["unionBounds"] = [None for _ in bounds]

        # extract browser info
        win_upper_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_upper_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_upper_bound": win_upper_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        # assert len(tree['documents']) == 1, "More than one document in the DOM tree"
        info: BrowserInfo = {"DOMTree": tree, "config": config}

        return info

    def get_element_center(self, element_id: str) -> tuple[float, float]:
        if not self.observation_type == "image_som":
            raise ValueError(
                "get_element_center() is only supported for 'image_som' observation type."
            )

        browser_config = self.browser_config
        center_x, center_y, width, height = self.som_id_info[element_id]
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class TextObervationProcessorWebRL(TextObervationProcessor):
    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ):
        super().__init__(
            observation_type,
            current_viewport_only,
            viewport_size,
            captioning_fn,
        )

    def process(self, page: Page) -> str:
        # fetch the parsed page HTML
        page_info = get_parsed_html(page)
        html = page_info["html"]
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(html, 'html.parser')
        obs_nodes_info = {}
        for tag in soup.find_all(True):
            if tag.has_attr('id') and tag.has_attr('data-bbox'):
                backend_id = tag['id']
                union_bound = tag['data-bbox']
                union_bound = [float(num) for num in union_bound.split(',')]
                obs_nodes_info[str(backend_id)] = {
                    "backend_id": backend_id,
                    "union_bound": union_bound,
                    "text": str(tag)
                }
        self.obs_nodes_info = obs_nodes_info
        self.meta_data["obs_nodes_info"] = obs_nodes_info
        return html

    def get_element_center(self, element_id: str, page: Page = None) -> tuple[float, float]:
        if page is not None:
            element = page.query_selector(f"[data-label-id='{element_id}']")
            bbox = element.bounding_box()
            relative_bbox = (bbox['x'], bbox['y'], bbox['x'] + bbox['width'], bbox['y'] + bbox['height'])
            center_x = (relative_bbox[0] + relative_bbox[2]) / 2
            center_y = (relative_bbox[1] + relative_bbox[3]) / 2
        else:
            node_info = self.obs_nodes_info[element_id]
            node_bound = node_info["union_bound"]
            x, y, width, height = node_bound
            center_x = x + width / 2
            center_y = y + height / 2

        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class ObservationHandler:
    """Main entry point to access all observation processors"""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
        captioning_fn=None,
    ) -> None:
        self.main_observation_type = main_observation_type
        if text_observation_type == "webrl":
            self.text_processor = TextObervationProcessorWebRL(
                text_observation_type,
                current_viewport_only,
                viewport_size,
                captioning_fn,
            )
        else:
            self.text_processor = TextObervationProcessor(
                text_observation_type,
                current_viewport_only,
                viewport_size,
                captioning_fn,
            )
        self.image_processor = ImageObservationProcessor(
            image_observation_type, viewport_size
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        image_space = spaces.Box(
            # Each position stores the RGB values. Note the swapped axes (height first).
            np.zeros(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            ),
            np.ones(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            )
            * 255.0,
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(self, page: Page) -> dict[str, Observation]:
        text_obs = self.text_processor.process(page)
        image_obs, content_str = self.image_processor.process(page)
        if content_str != "":
            text_obs = content_str
        return {"text": text_obs, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")
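

# --- Illustrative usage sketch (not part of the original module) -------------
# A minimal way to drive ObservationHandler with Playwright's sync API; the
# URL, viewport values, and observation types below are example choices, not
# values required by this file (note the processors assert that the page's
# devicePixelRatio is 1.0).
#
# from playwright.sync_api import sync_playwright
#
# with sync_playwright() as p:
#     browser = p.chromium.launch()
#     page = browser.new_page(viewport={"width": 1280, "height": 720})
#     page.goto("https://example.com")
#     handler = ObservationHandler(
#         main_observation_type="text",
#         text_observation_type="accessibility_tree",
#         image_observation_type="image",
#         current_viewport_only=True,
#         viewport_size={"width": 1280, "height": 720},
#     )
#     obs = handler.get_observation(page)
#     print(obs["text"])
#     browser.close()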