webrl/VAB-WebArena-Lite/browser_env/html_tools/prompt.py
2025-04-23 17:01:18 +08:00

97 lines
3.8 KiB
Python
Executable File

from .configs import prompts
class HtmlPrompt:
    """Render DOM nodes as text using a named template set from ``prompts``.

    The selected template set (``self.prompt``) is indexed with the keys
    ``'dom'``, ``'label'``, ``'attr'``, ``'attr_splitter'`` and
    ``'subtree_splitter'``; each of the first three is used as a
    ``str.format`` template.

    Attributes:
        name: The prompt key that was selected.
        prompt: The template set ``prompts[name]``.
        constructor: The bound method used by ``prompt_constructor``.
    """

    def __init__(self, prompt: str = '') -> None:
        """Select the template set and node constructor for ``prompt``.

        Args:
            prompt: Key into ``prompts``. ``None`` falls back to ``'xml'``;
                any other value (including the default ``''``) is used
                unchanged.

        Raises:
            ValueError: If ``prompt`` is not a key of ``prompts``.
            KeyError: If ``prompt`` is in ``prompts`` but no constructor is
                registered for it.
        """
        name = self.extract(prompt, 'xml')
        if name not in prompts:
            # ValueError is still caught by callers handling ``Exception``.
            raise ValueError(f'Unknown prompt: {name}')

        constructors = {
            'refine': self.normal_prompt_constructor,
            'xml': self.normal_prompt_constructor,
            'new_data': self.new_data_prompt_constructor,
        }
        constructor = constructors.get(name)
        if constructor is None:
            # Same exception type as the plain dict lookup, but with context.
            raise KeyError(f'No constructor registered for prompt: {name}')

        self.name = name
        self.prompt = prompts[name]
        self.constructor = constructor

    @staticmethod
    def extract(data, default=''):
        """Return ``data`` unless it is ``None``, in which case ``default``."""
        return default if data is None else data

    @staticmethod
    def _add_prefix(data: str, prefix: str) -> str:
        """Return ``prefix + data``, or ``''`` when ``data`` is empty."""
        return prefix + data if data else ''

    def subtree_constructor(self, subtree: 'list[str] | None' = None) -> str:
        """Join rendered child nodes with the template's subtree splitter.

        Args:
            subtree: Rendered child strings; ``None`` is treated as empty.

        Returns:
            The children joined by ``self.prompt['subtree_splitter']``.
        """
        return self.prompt['subtree_splitter'].join(self.extract(subtree, []))

    def _build_dom(self, tag, label, content, subtree_str, class_dict, *,
                   collect_empty: bool) -> str:
        """Shared implementation behind both public node constructors.

        Args:
            tag, label, content, subtree_str, class_dict: As for
                ``normal_prompt_constructor``; ``None`` means "empty".
            collect_empty: When true, attributes whose value is ``''`` are
                not rendered individually; their keys are gathered into a
                single trailing ``message`` attribute instead.

        Returns:
            The ``'dom'`` template filled in for this node.
        """
        tag = self.extract(tag)
        label = self.extract(label)
        content = self.extract(content)
        subtree_str = self.extract(subtree_str, '')
        class_dict = self.extract(class_dict, {})

        label_str = self.prompt['label'].format(label=label) if label else ''

        attr_template = self.prompt['attr']
        rendered = []
        seen_values = set()
        empty_keys = []
        for key, val in class_dict.items():
            if collect_empty and val == '':
                empty_keys.append(key)
                continue
            # Attributes are de-duplicated by value: only the first key
            # carrying a given value is rendered.
            if val in seen_values:
                continue
            seen_values.add(val)
            rendered.append(attr_template.format(key=key, attr=val))
        if empty_keys:
            rendered.append(
                attr_template.format(key='message', attr=' '.join(empty_keys))
            )

        attr_splitter = self.prompt['attr_splitter']
        classes_str = attr_splitter.join(rendered)
        # Content follows the attributes with the attribute splitter, or a
        # plain space when there are no attributes to separate it from.
        content_splitter = attr_splitter if classes_str else ' '

        return self.prompt['dom'].format(
            tag=tag,
            label=label_str,
            attr=self._add_prefix(classes_str, ' '),
            content=self._add_prefix(content, content_splitter),
            subtree=self._add_prefix(subtree_str, ' '),
        )

    def normal_prompt_constructor(self, tag: str = '', label: str = '',
                                  content: str = '', subtree_str: str = '',
                                  class_dict: 'dict[str, str] | None' = None) -> str:
        """Render one node, emitting every distinct-valued attribute.

        Args:
            tag: Tag name placed in the ``'dom'`` template.
            label: Node label; rendered through the ``'label'`` template
                only when non-empty.
            content: Text content; prefixed with a splitter when non-empty.
            subtree_str: Already-rendered children; prefixed with a space
                when non-empty.
            class_dict: Attribute name to value mapping. Later keys whose
                value repeats an earlier one are skipped.

        Returns:
            The rendered node string. ``None`` for any argument is treated
            as its empty value.
        """
        return self._build_dom(tag, label, content, subtree_str, class_dict,
                               collect_empty=False)

    def new_data_prompt_constructor(self, tag: str = '', label: str = '',
                                    content: str = '', subtree_str: str = '',
                                    class_dict: 'dict[str, str] | None' = None) -> str:
        """Render one node, folding empty-valued attributes into ``message``.

        Behaves like ``normal_prompt_constructor`` except that attributes
        whose value is ``''`` are not rendered one by one; their keys are
        joined with spaces and appended as a single ``message`` attribute.

        Args:
            tag: Tag name placed in the ``'dom'`` template.
            label: Node label; rendered only when non-empty.
            content: Text content; prefixed with a splitter when non-empty.
            subtree_str: Already-rendered children.
            class_dict: Attribute name to value mapping.

        Returns:
            The rendered node string.
        """
        return self._build_dom(tag, label, content, subtree_str, class_dict,
                               collect_empty=True)

    def prompt_constructor(self, tag: str = '', label: str = '',
                           content: str = '', subtree_str: str = '',
                           class_dict: 'dict[str, str] | None' = None) -> str:
        """Render one node with the constructor chosen in ``__init__``.

        Arguments are forwarded positionally to ``self.constructor``.
        """
        return self.constructor(tag, label, content, subtree_str, class_dict)