import json
import re
import time
from collections import defaultdict
from copy import deepcopy
from typing import Any, TypedDict, Union

# `import lxml` alone does not load the `lxml.html` submodule used below.
import lxml.html
import numpy as np
import numpy.typing as npt
from gymnasium import spaces
from playwright.sync_api import CDPSession, Page, ViewportSize

from browser_env.constants import (
    ASCII_CHARSET,
    FREQ_UNICODE_CHARSET,
    IGNORED_ACTREE_PROPERTIES,
    UTTERANCE_MAX_LENGTH,
)

from .utils import (
    AccessibilityTree,
    AccessibilityTreeNode,
    BrowserConfig,
    BrowserInfo,
    DOMNode,
    DOMTree,
    Observation,
    png_bytes_to_numpy,
)
from .html_tools import HtmlParser, basic_attrs, print_html_object

IN_VIEWPORT_RATIO_THRESHOLD = 0.8

# Inline style fragments like "{ opacity: 0; }" are blanked out of attribute values.
_OPACITY_STYLE_RE = re.compile(r"{\s*opacity:\s*.*;*\s*}")
# A rendered accessibility-tree line for a StaticText node; group 1 is the text.
_STATIC_TEXT_RE = re.compile(r"\[\d+\] StaticText '([^']+)'")
# Roles whose nodes are dropped from the accessibility-tree text when they have
# an empty name and no (non-ignored) properties.
_PRUNABLE_EMPTY_ROLES = (
    "generic",
    "img",
    "list",
    "strong",
    "paragraph",
    "banner",
    "navigation",
    "Section",
    "LabelText",
    "Legend",
    "listitem",
)


class TreeNode:
    """A node of the structured accessibility tree built by ``parse_accessibility_tree``.

    Each node stores its id, role, name, depth and optional ``properties``
    dict, plus ``parent``/``children`` links and a ``visible`` flag.
    """

    def __init__(self, node_id, role, name, depth, **kwargs):
        self.visible = True
        self.node_id = node_id
        self.role = role
        self.name = name
        self.depth = depth
        self.properties = kwargs.get("properties")
        self.children = []
        self.parent = None

    def add_child(self, child):
        """Attach ``child`` as the last child of this node."""
        child.parent = self
        self.children.append(child)

    def copy(self):
        """Return a deep copy of this node alone, detached from parent and children."""
        # Pre-seed the memo so deepcopy does not follow the parent/children links;
        # without this every call copied the entire connected tree only to discard it.
        memo = {id(self.parent): None, id(self.children): []}
        new_self = deepcopy(self, memo)
        new_self.children = []
        new_self.parent = None
        return new_self

    def get_visible_node_number(self):
        """Count the visible nodes in the subtree rooted here (self included)."""
        return int(bool(self.visible)) + sum(
            child.get_visible_node_number() for child in self.children
        )

    def delete_tree(self):
        """Recursively unlink this subtree so its nodes can be garbage collected."""
        for child in self.children:
            child.delete_tree()
        self.children.clear()
        self.parent = None

    def has_properties(self):
        """Return the node's properties (truthy when it has any)."""
        return getattr(self, "properties", {})

    def visible_children(self):
        """Return the visible direct children."""
        return [c for c in self.children if c.visible]

    def visible_siblings(self):
        """Return visible nodes sharing this node's parent (excluding this node)."""
        if not self.parent:
            return []
        return [
            n for n in self.parent.children if n.visible and n.node_id != self.node_id
        ]

    def siblings(self):
        """Return all nodes sharing this node's parent (excluding this node)."""
        if not self.parent:
            return []
        return [n for n in self.parent.children if n.node_id != self.node_id]

    def search_node_by_id(self, target_id):
        """Depth-first search for a node whose id is ``target_id``.

        A node whose name contains the literal ``"[target_id]"`` also matches.
        Returns ``None`` when nothing matches.
        """
        if self.node_id == target_id or (self.name and f"[{target_id}]" in self.name):
            return self
        for child in self.children:
            result = child.search_node_by_id(target_id)
            if result:
                return result
        return None

    def all_children_invisible(self):
        """True when the node has no children or none of them is visible."""
        return not any(child.visible for child in self.children)

    def has_the_same_properties_as(self, another_node):
        """Compare properties, treating missing/empty properties as equal to each other."""
        mine = getattr(self, "properties", "")
        theirs = getattr(another_node, "properties", "")
        if not mine or not theirs:
            # Equal only when both are empty/missing.
            return not mine and not theirs
        return self.properties == another_node.properties

    def is_identical_to(self, another_node):
        """True when ``another_node`` is a leaf with the same role, name and properties."""
        if another_node.children:
            return False
        return (
            self.role == another_node.role
            and self.name == another_node.name
            and self.has_the_same_properties_as(another_node=another_node)
        )

    def last_sibling(self, visible_required=False):
        """Return the sibling immediately before this node.

        With ``visible_required`` the nearest *visible* preceding sibling is
        returned instead. ``None`` if there is no such sibling.
        """
        if not self.parent:
            return None
        idx = self.parent.children.index(self)
        if idx == 0:
            return None
        if not visible_required:
            return self.parent.children[idx - 1]
        # Walk backwards from the sibling just before self. The former slice
        # ``children[:idx:-1]`` walked the siblings *after* self.
        for sibling in reversed(self.parent.children[:idx]):
            if sibling.visible:
                return sibling
        return None

    def next_sibling(self, visible_required=False):
        """Return the sibling immediately after this node.

        With ``visible_required`` the nearest *visible* following sibling is
        returned instead. ``None`` if there is no such sibling.
        """
        if not self.parent:
            return None
        next_sibling_idx = self.parent.children.index(self) + 1
        if next_sibling_idx >= len(self.parent.children):
            return None
        if not visible_required:
            return self.parent.children[next_sibling_idx]
        for sibling in self.parent.children[next_sibling_idx:]:
            if sibling.visible:
                return sibling
        return None

    def has_identical_siblings(self):
        """True when another childless-or-hidden-children sibling has the same role and name."""
        if not (self.parent and self.all_children_invisible()):
            return False
        return any(
            sibling.role == self.role and sibling.name == self.name
            for sibling in self.parent.children
            if sibling.node_id != self.node_id and sibling.all_children_invisible()
        )

    def has_identical_surrounding_siblings(self):
        """True when an adjacent sibling (raw or nearest visible) is identical to this node."""
        candidates = (
            self.last_sibling(visible_required=False),
            self.last_sibling(visible_required=True),
            self.next_sibling(visible_required=False),
            self.next_sibling(visible_required=True),
        )
        return any(
            candidate and self.is_identical_to(candidate) for candidate in candidates
        )

    def is_differentiable(self, strict=False):
        """Decide whether this node can be told apart from its siblings.

        Nodes inside a ``row`` always count as differentiable. Unless
        ``strict``, any identical sibling makes the node non-differentiable.
        Identical adjacent siblings always do.
        """
        if self.parent and self.parent.role == "row":
            return True
        if not strict and self.has_identical_siblings():
            return False
        if self.has_identical_surrounding_siblings():
            return False
        return True


