AgentOccam/browser_env/processors.py
2025-01-22 11:32:35 -08:00

1126 lines
41 KiB
Python

import json
import logging
import re
import time
from collections import defaultdict
from typing import Any, TypedDict, Union

import lxml
import lxml.html
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize

from browser_env.constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    UTTERANCE_MAX_LENGTH,
)
from .utils import (
    AccessibilityTree,
    AccessibilityTreeNode,
    BrowserConfig,
    BrowserInfo,
    DOMNode,
    DOMTree,
    Observation,
    png_bytes_to_numpy,
)
from .html_tools import HtmlParser, basic_attrs, print_html_object
# Minimum in-viewport ratio for an element to count as inside the current
# viewport; elements below it are filtered out of the observation (or
# reported as not visible).
IN_VIEWPORT_RATIO_THRESHOLD: float = 0.8
class TreeNode:
    """Node of the structured accessibility tree built by ``parse_accessibility_tree``.

    Holds the node's ``role`` and ``name``, its ``depth`` in the rendered tree,
    optional ``properties``, and parent/child links. ``visible`` starts as True;
    the sibling and counting helpers below take it into account.
    """

    def __init__(self, node_id, role, name, depth, **kwargs):
        self.visible = True
        self.node_id = node_id
        self.role = role
        self.name = name
        self.depth = depth
        # Only the ``properties`` keyword is read; other keywords are ignored.
        self.properties = kwargs.get("properties")
        self.children = []
        self.parent = None

    def add_child(self, child):
        """Append ``child`` and point its ``parent`` at this node."""
        child.parent = self
        self.children.append(child)

    def copy(self):
        """Return a deep copy of this node alone, detached from the tree.

        The copy has no parent and no children. The links are cut *before*
        deep-copying, so the rest of the tree is not copied only to be thrown
        away (deep-copying through ``parent`` would otherwise copy everything).
        """
        from copy import deepcopy

        children, parent = self.children, self.parent
        self.children, self.parent = [], None
        try:
            new_self = deepcopy(self)
        finally:
            # Always restore the original links, even if deepcopy fails.
            self.children, self.parent = children, parent
        return new_self

    def get_visible_node_number(self):
        """Count visible nodes in this subtree, including this node.

        Children of invisible nodes are still visited.
        """
        count = 0
        stack = [self]
        while stack:
            node = stack.pop()
            if node.visible:
                count += 1
            stack.extend(node.children)
        return count

    def delete_tree(self):
        """Recursively unlink this subtree: clear every child list and parent link."""
        for child in self.children:
            child.delete_tree()
        self.children.clear()
        self.parent = None

    def has_properties(self):
        """Return ``properties`` (None or a mapping); intended for truthiness checks."""
        return getattr(self, "properties", {})

    def visible_children(self):
        """Return the children whose ``visible`` flag is set."""
        return [c for c in self.children if c.visible]

    def visible_siblings(self):
        """Return visible siblings (other children of the parent with a different id)."""
        if not self.parent:
            return []
        return [n for n in self.parent.children if n.visible and n.node_id != self.node_id]

    def siblings(self):
        """Return all siblings (other children of the parent with a different id)."""
        if not self.parent:
            return []
        return [n for n in self.parent.children if n.node_id != self.node_id]

    def search_node_by_id(self, target_id):
        """Pre-order depth-first search for a node by id.

        A node matches when ``node_id == target_id`` or its name contains the
        literal text ``[target_id]``. Returns the first match, or None.
        """
        if self.node_id == target_id or (self.name and f"[{target_id}]" in self.name):
            return self
        for child in self.children:
            result = child.search_node_by_id(target_id)
            if result:
                return result
        return None

    def all_children_invisible(self):
        """Return True if no child is visible (also True for a leaf)."""
        return all(not child.visible for child in self.children)

    def has_the_same_properties_as(self, another_node):
        """Compare ``properties``, treating missing/None/empty as equal to each other."""
        node_a_has_properties = getattr(self, "properties", "")
        node_b_has_properties = getattr(another_node, "properties", "")
        if not node_a_has_properties and not node_b_has_properties:
            return True
        if bool(node_a_has_properties) != bool(node_b_has_properties):
            return False
        return self.properties == another_node.properties

    def is_identical_to(self, another_node):
        """Return True if ``another_node`` is a leaf with the same role, name and properties."""
        if another_node.children:
            return False
        return (
            self.role == another_node.role
            and self.name == another_node.name
            and self.has_the_same_properties_as(another_node=another_node)
        )

    def last_sibling(self, visible_required=False):
        """Return the closest preceding sibling, or None.

        With ``visible_required`` the closest *visible* preceding sibling is
        returned instead.
        """
        if not self.parent:
            return None
        own_idx = self.parent.children.index(self)
        if own_idx - 1 < 0:
            return None
        if not visible_required:
            return self.parent.children[own_idx - 1]
        # Walk backwards from the immediate predecessor. (The previous slice
        # ``children[:own_idx:-1]`` walked the siblings *after* this node.)
        for sibling in reversed(self.parent.children[:own_idx]):
            if sibling.visible:
                return sibling
        return None

    def next_sibling(self, visible_required=False):
        """Return the closest following sibling (visible one if ``visible_required``), or None."""
        if not self.parent:
            return None
        next_sibling_idx = self.parent.children.index(self) + 1
        if next_sibling_idx >= len(self.parent.children):
            return None
        if not visible_required:
            return self.parent.children[next_sibling_idx]
        for sibling in self.parent.children[next_sibling_idx:]:
            if sibling.visible:
                return sibling
        return None

    def has_identical_siblings(self):
        """Return True if this childless-in-view node shares role and name with such a sibling.

        Only siblings whose children are all invisible (and this node itself
        must satisfy the same) are considered.
        """
        if not (self.parent and self.all_children_invisible()):
            return False
        return any(
            sibling.role == self.role and sibling.name == self.name
            for sibling in self.parent.children
            if sibling.node_id != self.node_id and sibling.all_children_invisible()
        )

    def has_identical_surrounding_siblings(self):
        """Return True if an adjacent sibling (any or visible, before or after) is identical."""
        candidates = (
            self.last_sibling(visible_required=False),
            self.last_sibling(visible_required=True),
            self.next_sibling(visible_required=False),
            self.next_sibling(visible_required=True),
        )
        return any(c is not None and self.is_identical_to(c) for c in candidates)

    def is_differentiable(self, strict=False):
        """Return whether this node can be told apart from its siblings.

        Nodes directly under a ``row`` always count as differentiable. Otherwise
        a node is not differentiable when an adjacent sibling is identical, or
        (unless ``strict``) when any sibling has the same role and name.
        """
        if self.parent and self.parent.role == "row":
            return True
        if not strict and self.has_identical_siblings():
            return False
        if self.has_identical_surrounding_siblings():
            return False
        return True
