# AgentOccam/browser_env/html_tools/html_parser.py
# Snapshot timestamp: 2025-01-22 11:32:35 -08:00
# Original file: 426 lines, 17 KiB, Python

from lxml import html
import time, copy, random
import json, re, os
from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .configs import config_meta
from .utils import get_xpath_top_down, rect2tuple
class HtmlParser:
    """Convert raw HTML into a compact, labelled text representation.

    ``__init__`` cleans and parses the HTML (``ctx2tree``) and reads the
    options (``parse_args``). ``parse_tree`` renders the document through an
    ``HtmlPrompt``. ``prune_tree`` / ``get_segment`` apply Mind2Web-style
    pruning around a set of element ids ("bids", the values of ``id_attr``).

    lxml types are written as string annotations so that they are not
    evaluated when the class is defined.
    """

    def __init__(self, ctx: str, args: dict = None) -> None:
        """Parse ``ctx`` into ``self.dom_tree`` and apply the options in ``args``.

        Args:
            ctx: raw HTML source.
            args: option dict understood by ``parse_args``. ``None`` means
                all defaults.
        """
        stt = time.time()
        self.dom_tree = self.ctx2tree(ctx)
        # tool related
        self.bids2label = {}
        self.bids2xpath = {}
        self.used_labels = {}
        self.parse_args(args)
        # seconds spent parsing the HTML and applying the options
        self.init_time = time.time() - stt

    def parse_args(self, args: dict = None) -> None:
        """Read parser options from ``args`` and set up ids, labels and prompt.

        Recognised keys (all optional): ``dataset``, ``use_position``,
        ``window_size``, ``rect_dict``, ``label_attr``, ``regenerate_label``,
        ``label_generator``, ``id_attr``, ``attr_list``, ``parent_chain``,
        ``keep_elem``, ``obs_elem`` and ``prompt``.

        If ``id_attr`` is not a non-empty string, ``'temp_id'`` is used instead.
        In that case ``mark_id`` is run; it is also run when
        ``label_generator`` is not a non-empty string.

        Raises:
            ValueError: ``use_position`` is truthy but ``window_size`` is not
                a tuple, or ``rect_dict`` is not a dict.
        """
        def attr_check(attr, type_model='str'):
            # True iff ``attr`` has exactly the type of ``type_model`` and,
            # for strings, is non-empty.
            if attr is None or type(attr) != type(type_model):
                return False
            if type(attr) is str and len(attr) == 0:
                return False
            return True

        args = {} if args is None else args
        # [Position] use_position: False -> use full page, otherwise use window_size
        dataset = args.get('dataset', '')
        use_position = args.get('use_position', False)
        window_size = args.get('window_size', None)
        rect = args.get('rect_dict', None)
        if use_position:
            if not attr_check(window_size, ()):
                raise ValueError('window_size must be set when use_position is True')
            if not attr_check(rect, {}):
                raise ValueError('rect_dict must be set when use_position is True')
        if not attr_check(rect, {}):
            rect = {}
        # [Label] for vimium is temp_clickable_label, otherwise keep all of it
        label_attr = args.get('label_attr', '')
        get_new_label = args.get('regenerate_label', False)
        label_method = args.get('label_generator', None)
        regen_label = not attr_check(label_method)
        # [id] for mind2web is backend_node_id, for normal websites ids are generated
        id_attr = args.get('id_attr', '')
        regen_id = not attr_check(id_attr)
        if regen_id:
            id_attr = 'temp_id'
        # [attributes] empty list -> keep every attribute
        keep_attrs = args.get('attr_list', [])
        if not attr_check(keep_attrs, []):
            keep_attrs = []
        # [Tags] keep: must keep, obs: keep if it follows specific rules
        parent_chain = args.get('parent_chain', False)
        keep_elem = args.get('keep_elem', [])
        obs_elem = args.get('obs_elem', [])
        self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs,
                      keep_elem, obs_elem, parent_chain, get_new_label, dataset)
        # [Prompt]
        prompt = args.get('prompt', None)
        self.prompt = HtmlPrompt(prompt)
        # traverse and get special data
        if regen_id or regen_label:
            self.mark_id()
        if get_new_label:
            self.used_labels = {}
        self.identifier = IdentifierTool(label_method, self.used_labels)

    def set_args(self, use_position: bool = False, window_size: tuple = (), rect_dict: dict = None,
                 label_attr: str = '', id_attr: str = '', keep_attrs: list[str] = None,
                 keep_elem: list[str] = None, obs_elem: list[str] = None,
                 parent_chain: bool = False, get_new_label: bool = False, dataset: str = '') -> None:
        """Store parser options on the instance.

        ``None`` for ``rect_dict``, ``keep_attrs``, ``keep_elem`` or
        ``obs_elem`` yields a fresh empty container. This means instances
        never share a default object. That matters because ``parse`` writes
        into ``self.rect``.
        """
        self.use_position = use_position
        self.window_size = window_size
        self.rect = {} if rect_dict is None else rect_dict
        self.label_attr = label_attr
        self.id_attr = id_attr
        self.keep_attrs = [] if keep_attrs is None else keep_attrs
        self.keep = [] if keep_elem is None else keep_elem
        self.obs = [] if obs_elem is None else obs_elem
        self.parent_chain = parent_chain
        self.get_new_label = get_new_label
        self.dataset = dataset

    def get_config(self) -> tuple[dict, str]:
        """Return ``(config_dict, config_meta.format(**config_dict))``.

        Lists are truncated to their first 5 entries and ``rect`` to its
        first 3 items.
        """
        config = {
            'id_attr': self.id_attr,
            'keep_attrs': self.keep_attrs[:5],
            'label_attr': self.label_attr,
            'use_position': self.use_position,
            'window_size': self.window_size,
            'rect': dict(list(self.rect.items())[:3]),
            'keep_elem': self.keep[:5],
            'obs_elem': self.obs[:5],
            'parent_chain': self.parent_chain,
            'prompt_name': self.prompt.name,
            'identifier_name': self.identifier.name,
        }
        return config, config_meta.format(**config)

    def update_rect_dict(self, rect_dict: dict = None) -> None:
        """Replace the id -> rect mapping (``None`` -> a fresh empty dict)."""
        self.rect = {} if rect_dict is None else rect_dict

    @staticmethod
    def ctx2tree(ctx: str) -> 'html.HtmlElement':
        """Clean raw HTML and parse it with ``lxml.html.fromstring``.

        The cleaning steps are:
        - ``None`` is treated as an empty string.
        - Comments are removed.
        - ``<style>`` and ``<script>`` blocks are removed; tags are matched
          case-insensitively.
        - Whitespace runs are collapsed to single spaces.
        """
        if ctx is None:
            ctx = ''
        # remove useless tags, eg. style and script
        ctx = re.sub(r'<!--[\W\w]*?-->', '', ctx)
        ctx = re.sub(r'<style[\W\w]*?>[\W\w]*?</style>', '', ctx, flags=re.IGNORECASE)
        ctx = re.sub(r'<script[\W\w]*?>[\W\w]*?</script>', '', ctx, flags=re.IGNORECASE)
        ctx = re.sub(r'\s+', ' ', ctx).strip()
        return html.fromstring(ctx)

    @staticmethod
    def get_root(tree: 'html.HtmlElement') -> 'html.HtmlElement':
        """Return the top-most ancestor of the document containing ``tree``."""
        node = tree.xpath('//*')[0]
        parent = node.getparent()
        while parent is not None:
            node = parent
            parent = node.getparent()
        return node

    def get_node_by_bid(self, tree: 'html.HtmlElement', bid: str) -> 'html.HtmlElement':
        """Return the first element whose ``id_attr`` equals ``bid``, or ``None``.

        ``bid`` is bound as an XPath variable, so quotes in it cannot break
        the expression.
        """
        nodes = tree.xpath(f'//*[@{self.id_attr}=$bid]', bid=str(bid))
        return nodes[0] if nodes else None

    def _index_nodes_by_bid(self, tree: 'html.HtmlElement') -> dict:
        """Map each ``id_attr`` value to its first element in document order.

        This gives the same lookup result as ``get_node_by_bid``, but scans
        the document only once.
        """
        index = {}
        for node in tree.xpath('//*'):
            bid = node.attrib.get(self.id_attr)
            if bid is not None:
                index.setdefault(bid, node)
        return index

    def id_label_converter(self, label: str) -> str:
        """Look ``label`` up in ``bids2label`` (filled both ways); ``''`` if absent."""
        return self.bids2label.get(label, '')

    def id_xpath_converter(self, label: str) -> str:
        """Look ``label`` up in ``bids2xpath``; ``''`` if absent."""
        return self.bids2xpath.get(label, '')

    def mark_id(self) -> None:
        """Run ``get_xpath_top_down`` on the document root and store its results.

        Stores the returned id -> xpath map and the set of used labels.
        NOTE(review): presumably this also writes ``id_attr`` onto elements
        that lack one; confirm in ``utils.get_xpath_top_down``.
        """
        root = self.get_root(self.dom_tree)
        _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
        self.used_labels = used_labels
        self.bids2xpath = i2xpath

    def parse(self, root: 'html.HtmlElement', keep: list[str], obs: list[str],
              parent_chain: bool = False, get_new_label: bool = False) -> tuple[str, dict]:
        """Render the subtree under ``root`` through ``self.prompt``.

        Args:
            root: element to start from.
            keep: ids that are always rendered and labelled.
            obs: ids that are labelled (and rendered) when visible. Elements
                that already carry a non-empty ``label_attr`` are treated the
                same way.
            parent_chain: also render elements whose subtree produced output.
            get_new_label: replace existing labels with freshly generated ones.

        Returns:
            ``(dom, control_msg)``. ``dom`` is the rendered string.
            ``control_msg`` is a dict with these keys:
            - ``'have_clickable'``
            - ``'bids2label'`` (id -> label and label -> id)
            - ``'label_element'`` (ids that received a label)

        Side effects:
            - Generated labels are written into ``label_attr`` on the visited
              elements.
            - When ``dataset == 'mind2web'``, each element's rect (parsed from
              its ``bounding_box_rect`` attribute) is stored in ``self.rect``.
        """
        # sets: membership is tested once per visited element
        keep_set = set(keep)
        obs_set = set(obs)
        keep_all_attrs = len(self.keep_attrs) == 0
        skip_attrs = {self.id_attr, self.label_attr}

        def get_text(value: str) -> str:
            # strip and cap at 500 characters; None -> ''
            return '' if value is None else value.strip()[:500]

        def check_attr(attr: str, node) -> bool:
            # drop presentational/link roles and hidden input types
            value = node.attrib.get(attr, '')
            if attr == 'role' and value in ('presentation', 'none', 'link'):
                return False
            if attr == 'type' and value == 'hidden':
                return False
            return True

        def is_visible(node, bid: str) -> bool:
            if self.dataset == 'mind2web':
                bound = node.attrib.get('bounding_box_rect', None)
                self.rect[bid] = rect2tuple(bound)
            if not self.use_position:
                return True
            rect = self.rect.get(bid, None)
            if rect is None:
                return False
            if self.window_size is None:
                return True
            # both rects unpack as (x, y, w, h); visible unless entirely outside the window
            wx, wy, ww, wh = self.window_size
            x, y, w, h = rect
            return not (x + w < wx or x > wx + ww or y + h < wy or y > wy + wh)

        def _dfs(node) -> tuple[str, dict]:
            bid = node.attrib.get(self.id_attr, '')
            tag = node.tag
            label = node.attrib.get(self.label_attr, '')
            visible = is_visible(node, bid)
            in_keep_list = bid in keep_set
            in_obs_list = (bid in obs_set or len(label) > 0) and visible
            keep_element = in_keep_list or in_obs_list or visible
            # mark label
            bids2label, labeled_elems = {}, []
            have_label = False
            if in_keep_list or in_obs_list:
                if not label or get_new_label:
                    label = self.identifier.generate()
                    node.attrib[self.label_attr] = label
                bids2label[bid] = label
                bids2label[label] = bid
                have_label = True
            text = get_text(node.text)
            # collect attributes (all of them when no attr_list was configured)
            classes = {}
            attr_names = node.attrib.keys() if keep_all_attrs else self.keep_attrs
            for attr in attr_names:
                if attr not in node.attrib or attr in skip_attrs or not check_attr(attr, node):
                    continue
                val = get_text(node.attrib[attr])
                if len(val) > 0 or keep_all_attrs:
                    classes[attr] = val
            have_text = len(text) > 0 or len(classes) > 0
            parts = []
            clickable_count = 0
            for child in list(node):
                child_dom, child_msg = _dfs(child)
                if child_msg.get('have_clickable', False):
                    clickable_count += 1
                bids2label.update(child_msg.get('bids2label', {}))
                labeled_elems.extend(child_msg.get('label_element', []))
                if len(child_dom) != 0:
                    parts.append(child_dom)
            dom = self.prompt.subtree_constructor(parts)
            # strip the <text| ...> wrappers when everything rendered below is text
            keep_as_all_text = dom.count('<') == dom.count('<text|') and dom.count('<') > 0
            if keep_as_all_text:
                matches = re.findall(r'<text\| ([^>]+) >', dom)
                dom = self.prompt.subtree_constructor(matches)
            keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
            keep_as_parent = len(dom) > 0 and parent_chain
            if in_keep_list or keep_element or keep_as_parent:
                dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)
            if have_label:
                labeled_elems.append(bid)
            control_msg = {
                'have_clickable': bool(clickable_count or have_text),
                'bids2label': bids2label,
                'label_element': labeled_elems,
            }
            return dom, control_msg

        return _dfs(root)

    def parse_tree(self) -> dict:
        """Render the whole document and record the labels it produced.

        Side effects:
            - ``self.bids2label`` is replaced.
            - Every labelled id is added to ``self.keep`` (deduplicated; order
              is not preserved).

        Returns:
            ``{'html': rendered string, 'parse_time': seconds}``.
        """
        stt = time.time()
        root = self.get_root(self.dom_tree)
        dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
        self.bids2label = cmsg.get('bids2label', {})
        self.keep = list(set(self.keep + cmsg.get('label_element', [])))
        return {
            'html': dom,
            'parse_time': time.time() - stt,
        }

    # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
    def get_keep_elements(self, tree: 'html.HtmlElement', keep: list[str], max_depth: int, max_children: int,
                          max_sibling: int, dfs_count: int = 1, keep_parent: bool = False) -> list[str]:
        """Collect the ids of the elements to keep around the seed ids ``keep``.

        For every seed, the following are kept:
        - the element itself;
        - its ``max_depth + 1`` nearest ancestors;
        - the first ``max_children`` descendants in pre-order, down to
          ``max_depth + 1`` levels;
        - up to ``max_sibling`` non-``text`` siblings on each side.

        With ``dfs_count > 1`` the expansion is repeated from the collected
        set. Each round shrinks the limits: children and depth x0.5,
        siblings x0.7. ``keep_parent`` additionally keeps every ancestor of
        the original seeds. Elements without an id contribute ``''``.
        """
        def get_ancestors(node, depth_limit: int) -> list:
            # nearest first, at most depth_limit + 1 entries
            ancestors = []
            parent = node.getparent()
            while parent is not None and len(ancestors) <= depth_limit:
                ancestors.append(parent)
                parent = parent.getparent()
            return ancestors

        def get_descendants(node, max_depth: int, current_depth: int = 0) -> list:
            # pre-order, depth-limited
            if current_depth > max_depth:
                return []
            descendants = []
            for child in node:
                descendants.append(child)
                descendants.extend(get_descendants(child, max_depth, current_depth + 1))
            return descendants

        def ids_of(nodes) -> list[str]:
            return [x.attrib.get(self.id_attr, '') for x in nodes]

        # one scan of the tree instead of an XPath query per id
        bid_index = self._index_nodes_by_bid(tree)
        to_keep = set(keep)
        nodes_to_keep = set()
        for _ in range(max(1, dfs_count)):
            for bid in to_keep:
                candidate_node = bid_index.get(str(bid))
                if candidate_node is None:
                    continue
                nodes_to_keep.add(candidate_node.attrib[self.id_attr])
                nodes_to_keep.update(ids_of(get_ancestors(candidate_node, max_depth)))
                nodes_to_keep.update(ids_of(get_descendants(candidate_node, max_depth))[:max_children])
                # siblings within range
                parent = candidate_node.getparent()
                if parent is None:
                    continue
                siblings = [x for x in parent if x.tag != 'text']
                if candidate_node not in siblings:
                    continue
                idx = siblings.index(candidate_node)
                nodes_to_keep.update(ids_of(siblings[max(0, idx - max_sibling): idx + max_sibling + 1]))
            max_children = int(max_children * 0.5)
            max_depth = int(max_depth * 0.5)
            max_sibling = int(max_sibling * 0.7)
            to_keep = set(nodes_to_keep)
        if keep_parent:
            for bid in keep:
                candidate_node = bid_index.get(str(bid))
                if candidate_node is None:
                    continue
                nodes_to_keep.update(ids_of(candidate_node.xpath('ancestor::*')))
        return list(nodes_to_keep)

    def prune(self, tree: 'html.HtmlElement', nodes_to_keep: list[str]) -> 'html.HtmlElement':
        """Unwrap every element whose id is not in ``nodes_to_keep``.

        Elements are visited in reverse document order.
        - A dropped element's children are moved into its parent at its
          position.
        - ``text`` elements are kept or dropped according to their parent's id.
        - Kept non-``text`` elements with no attributes, no ``text`` children
          and at most one child are unwrapped as well.
        - Parentless elements (the root) are left alone.

        ``tree`` is modified in place and returned.
        """
        keep_ids = set(nodes_to_keep)

        def unwrap(node, parent) -> None:
            # move node's children in front of it, then detach node
            for child in list(node):
                node.addprevious(child)
            parent.remove(node)

        for node in tree.xpath('//*')[::-1]:
            parent = node.getparent()
            if parent is None:
                continue
            owner = parent if node.tag == 'text' else node
            if owner.attrib.get(self.id_attr, '') not in keep_ids:
                unwrap(node, parent)
                continue
            children = list(node)
            if (
                len(node.attrib) == 0
                and node.tag != 'text'
                and len(children) <= 1
                and not any(child.tag == 'text' for child in children)
            ):
                unwrap(node, parent)
        return tree

    def prune_tree(self, dfs_count: int = 1, max_depth: int = 3, max_children: int = 30,
                   max_sibling: int = 3, keep_parent: bool = False) -> None:
        """Replace ``self.dom_tree`` with a pruned copy centred on ``self.keep``."""
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children,
                                               max_sibling, dfs_count, keep_parent)
        self.dom_tree = self.prune(new_tree, nodes_to_keep)

    def get_segment(self, bid: str) -> str:
        """Render the neighbourhood of ``bid`` from a pruned copy of ``dom_tree``.

        The neighbourhood is the element, its parent, its first 2 children,
        and 1 non-``text`` sibling on each side. ``parse`` may still generate
        labels via ``self.identifier``.
        """
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
        new_tree = self.prune(new_tree, nodes_to_keep)
        dom, _ = self.parse(new_tree, self.keep, [], False)
        return dom

    def get_rect_data(self, bids: list[str]) -> list[dict]:
        """Return ``{'bid', 'label', 'rect'}`` records for ``bids``.

        A missing label becomes ``''`` and a missing rect becomes ``None``.
        """
        return [
            {
                'bid': bid,
                'label': self.bids2label.get(bid, ''),
                'rect': self.rect.get(bid, None),
            }
            for bid in bids
        ]