class ObservationProcessor:
    """Interface for observation processors."""

    def process(self, page: Page, client: CDPSession) -> Observation:
        raise NotImplementedError


class ObservationMetadata(TypedDict):
    obs_nodes_info: dict[str, Any]


def create_empty_metadata() -> ObservationMetadata:
    """Return a fresh metadata dict with an empty ``obs_nodes_info``."""
    return {
        "obs_nodes_info": {},
    }


class TextObervationProcessor(ObservationProcessor):
    """Builds text observations (HTML or accessibility tree) of a page."""

    def __init__(
        self,
        observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.viewport_size = viewport_size
        self.observation_tag = "text"
        # use the store meta data of this observation type
        self.meta_data = create_empty_metadata()

    def fetch_browser_info(
        self,
        page: Page,
        client: CDPSession,
    ) -> BrowserInfo:
        """Capture a DOM snapshot plus window geometry for the current page.

        Bounds in the snapshot are rescaled so that the first bound's width
        matches the configured viewport width.

        Raises:
            AssertionError: if ``window.devicePixelRatio`` is not 1.0.
        """
        tree = client.send(
            "DOMSnapshot.captureSnapshot",
            {
                "computedStyles": [],
                "includeDOMRects": True,
                "includePaintOrder": True,
            },
        )

        # calibrate the bounds, in some cases, the bounds are scaled somehow
        bounds = tree["documents"][0]["layout"]["bounds"]
        b = bounds[0]
        n = b[2] / self.viewport_size["width"]
        bounds = [[x / n for x in bound] for bound in bounds]
        tree["documents"][0]["layout"]["bounds"] = bounds

        # extract browser info
        win_top_bound = page.evaluate("window.pageYOffset")
        win_left_bound = page.evaluate("window.pageXOffset")
        win_width = page.evaluate("window.screen.width")
        win_height = page.evaluate("window.screen.height")
        win_right_bound = win_left_bound + win_width
        win_lower_bound = win_top_bound + win_height
        device_pixel_ratio = page.evaluate("window.devicePixelRatio")
        assert device_pixel_ratio == 1.0, "devicePixelRatio is not 1.0"

        config: BrowserConfig = {
            "win_top_bound": win_top_bound,
            "win_left_bound": win_left_bound,
            "win_width": win_width,
            "win_height": win_height,
            "win_right_bound": win_right_bound,
            "win_lower_bound": win_lower_bound,
            "device_pixel_ratio": device_pixel_ratio,
        }

        info: BrowserInfo = {"DOMTree": tree, "config": config}
        return info

    @staticmethod
    def get_bounding_client_rect(
        client: CDPSession, backend_node_id: str
    ) -> dict[str, Any]:
        """Return the CDP ``Runtime.callFunctionOn`` response holding the node's client rect.

        Text nodes are measured through a DOM Range. On any failure a stub
        ``{"result": {"subtype": "error"}}`` is returned instead of raising.
        """
        try:
            remote_object = client.send(
                "DOM.resolveNode", {"backendNodeId": int(backend_node_id)}
            )
            remote_object_id = remote_object["object"]["objectId"]
            response = client.send(
                "Runtime.callFunctionOn",
                {
                    "objectId": remote_object_id,
                    "functionDeclaration": """
                        function() {
                            if (this.nodeType == 3) {
                                var range = document.createRange();
                                range.selectNode(this);
                                var rect = range.getBoundingClientRect().toJSON();
                                range.detach();
                                return rect;
                            } else {
                                return this.getBoundingClientRect().toJSON();
                            }
                        }
                    """,
                    "returnByValue": True,
                },
            )
            return response
        except Exception:
            return {"result": {"subtype": "error"}}

    @staticmethod
    def get_element_in_viewport_ratio(
        elem_left_bound: float,
        elem_top_bound: float,
        width: float,
        height: float,
        config: BrowserConfig,
    ) -> float:
        """Return the fraction of the element's area that lies inside the viewport.

        The viewport is the rectangle ``[0, win_width] x [0, win_height]`` taken
        from ``config``. Elements with zero or negative area yield ``0.0``.
        """
        area = width * height
        if area <= 0:
            return 0.0

        elem_right_bound = elem_left_bound + width
        elem_lower_bound = elem_top_bound + height

        win_left_bound = 0
        win_right_bound = config["win_width"]
        win_top_bound = 0
        win_lower_bound = config["win_height"]

        # Compute the overlap in x and y axes
        overlap_width = max(
            0,
            min(elem_right_bound, win_right_bound)
            - max(elem_left_bound, win_left_bound),
        )
        overlap_height = max(
            0,
            min(elem_lower_bound, win_lower_bound)
            - max(elem_top_bound, win_top_bound),
        )

        # Divide by the full element area. The former `... / width * height`
        # parsed as `(... / width) * height`.
        return overlap_width * overlap_height / area

    def element_is_visible(self, page, element_id):
        """Return whether the observed element ``element_id`` is sufficiently inside the viewport.

        The element's live client rect is fetched through ``page.client``.
        Zero-area rects are treated as visible (inherited TODO behaviour).
        """
        try:
            browser_info = self.fetch_browser_info(page, page.client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, page.client)

        response = self.get_bounding_client_rect(
            page.client, self.obs_nodes_info[element_id]["backend_id"]
        )
        rect = response["result"]["value"]
        x = float(rect["x"])
        y = float(rect["y"])
        width = float(rect["width"])
        height = float(rect["height"])

        if width * height == 0:
            # TODO: degenerate rects are considered visible, as before.
            return True

        in_viewport_ratio = self.get_element_in_viewport_ratio(
            elem_left_bound=x,
            elem_top_bound=y,
            width=width,
            height=height,
            config=browser_info["config"],
        )
        if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
            return False
        return True

    def fetch_page_html(
        self,
        info: BrowserInfo,
        page: Page,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> DOMTree:
        """Flatten the DOM snapshot in ``info`` into a list of ``DOMNode`` dicts.

        Bounds come from the snapshot layout, shifted by the window scroll
        offset. With ``current_viewport_only``, nodes without bounds, with zero
        size (except OPTIONs inside a SELECT), or mostly outside the viewport
        are spliced out. Their children are re-parented to the grandparent.
        """
        # adopted from [natbot](https://github.com/nat/natbot)
        tree = info["DOMTree"]
        config = info["config"]
        strings = tree["strings"]
        document = tree["documents"][0]
        nodes = document["nodes"]
        layout = document["layout"]

        stt = time.time()
        # make a dom tree that is easier to navigate
        dom_tree: DOMTree = []
        graph = defaultdict(list)
        print(nodes.keys())

        # DOM node index -> position in the layout arrays. First occurrence wins,
        # matching list.index. Avoids an O(n) scan for every node.
        layout_pos: dict[int, int] = {}
        for pos, layout_node_idx in enumerate(layout["nodeIndex"]):
            layout_pos.setdefault(layout_node_idx, pos)

        for node_idx in range(len(nodes["nodeName"])):
            # NOTE(review): CDP's nodeType is an integer enum, yet it is used as a
            # string-table index here; kept as-is.
            node_type_idx = nodes["nodeType"][node_idx]
            node_type = "generic"
            if 0 <= node_type_idx < len(strings):
                node_type = strings[node_type_idx]

            node_name = strings[nodes["nodeName"][node_idx]]

            node_value_idx = nodes["nodeValue"][node_idx]
            node_value = ""
            if 0 <= node_value_idx < len(strings):
                node_value = " ".join(strings[node_value_idx].split())

            # attributes are a flat [name, value, name, value, ...] list
            node_attributes = [strings[i] for i in nodes["attributes"][node_idx]]
            attr_parts = []
            for i in range(0, len(node_attributes), 2):
                attr_name = node_attributes[i]
                attr_value = _OPACITY_STYLE_RE.sub(" ", node_attributes[i + 1])
                # drop tokens injected by the vimium extension
                attr_value = " ".join(
                    tok for tok in attr_value.split() if "vimium" not in tok
                )
                attr_parts.append(f'{attr_name}="{attr_value}"')
            node_attributes_str = " ".join(attr_parts).strip()

            node_id = str(node_idx)
            parent_id = str(nodes["parentIndex"][node_idx])
            if parent_id != "-1":
                graph[parent_id].append(node_id)

            if parent_id == "-1":
                # the root is always treated as inside the viewport
                union_bound = [0.0, 0.0, 10.0, 10.0]
            else:
                union_bound = [0.0, 0.0, 0.0, 0.0]
                if node_idx in layout_pos:
                    # copy so the snapshot's layout data is not mutated in place
                    union_bound = list(layout["bounds"][layout_pos[node_idx]])
                    union_bound[0] -= config["win_left_bound"]
                    union_bound[1] -= config["win_top_bound"]

            cur_node: DOMNode = {
                "nodeId": node_id,
                "nodeType": node_type,
                "nodeName": node_name,
                "nodeValue": node_value,
                "attributes": node_attributes_str,
                "backendNodeId": str(nodes["backendNodeId"][node_idx]),
                "parentId": parent_id,
                "childIds": [],
                "cursor": 0,
                "union_bound": union_bound,
            }
            dom_tree.append(cur_node)

        print('[build]', time.time() - stt)
        stt = time.time()
        # add parent children index to the node
        for parent_id, child_ids in graph.items():
            dom_tree[int(parent_id)]["childIds"] = child_ids
        print('[graph]', time.time() - stt)

        stt = time.time()
        # remove the nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: DOMNode) -> None:
                # Splice `node` out: its children take its place in the parent's
                # childIds and are re-parented to the grandparent.
                node_id = node["nodeId"]
                parent_id = node["parentId"]
                child_ids = node["childIds"]

                # update the children of the parent node
                assert dom_tree[int(parent_id)]["parentId"] != "[REMOVED]"
                # remove the nodeid from parent
                index = dom_tree[int(parent_id)]["childIds"].index(node_id)
                dom_tree[int(parent_id)]["childIds"].pop(index)

                # Insert children_nodeids in the same location
                for child_id in child_ids:
                    dom_tree[int(parent_id)]["childIds"].insert(index, child_id)
                    index += 1

                # update children node's parent
                for child_id in child_ids:
                    dom_tree[int(child_id)]["parentId"] = parent_id
                # mark as removed
                dom_tree[int(node_id)]["parentId"] = "[REMOVED]"

            for node in dom_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node; zero-size OPTIONs inside a SELECT are kept
                if width == 0.0 or height == 0.0:
                    parent_id = node["parentId"]
                    is_select_option = (
                        node["nodeName"] == "OPTION"
                        and dom_tree[int(parent_id)]["nodeName"] == "SELECT"
                    )
                    if not is_select_option:
                        remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            dom_tree = [
                node
                for node in dom_tree
                if node.get("parentId", "-1") != "[REMOVED]"
            ]

        print('[filter]', time.time() - stt)
        return dom_tree

    @staticmethod
    def parse_my_html(dom_tree: DOMTree) -> tuple[str, str, dict[str, Any], Any]:
        """Serialize ``dom_tree`` to HTML and run it through ``HtmlParser``.

        Each emitted element carries a ``backend-id="bid-<id>"`` attribute.
        Nodes with neither attributes nor a value are skipped, and their
        children are emitted in their place.

        Returns:
            ``(raw_html, parsed_html, obs_nodes_info, html_parser)``
        """
        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> tuple[str, list[str]]:
            tree_str, labeled_elems = '', []
            node = dom_tree[node_cursor]
            valid_node = True
            pure_text = False
            try:
                if node['nodeName'] == '#text':
                    node['nodeName'] = 'text'

                node_str = f"<{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += (
                    f" backend-id=\"bid-{node['backendNodeId']}\"> {node['nodeValue']}"
                )

                valid_node = bool(node["attributes"] or node["nodeValue"] or pure_text)

                if valid_node:
                    node_html = lxml.html.fromstring(node_str)
                    label = node_html.attrib.get('data-testid', '')
                    if len(label) > 0:
                        labeled_elems.append(node["backendNodeId"])
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node['nodeValue'],
                        "label": label,
                    }
                    tree_str += f"{node_str}"

            except Exception:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                child_str, elems = dfs(child_cursor, child_depth)
                tree_str += child_str
                labeled_elems.extend(elems)

            if valid_node and not pure_text:
                # close the element opened above so nesting is preserved
                tree_str += f"</{node['nodeName']}>"

            return tree_str, labeled_elems

        html, labeled_elems = dfs(0, 0)

        print(labeled_elems)

        args = {
            'use_position': False,
            'id_attr': 'backend-id',
            'label_generator': 'order',
            'label_attr': 'data-testid',
            'attr_list': basic_attrs,
            'prompt': 'refine',
        }

        hp = HtmlParser(html, args)
        packet = hp.parse_tree()
        page_html = packet['html']

        print(print_html_object(page_html))

        it, pt = packet.get('init_time', 0), packet.get('parse_time', 0)
        print(f'[Time] {it:.3f} {pt:.3f}')

        return html, page_html, obs_nodes_info, hp

    @staticmethod
    def parse_html(dom_tree: DOMTree) -> tuple[str, dict[str, Any]]:
        """Render ``dom_tree`` as an indented ``[cursor] <tag attrs> value`` text listing."""
        obs_nodes_info = {}
        nodeid_to_cursor = {node["nodeId"]: idx for idx, node in enumerate(dom_tree)}

        def dfs(node_cursor: int, depth: int) -> str:
            tree_str = ""
            node = dom_tree[node_cursor]
            indent = "\t" * depth
            valid_node = True
            try:
                node_str = f"[{node_cursor}] <{node['nodeName']}"
                if node["attributes"]:
                    node_str += f" {node['attributes']}"
                node_str += f"> {node['nodeValue']}"
                valid_node = bool(node["attributes"] or node["nodeValue"])

                if valid_node:
                    obs_nodes_info[str(node_cursor)] = {
                        "backend_id": node["backendNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }
                    tree_str += f"{indent}{node_str}\n"

            except Exception:
                valid_node = False

            for child_ids in node["childIds"]:
                child_cursor = nodeid_to_cursor[child_ids]
                child_depth = depth + 1 if valid_node else depth
                tree_str += dfs(child_cursor, child_depth)

            return tree_str

        html = dfs(0, 0)
        return html, obs_nodes_info

    def fetch_page_accessibility_tree(
        self,
        info: BrowserInfo,
        client: CDPSession,
        current_viewport_only: bool,
    ) -> AccessibilityTree:
        """Fetch the full accessibility tree, de-duplicate it and attach ``union_bound``.

        With ``current_viewport_only``, nodes without bounds, with zero size, or
        mostly outside the viewport are spliced out. Their children are
        re-parented to the grandparent.
        """
        accessibility_tree: AccessibilityTree = client.send(
            "Accessibility.getFullAXTree", {}
        )["nodes"]

        # a few nodes are repeated in the accessibility tree
        seen_ids = set()
        _accessibility_tree = []
        for node in accessibility_tree:
            if node["nodeId"] not in seen_ids:
                _accessibility_tree.append(node)
                seen_ids.add(node["nodeId"])
        accessibility_tree = _accessibility_tree

        nodeid_to_cursor = {}
        for cursor, node in enumerate(accessibility_tree):
            nodeid_to_cursor[node["nodeId"]] = cursor
            # usually because the node is not visible etc
            if "backendDOMNodeId" not in node:
                node["union_bound"] = None
                continue
            backend_node_id = str(node["backendDOMNodeId"])
            if node["role"]["value"] == "RootWebArea":
                # always inside the viewport
                node["union_bound"] = [0.0, 0.0, 10.0, 10.0]
            else:
                response = self.get_bounding_client_rect(client, backend_node_id)
                if response.get("result", {}).get("subtype", "") == "error":
                    node["union_bound"] = None
                else:
                    x = response["result"]["value"]["x"]
                    y = response["result"]["value"]["y"]
                    width = response["result"]["value"]["width"]
                    height = response["result"]["value"]["height"]
                    node["union_bound"] = [x, y, width, height]

        # filter nodes that are not in the current viewport
        if current_viewport_only:

            def remove_node_in_graph(node: AccessibilityTreeNode) -> None:
                # Splice `node` out: its children take its place in the parent's
                # childIds and are re-parented to the grandparent.
                nodeid = node["nodeId"]
                node_cursor = nodeid_to_cursor[nodeid]
                parent_nodeid = node["parentId"]
                children_nodeids = node["childIds"]
                parent_cursor = nodeid_to_cursor[parent_nodeid]
                # update the children of the parent node
                assert (
                    accessibility_tree[parent_cursor].get("parentId", "Root")
                    is not None
                )
                # remove the nodeid from parent's childIds
                index = accessibility_tree[parent_cursor]["childIds"].index(nodeid)
                accessibility_tree[parent_cursor]["childIds"].pop(index)
                # Insert children_nodeids in the same location
                for child_nodeid in children_nodeids:
                    accessibility_tree[parent_cursor]["childIds"].insert(
                        index, child_nodeid
                    )
                    index += 1
                # update children node's parent
                for child_nodeid in children_nodeids:
                    child_cursor = nodeid_to_cursor[child_nodeid]
                    accessibility_tree[child_cursor]["parentId"] = parent_nodeid
                # mark as removed
                accessibility_tree[node_cursor]["parentId"] = "[REMOVED]"

            config = info["config"]
            for node in accessibility_tree:
                if not node["union_bound"]:
                    remove_node_in_graph(node)
                    continue

                [x, y, width, height] = node["union_bound"]

                # invisible node
                if width == 0 or height == 0:
                    remove_node_in_graph(node)
                    continue

                in_viewport_ratio = self.get_element_in_viewport_ratio(
                    elem_left_bound=float(x),
                    elem_top_bound=float(y),
                    width=float(width),
                    height=float(height),
                    config=config,
                )

                if in_viewport_ratio < IN_VIEWPORT_RATIO_THRESHOLD:
                    remove_node_in_graph(node)

            accessibility_tree = [
                node
                for node in accessibility_tree
                if node.get("parentId", "Root") != "[REMOVED]"
            ]

        return accessibility_tree

    @staticmethod
    def parse_accessibility_tree(
        accessibility_tree: AccessibilityTree,
    ) -> tuple[str, dict[str, Any], TreeNode]:
        """Render the accessibility tree as indented text and build a ``TreeNode`` mirror.

        Returns:
            ``(tree_text, obs_nodes_info, root_tree_node)``. Nodes pruned as
            empty are skipped, and their children are attached to the nearest
            kept ancestor.
        """
        node_id_to_idx = {}
        for idx, node in enumerate(accessibility_tree):
            node_id_to_idx[node["nodeId"]] = idx

        obs_nodes_info = {}

        def dfs(idx: int, obs_node_id: str, depth: int, active_node_dict: dict) -> tuple[str, Any]:
            tree_str = ""
            node = accessibility_tree[idx]
            indent = "\t" * depth
            valid_node = True
            try:
                role = node["role"]["value"]
                name = node["name"]["value"]
                node_str = f"[{obs_node_id}] {role} {repr(name)}"
                properties = []
                structured_properties = {}
                for prop in node.get("properties", []):
                    try:
                        if prop["name"] in IGNORED_ACTREE_PROPERTIES:
                            continue
                        properties.append(f'{prop["name"]}: {prop["value"]["value"]}')
                        structured_properties[prop["name"]] = prop["value"]["value"]
                    except KeyError:
                        pass

                if properties:
                    node_str += " " + " ".join(properties)

                # prune nodes that carry no name (and, for most roles, no properties)
                if not name.strip():
                    if not properties:
                        if role in _PRUNABLE_EMPTY_ROLES:
                            valid_node = False
                    elif role == "listitem":
                        valid_node = False

                if valid_node:
                    tree_str += f"{indent}{node_str}"
                    obs_nodes_info[obs_node_id] = {
                        "backend_id": node["backendDOMNodeId"],
                        "union_bound": node["union_bound"],
                        "text": node_str,
                    }

            except Exception:
                valid_node = False

            structured_node = (
                TreeNode(
                    node_id=int(obs_node_id),
                    role=node["role"]["value"],
                    name=node["name"]["value"],
                    depth=depth,
                    properties=structured_properties,
                )
                if valid_node
                else None
            )
            # remember the most recent kept node per depth so children can attach to it
            active_node_dict[depth] = (
                structured_node if valid_node else active_node_dict.get(depth, None)
            )

            for child_node_id in node["childIds"]:
                if child_node_id not in node_id_to_idx:
                    continue
                # mark this to save some tokens
                child_depth = depth + 1 if valid_node else depth
                child_str, child_node = dfs(
                    node_id_to_idx[child_node_id],
                    child_node_id,
                    child_depth,
                    active_node_dict=active_node_dict,
                )
                if child_str.strip():
                    if tree_str.strip():
                        tree_str += "\n"
                    tree_str += child_str
                if child_depth > 0 and child_node:
                    active_node_dict[child_depth - 1].add_child(child_node)

            return tree_str, structured_node

        tree_str, structured_node = dfs(
            0, accessibility_tree[0]["nodeId"], 0, active_node_dict={}
        )
        return tree_str, obs_nodes_info, structured_node

    @staticmethod
    def clean_accesibility_tree(tree_str: str) -> str:
        """Drop StaticText lines whose text already appears in the previous three kept lines."""
        clean_lines: list[str] = []
        for line in tree_str.split("\n"):
            if "statictext" in line.lower():
                prev_lines = clean_lines[-3:]
                match = _STATIC_TEXT_RE.search(line)
                if match:
                    static_text = match.group(1)
                    if all(static_text not in prev_line for prev_line in prev_lines):
                        clean_lines.append(line)
            else:
                clean_lines.append(line)

        return "\n".join(clean_lines)

    def process(self, page: Page, client: CDPSession, context: str) -> tuple[str, Any]:
        """Build the text observation for ``page``.

        Side effects: populates ``self.meta_data`` and ``self.obs_nodes_info``.
        In accessibility-tree mode it also sets ``self.node_root``. It always
        sets ``self.browser_config``.

        Returns:
            ``(content, node_root)``. ``node_root`` is ``None`` in ``"html"`` mode.

        Raises:
            ValueError: for an unknown ``observation_type``.
        """
        # get the tab info
        open_tabs = page.context.pages
        try:
            current_tab_idx = open_tabs.index(page)
            tab_titles = []
            for idx, tab in enumerate(open_tabs):
                title = f"{idx+1}. {tab.title()}"
                if idx == current_tab_idx:
                    title += " <-- current tab"
                tab_titles.append(title)
            tab_title_str = "\n".join(tab_titles)
        except Exception:
            tab_title_str = "\n".join(
                [f"{idx+1}. Default" for idx in range(len(open_tabs))]
            )

        try:
            browser_info = self.fetch_browser_info(page, client)
        except Exception:
            page.wait_for_load_state("load", timeout=500)
            browser_info = self.fetch_browser_info(page, client)

        # Only the accessibility-tree mode builds a TreeNode root; previously the
        # html mode hit an unbound `node_root` at the return statement.
        node_root = None

        if self.observation_type == "html":
            stt = time.time()
            dom_tree = self.fetch_page_html(
                browser_info,
                page,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            print('[fetch]', time.time() - stt)
            stt = time.time()
            raw_html, content, obs_nodes_info, hp = self.parse_my_html(dom_tree)
            print('[parse]', time.time() - stt)

            window_height = page.evaluate("window.innerHeight")
            page_height = (
                page.evaluate('document.documentElement.scrollHeight') / window_height
            )
            position = page.evaluate("window.scrollY") / window_height

            self.obs_nodes_info = obs_nodes_info
            self.meta_data["obs_nodes_info"] = obs_nodes_info
            self.meta_data["position_info"] = {
                "page_height": page_height,
                "position": position,
            }
            self.meta_data["dom_info"] = {
                "raw_html": raw_html,
                "dom_tree": dom_tree,
            }
            self.meta_data["html_parser"] = hp
            self.meta_data["tab_title"] = tab_title_str

        elif self.observation_type == "accessibility_tree":
            accessibility_tree = self.fetch_page_accessibility_tree(
                browser_info,
                client,
                current_viewport_only=self.current_viewport_only,
            )
            content, obs_nodes_info, node_root = self.parse_accessibility_tree(
                accessibility_tree
            )
            content = self.clean_accesibility_tree(content)
            self.obs_nodes_info = obs_nodes_info

            page_dialog_message = getattr(page, "dialog_message", "")
            if page_dialog_message:
                node_root.properties["page_dialog_message"] = (
                    page_dialog_message + " Retry."
                )
                page.dialog_message = None

            self.node_root = node_root
            self.meta_data["obs_nodes_info"] = obs_nodes_info

        else:
            raise ValueError(
                f"Invalid observatrion type: {self.observation_type}"
            )

        self.browser_config = browser_info["config"]
        return (content, node_root)

    def get_node_info_by_element_id(self, AXTreeId):
        """Look up a ``TreeNode`` by id in the last accessibility-tree observation."""
        return self.node_root.search_node_by_id(AXTreeId)

    def get_element_center(self, element_id: str, page) -> tuple[float, float]:
        """Return the element's centre as fractions of the viewport width and height."""
        node = self.obs_nodes_info[element_id]
        backend_node_id = str(node["backend_id"])
        response = self.get_bounding_client_rect(page.client, backend_node_id)
        x = response["result"]["value"]["x"]
        y = response["result"]["value"]["y"]
        width = response["result"]["value"]["width"]
        height = response["result"]["value"]["height"]
        center_x = x + width / 2
        center_y = y + height / 2
        return (
            center_x / self.viewport_size["width"],
            center_y / self.viewport_size["height"],
        )


class ImageObservationProcessor(ObservationProcessor):
    """Produces screenshot observations as uint8 numpy arrays."""

    def __init__(self, observation_type: str, current_viewport_only: bool):
        self.observation_type = observation_type
        self.current_viewport_only = current_viewport_only
        self.observation_tag = "image"
        self.meta_data = create_empty_metadata()

    def process(self, page: Page, client: CDPSession, context: str) -> npt.NDArray[np.uint8]:
        """Take a screenshot, full-page unless ``current_viewport_only``.

        The height is capped at twice the width. If the first attempt fails,
        waits for the page ``load`` event and retries once.
        """
        full_page = not self.current_viewport_only
        try:
            screenshot = png_bytes_to_numpy(page.screenshot(full_page=full_page))
        except Exception:
            page.wait_for_event("load")
            screenshot = png_bytes_to_numpy(page.screenshot(full_page=full_page))
        # cap very tall full-page captures; applied on both the first and retry paths
        return screenshot[:2 * screenshot.shape[1], :, :]


class ObservationHandler:
    """Main entry point to access all observation processor"""

    def __init__(
        self,
        main_observation_type: str,
        text_observation_type: str,
        image_observation_type: str,
        current_viewport_only: bool,
        viewport_size: ViewportSize,
    ) -> None:
        self.main_observation_type = main_observation_type
        self.text_processor = TextObervationProcessor(
            text_observation_type, current_viewport_only, viewport_size
        )
        self.image_processor = ImageObservationProcessor(
            image_observation_type, current_viewport_only
        )
        self.viewport_size = viewport_size

    def get_observation_space(self) -> spaces.Dict:
        """Return a gymnasium Dict space with ``text`` and ``image`` entries."""
        text_space = spaces.Text(
            min_length=0,
            max_length=UTTERANCE_MAX_LENGTH,
            charset=ASCII_CHARSET + FREQ_UNICODE_CHARSET,
        )

        image_space = spaces.Box(
            # Each position stores the RGB values. Note the swapped axes (height first).
            np.zeros(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            ),
            np.ones(
                (self.viewport_size["height"], self.viewport_size["width"], 3),
                dtype=np.uint8,
            )
            * 255.0,
            dtype=np.uint8,
        )

        return spaces.Dict({"text": text_space, "image": image_space})

    def get_observation(
        self, page: Page, client: CDPSession, context: str = '',
    ) -> dict[str, Observation]:
        """Collect text and image observations for ``page``."""
        text_obs = self.text_processor.process(page, client, context)
        image_obs = self.image_processor.process(page, client, context)
        return {"text": text_obs, "image": image_obs}

    def get_observation_metadata(self) -> dict[str, ObservationMetadata]:
        """Return the metadata dicts of both processors."""
        return {
            "text": self.text_processor.meta_data,
            "image": self.image_processor.meta_data,
        }

    @property
    def action_processor(self) -> ObservationProcessor:
        """Return the main processor that is associated with the action space"""
        if self.main_observation_type == "text":
            return self.text_processor
        elif self.main_observation_type == "image":
            return self.image_processor
        else:
            raise ValueError("Invalid main observation type")