class ObservationProcessor:
    """Interface for turning a browser page into an observation."""

    def process(self, page: Page, client: CDPSession, context: str = "") -> Observation:
        """Return the observation for ``page``; ``client`` is its CDP session.

        ``context`` is optional so that the base signature matches the
        subclasses in this module, which accept it as a third argument.
        Subclasses must override this method.
        """
        raise NotImplementedError
class ObservationMetadata(TypedDict):
    """Per-observation metadata published by the observation processors.

    Only ``obs_nodes_info`` is declared; the text processor also stores extra
    keys (position, DOM and tab information) in the same dict at runtime.
    """

    # Observation node id -> per-node info dict (backend id, bounds, text, ...).
    obs_nodes_info: dict[str, Any]
def create_empty_metadata() -> ObservationMetadata:
    """Return a fresh metadata dict whose ``obs_nodes_info`` mapping is empty."""
    return ObservationMetadata(obs_nodes_info={})
class TextObervationProcessor(ObservationProcessor):
    """Build text observations (HTML or accessibility tree) of a browser page.

    ``observation_type`` selects what ``process`` produces: ``"html"`` serialises
    a ``DOMSnapshot.captureSnapshot`` result and simplifies it with
    ``HtmlParser``; ``"accessibility_tree"`` renders
    ``Accessibility.getFullAXTree``. When ``current_viewport_only`` is true,
    nodes whose in-viewport ratio is below ``IN_VIEWPORT_RATIO_THRESHOLD`` are
    spliced out of the tree (their children are re-attached to the removed
    node's parent).
    """

    # Inline ``{ opacity: ... }`` fragments blanked out of attribute values.
    _OPACITY_STYLE_RE = re.compile(r"{\s*opacity:\s*.*;*\s*}")
    # A rendered StaticText line; group 2 is the quoted text. ``repr`` switches
    # to double quotes when the text contains an apostrophe, so accept both.
    _STATIC_TEXT_RE = re.compile(r"\[\d+\] StaticText (['\"])(.*?)\1")
    # Roles whose nameless, property-less nodes are hidden from the AX text.
    _EMPTY_NODE_ROLES = frozenset(
        {
            "generic",
            "img",
            "list",
            "strong",
            "paragraph",
            "banner",
            "navigation",
            "Section",
            "LabelText",
            "Legend",
            "listitem",
        }
    )
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        # Metadata of the most recent observation of this type.
        self.meta_data = create_empty_metadata()
        # Refreshed by ``process``; initialised so that lookups made before the
        # first observation fail cleanly instead of with AttributeError.
        self.obs_nodes_info: dict[str, Any] = {}
        self.node_root: Union[TreeNode, None] = None

    def fetch_browser_info(
        self,
        page: Page,
        client: CDPSession,
    ) -> BrowserInfo:
        """Capture a DOM snapshot of ``page`` plus its window geometry.

        Returns ``{"DOMTree": snapshot, "config": BrowserConfig}``. Layout
        bounds of the first document are rescaled so the first layout box's
        width matches the viewport width. Asserts ``devicePixelRatio == 1.0``.
        """
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )

        # Calibrate the bounds: in some cases they come back scaled. Skip the
        # rescale when there are no boxes or the reference width is zero
        # (previously an IndexError / ZeroDivisionError).
        layout = tree["documents"][0]["layout"]
        bounds = layout["bounds"]
        if bounds:
            scale = bounds[0][2] / self.viewport_size["width"]
            if scale:
                layout["bounds"] = [[x / scale for x in bound] for bound in bounds]

        win_top_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_top_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_top_bound": win_top_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }
        info: BrowserInfo = {"DOMTree": tree, "config": config}
        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        """Return the CDP ``Runtime.callFunctionOn`` response holding the node's client rect.

        Text nodes (``nodeType == 3``) are measured through a Range. On any
        failure ``{"result": {"subtype": "error"}}`` is returned instead.
        """
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
function() {
if (this.nodeType == 3) {
var range = document.createRange();
range.selectNode(this);
var rect = range.getBoundingClientRect().toJSON();
range.detach();
return rect;
} else {
return this.getBoundingClientRect().toJSON();
}
}
""",
                    "returnByValue": True,
                },
            )
            return response
        except Exception:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        """Return the fraction of the element's area that lies inside the window.

        The window is the rectangle ``[0, win_width] x [0, win_height]`` in the
        same coordinate frame as the element box. Elements with a non-positive
        area yield 0.0.
        """
        area = width * height
        if area <= 0:
            return 0.0
        overlap_width = max(
            0,
            min(elem_left_bound + width, config["win_width"]) - max(elem_left_bound, 0),
        )
        overlap_height = max(
            0,
            min(elem_top_bound + height, config["win_height"]) - max(elem_top_bound, 0),
        )
        # Divide by the whole area; ``a / width * height`` parses as
        # ``(a / width) * height`` and inflated the ratio.
        return overlap_width * overlap_height / area

    def element_is_visible(self, page, element_id):
        """Return whether ``element_id`` is sufficiently inside the current window.

        ``element_id`` must be a key of ``self.obs_nodes_info`` from the latest
        ``process`` call, and ``page.client`` is used as the CDP session.
        Zero-area elements count as visible (kept from the original behaviour).
        Elements that cannot be resolved through CDP are reported as not visible.
        """
        try:
            browser_info = self.fetch_browser_info(page, page.client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, page.client)

        response = self.get_bounding_client_rect(
            page.client, self.obs_nodes_info[element_id]["backend_id"]
        )
        rect = response.get("result", {}).get("value")
        if not rect:
            return False
        width = float(rect["width"])
        height = float(rect["height"])
        if width * height == 0:
            return True
        in_viewport_ratio = self.get_element_in_viewport_ratio(
            elem_left_bound=float(rect["x"]),
            elem_top_bound=float(rect["y"]),
            width=width,
            height=height,
            config=browser_info["config"],
        )
        return in_viewport_ratio >= IN_VIEWPORT_RATIO_THRESHOLD

    @staticmethod
    def _format_attributes(node_attributes: list[str]) -> str:
        """Render a flat ``[name, value, name, value, ...]`` list as ``name="value"`` pairs.

        Values have inline ``{ opacity: ... }`` fragments blanked, tokens
        containing ``vimium`` (presumably injected by that browser extension)
        dropped, and whitespace collapsed.
        """
        pairs = []
        for name, value in zip(node_attributes[0::2], node_attributes[1::2]):
            value = TextObervationProcessor._OPACITY_STYLE_RE.sub(" ", value)
            value = " ".join(token for token in value.split() if "vimium" not in token)
            pairs.append(f'{name}="{value}"')
        return " ".join(pairs)

    def _drop_nodes_outside_viewport(self, dom_tree: DOMTree, config: BrowserConfig) -> None:
        """Mark nodes outside the viewport with ``parentId == "[REMOVED]"`` in place.

        A removed node's children are spliced into its parent's ``childIds`` at
        the removed node's position and re-parented. Zero-size ``OPTION`` nodes
        directly under a ``SELECT`` are kept (presumably because options of a
        closed dropdown have no box yet remain selectable).
        """

        def remove_node_in_graph(node: DOMNode) -> None:
            node_id = node["nodeId"]
            parent_id = node["parentId"]
            child_ids = node["childIds"]
            parent = dom_tree[int(parent_id)]
            assert parent["parentId"] != "[REMOVED]"
            # Replace node_id in the parent's children by node's own children.
            index = parent["childIds"].index(node_id)
            parent["childIds"][index:index + 1] = child_ids
            for child_id in child_ids:
                dom_tree[int(child_id)]["parentId"] = parent_id
            node["parentId"] = "[REMOVED]"

        for node in dom_tree:
            if not node["union_bound"]:
                remove_node_in_graph(node)
                continue
            x, y, width, height = node["union_bound"]
            if width == 0.0 or height == 0.0:
                keep = (
                    node["nodeName"] == "OPTION"
                    and dom_tree[int(node["parentId"])]["nodeName"] == "SELECT"
                )
                if not keep:
                    remove_node_in_graph(node)
                continue
            in_viewport_ratio = self.get_element_in_viewport_ratio(
                elem_left_bound=float(x),
                elem_top_bound=float(y),
                width=float(width),
                height=float(height),
                config=config,
            )
            if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                remove_node_in_graph(node)

    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> DOMTree:
        """Flatten the DOM snapshot in ``info`` into a list of node dicts.

        Snapshot node ``i`` becomes a dict with ``nodeId == str(i)``, ``childIds``
        wired from ``parentIndex``, and ``union_bound = [x, y, width, height]``
        shifted by the window scroll offset (the root gets ``[0, 0, 10, 10]``).
        With ``current_viewport_only`` the nodes outside the viewport are
        removed. ``page`` and ``client`` are unused.
        """
        # adopted from [natbot](https://github.com/nat/natbot)
        config = info["config"]
        strings = info["DOMTree"]["strings"]
        document = info["DOMTree"]["documents"][0]
        nodes = document["nodes"]
        layout = document["layout"]

        start = time.perf_counter()
        self._logger.debug("snapshot node fields: %s", list(nodes.keys()))

        # Snapshot node index -> position in the layout arrays, built once
        # instead of a linear ``list.index`` per node; keeps the first
        # occurrence like ``list.index`` did.
        layout_position: dict[int, int] = {}
        for position, layout_node_idx in enumerate(layout["nodeIndex"]):
            layout_position.setdefault(layout_node_idx, position)

        dom_tree: DOMTree = []
        graph = defaultdict(list)
        for node_idx in range(len(nodes["nodeName"])):
            # NOTE(review): the CDP docs describe ``nodeType`` as the integer
            # DOM node type rather than an index into ``strings``; this lookup
            # is kept unchanged for existing consumers — confirm before fixing.
            node_type_idx = nodes["nodeType"][node_idx]
            node_type = "generic"
            if 0 <= node_type_idx < len(strings):
                node_type = strings[node_type_idx]

            node_value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if 0 <= node_value_idx < len(strings):
                # Collapse whitespace runs into single spaces.
                node_value = " ".join(strings[node_value_idx].split())

            node_id = str(node_idx)
            parent_id = str(nodes["parentIndex"][node_idx])
            if parent_id == "-1":
                # The document root is always treated as inside the viewport.
                union_bound = [0.0, 0.0, 10.0, 10.0]
            else:
                graph[parent_id].append(node_id)
                layout_idx = layout_position.get(node_idx)
                if layout_idx is None:
                    # No layout box: recorded as zero-size.
                    union_bound = [0.0, 0.0, 0.0, 0.0]
                else:
                    # Copy so the scroll shift does not mutate the snapshot in ``info``.
                    union_bound = list(layout["bounds"][layout_idx])
                    union_bound[0] -= config["win_left_bound"]
                    union_bound[1] -= config["win_top_bound"]

            cur_node: DOMNode = {
                "nodeId": node_id,
                "nodeType": node_type,
                "nodeName": strings[nodes["nodeName"][node_idx]],
                "nodeValue": node_value,
                "attributes": self._format_attributes(
                    [strings[i] for i in nodes["attributes"][node_idx]]
                ),
                "backendNodeId": str(nodes["backendNodeId"][node_idx]),
                "parentId": parent_id,
                "childIds": [],
                "cursor": 0,
                "union_bound": union_bound,
            }
            dom_tree.append(cur_node)
        self._logger.debug("[build] %.3fs", time.perf_counter() - start)

        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids

        if current_viewport_only:
            start = time.perf_counter()
            self._drop_nodes_outside_viewport(dom_tree, config)
            dom_tree = [
                node
                for node in dom_tree
                if node.get("parentId", "-1") != "[REMOVED]"
            ]
            self._logger.debug("[filter] %.3fs", time.perf_counter() - start)
        return dom_tree

    @staticmethod
    def parse_my_html(dom_tree: DOMTree) -> tuple[str, str, dict[str, Any], Any]:
        """Serialise ``dom_tree`` to HTML and simplify it with ``HtmlParser``.

        Nodes with neither attributes nor a value are omitted (their children
        are still emitted). ``#text`` node names are rewritten to ``text`` in
        place. Every emitted tag carries ``backend-id="bid-<backendNodeId>"``.

        Returns ``(raw_html, page_html, obs_nodes_info, parser)`` where
        ``obs_nodes_info`` is keyed by the node's position in ``dom_tree``.
        """
        logger = logging.getLogger(__name__)
        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> tuple[str, list[str]]:
            tree_str, labeled_elems = "", []
            node = dom_tree[node_cursor]
            valid_node = True
            try:
                if node["nodeName"] == "#text":
                    node["nodeName"] = "text"
                node_str = f"<{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f" backend-id=\"bid-{node['backendNodeId']}\"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])
                if valid_node:
                    node_html = lxml.html.fromstring(node_str)
                    label = node_html.attrib.get("data-testid", "")
                    if len(label) > 0:
                        labeled_elems.append(node["backendNodeId"])
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node["nodeValue"],
                        "label": label,
                    }
                    tree_str += node_str
            except Exception:
                valid_node = False

            child_depth = depth + 1 if valid_node else depth
            for child_id in node["childIds"]:
                child_str, elems = dfs(nodeid_to_cursor[child_id], child_depth)
                tree_str += child_str
                labeled_elems.extend(elems)

            if valid_node:
                tree_str += f"</{node['nodeName']}>"
            return tree_str, labeled_elems

        html, labeled_elems = dfs(0, 0)
        logger.debug("labeled elements: %s", labeled_elems)

        args = {
            "use_position": False,
            "id_attr": "backend-id",
            "label_generator": "order",
            "label_attr": "data-testid",
            "attr_list": basic_attrs,
            "prompt": "refine",
        }
        hp = HtmlParser(html, args)
        packet = hp.parse_tree()
        page_html = packet["html"]
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("%s", print_html_object(page_html))
        it, pt = packet.get("init_time", 0), packet.get("parse_time", 0)
        logger.debug("[Time] %.3f %.3f", it, pt)
        return html, page_html, obs_nodes_info, hp

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Render ``dom_tree`` as a tab-indented outline, one ``[cursor] <tag ...> value`` per line.

        Nodes with neither attributes nor a value are omitted and their
        children are rendered at the same depth. Returns the text and a map
        from cursor (as str) to backend id, bound and line text.
        """
        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])
                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"
            except Exception:
                valid_node = False

            child_depth = depth + 1 if valid_node else depth
            for child_id in node["childIds"]:
                tree_str += dfs(nodeid_to_cursor[child_id], child_depth)
            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        info: BrowserInfo,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch the full accessibility tree and attach a ``union_bound`` to each node.

        Duplicate node ids are dropped (first copy kept). ``union_bound`` is
        ``[x, y, width, height]`` from the client rect, ``[0, 0, 10, 10]`` for
        the RootWebArea, or None when the node has no DOM node or cannot be
        measured. With ``current_viewport_only``, nodes outside the viewport
        are spliced out.
        """
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # A few nodes are repeated in the accessibility tree.
        seen_ids = set()
        deduplicated = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                deduplicated.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = deduplicated

        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            # Usually missing because the node is not visible etc.
            if "backendDOMNodeId" not in node:
                node["union_bound"] = None
                continue
            if node["role"]["value"] == "RootWebArea":
                # Always inside the viewport.
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
                continue
            response = self.get_bounding_client_rect(
                client, str(node["backendDOMNodeId"])
            )
            if response.get("result", {}).get("subtype", "") == "error":
                node["union_bound"] = None
            else:
                rect = response["result"]["value"]
                node["union_bound"] = [rect["x"], rect["y"], rect["width"], rect["height"]]

        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                nodeid = node["nodeId"]
                parent_nodeid = node["parentId"]
                children_nodeids = node.get("childIds", [])
                parent = accessibility_tree[nodeid_to_cursor[parent_nodeid]]
                assert parent.get("parentId", "Root") is not None
                # Replace nodeid in the parent's children by node's own children.
                index = parent["childIds"].index(nodeid)
                parent["childIds"][index:index + 1] = children_nodeids
                for child_nodeid in children_nodeids:
                    # childIds may reference nodes absent from the fetched list.
                    child_cursor = nodeid_to_cursor.get(child_nodeid)
                    if child_cursor is not None:
                        accessibility_tree[child_cursor]["parentId"] = parent_nodeid
                node["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue
                x, y, width, height = node["union_bound"]
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue
                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )
                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]
        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any], Union[TreeNode, None]]:
        """Render the accessibility tree as indented text and build a TreeNode tree.

        Each kept node becomes ``[<nodeId>] <role> <repr(name)> [props]``,
        indented by one tab per kept ancestor. Nodes that fail to render, and
        nameless nodes that carry no information, are skipped; their children
        are rendered at the same depth.

        Returns ``(text, obs_nodes_info, root)``: ``obs_nodes_info`` maps node id
        to backend id, bound and line text; ``root`` is the TreeNode of the
        first node (None if it was skipped or the tree is empty).
        """
        if not accessibility_tree:
            return "", {}, None
        node_id_to_idx = {node["nodeId"]: idx for idx, node in enumerate(accessibility_tree)}
        obs_nodes_info = {}
        empty_node_roles = TextObervationProcessor._EMPTY_NODE_ROLES

        def dfs(
            idx: int, obs_node_id: str, depth: int, active_node_dict: dict
        ) -> tuple[str, Union[TreeNode, None]]:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            structured_properties = {}
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                for prop in node.get("properties", []):
                    try:
                        if prop["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(f'{prop["name"]}: {prop["value"]["value"]}')
                        structured_properties[prop["name"]] = prop["value"]["value"]
                    except KeyError:
                        pass
                if properties:
                    node_str += " " + " ".join(properties)

                # Hide nameless nodes that carry no information.
                if not name.strip():
                    if not properties:
                        if role in empty_node_roles:
                            valid_node = False
                    elif role == "listitem":
                        valid_node = False

                if valid_node:
                    # Build the info before emitting the line: if a field is
                    # missing the node becomes invalid and must not appear in
                    # the text either (text and TreeNode tree stay consistent).
                    node_info = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    obs_nodes_info[obs_node_id] = node_info
                    tree_str += f"{indent}{node_str}"
            except Exception:
                valid_node = False

            structured_node = (
                TreeNode(
                    node_id=int(obs_node_id),
                    role=node["role"]["value"],
                    name=node["name"]["value"],
                    depth=depth,
                    properties=structured_properties,
                )
                if valid_node
                else None
            )
            # Most recent kept node per depth: the attach point for descendants.
            active_node_dict[depth] = (
                structured_node if valid_node else active_node_dict.get(depth, None)
            )

            child_depth = depth + 1 if valid_node else depth
            for child_node_id in node.get("childIds", []):
                if child_node_id not in node_id_to_idx:
                    continue
                child_str, child_node = dfs(
                    node_id_to_idx[child_node_id],
                    child_node_id,
                    child_depth,
                    active_node_dict=active_node_dict,
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str
                if child_depth > 0 and child_node:
                    active_node_dict[child_depth - 1].add_child(child_node)
            return tree_str, structured_node

        tree_str, structured_node = dfs(
            0, accessibility_tree[0]["nodeId"], 0, active_node_dict={}
        )
        return tree_str, obs_nodes_info, structured_node

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Drop StaticText lines that add nothing to the preceding lines.

        A StaticText line is kept only if its text is non-empty and does not
        occur in any of the last three kept lines; every other line is kept.
        Both quoting styles produced by ``repr`` are recognised, so text
        containing an apostrophe is no longer dropped unconditionally.
        """
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            match = TextObervationProcessor._STATIC_TEXT_RE.search(line)
            if match is None:
                clean_lines.append(line)
                continue
            static_text = match.group(2)
            if static_text and all(
                static_text not in prev_line for prev_line in clean_lines[-3:]
            ):
                clean_lines.append(line)
        return "\n".join(clean_lines)

    @staticmethod
    def _describe_tabs(page: Page) -> str:
        """Return a numbered list of open tab titles, marking the current tab.

        Falls back to ``"<n>. Default"`` entries if titles cannot be read.
        """
        open_tabs = page.context.pages
        try:
            current_tab_idx = open_tabs.index(page)
            lines = []
            for idx, tab in enumerate(open_tabs):
                suffix = " <-- current tab" if idx == current_tab_idx else ""
                lines.append(f"{idx + 1}. {tab.title()}{suffix}")
            return "\n".join(lines)
        except Exception:
            return "\n".join(f"{idx + 1}. Default" for idx in range(len(open_tabs)))

    def process(
        self, page: Page, client: CDPSession, context: str
    ) -> tuple[str, Union[TreeNode, None]]:
        """Build the text observation of ``page``.

        Returns ``(content, node_root)``. ``content`` is the simplified HTML or
        the cleaned accessibility-tree text. ``node_root`` is the root TreeNode
        for ``"accessibility_tree"`` and None for ``"html"`` (this path used to
        raise UnboundLocalError). Also refreshes ``obs_nodes_info``,
        ``node_root``, ``meta_data`` and ``browser_config``. ``context`` is unused.

        Raises ValueError for an unknown observation type.
        """
        tab_title_str = self._describe_tabs(page)

        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        node_root = None
        if self.observation_type == "html":
            start = time.perf_counter()
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            self._logger.debug("[fetch] %.3fs", time.perf_counter() - start)
            start = time.perf_counter()
            raw_html, content, obs_nodes_info, hp = self.parse_my_html(dom_tree)
            self._logger.debug("[parse] %.3fs", time.perf_counter() - start)

            window_height = page.evaluate("window.innerHeight")
            page_height = page.evaluate("document.documentElement.scrollHeight") / window_height
            position = page.evaluate("window.scrollY") / window_height

            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info
            self.meta_data["position_info"] = {
                "page_height": page_height,
                "position": position,
            }
            self.meta_data["dom_info"] = {
                "raw_html": raw_html,
                "dom_tree": dom_tree,
            }
            self.meta_data["html_parser"] = hp
            self.meta_data["tab_title"] = tab_title_str
        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            content, obs_nodes_info, node_root = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info

            # Surface a pending dialog message on the root node once, then clear
            # it. Without a root it stays pending for the next observation.
            page_dialog_message = getattr(page, "dialog_message", "")
            if page_dialog_message and node_root is not None:
                if node_root.properties is None:
                    node_root.properties = {}
                node_root.properties["page_dialog_message"] = page_dialog_message + " Retry."
                page.dialog_message = None
            self.meta_data["obs_nodes_info"] = obs_nodes_info
        else:
            raise ValueError(
                f"Invalid observation type: {self.observation_type}"
            )

        self.node_root = node_root
        self.browser_config = browser_info["config"]
        return (content, node_root)

    def get_node_info_by_element_id(self, AXTreeId):
        """Return the TreeNode matching ``AXTreeId`` in the latest tree, or None."""
        node_root = getattr(self, "node_root", None)
        if node_root is None:
            return None
        return node_root.search_node_by_id(AXTreeId)

    def get_element_center(self, element_id: str, page) -> tuple[float, float]:
        """Return the centre of ``element_id`` as fractions of the viewport size.

        ``element_id`` must be a key of ``self.obs_nodes_info``; ``page.client``
        is used as the CDP session. Raises KeyError when CDP cannot resolve the
        node (the error response has no ``value``).
        """
        node = self.obs_nodes_info[element_id]
        response = self.get_bounding_client_rect(page.client, str(node["backend_id"]))
        rect = response["result"]["value"]
        center_x = rect["x"] + rect["width"] / 2
        center_y = rect["y"] + rect["height"] / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )
