426 lines
17 KiB
Python
426 lines
17 KiB
Python
from lxml import html
|
|
import time, copy, random
|
|
import json, re, os
|
|
|
|
from .identifier import IdentifierTool
|
|
from .prompt import HtmlPrompt
|
|
from .configs import config_meta
|
|
from .utils import get_xpath_top_down, rect2tuple
|
|
|
|
class HtmlParser():
    """Parse an HTML string into an lxml tree and render it as a compact prompt.

    The parser keeps three lookup tables (``bids2label``, ``bids2xpath`` and
    ``used_labels``) plus the options set through :meth:`set_args`.  Rendering
    is delegated to ``self.prompt`` (an ``HtmlPrompt``) and label generation
    to ``self.identifier`` (an ``IdentifierTool``).
    """

    def __init__(self, ctx: str, args: "dict | None" = None) -> None:
        """Build the DOM tree from ``ctx`` and apply the options in ``args``.

        Args:
            ctx: Raw HTML text.
            args: Option dictionary, see :meth:`parse_args`.  ``None`` means
                "use every default".
        """
        stt = time.time()
        self.dom_tree = self.ctx2tree(ctx)

        # tool related
        self.bids2label = {}
        self.bids2xpath = {}
        self.used_labels = {}

        # parse args
        self.parse_args(args)
        # wall-clock seconds spent constructing the parser
        self.init_time = time.time() - stt

    def parse_args(self, args: "dict | None" = None) -> None:
        """Read the option dictionary and configure the parser.

        Keys read from ``args``: ``dataset``, ``use_position``,
        ``window_size``, ``rect_dict``, ``label_attr``, ``regenerate_label``,
        ``label_generator``, ``id_attr``, ``attr_list``, ``parent_chain``,
        ``keep_elem``, ``obs_elem`` and ``prompt``.

        Raises:
            ValueError: if ``use_position`` is truthy and ``window_size`` is
                not a tuple or ``rect_dict`` is not a dict.
        """
        def attr_check(attr, type_model='str'):
            # An option is usable when it is set, has the same type as
            # ``type_model`` and, for strings, is not empty.
            if attr is None:
                return False
            if not isinstance(attr, type(type_model)):
                return False
            if isinstance(attr, str) and len(attr) == 0:
                return False
            return True

        args = {} if args is None else args

        # [Position] use_pos: False -> use full page, otherwise use window_size
        dataset = args.get('dataset', '')
        use_position = args.get('use_position', False)
        window_size = args.get('window_size', None)
        rect = args.get('rect_dict', None)
        if use_position:
            if not attr_check(window_size, ()):
                raise ValueError('window_size must be set when use_position is True')
            if not attr_check(rect, {}):
                raise ValueError('rect_dict must be set when use_position is True')

        if not attr_check(rect, {}):
            rect = {}

        # [Label] for vimium is temp_clickable_label, otherwise keep all of it
        label_attr = args.get('label_attr', '')
        get_new_label = args.get('regenerate_label', False)
        label_method = args.get('label_generator', None)
        # NOTE(review): this tests ``label_generator`` while the id branch
        # below tests ``id_attr``; confirm ``label_attr`` was not intended.
        regen_label = not attr_check(label_method)

        # [id] for mind2web is backend_node_id, for normal website use our method
        id_attr = args.get('id_attr', '')
        regen_id = not attr_check(id_attr)
        if regen_id:
            id_attr = 'temp_id'

        # [attributes]
        keep_attrs = args.get('attr_list', [])
        if not attr_check(keep_attrs, []):
            keep_attrs = []

        # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
        parent_chain = args.get('parent_chain', False)
        keep_elem = args.get('keep_elem', [])
        obs_elem = args.get('obs_elem', [])

        # store the validated options on the instance
        self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs,
                      keep_elem, obs_elem, parent_chain, get_new_label, dataset)

        # [Prompt]
        prompt = args.get('prompt', None)
        self.prompt = HtmlPrompt(prompt)

        # traverse and get special data
        if regen_id or regen_label:
            self.mark_id()

        if get_new_label:
            # forget the labels collected by mark_id so new ones can reuse them
            self.used_labels = {}

        self.identifier = IdentifierTool(label_method, self.used_labels)

    def set_args(self, use_position: bool = False, window_size: tuple = (),
                 rect_dict: "dict | None" = None, label_attr: str = '',
                 id_attr: str = '', keep_attrs: "list[str] | None" = None,
                 keep_elem: "list[str] | None" = None, obs_elem: "list[str] | None" = None,
                 parent_chain: bool = False, get_new_label: bool = False,
                 dataset: str = '') -> None:
        """Store the parsing options on the instance.

        Container arguments left as ``None`` become a fresh empty container,
        so instances never share (and accidentally mutate) a default object.
        Containers passed explicitly are stored by reference.
        """
        self.use_position = use_position
        self.window_size = window_size
        self.rect = {} if rect_dict is None else rect_dict
        self.label_attr = label_attr
        self.id_attr = id_attr
        self.keep_attrs = [] if keep_attrs is None else keep_attrs
        self.keep = [] if keep_elem is None else keep_elem
        self.obs = [] if obs_elem is None else obs_elem
        self.parent_chain = parent_chain
        self.get_new_label = get_new_label
        self.dataset = dataset

    def get_config(self):
        """Return ``(config, text)``: a truncated option summary and its rendering.

        List options are cut to their first five entries and ``rect`` to its
        first three items; ``text`` is ``config_meta`` formatted with the
        summary.
        """
        config = {
            'id_attr': self.id_attr,
            'keep_attrs': self.keep_attrs[:5],
            'label_attr': self.label_attr,
            'use_position': self.use_position,
            'window_size': self.window_size,
            'rect': dict(list(self.rect.items())[:3]),
            'keep_elem': self.keep[:5],
            'obs_elem': self.obs[:5],
            'parent_chain': self.parent_chain,
            'prompt_name': self.prompt.name,
            'identifier_name': self.identifier.name,
        }

        return config, config_meta.format(**config)

    def update_rect_dict(self, rect_dict: "dict | None" = None) -> None:
        """Replace the rectangle table; ``None`` resets it to a fresh empty dict."""
        self.rect = {} if rect_dict is None else rect_dict

    @staticmethod
    def ctx2tree(ctx: str) -> "html.HtmlElement":
        """Strip comments, ``<style>`` and ``<script>`` blocks, then parse ``ctx``.

        ``None`` is treated as an empty string *before* any regex runs (the
        regexes would otherwise raise ``TypeError``).  Whitespace runs are
        collapsed to a single space.  The cleaned text is handed to
        ``lxml.html.fromstring``; presumably that call rejects an empty
        document -- confirm against the lxml version in use.
        """
        ctx = '' if ctx is None else ctx
        # remove useless tags, eg. style and script
        ctx = re.sub(r'<!--[\W\w]*?-->', '', ctx)
        ctx = re.sub(r'<style[\W\w]*?>[\W\w]*?</style>', '', ctx)
        ctx = re.sub(r'<script[\W\w]*?>[\W\w]*?</script>', '', ctx)
        ctx = re.sub(r'\s+', ' ', ctx).strip()
        return html.fromstring(ctx)

    @staticmethod
    def get_root(tree: "html.HtmlElement") -> "html.HtmlElement":
        """Return the top-most ancestor reachable from the first ``//*`` match."""
        node = tree.xpath('//*')[0]
        parent = node.getparent()
        while parent is not None:
            node = parent
            parent = node.getparent()
        return node

    def get_node_by_bid(self, tree: "html.HtmlElement", bid: str) -> "html.HtmlElement | None":
        """Return the first element whose ``self.id_attr`` equals ``bid``, else ``None``.

        ``bid`` is bound as an XPath variable instead of being pasted into the
        expression, so values containing quotes cannot break the query.
        """
        nodes = tree.xpath(f'//*[@{self.id_attr}=$bid]', bid=str(bid))
        if len(nodes) == 0:
            return None
        return nodes[0]

    def id_label_converter(self, label: str) -> str:
        """Look ``label`` up in ``bids2label``; return ``''`` when absent."""
        return self.bids2label.get(label, '')

    def id_xpath_converter(self, label: str) -> str:
        """Look ``label`` up in ``bids2xpath``; return ``''`` when absent."""
        return self.bids2xpath.get(label, '')

    def mark_id(self) -> None:
        """Run ``get_xpath_top_down`` from the root and store its xpath/label tables."""
        root = self.get_root(self.dom_tree)
        _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
        self.used_labels = used_labels
        self.bids2xpath = i2xpath

    def parse(self, root: "html.HtmlElement", keep: "list[str]", obs: "list[str]",
              parent_chain: bool = False, get_new_label: bool = False) -> "tuple[str, dict]":
        """Render the subtree under ``root`` through ``self.prompt``.

        Args:
            root: Element to start the depth-first traversal from.
            keep: Ids of elements that are always rendered and labelled.
            obs: Ids of elements that are labelled only when visible.
            parent_chain: When true, an element with rendered descendants is
                rendered as well.
            get_new_label: When true, a new label is generated even for
                elements that already carry one.

        Returns:
            ``(dom, control_msg)`` where ``dom`` is the rendered text and
            ``control_msg`` holds ``have_clickable``, ``bids2label`` (both
            id -> label and label -> id entries) and ``label_element`` (ids
            that received a label).
        """
        # Membership is tested once per node, so build the sets up front.
        keep_ids = set(keep)
        obs_ids = set(obs)

        def get_text(raw: str) -> str:
            # stripped text, capped at 500 characters
            return '' if raw is None else raw.strip()[:500]

        def check_attr(attr: str, node: "html.HtmlElement") -> bool:
            # Drop purely presentational roles and hidden input types.
            if attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link']:
                return False
            if attr == 'type' and node.attrib.get(attr, '') == 'hidden':
                return False
            return True

        def is_visible(node: "html.HtmlElement", bid: str) -> bool:
            if self.dataset == 'mind2web':
                # mind2web stores the rectangle on the element itself
                bound = node.attrib.get('bounding_box_rect', None)
                self.rect[bid] = rect2tuple(bound)

            if not self.use_position:
                return True

            rect = self.rect.get(bid, None)
            if rect is None:
                return False

            if self.window_size is None:
                return True

            # visible unless the rectangle lies completely outside the window
            wx, wy, ww, wh = self.window_size
            x, y, w, h = rect
            if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
                return False

            return True

        def _dfs(node: "html.HtmlElement") -> "tuple[str, dict]":
            # basic information
            bid = node.attrib.get(self.id_attr, '')
            tag = node.tag
            label = node.attrib.get(self.label_attr, '')

            # element which is kept is equivalent to visible
            visible = is_visible(node, bid)
            in_keep_list = bid in keep_ids
            in_obs_list = (bid in obs_ids or len(label) > 0) and visible
            keep_element = in_keep_list or in_obs_list or visible

            # mark label
            bids2label, labeled_elems = {}, []
            have_label = False
            if in_keep_list or in_obs_list:
                if label is None or len(label) == 0 or get_new_label:
                    label = self.identifier.generate()
                    node.attrib[self.label_attr] = label
                # record the mapping in both directions
                bids2label[bid] = label
                bids2label[label] = bid
                have_label = True

            # get text or alt_text of current element
            text = get_text(node.text)

            # keep attributes if needed (an empty keep_attrs means "all")
            classes = {}
            keep_all_attrs = len(self.keep_attrs) == 0
            keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs

            # traverse attributes
            for attr in keep_attrs:
                if attr not in node.attrib or not check_attr(attr, node):
                    continue
                if attr in [self.id_attr, self.label_attr]:
                    continue
                val = get_text(node.attrib[attr])
                if len(val) > 0 or keep_all_attrs:
                    classes[attr] = val

            have_text = len(text) > 0 or len(classes) > 0

            # render children first and merge what they report
            parts = []
            clickable_count = 0
            for child in node.getchildren():
                cres, cmsg = _dfs(child)
                clickable_count += 1 if cmsg.get('have_clickable', False) else 0
                bids2label.update(cmsg.get('bids2label', {}))
                labeled_elems.extend(cmsg.get('label_element', []))
                if len(cres) != 0:
                    parts.append(cres)

            dom = self.prompt.subtree_constructor(parts)

            # remove <text|> if all children are text
            keep_as_all_text = (dom.count('<') == dom.count('<text|')) and dom.count('<') > 0
            if keep_as_all_text:
                matches = re.findall(r'<text\| ([^>]+) >', dom)
                dom = self.prompt.subtree_constructor(matches)

            keep_element = keep_element and (
                clickable_count > 1 or have_text or have_label or keep_as_all_text)
            keep_as_parent = len(dom) > 0 and parent_chain
            if in_keep_list or keep_element or keep_as_parent:
                dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)

            if have_label:
                labeled_elems.append(bid)

            control_msg = {
                'have_clickable': bool(clickable_count or have_text),
                'bids2label': bids2label,
                'label_element': labeled_elems,
            }

            return dom, control_msg

        dom, cmsg = _dfs(root)
        return dom, cmsg

    def parse_tree(self) -> dict:
        """Parse the whole tree and return ``{'html': ..., 'parse_time': ...}``.

        Also refreshes ``bids2label`` and adds every labelled element id to
        ``self.keep``.
        """
        # start from here
        stt = time.time()
        root = self.get_root(self.dom_tree)
        dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
        self.bids2label = cmsg.get('bids2label', {})
        self.keep = list(set(self.keep + cmsg.get('label_element', [])))

        obj = {
            'html': dom,
            'parse_time': time.time() - stt,
        }

        return obj

    # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
    def get_keep_elements(self, tree: "html.HtmlElement", keep: "list[str]", max_depth: int,
                          max_children: int, max_sibling: int, dfs_count: int = 1,
                          keep_parent: bool = False) -> "list[str]":
        """Collect the ids of ``keep`` plus their surrounding context.

        For every candidate the result gains its ancestors and descendants
        (depth-limited by ``max_depth``; descendants capped at
        ``max_children``) and up to ``max_sibling`` siblings on each side.
        The expansion runs ``max(1, dfs_count)`` rounds; after each round the
        limits shrink (x0.5, x0.5, x0.7) and the ids found so far become the
        next candidates.  With ``keep_parent`` every ancestor of the original
        ``keep`` ids is added as well.  Elements lacking ``self.id_attr``
        contribute ``''``.
        """
        def get_ancestors(node, max_depth: int, current_depth: int = 0) -> list:
            if current_depth > max_depth:
                return []

            ancestors = []
            parent = node.getparent()
            if parent is not None:
                ancestors.append(parent)
                ancestors.extend(get_ancestors(parent, max_depth, current_depth + 1))

            return ancestors

        def get_descendants(node, max_depth: int, current_depth: int = 0) -> list:
            if current_depth > max_depth:
                return []

            descendants = []
            for child in node:
                descendants.append(child)
                descendants.extend(get_descendants(child, max_depth, current_depth + 1))

            return descendants

        def node_id(elem) -> str:
            return elem.attrib.get(self.id_attr, '')

        to_keep = set(keep)
        nodes_to_keep = set()

        for _ in range(max(1, dfs_count)):
            for bid in to_keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue

                nodes_to_keep.add(candidate_node.attrib[self.id_attr])
                # get all ancestors or with max depth
                nodes_to_keep.update(
                    node_id(x) for x in get_ancestors(candidate_node, max_depth))

                # get descendants with max depth
                nodes_to_keep.update(
                    [node_id(x) for x in get_descendants(candidate_node, max_depth)][:max_children])

                # get siblings within range
                parent = candidate_node.getparent()
                if parent is None:
                    continue

                siblings = [x for x in parent.getchildren() if x.tag != 'text']
                if candidate_node not in siblings:
                    continue

                idx_in_sibling = siblings.index(candidate_node)
                first = max(0, idx_in_sibling - max_sibling)
                last = idx_in_sibling + max_sibling + 1
                nodes_to_keep.update(node_id(x) for x in siblings[first:last])

            # narrow the neighbourhood for the next round
            max_children = int(max_children * 0.5)
            max_depth = int(max_depth * 0.5)
            max_sibling = int(max_sibling * 0.7)

            to_keep = set(nodes_to_keep)

        if keep_parent:
            for bid in keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue
                nodes_to_keep.update(node_id(x) for x in candidate_node.xpath("ancestor::*"))

        return list(nodes_to_keep)

    def prune(self, tree: "html.HtmlElement", nodes_to_keep: "list[str]") -> "html.HtmlElement":
        """Remove elements not listed in ``nodes_to_keep``, hoisting their children.

        Elements are visited in reverse document order.  A kept element is
        still collapsed into its parent when it has no attributes, no
        ``text`` child, at most one child, a parent, and is not itself a
        ``text`` element.  ``tree`` is modified in place and returned.
        """
        keep_ids = set(nodes_to_keep)

        def splice_out(node) -> None:
            # insert all children into parent, then drop the node itself
            for child in node.getchildren():
                node.addprevious(child)
            node.getparent().remove(node)

        # remove nodes not in nodes_to_keep
        for node in reversed(tree.xpath('//*')):
            if node.tag != 'text':
                is_keep = node.attrib.get(self.id_attr, '') in keep_ids
            else:
                # a text element is kept exactly when its parent is kept
                is_keep = node.getparent().attrib.get(self.id_attr, '') in keep_ids

            has_parent = node.getparent() is not None
            if not is_keep and has_parent:
                splice_out(node)
                continue

            children = node.getchildren()
            if (
                len(node.attrib) == 0
                and not any(x.tag == 'text' for x in children)
                and has_parent
                and node.tag != "text"
                and len(children) <= 1
            ):
                splice_out(node)

        return tree

    def prune_tree(self, dfs_count: int = 1, max_depth: int = 3, max_children: int = 30,
                   max_sibling: int = 3, keep_parent: bool = False) -> None:
        """Replace ``self.dom_tree`` with a pruned copy built around ``self.keep``."""
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(
            new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent)
        new_tree = self.prune(new_tree, nodes_to_keep)

        self.dom_tree = new_tree

    def get_segment(self, bid: str) -> str:
        """Render the small neighbourhood of ``bid`` without touching ``self.dom_tree``."""
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
        new_tree = self.prune(new_tree, nodes_to_keep)
        dom, _ = self.parse(new_tree, self.keep, [], False)
        return dom

    def get_rect_data(self, bids: "list[str]") -> "list[dict]":
        """Return one ``{'bid', 'label', 'rect'}`` dict per id in ``bids``.

        Unknown ids get ``''`` as label and ``None`` as rect.
        """
        return [
            {
                'bid': bid,
                'label': self.bids2label.get(bid, ''),
                'rect': self.rect.get(bid, None),
            }
            for bid in bids
        ]
|
|
|