class ImageObservationProcessor(ObservationProcessor):
    """Produce screenshot observations as numpy arrays (via ``png_bytes_to_numpy``)."""

    def __init__(self, observation_type: str, current_viewport_only: bool):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def _screenshot(self, page: Page) -> npt.NDArray[np.uint8]:
        """Capture ``page`` (full page unless viewport-only) and crop its height.

        The height is limited to twice the width, presumably to bound the
        size of tall full-page captures.
        """
        screenshot = png_bytes_to_numpy(
            page.screenshot(full_page=(not self.current_viewport_only))
        )
        return screenshot[: 2 * screenshot.shape[1], :, :]

    def process(self, page: Page, client: CDPSession, context: str) -> npt.NDArray[np.uint8]:
        """Return a screenshot of ``page``; retries once after the page has loaded.

        Both attempts apply the same crop, so the output shape no longer
        depends on whether the first attempt failed. ``client`` and
        ``context`` are unused.
        """
        try:
            return self._screenshot(page)
        except Exception:
            # Unlike wait_for_event("load"), this returns immediately when the
            # page is already loaded instead of waiting for a new load event.
            page.wait_for_load_state("load")
            return self._screenshot(page)
class ObservationHandler:
    """Bundle the text and image observation processors behind one interface."""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type, current_viewport_only
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Describe observations as a dict of a bounded text space and an RGB image box."""
        height = self.viewport_size["height"]
        width = self.viewport_size["width"]
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )
        # Height comes first in the pixel grid; each RGB channel spans 0..255.
        pixel_shape = (height, width, 3)
        lower = np.zeros(pixel_shape, dtype=np.uint8)
        upper = np.full(pixel_shape, 255.0)
        image_space = spaces.Box(lower, upper, dtype=np.uint8)
        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
        self, page: Page, client: CDPSession, context: str = '',
    ) -> dict[str, Observation]:
        """Run both processors on ``page`` and return their outputs keyed by modality."""
        observation = {}
        observation["text"] = self.text_processor.process(page, client, context)
        observation["image"] = self.image_processor.process(page, client, context)
        return observation

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return each processor's metadata dict keyed by modality."""
        return dict(
            text=self.text_processor.meta_data,
            image=self.image_processor.meta_data,
        )

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the processor associated with the action space (per ``main_observation_type``)."""
        processor = {
            "text": self.text_processor,
            "image": self.image_processor,
        }.get(self.main_observation_type)
        if processor is None:
            raise ValueError("Invalid main observation type")
        return processor