commit 5e3bc555d9

@@ -162,6 +162,75 @@ After all parallel groups finish, run `score.py` to get the pass rate
python score.py <your_result_dir>
```

## 🚀 Evaluating in the WebRL Setting (Text Modality)

[WebRL](https://github.com/THUDM/WebRL) is one of the top-performing models on WebArena-Lite. It uses plain text as its input modality, and we provide evaluation scripts that support this plain-text setting.

### Evaluation of Finetuned Models

To run a finetuned model in the WebRL setting, launch the evaluation with the following flags:

```bash
python run.py \
  --instruction_path agent/prompts/jsons/p_webrl.json \
  --test_start_idx 0 \
  --test_end_idx 1 \
  --result_dir <your_result_dir> \
  --test_config_base_dir config_files/wa/test_webarena_lite \
  --provider openai \
  --mode completion \
  --model <your_deployed_model_name> \
  --planner_ip <your_deployed_model_ip> \
  --stop_token "<|eot_id|>" \
  --max_obs_length 0 \
  --max_tokens 2048 \
  --viewport_width 1280 \
  --viewport_height 720 \
  --action_set_tag webrl_id --observation_type webrl
```

You first need to deploy the finetuned model locally with a tool such as vLLM. Once deployed, the model can be accessed through an OpenAI-compatible API call.
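Before launching the full evaluation, it can help to confirm the deployed endpoint is reachable. The sketch below is a minimal, assumption-laden example: the port (`8000`) and the use of the OpenAI-compatible completions API are assumptions about your own vLLM deployment, and the placeholders mirror the ones used in the command above.

```python
# Hedged connectivity check for a locally deployed model.
# <your_deployed_model_ip>, the port, and <your_deployed_model_name> are
# placeholders that must match your own deployment.
from openai import OpenAI

client = OpenAI(base_url="http://<your_deployed_model_ip>:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="<your_deployed_model_name>",
    prompt="Task Instruction: confirm the endpoint responds\n",
    max_tokens=32,
    stop=["<|eot_id|>"],
)
print(completion.choices[0].text)
```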
Ensure that the `--model` and `--planner_ip` fields are filled with the name and IP address of the deployed model instance.

We also provide a parallel run script:

```bash
# Remember to first launch a tmux session
tmux
bash wa_parallel_run_webrl.sh
```

### Evaluation of Proprietary Models

To run a proprietary model in the WebRL setting, launch the evaluation with the following flags:

```bash
python run.py \
  --instruction_path agent/prompts/jsons/p_webrl_chat.json \
  --test_start_idx 0 \
  --test_end_idx 1 \
  --result_dir <your_result_dir> \
  --test_config_base_dir config_files/wa/test_webarena_lite \
  --provider openai \
  --model GPT-4o \
  --mode chat \
  --planner_ip '' \
  --max_obs_length 0 \
  --max_tokens 2048 \
  --viewport_width 1280 \
  --viewport_height 720 \
  --action_set_tag webrl_id --observation_type webrl
```

You can switch the evaluation model by changing `--model`. A parallel run script is provided as well:

```bash
# Remember to first launch a tmux session
tmux
bash wa_parallel_run_webrl_chat.sh
```

### 🚨 Important: Refresh all websites before re-running another round of testing!

Since WebArena tasks may change the state and databases of the websites (e.g., posting comments on Reddit), the results will be unreliable if the websites are not all refreshed before another round of evaluation.

VAB-WebArena-Lite/new/actions.py (new file, 2006 lines)
File diff suppressed because it is too large.

@@ -14,6 +14,7 @@ from browser_env.actions import (
     create_id_based_action,
     create_none_action,
     create_playwright_action,
+    create_webrl_id_based_action
 )
 from browser_env.utils import Observation, StateInfo
 from llms import (

@@ -108,12 +109,14 @@ class PromptAgent(Agent):
         lm_config: lm_config.LMConfig,
         prompt_constructor: PromptConstructor,
         captioning_fn = None,
+        planner_ip = None
     ) -> None:
         super().__init__()
         self.lm_config = lm_config
         self.prompt_constructor = prompt_constructor
         self.action_set_tag = action_set_tag
         self.captioning_fn = captioning_fn
+        self.planner_ip = planner_ip

         # Check if the model is multimodal.
         if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:

@@ -165,7 +168,10 @@ class PromptAgent(Agent):
         lm_config = self.lm_config
         n = 0
         while True:
-            response = call_llm(lm_config, prompt)
+            if self.planner_ip is not None and self.planner_ip != "":
+                response = call_llm(lm_config, prompt, 'EMPTY', self.planner_ip)
+            else:
+                response = call_llm(lm_config, prompt)
             force_prefix = self.prompt_constructor.instruction[
                 "meta_data"
             ].get("force_prefix", "")

@@ -183,6 +189,8 @@ class PromptAgent(Agent):
                 action = create_playwright_action(parsed_response)
             elif self.action_set_tag == "som":
                 action = create_id_based_action(parsed_response)
+            elif self.action_set_tag == 'webrl_id':
+                action = create_webrl_id_based_action(parsed_response)
             else:
                 raise ValueError(
                     f"Unknown action type {self.action_set_tag}"

@@ -218,7 +226,8 @@ def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
             action_set_tag=args.action_set_tag,
             lm_config=llm_config,
             prompt_constructor=prompt_constructor,
-            captioning_fn=captioning_fn
+            captioning_fn=captioning_fn,
+            planner_ip=args.planner_ip
         )
     else:
         raise NotImplementedError(

VAB-WebArena-Lite/new/envs.py (new file, 319 lines)
@@ -0,0 +1,319 @@
import json
import os
import re
import subprocess
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union

import numpy as np
import numpy.typing as npt
import requests
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
    CDPSession,
    Page,
    Playwright,
    ViewportSize,
    expect,
    sync_playwright,
)

DATASET = os.environ["DATASET"]
if DATASET == "visualwebarena":
    from browser_env.env_config import (
        CLASSIFIEDS,
        CLASSIFIEDS_RESET_TOKEN,
    )

from .actions import Action, execute_action, get_action_space, execute_action_webrl
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
    AccessibilityTree,
    DetachedPage,
    Observation,
    png_bytes_to_numpy,
)


@dataclass
class PlaywrightScript:
    function: str  # goto, get_by_role
    destination: str  # https://www.google.com/, combobox
    name: str | None = None  # Search, Avatar 2009
    operation: str | None = None  # click, fill, press
    value: str | None = None  # avatar movie, Enter


def parse_action(action: str) -> PlaywrightScript:
    splitted = action.strip().split(" ")
    assert len(splitted) >= 2
    match splitted[:2]:
        case ["goto", url]:
            assert len(splitted) == 2
            return PlaywrightScript("goto", url)
        case ["get_by_role", destination]:
            assert len(splitted) >= 4
            match splitted[2:]:
                case [name, operation]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation
                    )
                case [name, operation, value]:
                    return PlaywrightScript(
                        "get_by_role", destination, name, operation, value
                    )
                case _:
                    raise ValueError("Invalid action")
        case _:
            raise ValueError(f"Invalid action {action}")


class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
    """
    The goal of this environment is to produce a prototype of a browser environment.
    In the end, we want to support a fully configurable browser environment with wide
    range of action spaces and observation spaces, both structured and unstructured.
    But in this prototype, we just support action space specified by Playwright script,
    and observation space is the html content of the page.
    """

    @beartype
    def __init__(
        self,
        max_page_length: int = 8192,
        headless: bool = True,
        slow_mo: int = 0,
        observation_type: str = "html",
        current_viewport_only: bool = False,
        viewport_size: ViewportSize = {"width": 1280, "height": 720},
        save_trace_enabled: bool = False,
        sleep_after_execution: float = 0.0,
        captioning_fn=None,
    ):
        # TODO: make Space[Action] = ActionSpace
        self.action_space = get_action_space()  # type: ignore[assignment]
        self.headless = headless
        self.slow_mo = slow_mo
        self.current_viewport_only = current_viewport_only
        self.reset_finished = False
        self.viewport_size = viewport_size
        self.save_trace_enabled = save_trace_enabled
        self.sleep_after_execution = sleep_after_execution

        match observation_type:
            case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
                self.text_observation_type = observation_type
                self.image_observation_type = ""
                self.main_observation_type = "text"
            case "image":
                self.image_observation_type = observation_type
                self.text_observation_type = ""  # type: ignore[assignment]
                self.main_observation_type = "image"
            case "image_som":
                self.image_observation_type = observation_type
                self.text_observation_type = observation_type  # type: ignore[assignment]
                self.main_observation_type = "image"
            case _:
                raise ValueError(
                    f"Unsupported observation type: {observation_type}"
                )

        self.observation_handler = ObservationHandler(
            self.main_observation_type,
            self.text_observation_type,
            self.image_observation_type,
            self.current_viewport_only,
            self.viewport_size,
            captioning_fn,
        )

        self.observation_space = (
            self.observation_handler.get_observation_space()
        )

    @beartype
    def setup(self, config_file: Path | None = None) -> None:
        self.context_manager = sync_playwright()
        self.playwright = self.context_manager.__enter__()
        self.browser = self.playwright.chromium.launch(
            headless=self.headless, slow_mo=self.slow_mo
        )

        if config_file:
            with open(config_file, "r") as f:
                instance_config = json.load(f)
        else:
            instance_config = {}

        # Reset site if needed. Currently only supported for Classifieds.
        # TODO(jykoh): Add reset functionality for Shopping/Reddit.
        if instance_config.get("require_reset", False):
            if "classifieds" in instance_config["sites"]:
                # Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
                response = requests.post(
                    f"{CLASSIFIEDS}/index.php?page=reset",
                    data={"token": CLASSIFIEDS_RESET_TOKEN},
                )

                # Check if the request was successful
                if response.status_code == 200:
                    print("Reset Classifieds site.")
                else:
                    print(
                        "Failed to reset Classifieds site:",
                        response.status_code,
                    )
            else:
                print(
                    "WARNING: Reset is not supported for this site. Please manually reset the site."
                )

        storage_state = instance_config.get("storage_state", None)
        start_url = instance_config.get("start_url", None)
        geolocation = instance_config.get("geolocation", None)

        # Use custom viewport size if specified in the config, otherwise use the default.
        viewport_size = self.viewport_size.copy()
        viewport_size.update(instance_config.get("viewport_size", {}))
        self.observation_handler.viewport_size = viewport_size

        self.context = self.browser.new_context(
            viewport=viewport_size,
            storage_state=storage_state,
            geolocation=geolocation,
            device_scale_factor=1,
        )
        if self.save_trace_enabled:
            self.context.tracing.start(screenshots=True, snapshots=True)

        if start_url:
            start_urls = start_url.split(" |AND| ")
            for url in start_urls:
                page = self.context.new_page()
                if self.text_observation_type in [
                    "accessibility_tree",
                    "accessibility_tree_with_captioner",
                ]:
                    client = page.context.new_cdp_session(page)
                    client.send("Accessibility.enable")
                    client.detach()
                page.goto(url)
            # set the first page as the current page
            self.page = self.context.pages[0]
            self.page.bring_to_front()
        else:
            self.page = self.context.new_page()
            if self.text_observation_type in [
                "accessibility_tree",
                "accessibility_tree_with_captioner",
            ]:
                client = self.page.context.new_cdp_session(self.page)
                client.send("Accessibility.enable")
                client.detach()

    def _get_obs(self) -> dict[str, Observation]:
        obs = self.observation_handler.get_observation(self.page)
        return obs

    def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
        metadata = self.observation_handler.get_observation_metadata()
        return metadata

    @beartype
    def reset(
        self,
        *,
        seed: int | None = None,
        options: dict[str, str] | None = None,
    ) -> tuple[dict[str, Observation], dict[str, Any]]:
        """
        Reset the environment.
        :param options: options for the environment. The current supported options are:
            - "storage_state": the storage state of the browser. It is a file path to a json file.
        """
        super().reset(seed=seed, options=options)
        if self.reset_finished:
            self.context_manager.__exit__()

        if options is not None and "config_file" in options:
            config_file = Path(options["config_file"])
            if config_file.exists():
                self.setup(config_file=config_file)
            else:
                raise ValueError(f"Config file {config_file} does not exist.")
        else:
            self.setup()
        self.reset_finished = True
        timeout_in_ms = 120000
        self.page.set_default_timeout(timeout_in_ms)
        self.page.set_default_navigation_timeout(timeout_in_ms)
        self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()
        info = {
            "page": DetachedPage(self.page.url, ""),
            "fail_error": "",
            "observation_metadata": observation_metadata,
        }

        return (observation, info)

    def save_trace(self, trace_path: str | Path) -> None:
        if self.save_trace_enabled:
            self.context.tracing.stop(path=trace_path)

    def close(self) -> None:
        if self.reset_finished:
            self.context_manager.__exit__()

    def step(
        self, action: Action
    ) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
        if not self.reset_finished:
            raise RuntimeError("Call reset first before calling step.")

        success = False
        fail_error = ""
        try:
            if self.text_observation_type == 'webrl':
                self.page = execute_action_webrl(
                    action,
                    self.page,
                    self.context,
                    self.observation_handler.action_processor,
                    self.sleep_after_execution,
                )
            else:
                self.page = execute_action(
                    action,
                    self.page,
                    self.context,
                    self.observation_handler.action_processor,
                    self.sleep_after_execution,
                )
            success = True
        except Exception as e:
            fail_error = str(e)

        observation = self._get_obs()
        observation_metadata = self._get_obs_metadata()

        info = {
            "page": DetachedPage(self.page.url, self.page.content()),
            "fail_error": fail_error,
            "observation_metadata": observation_metadata,
        }
        msg = (
            observation,
            float(success),  # reward
            False,  # terminated
            False,  # truncated
            info,
        )
        return msg

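As a rough illustration of how the evaluation loop drives this environment, the sketch below resets `ScriptBrowserEnv` with the `webrl` observation type and executes one action. The config path, the import locations, and the action string are assumptions for illustration, not a verbatim excerpt of `run.py`.

```python
# Hedged usage sketch for ScriptBrowserEnv (paths and the action string are hypothetical).
from browser_env import ScriptBrowserEnv
from browser_env.actions import create_webrl_id_based_action

env = ScriptBrowserEnv(
    headless=True,
    observation_type="webrl",
    viewport_size={"width": 1280, "height": 720},
)
obs, info = env.reset(options={"config_file": "config_files/wa/test_webarena_lite/0.json"})
# Exact action syntax depends on create_webrl_id_based_action; this mirrors WebRL-style output.
action = create_webrl_id_based_action('do(action="Click", element="13")')
obs, reward, terminated, truncated, info = env.step(action)
env.close()
```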
@@ -172,6 +172,10 @@ class StringEvaluator(Evaluator):
         # prevent false positive (e.g, 0)
         if len(word_tokenize(clean_ref)) == 1:
             tok_pred = word_tokenize(clean_pred)
+            for token in tok_pred:
+                if '/' in token:
+                    sub_tokens = token.split('/')
+                    tok_pred.extend(sub_tokens)
             return float(clean_ref in tok_pred)
         else:
             return float(clean_ref in clean_pred)

VAB-WebArena-Lite/new/helper_functions_browser.py (new file, 239 lines)
@@ -0,0 +1,239 @@
import base64
import io
import json
import re
from pathlib import Path
from typing import Any

from PIL import Image

from agent.prompts import *
from browser_env import (
    Action,
    ActionTypes,
    ObservationMetadata,
    StateInfo,
    action2str,
)

HTML_TEMPLATE = """
<!DOCTYPE html>
<head>
    <style>
        pre {{
            white-space: pre-wrap;
            word-wrap: break-word;
        }}
    </style>
</head>
<html>
    <body>
     {body}
    </body>
</html>
"""


def get_render_action(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
) -> str:
    """Parse the predicted actions for rendering purpose. More comprehensive information"""
    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]["text"]
            else:
                node_content = "No match found"

            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "som":
            text_meta_data = observation_metadata["text"]
            if action["element_id"] in text_meta_data["obs_nodes_info"]:
                node_content = text_meta_data["obs_nodes_info"][
                    action["element_id"]
                ]
            else:
                node_content = "No match found"
            action_str = f"<div class='raw_parsed_prediction' style='background-color:grey'><pre>{action['raw_prediction']}</pre></div>"
            action_str += f"<div class='action_object' style='background-color:grey'><pre>{repr(action)}</pre></div>"
            action_str += f"<div class='parsed_action' style='background-color:yellow'><pre>{action2str(action, action_set_tag, node_content)}</pre></div>"

        case "playwright":
            action_str = action["pw_code"]
        case "webrl_id":
            action_str = action["raw_prediction"]
        case _:
            raise ValueError(
                f"Unknown action type {action['action_type'], action_set_tag}"
            )
    return action_str


def get_action_description(
    action: Action,
    observation_metadata: dict[str, ObservationMetadata],
    action_set_tag: str,
    prompt_constructor: PromptConstructor | None,
) -> str:
    """Generate the text version of the predicted actions to store in action history for prompt use.
    May contain hint information to recover from the failures"""

    match action_set_tag:
        case "id_accessibility_tree":
            text_meta_data = observation_metadata["text"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    node_content = text_meta_data["obs_nodes_info"][
                        action["element_id"]
                    ]["text"]
                    node_content = " ".join(node_content.split()[1:])
                    action_str = action2str(
                        action, action_set_tag, node_content
                    )
                else:
                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "som":
            text_meta_data = observation_metadata["image"]
            if action["action_type"] in [
                ActionTypes.CLICK,
                ActionTypes.HOVER,
                ActionTypes.TYPE,
            ]:
                action_name = str(action["action_type"]).split(".")[1].lower()
                if action["element_id"] in text_meta_data["obs_nodes_info"]:
                    action_str = action2str(action, action_set_tag, "")
                else:
                    print(
                        'action["element_id"], text_meta_data["obs_nodes_info"]',
                        action["element_id"],
                        text_meta_data["obs_nodes_info"],
                    )
                    action_str = f"Attempt to perfom \"{action_name}\" on element \"[{action['element_id']}]\" but no matching element found. Please check the observation more carefully."
            else:
                if (
                    action["action_type"] == ActionTypes.NONE
                    and prompt_constructor is not None
                ):
                    action_splitter = prompt_constructor.instruction[
                        "meta_data"
                    ]["action_splitter"]
                    action_str = f'The previous prediction you issued was "{action["raw_prediction"]}". However, the format was incorrect. Ensure that the action is wrapped inside a pair of {action_splitter} and enclose arguments within [] as follows: {action_splitter}action [arg] ...{action_splitter}.'
                else:
                    action_str = action2str(action, action_set_tag, "")

        case "playwright":
            action_str = action["pw_code"]

        case "webrl_id":
            action_str = action["raw_prediction"]

        case _:
            raise ValueError(f"Unknown action type {action['action_type']}")

    return action_str


class RenderHelper(object):
    """Helper class to render text and image observations and meta data in the trajectory"""

    def __init__(
        self, config_file: str, result_dir: str, action_set_tag: str
    ) -> None:
        with open(config_file, "r") as f:
            _config = json.load(f)
            _config_str = ""
            for k, v in _config.items():
                _config_str += f"{k}: {v}\n"
            _config_str = f"<pre>{_config_str}</pre>\n"
            task_id = _config["task_id"]

        self.action_set_tag = action_set_tag

        self.render_file = open(
            Path(result_dir) / f"render_{task_id}.html", "a+", encoding="utf-8"
        )
        self.render_file.truncate(0)
        # write init template
        self.render_file.write(HTML_TEMPLATE.format(body=f"{_config_str}"))
        self.render_file.read()
        self.render_file.flush()

    def render(
        self,
        action: Action,
        state_info: StateInfo,
        meta_data: dict[str, Any],
        render_screenshot: bool = False,
    ) -> None:
        """Render the trajectory"""
        # text observation
        observation = state_info["observation"]
        text_obs = observation["text"]
        info = state_info["info"]
        new_content = f"<h2>New Page</h2>\n"
        new_content += f"<h3 class='url'><a href={state_info['info']['page'].url}>URL: {state_info['info']['page'].url}</a></h3>\n"
        new_content += f"<div class='state_obv'><pre>{text_obs}</pre><div>\n"

        if render_screenshot:
            # image observation
            img_obs = observation["image"]
            image = Image.fromarray(img_obs)
            byte_io = io.BytesIO()
            image.save(byte_io, format="PNG")
            byte_io.seek(0)
            image_bytes = base64.b64encode(byte_io.read())
            image_str = image_bytes.decode("utf-8")
            new_content += f"<img src='data:image/png;base64,{image_str}' style='width:50vw; height:auto;'/>\n"

        # meta data
        new_content += f"<div class='prev_action' style='background-color:pink'>{meta_data['action_history'][-1]}</div>\n"

        # action
        action_str = get_render_action(
            action,
            info["observation_metadata"],
            action_set_tag=self.action_set_tag,
        )
        # with yellow background
        action_str = f"<div class='predict_action'>{action_str}</div>"
        new_content += f"{action_str}\n"

        # add new content
        self.render_file.seek(0)
        html = self.render_file.read()
        html_body = re.findall(r"<body>(.*?)</body>", html, re.DOTALL)[0]
        html_body += new_content

        html = HTML_TEMPLATE.format(body=html_body)
        self.render_file.seek(0)
        self.render_file.truncate()
        self.render_file.write(html)
        self.render_file.flush()

    def close(self) -> None:
        self.render_file.close()
VAB-WebArena-Lite/new/html_tools/__init__.py (new executable file, 7 lines)
@@ -0,0 +1,7 @@
from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .html_parser import HtmlParser

from .utils import print_html_object
from .configs import basic_attrs, mind2web_keep_attrs
from .fetch import get_parsed_html

VAB-WebArena-Lite/new/html_tools/configs/__init__.py (new executable file, 3 lines)
@@ -0,0 +1,3 @@
from .html_prompt import prompts
from .config import basic_attrs, mind2web_keep_attrs, miniwob_attrs
from .config import config_meta

VAB-WebArena-Lite/new/html_tools/configs/config.py (new executable file, 56 lines)
@@ -0,0 +1,56 @@
basic_attrs = [
    'title',
    'value',
    'type',
    'placeholder',
    'selected',
    'data-value',
    'data-text',
    'data-testid',
    'data-label',
    'data-bbox',
    'data-status'
]

mind2web_keep_attrs = [
    'alt',
    'aria_description',
    'aria_label',
    'aria_role',
    'input_checked',
    'input_value',
    'label',
    'name',
    'option_selected',
    'placeholder',
    'role',
    'text_value',
    'title',
    'type',
    'value',
]

miniwob_attrs = [
    'id',
    'type',
    'value',
]

config_meta = """
======= Configs =======
Columns:
    - id: {id_attr}
    - label: {label_attr}
Position: {use_position}
    - window: {window_size}
    - rect_dict: {rect}
Keep:
    - parents: {parent_chain}
    - attrs: {keep_attrs}
    - elems: {keep_elem}
    - obs_elem: {obs_elem}
Generator:
    - prompt: {prompt_name}
    - label: {identifier_name}
========================
"""

VAB-WebArena-Lite/new/html_tools/configs/html_prompt.py (new executable file, 22 lines)
@@ -0,0 +1,22 @@
refine_prompt = {
    'dom': '<{tag}{label}|{attr}{content}{subtree} >',
    'label': '[{label}]',
    'attr': '{attr}',
    'attr_splitter': '; ',
    'subtree_splitter': ' ',
}

xml_prompt = {
    'dom': '<{tag}{label}{attr}>{content}{subtree} </{tag}>',
    'label': ' id="{label}"',
    'attr': '{key}="{attr}"',
    'attr_splitter': ' ',
    'subtree_splitter': ' ',
}

prompts = {
    'refine': refine_prompt,
    'xml': xml_prompt,
    'new_data': refine_prompt,
}

VAB-WebArena-Lite/new/html_tools/fetch.py (new executable file, 108 lines)
@@ -0,0 +1,108 @@
import os
import json
import base64
from .html_parser import HtmlParser
from .configs import basic_attrs
from .scripts import *

def get_window(page):
    x = page.evaluate("window.scrollX")
    y = page.evaluate("window.scrollY")
    w = page.evaluate("window.innerWidth")
    h = page.evaluate("window.innerHeight")
    return (x, y, w, h)

def modify_page(page):
    page.wait_for_timeout(500)

    try:
        page.evaluate(remove_id_script)
    except:
        pass

    packet = {
        "raw_html": page.evaluate("document.documentElement.outerHTML"),
        "window": get_window(page)
    }

    page.evaluate(prepare_script)
    page.wait_for_timeout(100)

    img_bytes = page.screenshot(path="debug_info/screenshot_raw.png")
    raw_image = base64.b64encode(img_bytes).decode()

    page.evaluate(clickable_checker_script)
    page.wait_for_timeout(50)

    # get all clickable elements
    start_id = 0
    items, start_id = page.evaluate(label_script, {
        "selector": ".possible-clickable-element",
        "startIndex": start_id
    })
    page.wait_for_timeout(50)

    # mark our own labels and get the images
    items = page.evaluate(label_marker_script, items)
    page.wait_for_timeout(100)
    img_bytes = page.screenshot(path="debug_info/marked.png")
    marked_image = base64.b64encode(img_bytes).decode()

    # remove markers on the page
    page.evaluate(remove_label_mark_script)

    packet.update({
        "raw_image": raw_image,
        "marked_image": marked_image,
        "modified_html": page.evaluate("document.documentElement.outerHTML")
    })

    # element_info, include "all_elements" and "clickable_elements"
    element_info = page.evaluate(element_info_script)
    page.wait_for_timeout(100)
    packet.update(element_info)
    return packet

def save_debug_info(packet):
    with open("debug_info/raw.html", "w") as f:
        f.write(packet["modified_html"])
    with open("debug_info/parsed.html", "w") as f:
        f.write(packet["html"])
    with open("debug_info/all_element.json", "w") as f:
        f.write(json.dumps(packet["all_elements"]))

def get_parsed_html(page):
    if not os.path.exists("debug_info"):
        os.makedirs("debug_info")

    print("parsing html...")

    packet = modify_page(page)
    raw_html = packet["modified_html"]

    args = {
        "use_position": True,
        "rect_dict": {},
        "window_size": packet["window"],
        "id-attr": "data-backend-node-id",
        "label_attr": "data-label-id",
        "label_generator": "order",
        "regenerate_label": False,
        "attr_list": basic_attrs,
        "prompt": "xml",
        "dataset": "pipeline"
    }

    hp = HtmlParser(raw_html, args)
    res = hp.parse_tree()
    page_html = res.get("html", "")

    packet["html"] = page_html

    # for debug
    save_debug_info(packet)

    print("parsing finished.")

    return packet

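For context, `get_parsed_html` is driven by a live Playwright page. The sketch below is a rough, hedged illustration of calling it from a `sync_playwright` session; the target URL is a placeholder and the top-level import assumes the `html_tools` package is on the path.

```python
# Hedged sketch: parse a live page into the compact HTML observation.
from playwright.sync_api import sync_playwright
from html_tools import get_parsed_html  # exported by html_tools/__init__.py

with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_page()
    page.goto("http://localhost:7770")  # placeholder URL for one of the WebArena sites
    packet = get_parsed_html(page)      # also writes debug artifacts under ./debug_info
    print(packet["html"][:500])
    browser.close()
```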
VAB-WebArena-Lite/new/html_tools/html_parser.py (new executable file, 447 lines)
@@ -0,0 +1,447 @@
from lxml import html
import time, copy, random
import json, re, os

from .identifier import IdentifierTool
from .prompt import HtmlPrompt
from .configs import config_meta
from .utils import get_xpath_top_down, rect2tuple

class HtmlParser():
    def __init__(self, ctx: str, args: dict[str]={}) -> None:
        stt = time.time()
        self.dom_tree = self.ctx2tree(ctx)
        # tool related
        self.bids2label = {}
        self.bids2xpath = {}
        self.used_labels = {}

        # parse args
        self.parse_args(args)
        self.init_time = time.time() - stt

    def parse_args(self, args: dict[str]={}) -> None:
        def attr_check(attr, type_model='str'):
            if attr is None:
                return False
            attr_type = type(attr)
            if attr_type != type(type_model):
                return False
            if attr_type == type('str') and len(attr) == 0:
                return False
            return True

        args = {} if args is None else args

        # [Position] use_pos: False -> use full page, otherwise use window_size
        dataset = args.get('dataset', '')
        use_position = args.get('use_position', False)
        window_size = args.get('window_size', None)
        rect = args.get('rect_dict', None)
        if use_position:
            if not attr_check(window_size, ()):
                raise ValueError('window_size must be set when use_position is True')
            if not attr_check(rect, {}):
                raise ValueError('rect_dict must be set when use_position is True')

        if not attr_check(rect, {}):
            rect = {}

        # [Label] for vimium is temp_clickable_label, otherwise keep all of it
        label_attr = args.get('label_attr', '')
        get_new_label = args.get('regenerate_label', False)
        label_method = args.get('label_generator', None)
        regen_label = not attr_check(label_method)

        # [id] for mind2web is backend_node_id, for normal website use our method
        id_attr = args.get('id_attr', '')
        regen_id = not attr_check(id_attr)

        if regen_id:
            id_attr = 'temp_id'

        # [attributes]
        keep_attrs = args.get('attr_list', [])
        if not attr_check(keep_attrs, []):
            keep_attrs = []

        # [Tags] for clickable elem, keep: must keep, obs: keep if follow specific rule
        parent_chain = args.get('parent_chain', False)
        keep_elem = args.get('keep_elem', [])
        obs_elem = args.get('obs_elem', [])

        # sanity check
        self.set_args(use_position, window_size, rect, label_attr, id_attr, keep_attrs, keep_elem, obs_elem, parent_chain, get_new_label, dataset)

        # [Prompt]
        prompt = args.get('prompt', None)
        self.prompt = HtmlPrompt(prompt)

        # traverse and get special data
        if regen_id or regen_label:
            self.mark_id()

        if get_new_label:
            self.used_labels = {}

        self.identifier = IdentifierTool(label_method, self.used_labels)

    def set_args(self, use_position: bool=False, window_size: tuple=(), rect_dict: dict[str]={}, label_attr: str='',
                 id_attr: str='', keep_attrs: list[str]=[], keep_elem: list[str]=[], obs_elem: list[str]=[],
                 parent_chain: bool=False, get_new_label: bool=False, dataset: str='') -> None:

        self.use_position = use_position
        self.window_size = window_size
        self.rect = rect_dict
        self.label_attr = label_attr
        self.id_attr = id_attr
        self.keep_attrs = keep_attrs
        self.keep = keep_elem
        self.obs = obs_elem
        self.parent_chain = parent_chain
        self.get_new_label = get_new_label
        self.dataset = dataset

    def get_config(self):
        config = {
            'id_attr': self.id_attr,
            'keep_attrs': self.keep_attrs[:5],
            'label_attr': self.label_attr,
            'use_position': self.use_position,
            'window_size': self.window_size,
            'rect': dict(list(self.rect.items())[:3]),
            'keep_elem': self.keep[:5],
            'obs_elem': self.obs[:5],
            'parent_chain': self.parent_chain,
            'prompt_name': self.prompt.name,
            'identifier_name': self.identifier.name
        }

        return config, config_meta.format(**config)

    def update_rect_dict(self, rect_dict: dict[str]={}) -> None:
        self.rect = rect_dict

    @staticmethod
    def ctx2tree(ctx: str) -> html.HtmlElement:
        # remove useless tags, eg. style and script
        ctx = re.sub('<!--[\W\w]*?-->', '', ctx)
        ctx = re.sub('<style[\W\w]*?>[\W\w]*?</style>', '', ctx)
        ctx = re.sub('<script[\W\w]*?>[\W\w]*?</script>', '', ctx)
        ctx = '' if ctx is None else re.sub(r'\s+', ' ', ctx).strip()
        dom_tree = html.fromstring(ctx.encode('utf-8'))
        match = re.search('<meta charset="([^"]*)"', ctx)
        if match:
            charset = match.group(1)
            ctx = ctx.encode(charset)
            print(charset)
        else:
            print("Charset not found")
        return dom_tree

    @staticmethod
    def get_root(tree: html.HtmlElement) -> html.HtmlElement:
        node = tree.xpath('//*')[0]
        while True:
            parent = node.getparent()
            if parent is None:
                break
            node = parent
        return node

    def get_node_by_bid(self, tree: html.HtmlElement, bid: str) -> html.HtmlElement:
        nodes = tree.xpath(f'//*[@{self.id_attr}="{bid}"]')
        if len(nodes) == 0:
            return None
        return nodes[0]

    def id_label_converter(self, label: str) -> str:
        return self.bids2label.get(label, '')

    def id_xpath_converter(self, label: str) -> str:
        return self.bids2xpath.get(label, '')

    def mark_id(self) -> None:
        root = self.get_root(self.dom_tree)
        _, i2xpath, used_labels = get_xpath_top_down(root, self.id_attr, self.label_attr)
        self.used_labels = used_labels
        self.bids2xpath = i2xpath

    def parse(self, root: html.HtmlElement, keep: list[str], obs: list[str], parent_chain: bool=False, get_new_label: bool=False) -> dict[str]:
        def get_text(str: str) -> str:
            return '' if str is None else str.strip()[:500]

        def check_attr(attr: str, node: html.HtmlElement) -> bool:
            tag = node.tag
            if (
                ( attr == 'role' and node.attrib.get(attr, '') in ['presentation', 'none', 'link'] )
                or ( attr == 'type' and node.attrib.get(attr, '') == 'hidden' )
                or ( attr == 'aria-hidden' and node.attrib.get(attr, '') == 'true')
                # or ( attr == 'value' and tag in ['option'] )
            ):
                return False
            return True

        def is_visible(node: html.HtmlElement, bid: str) -> bool:
            if self.dataset == 'mind2web':
                bound = node.attrib.get('bounding_box_rect', None)
                self.rect[bid] = rect2tuple(bound)

            if self.dataset == 'pipeline':
                bound = node.attrib.get('data-bbox', None)
                self.rect[bid] = rect2tuple(bound)

            if node.attrib.get('aria-hidden', 'false') == 'true':
                return False

            if not self.use_position:
                return True

            rect = self.rect.get(bid, None)
            if rect is None:
                return False

            if self.window_size is None:
                return True

            # get window size
            wx, wy, ww, wh = self.window_size
            wx, wy = 0, 0
            x, y, w, h = rect
            if x + w < wx or x > wx + ww or y + h < wy or y > wy + wh:
                return False

            return True

        def _dfs(node: html.HtmlElement, keep: list[str]=[], obs: list[str]=[],
                 parent_chain: bool=False, get_new_label: bool=False, par_keep: bool=False) -> (str, dict[str]):
            # basic information
            bid = node.attrib.get(self.id_attr, '')
            tag = node.tag
            label = node.attrib.get(self.label_attr, '')

            # element which is keeped equivalent to visible
            visible = is_visible(node, bid)
            in_keep_list = bid in keep
            in_obs_list = (bid in obs or len(label) > 0) and visible
            par_keep = par_keep and tag == "option"
            keep_element = in_keep_list or in_obs_list or visible or par_keep

            if label:
                keep_element = True

            # mark label
            bids2label, labeled_elems = {}, []
            have_label = False
            if in_keep_list or in_obs_list:
                if label is None or len(label) == 0 or get_new_label:
                    label = self.identifier.generate()
                    node.attrib[self.label_attr] = label
                bids2label[bid] = label
                bids2label[label] = bid
                have_label = True

            # get text or alt_text of current element
            text = get_text(node.text)

            classes = {}
            # keep attributes if needed
            keep_all_attrs = len(self.keep_attrs) == 0
            keep_attrs = node.attrib.keys() if keep_all_attrs else self.keep_attrs

            # traverse attributes
            for attr in keep_attrs:
                if attr not in node.attrib or not check_attr(attr, node):
                    continue
                if attr in [self.id_attr, self.label_attr]:
                    continue
                val = get_text(node.attrib[attr])
                if len(val) > 0 or keep_all_attrs:
                    classes[attr] = val

            have_text = len(text) > 0 or len(classes) - (1 if 'data-bbox' in classes else 0) > 0
            par_keep = keep_element and tag == 'select'

            parts = []
            clickable_count = 0
            children = node.getchildren()
            for child in children:
                cres, cmsg = _dfs(child, keep, obs, parent_chain, get_new_label, par_keep)
                clickable_count += 1 if cmsg.get('have_clickable', False) else 0
                bids2label.update(cmsg.get('bids2label', {}))
                labeled_elems.extend(cmsg.get('label_element', []))
                if len(cres) != 0:
                    parts.append(cres)

            dom = self.prompt.subtree_constructor(parts)

            # remove <text|> if all children are text
            keep_as_all_text = (dom.count('<') == dom.count('<text|')) and dom.count('<') > 0
            if keep_as_all_text:
                matches = re.findall(r'<text\| ([^>]+) >', dom)
                dom = self.prompt.subtree_constructor(matches)

            keep_element = keep_element and (clickable_count > 1 or have_text or have_label or keep_as_all_text)
            keep_as_parent = len(dom) > 0 and parent_chain
            if in_keep_list or keep_element or keep_as_parent:
                dom = self.prompt.prompt_constructor(tag, label, text, dom, classes)

            if have_label:
                labeled_elems.append(bid)

            control_msg = {
                'have_clickable': bool(clickable_count or have_text),
                'bids2label': bids2label,
                'label_element': labeled_elems,
            }

            return dom, control_msg

        dom, cmsg = _dfs(root, keep, obs, parent_chain, get_new_label)
        return dom, cmsg

    def parse_tree(self) -> dict[str]:
        # start from here
        stt = time.time()
        root = self.get_root(self.dom_tree)
        dom, cmsg = self.parse(root, self.keep, self.obs, self.parent_chain, self.get_new_label)
        self.bids2label = cmsg.get('bids2label', {})
        self.keep = list(set(self.keep + cmsg.get('label_element', [])))

        obj = {
            'html': dom,
            'parse_time': time.time() - stt
        }

        return obj

    # From mind2web, https://github.com/OSU-NLP-Group/Mind2Web/blob/main/src/data_utils/dom_utils.py
    def get_keep_elements(self, tree: html.HtmlElement, keep: list[str], max_depth: int, max_children: int,
                          max_sibling: int, dfs_count: int=1, keep_parent: bool=False) -> list[str]:
        def get_anscendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
            if current_depth > max_depth:
                return []

            anscendants = []
            parent = node.getparent()
            if parent is not None:
                anscendants.append(parent)
                anscendants.extend(get_anscendants(parent, max_depth, current_depth + 1))

            return anscendants

        def get_descendants(node: html.HtmlElement, max_depth: int, current_depth: int=0) -> list[str]:
            if current_depth > max_depth:
                return []

            descendants = []
            for child in node:
                descendants.append(child)
                descendants.extend(get_descendants(child, max_depth, current_depth + 1))

            return descendants

        to_keep = set(copy.deepcopy(keep))
        nodes_to_keep = set()

        for _ in range(max(1, dfs_count)):
            for bid in to_keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue

                nodes_to_keep.add(candidate_node.attrib[self.id_attr])
                # get all ancestors or with max depth
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_anscendants(candidate_node, max_depth)])

                # get descendants with max depth
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in get_descendants(candidate_node, max_depth)][:max_children])
                # get siblings within range
                parent = candidate_node.getparent()
                if parent is None:
                    continue

                siblings = [x for x in parent.getchildren() if x.tag != 'text']
                if candidate_node not in siblings:
                    continue

                idx_in_sibling = siblings.index(candidate_node)
                nodes_to_keep.update([x.attrib.get(self.id_attr, '')
                                      for x in siblings[max(0, idx_in_sibling - max_sibling) : idx_in_sibling + max_sibling + 1]])

            max_children = int(max_children * 0.5)
            max_depth = int(max_depth * 0.5)
            max_sibling = int(max_sibling * 0.7)

            to_keep = copy.deepcopy(nodes_to_keep)

        if keep_parent:
            for bid in keep:
                candidate_node = self.get_node_by_bid(tree, bid)
                if candidate_node is None:
                    continue
                nodes_to_keep.update([x.attrib.get(self.id_attr, '') for x in candidate_node.xpath("ancestor::*")])

        return list(nodes_to_keep)

    def prune(self, tree: html.HtmlElement, nodes_to_keep: list[str]) -> html.HtmlElement:
        # remove nodes not in nodes_to_keep
        for node in tree.xpath('//*')[::-1]:
            if node.tag != 'text':
                is_keep = node.attrib.get(self.id_attr, '') in nodes_to_keep
                is_candidate = node.attrib.get(self.id_attr, '') in self.keep
            else:
                is_keep = (node.getparent().attrib.get(self.id_attr, '') in nodes_to_keep)
                is_candidate = (node.getparent().attrib.get(self.id_attr, '') in self.keep)

            if not is_keep and node.getparent() is not None:
                # insert all children into parent
                for child in node.getchildren():
                    node.addprevious(child)
                node.getparent().remove(node)
            else:
                # if not is_candidate or node.tag == 'text':
                #     node.attrib.pop(self.id_attr, None)
                if (
                    len(node.attrib) == 0
                    and not any([x.tag == 'text' for x in node.getchildren()])
                    and node.getparent() is not None
                    and node.tag != "text"
                    and len(node.getchildren()) <= 1
                ):
                    # insert all children into parent
                    for child in node.getchildren():
                        node.addprevious(child)
                    node.getparent().remove(node)

        return tree

    def prune_tree(self, dfs_count: int=1, max_depth: int=3, max_children: int=30,
                   max_sibling: int=3, keep_parent: bool=False) -> None:
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, self.keep, max_depth, max_children, max_sibling, dfs_count, keep_parent)
        new_tree = self.prune(new_tree, nodes_to_keep)

        self.dom_tree = new_tree

    def get_segment(self, bid: str) -> str:
        # clone the tree
        new_tree = copy.deepcopy(self.dom_tree)
        nodes_to_keep = self.get_keep_elements(new_tree, [bid], 0, 2, 1)
        new_tree = self.prune(new_tree, nodes_to_keep)
        dom, _ = self.parse(new_tree, self.keep, [], False)
        return dom

    def get_rect_data(self, bids: list[str]) -> list[dict[str]]:
        res = []
        for bid in bids:
            label = self.bids2label.get(bid, '')
            rect = self.rect.get(bid, None)
            res.append({
                'bid': bid,
                'label': label,
                'rect': rect
            })
        return res

64
VAB-WebArena-Lite/new/html_tools/identifier.py
Executable file
64
VAB-WebArena-Lite/new/html_tools/identifier.py
Executable file
|
@ -0,0 +1,64 @@
import secrets

class IdentifierTool:
    def __init__(self, method: str='order', existing_labels: dict[str]={}) -> None:
        self.methods = {
            'order': self.get_identifier_in_order,
            'random': self.get_random_identifier,
        }

        if method is None:
            method = 'order'

        self.func = self.methods.get(method, None)
        self.name = method
        if self.func is None:
            raise ValueError(f'Invalid method for identifier: {method}')

        self.reset(existing_labels)

    def reset(self, exists: dict[str]={}) -> None:
        self.identifier = -1
        self.exists = {} if exists is None else exists

    def get_identifier_in_order(self) -> str:
        def id2str(id: int) -> str:
            if id < 26:
                return chr(id + 65)
            id -= 26
            c0 = id // 676
            c1 = (id // 26) % 26
            c2 = id % 26
            label = f'{chr(c1 + 65)}{chr(c2 + 65)}'
            return label if c0 == 0 else f'{chr(c0 + 64)}{label}'

        self.identifier += 1
        label = id2str(self.identifier)

        while label in self.exists:
            self.identifier += 1
            label = id2str(self.identifier)

        self.exists[label] = True
        return label

    def get_random_identifier(self) -> str:
        secret_generator = secrets.SystemRandom()

        def get_random_label(n: int=2) -> str:
            tmp = ''
            for _ in range(n):
                tmp += chr(secret_generator.randint(65, 90))
            return tmp

        wc = 3 if len(self.exists) > 280 else 2

        label = get_random_label(wc)
        while label in self.exists:
            label = get_random_label(wc)

        self.exists[label] = True
        return label

    def generate(self):
        return self.func()
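As a quick sanity check of the labeling scheme above, here is a minimal usage sketch (the `html_tools.identifier` import path is an assumption about how the package is installed; adjust it to your layout):

```python
# Hypothetical usage sketch for IdentifierTool (not part of the diff above).
from html_tools.identifier import IdentifierTool  # import path assumed

tool = IdentifierTool(method='order')           # sequential labels: A, B, ..., Z, AA, AB, ...
print([tool.generate() for _ in range(3)])      # ['A', 'B', 'C']

rand_tool = IdentifierTool(method='random', existing_labels={'AB': True})
label = rand_tool.generate()                    # two random uppercase letters, never a label already taken
assert label != 'AB' and label.isupper() and len(label) == 2
```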
VAB-WebArena-Lite/new/html_tools/prompt.py (new executable file, 97 lines)
@@ -0,0 +1,97 @@
|
||||||
|
from .configs import prompts
|
||||||
|
|
||||||
|
class HtmlPrompt:
|
||||||
|
def __init__(self, prompt: str='') -> None:
|
||||||
|
prompt = self.extract(prompt, 'xml')
|
||||||
|
if prompt not in prompts:
|
||||||
|
raise Exception('Unknown prompt: ' + prompt)
|
||||||
|
|
||||||
|
constructors = {
|
||||||
|
'refine': self.normal_prompt_constructor,
|
||||||
|
'xml': self.normal_prompt_constructor,
|
||||||
|
'new_data': self.new_data_prompt_constructor,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.name = prompt
|
||||||
|
self.prompt = prompts[prompt]
|
||||||
|
self.constructor = constructors[prompt]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def extract(data, default=''):
|
||||||
|
return data if data is not None else default
|
||||||
|
|
||||||
|
def subtree_constructor(self, subtree: list[str]=[]) -> str:
|
||||||
|
return self.prompt['subtree_splitter'].join(subtree)
|
||||||
|
|
||||||
|
def normal_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
|
||||||
|
def add_prefix(data, prefix):
|
||||||
|
return prefix + data if len(data) > 0 else ''
|
||||||
|
|
||||||
|
tag = self.extract(tag)
|
||||||
|
label = self.extract(label)
|
||||||
|
content = self.extract(content)
|
||||||
|
subtree_str = self.extract(subtree_str, '')
|
||||||
|
class_dict = self.extract(class_dict, {})
|
||||||
|
|
||||||
|
label_str = ''
|
||||||
|
if len(label) > 0:
|
||||||
|
label_str = self.prompt['label'].format(label=label)
|
||||||
|
|
||||||
|
classes = []
|
||||||
|
values = set()
|
||||||
|
for key, val in class_dict.items():
|
||||||
|
if val in values:
|
||||||
|
continue
|
||||||
|
values.add(val)
|
||||||
|
classes.append(self.prompt['attr'].format(key=key, attr=val))
|
||||||
|
classes_str = self.prompt['attr_splitter'].join(classes)
|
||||||
|
|
||||||
|
content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
|
||||||
|
classes_str = add_prefix(classes_str, ' ')
|
||||||
|
content_str = add_prefix(content, content_splitter)
|
||||||
|
subtree_str = add_prefix(subtree_str, ' ')
|
||||||
|
|
||||||
|
return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
|
||||||
|
|
||||||
|
def new_data_prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
|
||||||
|
def add_prefix(data, prefix):
|
||||||
|
return prefix + data if len(data) > 0 else ''
|
||||||
|
|
||||||
|
tag = self.extract(tag)
|
||||||
|
label = self.extract(label)
|
||||||
|
content = self.extract(content)
|
||||||
|
subtree_str = self.extract(subtree_str, '')
|
||||||
|
class_dict = self.extract(class_dict, {})
|
||||||
|
|
||||||
|
label_str = ''
|
||||||
|
if len(label) > 0:
|
||||||
|
label_str = self.prompt['label'].format(label=label)
|
||||||
|
|
||||||
|
classes = []
|
||||||
|
values = set()
|
||||||
|
|
||||||
|
message = []
|
||||||
|
for key, val in class_dict.items():
|
||||||
|
if val == '':
|
||||||
|
message.append(key)
|
||||||
|
continue
|
||||||
|
if val in values:
|
||||||
|
continue
|
||||||
|
values.add(val)
|
||||||
|
classes.append(self.prompt['attr'].format(key=key, attr=val))
|
||||||
|
|
||||||
|
if len(message) > 0:
|
||||||
|
message_str = ' '.join(message)
|
||||||
|
classes.append(self.prompt['attr'].format(key='message', attr=message_str))
|
||||||
|
|
||||||
|
classes_str = self.prompt['attr_splitter'].join(classes)
|
||||||
|
|
||||||
|
content_splitter = ' ' if len(classes_str) == 0 else self.prompt['attr_splitter']
|
||||||
|
classes_str = add_prefix(classes_str, ' ')
|
||||||
|
content_str = add_prefix(content, content_splitter)
|
||||||
|
subtree_str = add_prefix(subtree_str, ' ')
|
||||||
|
|
||||||
|
return self.prompt['dom'].format(tag=tag, label=label_str, attr=classes_str, content=content_str, subtree=subtree_str)
|
||||||
|
|
||||||
|
def prompt_constructor(self, tag: str='', label: str='', content: str='', subtree_str: str='', class_dict: dict[str]={}) -> str:
|
||||||
|
return self.constructor(tag, label, content, subtree_str, class_dict)
|
VAB-WebArena-Lite/new/html_tools/scripts/__init__.py (new executable file, 43 lines)
@@ -0,0 +1,43 @@
import os
from pathlib import Path
rootdir = Path(__file__).parent

with open(os.path.join(rootdir, 'prepare.js'), 'r') as f:
    prepare_script = f.read()

with open(os.path.join(rootdir, 'clickable_checker.js'), 'r') as f:
    clickable_checker_script = f.read()

with open(os.path.join(rootdir, 'label.js'), 'r') as f:
    label_script = f.read()

with open(os.path.join(rootdir, 'element_info.js'), 'r') as f:
    element_info_script = f.read()

# draw label on page
with open(os.path.join(rootdir, 'label_marker.js'), 'r') as f:
    label_marker_script = f.read()

# remove label draw on page
remove_label_mark_script = """
() => {
    document.querySelectorAll(".our-dom-marker").forEach(item => {
        document.body.removeChild(item);
    });
}
"""

remove_id_script = """
() => {
    Array.from(document.getElementsByClassName('possible-clickable-element')).forEach((element) => {
        element.classList.remove('possible-clickable-element');
        element.removeAttribute('data-value');
        element.removeAttribute('data-text');
        element.removeAttribute('data-label');
        element.removeAttribute('data-bbox');
        element.removeAttribute('data-status');
        element.removeAttribute('data-backend-node-id');
        element.removeAttribute('data-label-id');
    });
}
"""
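These script strings are meant to be injected into the live page; the actual call sites live in the browser environment code, which is not shown in this diff. A minimal sketch of how they could be evaluated with Playwright (import path and page URL are placeholders):

```python
# Hypothetical sketch: evaluating the bundled scripts through Playwright (not part of the diff).
from playwright.sync_api import sync_playwright
from html_tools.scripts import prepare_script, element_info_script  # import path assumed

with sync_playwright() as p:
    page = p.chromium.launch().new_page()
    page.goto("https://example.com")
    page.evaluate(prepare_script)               # tags elements with data-backend-node-id / data-bbox / data-text
    info = page.evaluate(element_info_script)   # {'all_elements': [...], 'clickable_elements': [...]}
    print(len(info["all_elements"]))
```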
VAB-WebArena-Lite/new/html_tools/scripts/clickable_checker.js (new executable file, 148 lines)
@@ -0,0 +1,148 @@
|
||||||
|
() => {
|
||||||
|
var items = Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll('*')
|
||||||
|
).map(function(element) {
|
||||||
|
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
|
||||||
|
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
|
||||||
|
|
||||||
|
var rects = [...element.getClientRects()].filter(bb => {
|
||||||
|
var center_x = bb.left + bb.width / 2;
|
||||||
|
var center_y = bb.top + bb.height / 2;
|
||||||
|
var elAtCenter = document.elementFromPoint(center_x, center_y);
|
||||||
|
|
||||||
|
if (!elAtCenter) return false;
|
||||||
|
return elAtCenter === element || element.contains(elAtCenter)
|
||||||
|
}).map(bb => {
|
||||||
|
const rect = {
|
||||||
|
left: Math.max(0, bb.left),
|
||||||
|
top: Math.max(0, bb.top),
|
||||||
|
right: Math.min(vw, bb.right),
|
||||||
|
bottom: Math.min(vh, bb.bottom)
|
||||||
|
};
|
||||||
|
return {
|
||||||
|
...rect,
|
||||||
|
width: rect.right - rect.left,
|
||||||
|
height: rect.bottom - rect.top
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// var rects = [];
|
||||||
|
var area = rects.reduce((acc, rect) => acc + rect.width * rect.height, 0);
|
||||||
|
|
||||||
|
const tagName = element.tagName.toLowerCase?.() || "";
|
||||||
|
let isClickable = ((element.onclick != null) || window.getComputedStyle(element).cursor == "pointer");
|
||||||
|
|
||||||
|
// Insert area elements that provide click functionality to an img.
|
||||||
|
if (tagName === "img") {
|
||||||
|
let mapName = element.getAttribute("usemap");
|
||||||
|
if (mapName) {
|
||||||
|
const imgClientRects = element.getClientRects();
|
||||||
|
mapName = mapName.replace(/^#/, "").replace('"', '\\"');
|
||||||
|
const map = document.querySelector(`map[name=\"${mapName}\"]`);
|
||||||
|
if (map && (imgClientRects.length > 0)) isClickable = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isClickable) {
|
||||||
|
const role = element.getAttribute("role");
|
||||||
|
const clickableRoles = [
|
||||||
|
"button",
|
||||||
|
"tab",
|
||||||
|
"link",
|
||||||
|
"checkbox",
|
||||||
|
"menuitem",
|
||||||
|
"menuitemcheckbox",
|
||||||
|
"menuitemradio",
|
||||||
|
"radio",
|
||||||
|
];
|
||||||
|
if (role != null && clickableRoles.includes(role.toLowerCase())) {
|
||||||
|
isClickable = true;
|
||||||
|
} else {
|
||||||
|
const contentEditable = element.getAttribute("contentEditable");
|
||||||
|
if (
|
||||||
|
contentEditable != null &&
|
||||||
|
["", "contenteditable", "true"].includes(contentEditable.toLowerCase())
|
||||||
|
) {
|
||||||
|
isClickable = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for jsaction event listeners on the element.
|
||||||
|
if (!isClickable && element.hasAttribute("jsaction")) {
|
||||||
|
const jsactionRules = element.getAttribute("jsaction").split(";");
|
||||||
|
for (let jsactionRule of jsactionRules) {
|
||||||
|
const ruleSplit = jsactionRule.trim().split(":");
|
||||||
|
if ((ruleSplit.length >= 1) && (ruleSplit.length <= 2)) {
|
||||||
|
const [eventType, namespace, actionName] = ruleSplit.length === 1
|
||||||
|
? ["click", ...ruleSplit[0].trim().split("."), "_"]
|
||||||
|
: [ruleSplit[0], ...ruleSplit[1].trim().split("."), "_"];
|
||||||
|
if (!isClickable) {
|
||||||
|
isClickable = (eventType === "click") && (namespace !== "none") && (actionName !== "_");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isClickable) {
|
||||||
|
const clickableTags = [
|
||||||
|
"input",
|
||||||
|
"textarea",
|
||||||
|
"select",
|
||||||
|
"button",
|
||||||
|
"a",
|
||||||
|
"iframe",
|
||||||
|
"video",
|
||||||
|
"object",
|
||||||
|
"embed",
|
||||||
|
"details"
|
||||||
|
];
|
||||||
|
isClickable = clickableTags.includes(tagName);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isClickable) {
|
||||||
|
if (tagName === "label")
|
||||||
|
isClickable = (element.control != null) && !element.control.disabled;
|
||||||
|
else if (tagName === "img")
|
||||||
|
isClickable = ["zoom-in", "zoom-out"].includes(element.style.cursor);
|
||||||
|
}
|
||||||
|
|
||||||
|
// An element with a class name containing the text "button" might be clickable. However, real
|
||||||
|
// clickables are often wrapped in elements with such class names. So, when we find clickables
|
||||||
|
// based only on their class name, we mark them as unreliable.
|
||||||
|
const className = element.getAttribute("class");
|
||||||
|
if (!isClickable && className && className.toLowerCase().includes("button")) {
|
||||||
|
isClickable = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Elements with tabindex are sometimes useful, but usually not. We can treat them as second
|
||||||
|
// class citizens when it improves UX, so take special note of them.
|
||||||
|
const tabIndexValue = element.getAttribute("tabindex");
|
||||||
|
const tabIndex = tabIndexValue ? parseInt(tabIndexValue) : -1;
|
||||||
|
if (!isClickable && !(tabIndex < 0) && !isNaN(tabIndex)) {
|
||||||
|
isClickable = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
const idValue = element.getAttribute("id");
|
||||||
|
const id = idValue ? idValue.toLowerCase() : "";
|
||||||
|
if (isClickable && area == 0) {
|
||||||
|
const textValue = element.textContent.trim().replace(/\s{2,}/g, ' ');
|
||||||
|
clickable_msg = `${tagName}[id=${id}] ${isClickable} (${area}) ${textValue}`
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
element: element,
|
||||||
|
include: isClickable,
|
||||||
|
area,
|
||||||
|
rects,
|
||||||
|
text: element.textContent.trim().replace(/\s{2,}/g, ' ')
|
||||||
|
};
|
||||||
|
}).filter(item =>
|
||||||
|
item.include && (item.area >= 1)
|
||||||
|
);
|
||||||
|
|
||||||
|
items = items.filter(x => !items.some(y => x.element.contains(y.element) && !(x == y)))
|
||||||
|
|
||||||
|
items.forEach(item => {
|
||||||
|
item.element.classList.add('possible-clickable-element');
|
||||||
|
});
|
||||||
|
}
|
VAB-WebArena-Lite/new/html_tools/scripts/element_info.js (new executable file, 34 lines)
@@ -0,0 +1,34 @@
|
||||||
|
() => {
|
||||||
|
function getElementInfo(element) {
|
||||||
|
return {
|
||||||
|
"bid": element.getAttribute("data-backend-node-id") || "",
|
||||||
|
"label": element.getAttribute("data-label-id") || "",
|
||||||
|
"tag": element.tagName.toLowerCase?.() || "",
|
||||||
|
"area": JSON.parse("[" + (element.getAttribute("data-bbox") || "") + "]"),
|
||||||
|
"text": element.innerText?.trim().replace(/\s{2,}/g, " ") || "",
|
||||||
|
"id": element.getAttribute("id") || "",
|
||||||
|
"role": element.getAttribute("role") || "",
|
||||||
|
"aria-label": element.getAttribute("aria-label") || "",
|
||||||
|
"href": element.getAttribute("href") || "",
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
var all_items = Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll("*")
|
||||||
|
).map((element) => {
|
||||||
|
return getElementInfo(element);
|
||||||
|
});
|
||||||
|
|
||||||
|
var clickable_items = Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll("*")
|
||||||
|
).filter(
|
||||||
|
element => element.getAttribute("data-label-id")
|
||||||
|
).map((element) => {
|
||||||
|
return getElementInfo(element);
|
||||||
|
});
|
||||||
|
|
||||||
|
return {
|
||||||
|
all_elements: all_items,
|
||||||
|
clickable_elements: clickable_items
|
||||||
|
};
|
||||||
|
}
|
VAB-WebArena-Lite/new/html_tools/scripts/label.js (new executable file, 74 lines)
@@ -0,0 +1,74 @@
|
||||||
|
(packet) => {
|
||||||
|
function int2str(index) {
|
||||||
|
var str = "";
|
||||||
|
while (index >= 0) {
|
||||||
|
str = String.fromCharCode(65 + index % 26) + str;
|
||||||
|
index = Math.floor(index / 26) - 1;
|
||||||
|
}
|
||||||
|
return str;
|
||||||
|
};
|
||||||
|
|
||||||
|
selector = packet.selector
|
||||||
|
index = packet.startIndex
|
||||||
|
var items = Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll(selector)
|
||||||
|
);
|
||||||
|
|
||||||
|
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
|
||||||
|
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
|
||||||
|
|
||||||
|
items = items.filter(
|
||||||
|
x => !items.some(y => x.contains(y) && !(x == y))
|
||||||
|
).map(element => {
|
||||||
|
var bb = element.getClientRects();
|
||||||
|
var rect = {
|
||||||
|
left: 0,
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
bottom: 0,
|
||||||
|
width: 0,
|
||||||
|
height: 0
|
||||||
|
};
|
||||||
|
var keep = false;
|
||||||
|
var text = "", id = -1;
|
||||||
|
if (bb.length > 0) {
|
||||||
|
bb = bb[0];
|
||||||
|
rect = {
|
||||||
|
left: Math.max(0, bb.left),
|
||||||
|
top: Math.max(0, bb.top),
|
||||||
|
right: Math.min(vw, bb.right),
|
||||||
|
bottom: Math.min(vh, bb.bottom)
|
||||||
|
};
|
||||||
|
rect = {
|
||||||
|
...rect,
|
||||||
|
width: rect.right - rect.left,
|
||||||
|
height: rect.bottom - rect.top
|
||||||
|
};
|
||||||
|
if (rect.width > 0 || rect.height > 0) {
|
||||||
|
keep = true;
|
||||||
|
if (index >= 0) {
|
||||||
|
// id = int2str(index++);
|
||||||
|
id = index++;
|
||||||
|
element.setAttribute("data-label-id", id);
|
||||||
|
}
|
||||||
|
var childNodes = element.childNodes;
|
||||||
|
|
||||||
|
for (var i = 0; i < childNodes.length; i++) {
|
||||||
|
if (childNodes[i].nodeType == Node.TEXT_NODE) {
|
||||||
|
text += childNodes[i].textContent;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
keep: true,
|
||||||
|
id,
|
||||||
|
rects: rect,
|
||||||
|
tag: element.tagName.toLowerCase?.() || "",
|
||||||
|
text,//: element.innerText?.trim().replace(/\s{2,}/g, " ") || ""
|
||||||
|
};
|
||||||
|
}).filter(x => x.keep);
|
||||||
|
|
||||||
|
return [items, index];
|
||||||
|
}
|
VAB-WebArena-Lite/new/html_tools/scripts/label_marker.js (new executable file, 65 lines)
@@ -0,0 +1,65 @@
|
||||||
|
(items) => {
|
||||||
|
function getRandomColor() {
|
||||||
|
var letters = '0123456789ABCDEF';
|
||||||
|
var color = '#';
|
||||||
|
for (var i = 0; i < 6; i++) {
|
||||||
|
color += letters[Math.floor(Math.random() * 16)];
|
||||||
|
}
|
||||||
|
return color;
|
||||||
|
}
|
||||||
|
|
||||||
|
items.filter(
|
||||||
|
item => item.id != ""
|
||||||
|
).forEach((item) => {
|
||||||
|
const bbox = item.rects;
|
||||||
|
const id_string = `dom-marker-id-${index}`;
|
||||||
|
|
||||||
|
index = item.id;
|
||||||
|
|
||||||
|
outerElement = document.createElement("div");
|
||||||
|
outerElement.classList.add("our-dom-marker");
|
||||||
|
// var borderColor = getRandomColor();
|
||||||
|
var borderColor = "#FFFF00";
|
||||||
|
outerElement.style.outline = `2px dashed ${borderColor}`;
|
||||||
|
outerElement.style.position = "fixed";
|
||||||
|
outerElement.style.left = bbox.left - 2 + "px";
|
||||||
|
outerElement.style.top = bbox.top - 2 + "px";
|
||||||
|
outerElement.style.width = bbox.width + 4 + "px";
|
||||||
|
outerElement.style.height = bbox.height + 4 + "px";
|
||||||
|
outerElement.style.pointerEvents = "none";
|
||||||
|
outerElement.style.boxSizing = "border-box";
|
||||||
|
outerElement.style.zIndex = 2147483647;
|
||||||
|
|
||||||
|
innerElement = document.createElement("div");
|
||||||
|
innerElement.classList.add("our-dom-marker");
|
||||||
|
innerElement.style.outline = `2px dashed #222288`;
|
||||||
|
innerElement.style.position = "fixed";
|
||||||
|
innerElement.style.left = bbox.left + "px";
|
||||||
|
innerElement.style.top = bbox.top + "px";
|
||||||
|
innerElement.style.width = bbox.width + "px";
|
||||||
|
innerElement.style.height = bbox.height + "px";
|
||||||
|
innerElement.style.pointerEvents = "none";
|
||||||
|
innerElement.style.boxSizing = "border-box";
|
||||||
|
innerElement.style.zIndex = 2147483647;
|
||||||
|
|
||||||
|
// Add floating label at the corner
|
||||||
|
var label = document.createElement("span");
|
||||||
|
var topPosition = 25;
|
||||||
|
if (bbox.top < 25) topPosition = bbox.top;
|
||||||
|
label.textContent = index;
|
||||||
|
label.style.position = "absolute";
|
||||||
|
label.style.top = `-${topPosition}px`;
|
||||||
|
label.style.left = "0px";
|
||||||
|
label.style.background = borderColor;
|
||||||
|
label.style.color = "black";
|
||||||
|
label.style.padding = "2px 4px";
|
||||||
|
label.style.fontSize = "16px";
|
||||||
|
label.style.borderRadius = "2px";
|
||||||
|
label.style.fontWeight = "bold";
|
||||||
|
outerElement.appendChild(label);
|
||||||
|
|
||||||
|
document.body.appendChild(outerElement);
|
||||||
|
document.body.appendChild(innerElement);
|
||||||
|
})
|
||||||
|
return items;
|
||||||
|
}
|
VAB-WebArena-Lite/new/html_tools/scripts/prepare.js (new executable file, 83 lines)
@@ -0,0 +1,83 @@
|
||||||
|
() => {
|
||||||
|
// mark backend node id
|
||||||
|
var vw = Math.max(document.documentElement.clientWidth || 0, window.innerWidth || 0);
|
||||||
|
var vh = Math.max(document.documentElement.clientHeight || 0, window.innerHeight || 0);
|
||||||
|
|
||||||
|
var backendId = 0;
|
||||||
|
Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll("*")
|
||||||
|
).forEach((element) => {
|
||||||
|
element.setAttribute("data-backend-node-id", backendId);
|
||||||
|
backendId++;
|
||||||
|
|
||||||
|
var tag = element.tagName.toLowerCase?.() || "";
|
||||||
|
var bb = element.getClientRects();
|
||||||
|
var rect = {
|
||||||
|
left: 0,
|
||||||
|
top: 0,
|
||||||
|
right: 0,
|
||||||
|
bottom: 0,
|
||||||
|
width: 0,
|
||||||
|
height: 0
|
||||||
|
};
|
||||||
|
|
||||||
|
if (bb.length > 0) {
|
||||||
|
bb = bb[0];
|
||||||
|
// rect = {
|
||||||
|
// left: Math.round(Math.max(0, bb.left) * 100) / 100,
|
||||||
|
// top: Math.round(Math.max(0, bb.top) * 100) / 100,
|
||||||
|
// right: Math.round(Math.min(vw, bb.right) * 100) / 100,
|
||||||
|
// bottom: Math.round(Math.min(vh, bb.bottom) * 100) / 100
|
||||||
|
// };
|
||||||
|
rect = {
|
||||||
|
left: (Math.round(bb.left) * 100) / 100,
|
||||||
|
top: (Math.round(bb.top) * 100) / 100,
|
||||||
|
right: (Math.round(bb.right) * 100) / 100,
|
||||||
|
bottom: (Math.round(bb.bottom) * 100) / 100
|
||||||
|
};
|
||||||
|
rect = {
|
||||||
|
...rect,
|
||||||
|
width: Math.round((rect.right - rect.left) * 100) / 100,
|
||||||
|
height: Math.round((rect.bottom - rect.top) * 100) / 100
|
||||||
|
};
|
||||||
|
|
||||||
|
element.setAttribute("data-bbox", `${rect.left},${rect.top},${rect.width},${rect.height}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (element.hasChildNodes()) {
|
||||||
|
let children = Array.prototype.slice.call(element.childNodes);
|
||||||
|
var texts = children.filter(
|
||||||
|
(node) => node.nodeType == Node.TEXT_NODE
|
||||||
|
).map(
|
||||||
|
(node) => node.textContent.trim().replace(/\s{2,}/g, " ") || ""
|
||||||
|
).filter(
|
||||||
|
(text) => text.length > 0
|
||||||
|
)
|
||||||
|
element.setAttribute("data-text", texts.join(","));
|
||||||
|
}
|
||||||
|
|
||||||
|
// fix select issue
|
||||||
|
if (tag == "select") {
|
||||||
|
var value = element.value;
|
||||||
|
var text = element.options[element.selectedIndex]?.text || "";
|
||||||
|
element.setAttribute("data-value", value);
|
||||||
|
element.setAttribute("data-text", text);
|
||||||
|
element.options[element.selectedIndex]?.setAttribute("data-status", "selected");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tag == "input") {
|
||||||
|
var input_type = element.getAttribute("type") || "";
|
||||||
|
if (input_type == "checkbox") {
|
||||||
|
var status = element.checked? "checked" : "not-checked";
|
||||||
|
element.setAttribute("data-status", status);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// fix input and textarea issue
|
||||||
|
Array.prototype.slice.call(
|
||||||
|
document.querySelectorAll("input, textarea")
|
||||||
|
).forEach(element => {
|
||||||
|
element.setAttribute("data-value", element.value);
|
||||||
|
});
|
||||||
|
}
|
VAB-WebArena-Lite/new/html_tools/utils.py (new executable file, 101 lines)
@@ -0,0 +1,101 @@
from lxml import html

def get_xpath_top_down(element: html.HtmlElement, id_column: str='temp_id', label_column: str='temp_clickable_label', path: str='', order: int=0,
                       in_svg: bool=False, temp_id: int=0) -> tuple[int, dict[str, str], dict[str]]:
    used_labels, i2xpath = {}, {}
    # path
    tag = element.tag.lower()
    in_svg = in_svg or (tag == 'svg')

    if not in_svg and 'id' in element.attrib:
        node_id = element.attrib['id']
        path = f'//*[@id="{node_id}"]'
    else:
        suffix = f'[{order}]' if order > 0 else ''
        prefix = f'*[name()="{tag}"]' if in_svg else tag
        path = path + '/' + prefix + suffix

    # add temp id
    element.attrib[id_column] = str(temp_id)
    ori_label = element.attrib.get(label_column, '')
    if ori_label != '':
        used_labels[ori_label] = True

    bid = str(temp_id)
    i2xpath[bid] = path
    i2xpath[path] = bid
    i2xpath[f'xpath/{path}'] = bid
    i2xpath[f'xpath=/{path}'] = bid

    temp_id += 1

    # traverse node
    children = element.getchildren()
    tag_dict = {}
    id_list = []
    for child in children:
        ctag = child.tag.lower()
        if ctag not in tag_dict:
            tag_dict[ctag] = 0
        tag_dict[ctag] += 1
        id_list.append(tag_dict[ctag])

    for cid, child in zip(id_list, children):
        ctag = child.tag.lower()
        cod = cid if tag_dict[ctag] > 1 else 0
        temp_id, i2x, ulabels = get_xpath_top_down(child, id_column, label_column, path, cod, in_svg, temp_id)
        i2xpath.update(i2x)
        used_labels.update(ulabels)

    return temp_id, i2xpath, used_labels

def print_html_object(obj: str='') -> str:
    tab_cnt = 0
    result, content, sep = '', '', ''
    last_is_left, last_is_right = False, False
    for ch in obj:
        if ch == '<':
            result += '\n'
            if len(content.strip()) > 0:
                result += sep + content.strip() + '\n'
            result += sep + '<'

            tab_cnt += 1
            sep = ' ' * tab_cnt

            content = ''
            last_is_right = False
            last_is_left = True
        elif ch == '>':
            if last_is_left:
                result += content
            else:
                if last_is_right:
                    result += '\n'
                if len(content.strip()) > 0:
                    result += sep + content.strip() + '\n'

                tab_cnt -= 1
                sep = ' ' * tab_cnt

                if not last_is_left:
                    result += sep

            result += '>'
            content = ''

            last_is_right = True
            last_is_left = False
        else:
            content += ch

    return result

def rect2tuple(rect: str) -> tuple[int, int, int, int]:
    if rect is None or type(rect) != type('str'):
        return None
    rect = rect.strip()
    if rect.count(',') != 3:
        return None
    rect = rect.split(',')
    rect = [float(r) for r in rect]
    return tuple(rect)
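To illustrate what `get_xpath_top_down` and `rect2tuple` produce, here is a small hypothetical sketch on a toy fragment (the `html_tools.utils` import path is assumed; the ids and xpaths shown are simply what the algorithm assigns for this input):

```python
# Hypothetical usage sketch for get_xpath_top_down / rect2tuple (not part of the diff above).
from lxml import html
from html_tools.utils import get_xpath_top_down, rect2tuple  # import path assumed

tree = html.fromstring('<div id="nav"><ul><li>A</li><li>B</li></ul></div>')
n, i2xpath, _ = get_xpath_top_down(tree)
print(n)                             # 4: div, ul and two li elements got temp ids 0..3
print(i2xpath['2'])                  # '//*[@id="nav"]/ul/li[1]' (the mapping works in both directions)
print(rect2tuple('10,20,300,40'))    # (10.0, 20.0, 300.0, 40.0), as written into data-bbox
```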
@@ -36,7 +36,6 @@ def retry_with_exponential_backoff(  # type: ignore
     # Initialize variables
     num_retries = 0
     delay = initial_delay
-
     # Loop until a successful response or max_retries is hit or an exception is raised
     while True:
         try:
@@ -142,27 +141,32 @@ async def agenerate_from_openai_completion(
 @retry_with_exponential_backoff
 def generate_from_openai_completion(
     prompt: str,
-    engine: str,
+    model: str,
     temperature: float,
     max_tokens: int,
     top_p: float,
-    context_length: int,
     stop_token: str | None = None,
+    api_key: str | None = None,
+    base_url: str | None = None
 ) -> str:
-    if "OPENAI_API_KEY" not in os.environ:
+    if "OPENAI_API_KEY" not in os.environ and api_key is None:
         raise ValueError(
             "OPENAI_API_KEY environment variable must be set when using OpenAI API."
         )
+    if api_key is not None:
+        client = OpenAI(api_key=api_key, base_url=base_url)
     response = client.completions.create(
         prompt=prompt,
-        engine=engine,
+        model=model,
         temperature=temperature,
         max_tokens=max_tokens,
         top_p=top_p,
         stop=[stop_token],
     )
-    answer: str = response["choices"][0]["text"]
+    try:
+        answer: str = response["choices"][0]["text"]
+    except:
+        answer: str = response.choices[0].text
     return answer
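With the `api_key`/`base_url` parameters added above, a locally served (e.g. vLLM) WebRL model can be queried through the same completion helper. A hedged sketch follows; the module path, model name, port, and key are placeholders, not values from this commit:

```python
# Hypothetical sketch: calling the patched completion helper against a self-deployed model.
from llms.providers.openai_utils import generate_from_openai_completion  # module path assumed

answer = generate_from_openai_completion(
    prompt="Task Instruction: ...\n\nRound 0\n\n<|eot_id|>...",
    model="webrl-llama-3.1-8b",                       # name registered when deploying the model
    temperature=0.0,
    max_tokens=2048,
    top_p=1.0,
    stop_token="<|eot_id|>",
    api_key="EMPTY",                                  # vLLM accepts any non-empty key
    base_url="http://<your_deployed_model_ip>:8000/v1",
)
print(answer)
```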
VAB-WebArena-Lite/new/p_webrl.json (new file, 13 lines)
@@ -0,0 +1,13 @@
{
    "intro": "",
    "examples": [],
    "template": "",
    "meta_data": {
        "observation": "webrl",
        "action_type": "webrl_id",
        "keywords": [],
        "prompt_constructor": "WebRLPromptConstructor",
        "answer_phrase": "",
        "action_splitter": ""
    }
}
VAB-WebArena-Lite/new/p_webrl_chat.json (new file, 13 lines)
@@ -0,0 +1,13 @@
{
"intro": "# Setup\nYou are a professional web browsing agent assistant that can fulfill user's high-level instructions. Given Simplified html of the browsed webpage at each step, you plan operations in python-style pseudo code using provided functions, or customize functions (if necessary) and then provide their implementations. \n# More details about the code\nYour code should be readable, simple, and only **ONE-LINE-OF-CODE** at a time, avoid using loop statement and only use if-else control if necessary. Predefined functions are as follow:\n\n```\ndef do(action, argument, element):\n\t\"\"\"A single browsing operation on the webpage.\n\tArgs:\n\t\t:param action: one of the actions from [\"Click\", \"Right Click\", \"Type\", \"Search\", \"Hover\", \"Scroll Up\", \"Scroll Down\", \"Press Enter\", \"Switch Tab\", \"Select Dropdown Option\", \"Wait\"].\n\t\t:param argument: optional. Only for \"Type\", \"Search\", \"Switch Page\", and \"Select Dropdown Option\", indicating the content to type in, page number(start from 0) to switch, or key to press.\n\t\t \"Search\" action is equivalent to \"Type\" action plus \"Enter\" key press.\n\t\t:param element: optional. Only for \"Click\", \"Right Click\", \"Type\", \"Search\", \"Select Dropdown Option\", and \"Hover\". Should be specific element id in the html.\n\tReturns:\n\t\tNone. The webpage will be updated after executing the action.\n\t\"\"\"\n\ndef exit(message):\n\t\"\"\"Ending the browsing process if the assistant think it has fulfilled the goal.\n\tArgs:\n\t\t:param message: optional. If user's instruction is a question, return assistant's answer in the message based on the browsing content.\n\tReturns:\n\t\tNone.\n\t\"\"\"\n\ndef go_backward():\n\t\"\"\"Go back to the previous page.\n\t\"\"\"\n\ndef go_forward():\n \"\"\"Go forward to the next page.\n \"\"\"\n```\n\nHere are some examples:\n- # Element: the 'REPORTS' section on the left sidebar\ndo(action=\"Click\", element=\"7\")\n- # Element: the 'Period' dropdown, middle center\ndo(action=\"Select Dropdown Option\", argument=\"Month\", element=\"20\")\n- # Element: the 'From' date picker input field, middle center\ndo(action=\"Type\", argument=\"01/01/2023\", element=\"22\")\n- do(action=\"Scroll Down\")\n- exit(message=\"The top-3 best-selling products in January 2023 are: 1\")\n- # Element: The search bar\ndo(action=\"Search\", argument=\"international airport near Carnegie Mellon University within a driving distance of 50 km\", element=\"13\")\n- # Note: Pittsburgh International Airport, Southern Beltway, Findlay Township, Allegheny County, 15231, United States\n# Element: The field labeled 'Pittsburgh International Airport' in the top left corner\ndo(action=\"Type\", argument=\"Cleveland Hopkins International Airport\", element=\"14\")\n\nREMEMBER: \n- only **ONE-LINE-OF-CODE** at a time\n- Don't generate an operation element that you do not see in the screenshot.\n- Use \"# Element\" to describe the element you choose in the html.\n- Use '# Note\" to record information useful to answer the instruction if needed.\n- If you find yourself fallen into some sort of loop, try to use another method or change your action.\n- If you think a page is still loading or still playing animation and you want to wait a while, use \"Wait\" action.\n- You are acting in a real world, try your best not to reject user's demand. 
Solve all the problem you encounter.\n- If you think you didn't get expected webpage, you should try using more precise and locative description of the element.\n- You must make sure the target element of `find_element*` exists on current screenshot, if not, you should navigate to the target place first.\n- You must identify potential errors or mistakes made by `find_element*` function and correct them. If the webpage is not as expected, you should try to re-do or un-do the operation.\n- You should **NEVER** try to use the browser's address bar at the top of the page to navigate.\n- Your answer shouldn't be in a code snippet format. Just write the function name and its arguments.\n- For quote, exit, go_backward, go_forward request, you should strictly obey the format of quote, exit, go_backward, go_forward functions, answers like do(\"Quote\", xxx, None) or do(\"quote\", xxx, None)are not allowed.\n- If you use do function to perform \"Click\", \"Right Click\", \"Type\", \"Search\", \"Select Dropdown Option\", and \"Hover\", the param element must not be None.\n",
    "examples": [],
    "template": "",
    "meta_data": {
        "observation": "webrl",
        "action_type": "webrl_id",
        "keywords": [],
        "prompt_constructor": "WebRLChatPromptConstructor",
        "answer_phrase": "",
        "action_splitter": ""
    }
}
VAB-WebArena-Lite/new/processors.py (new file, 1351 lines)
File diff suppressed because it is too large.
@@ -499,3 +499,119 @@ class MultimodalCoTPromptConstructor(CoTPromptConstructor):
             raise NotImplementedError(
                 f"Provider {self.lm_config.provider} not implemented"
             )
+
+
+class WebRLPromptConstructor(PromptConstructor):
+    """The agent will direct predict the action"""
+
+    def __init__(
+        self,
+        instruction_path: str | Path,
+        lm_config: lm_config.LMConfig,
+        tokenizer: Tokenizer,
+    ):
+        super().__init__(instruction_path, lm_config, tokenizer)
+
+    def construct(
+        self,
+        trajectory: Trajectory,
+        intent: str,
+        meta_data: dict[str, Any] = {},
+    ) -> APIInput:
+        """Construct prompt given the trajectory"""
+        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]
+
+        obs = state_info["observation"][self.obs_modality]
+        max_obs_length = self.lm_config.gen_config["max_obs_length"]
+        if max_obs_length:
+            if self.lm_config.provider == "google":
+                print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
+                obs = obs[:max_obs_length]
+            else:
+                try:
+                    obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]
+                except:
+                    print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
+                    obs = obs[:max_obs_length]
+
+        turn_num = len(meta_data["action_history"])
+        if turn_num == 1:
+            previous_action_str = []
+        else:
+            previous_action_str = meta_data["action_history"][1:]
+
+        index = turn_num - 1
+        history = ""
+        for i in range(index - 1, -1, -1):
+            if i == 0:
+                history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
+            else:
+                history = f"Round {i}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n** Simplified html **\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n{previous_action_str[i]}\n\n" + history
+        if len(history) + len(obs) > (16384 - 512):
+            obs = obs[:(16384 - 512) - len(history)]
+        current_turn = f"Round {index}\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
+        prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"
+
+        return prompt
+
+    def extract_action(self, response: str) -> str:
+        return response
+
+
+class WebRLChatPromptConstructor(PromptConstructor):
+    """The agent will direct predict the action"""
+
+    def __init__(
+        self,
+        instruction_path: str | Path,
+        lm_config: lm_config.LMConfig,
+        tokenizer: Tokenizer,
+    ):
+        super().__init__(instruction_path, lm_config, tokenizer)
+
+    def construct(
+        self,
+        trajectory: Trajectory,
+        intent: str,
+        meta_data: dict[str, Any] = {},
+    ) -> APIInput:
+        """Construct prompt given the trajectory"""
+        state_info: StateInfo = trajectory[-1]  # type: ignore[assignment]
+
+        obs = state_info["observation"][self.obs_modality]
+        max_obs_length = self.lm_config.gen_config["max_obs_length"]
+        if max_obs_length:
+            if self.lm_config.provider == "google":
+                print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
+                obs = obs[:max_obs_length]
+            else:
+                try:
+                    obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length])  # type: ignore[arg-type]
+                except:
+                    print("NOTE: There is no available tokenizer, so we use characters instead of tokens for max_obs_length.")
+                    obs = obs[:max_obs_length]
+
+        turn_num = len(meta_data["action_history"])
+        if turn_num == 1:
+            previous_action_str = []
+        else:
+            previous_action_str = meta_data["action_history"][1:]
+
+        index = turn_num - 1
+        conversations = []
+        for i in range(index - 1, -1, -1):
+            if i == 0:
+                content_user = f"Task Instruction: {intent}\n\nRound {i}\n{intent}"
+                content_assistant = f"{previous_action_str[i]}"
+            else:
+                content_user = f"Round {i}\n** Simplified html **"
+                content_assistant = f"{previous_action_str[i]}"
+            conversation = [{'role': 'user', 'content': content_user}, {'role': 'assistant', 'content': content_assistant}]
+            conversations = conversation + conversations
+
+        system_turn = [{'role': 'system', 'content': self.instruction['intro']}]
+        current_turn = [{'role': 'user', 'content': f'Round {index}\n\n{obs}'}]
+        conversations = system_turn + conversations + current_turn
+
+        return conversations
+
+    def extract_action(self, response: str) -> str:
+        return response
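To make the Round-based plain-text format concrete, here is an illustrative sketch of the string `WebRLPromptConstructor` assembles for a two-step trajectory (the intent, action, and observation values are made up; the real observation is the simplified HTML produced by the `webrl` observation processor):

```python
# Illustrative only: the prompt layout WebRLPromptConstructor produces for turn_num == 2.
intent = "Find the cheapest laptop."
action_history = ["None", 'do(action="Click", element="12")']   # meta_data["action_history"]
obs = "<html> ... simplified html of the current page ... </html>"

history = (
    "Round 0\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
    f"{intent}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
    f"{action_history[1]}\n\n"
)
current_turn = (
    "Round 1\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n"
    f"{obs}\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
)
prompt = f"Task Instruction: {intent}\n\n{history}{current_turn}"
print(prompt)
```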
@@ -13,11 +13,13 @@ import tempfile
 import time
 from pathlib import Path
 from typing import List
+import cv2
+import shutil

 import openai
 import requests
 import torch
-from PIL import Image
+from PIL import Image, ImageDraw, ImageFont

 from agent import (
     PromptAgent,
@@ -62,6 +64,34 @@ formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
 console_handler.setFormatter(formatter)
 file_handler.setFormatter(formatter)

+def text_wrap(text, font, max_width):
+    lines = []
+    paragraphs = text.split('\n')  # split the text into paragraphs on '\n'
+    for paragraph in paragraphs:
+        words = paragraph.split(' ')
+        line = ''
+        for word in words:
+            # tentative line with the next word appended
+            test_line = f"{line} {word}".strip()
+            # measure the width of the tentative line
+            test_line_bbox = font.getbbox(test_line)
+            test_line_width = test_line_bbox[2] - test_line_bbox[0]
+            if test_line_width <= max_width:
+                # the tentative line still fits within the image width, keep adding words
+                line = test_line
+            else:
+                # it exceeds the maximum width: save the current line and start a new one
+                lines.append(line)
+                line = word
+        # append the last line of the paragraph
+        if line:
+            lines.append(line)
+        # add an empty line after each paragraph to preserve the paragraph break
+        lines.append('')
+    # remove the trailing empty line (no extra blank line is needed)
+    if lines[-1] == '':
+        lines.pop()
+    return lines
+
 def config() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
@@ -88,6 +118,7 @@ def config() -> argparse.Namespace:
             "html",
             "image",
             "image_som",
+            "webrl",
         ],
         default="accessibility_tree",
         help="Observation type",
@@ -176,6 +207,9 @@ def config() -> argparse.Namespace:

     # logging related
     parser.add_argument("--result_dir", type=str, default="")
+
+    # if use self-deployed model
+    parser.add_argument("--planner_ip", type=str, default=None)
     args = parser.parse_args()

     # check the whether the action space is compatible with the observation space
@@ -196,7 +230,7 @@


 def early_stop(
-    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
+    trajectory: Trajectory, max_steps: int, thresholds: dict[str, int], actions=None
 ) -> tuple[bool, str]:
     """Check whether need to stop early"""

@@ -228,28 +262,38 @@ def early_stop(
     if len(action_seq) == 0:
         return False, ""

-    last_action: Action = action_seq[-1]
-
-    if last_action["action_type"] != ActionTypes.TYPE:
-        if len(last_k_actions) >= k:
-            if all(
-                [
-                    is_equivalent(action, last_action)
-                    for action in last_k_actions
-                ]
-            ):
-                return True, f"Same action for {k} times"
-    else:
-        # check the action sequence
-        if (
-            sum([is_equivalent(action, last_action) for action in action_seq])
-            >= k
-        ):
-            return True, f"Same typing action for {k} times"
-
-    return False, ""
+    if actions is None:
+        last_action: Action = action_seq[-1]
+
+        if last_action["action_type"] != ActionTypes.TYPE:
+            if len(last_k_actions) >= k:
+                if all(
+                    [
+                        is_equivalent(action, last_action)
+                        for action in last_k_actions
+                    ]
+                ):
+                    return True, f"Same action for {k} times"
+        else:
+            # check the action sequence
+            if (
+                sum([is_equivalent(action, last_action) for action in action_seq])
+                >= k
+            ):
+                return True, f"Same typing action for {k} times"
+        return False, ""
+    else:
+        last_k_actions = actions[-k:]
+        last_action = actions[-1]
+        if len(last_k_actions) >= k:
+            if all(
+                [
+                    action == last_action
+                    for action in last_k_actions
+                ]
+            ):
+                return True, f"Same action for {k} times"
+        return False, ""


 def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
     obj = {
@@ -387,16 +431,20 @@ def test(
             obs, info = env.reset(options={"config_file": config_file})
             state_info: StateInfo = {"observation": obs, "info": info}
             trajectory.append(state_info)

             meta_data = {"action_history": ["None"]}
             out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
             actions = []
+
+            os.makedirs(os.path.join(args.result_dir, 'screehshots'), exist_ok=True)
+            if os.path.exists(os.path.join(args.result_dir, 'screehshots', f"{task_id}")):
+                shutil.rmtree(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+            os.makedirs(os.path.join(args.result_dir, 'screehshots', f"{task_id}"))
+
             while True:
                 update_action_history(out_path, task_id, actions=actions)
+                # If no actions variable is passed, the behavior of early_stop is the same as the original one.
                 early_stop_flag, stop_info = early_stop(
-                    trajectory, max_steps, early_stop_thresholds
+                    trajectory, max_steps, early_stop_thresholds, actions
                 )

                 if early_stop_flag:
@@ -407,7 +455,7 @@ def test(
                         trajectory,
                         intent,
                         images=images,
-                        meta_data=meta_data,
+                        meta_data=meta_data
                     )
                 except ValueError as e:
                     # get the error message
@@ -426,9 +474,25 @@ def test(
                 render_helper.render(
                     action, state_info, meta_data, args.render_screenshot
                 )
+
+                current_screenshot = os.path.join(args.result_dir, 'screehshots', f"{task_id}", f"{len(actions)}.png")
+                _ = env.page.viewport_size
+                env.page.screenshot(path="/dev/null")
+                env.page.screenshot(path=current_screenshot)
+                element_id = action["element_id"]
+                if element_id != "":
+                    element = env.page.query_selector(f"[data-label-id='{element_id}']")
+                    if element:
+                        bbox = element.bounding_box()
+                        bbox = [int(bbox['x']), int(bbox['y']), int(bbox['width']), int(bbox['height'])]
+                        image = cv2.imread(current_screenshot)
+                        cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), (0, 255, 0), 2)
+                        cv2.circle(image, (int(bbox[0] + bbox[2] / 2), int(bbox[1] + bbox[3] / 2)), radius=0, color=(0, 255, 0), thickness=2)
+                        cv2.imwrite(current_screenshot, image)
+
                 meta_data["action_history"].append(action_str)
                 actions.append(action_str)
-                print(action_str)
+                print('Action String: ', action_str)

                 if action["action_type"] == ActionTypes.STOP:
                     break
@@ -540,7 +604,7 @@ if __name__ == "__main__":
     os.environ["TOKENIZERS_PARALLELISM"] = "false"

     args = config()
-    args.sleep_after_execution = 2.5
+    args.sleep_after_execution = 3.0
     prepare(args)

     test_config_base_dir = args.test_config_base_dir
VAB-WebArena-Lite/new/score.py (new file, 111 lines)
@@ -0,0 +1,111 @@
import os, json, sys, copy

USE_TASKS = [i for i in range(165)]

def get_result(res_dict, src="all"):
    if len(res_dict) == 0:
        return ''

    success_id = [k for k, v in res_dict.items() if v >= 1.0]
    score = len(success_id)
    finish_count = len(res_dict)
    pacc, acc = score / finish_count * 100, score / TASKS * 100

    print(sorted(success_id))

    meta = """
--------
src file: {}
successed: {:3} / {:4} (812)
partial accuracy: {:7}
overall accuracy: {:7}
--------
""".format(src, int(score), finish_count, round(pacc, 2), round(acc, 2))

    print(meta)

def export_result(res_dict, src=".", note=["1.0", "0.0"], show_all=False):
    out_string = ""
    for id in USE_TASKS:
        # with open(f"Pipeline/config_files/{id}.json", "r") as f:
        #     jd = json.load(f)

        # if "map" in jd["sites"]:
        #     continue
        if id in res_dict:
            if res_dict[id] >= 1.0:
                out_string += note[0]
            else:
                out_string += note[1]
        elif show_all:
            out_string += note[1]
        out_string += "\n"

    with open(os.path.join(src, 'export.txt'), 'w') as f:
        f.write(out_string)

TASKS = 165

files = sys.argv[1]
file_list = files.split(',')

all_result = {}

for src in file_list:
    path = os.path.join(src, 'actions')

    result = {}
    finished = os.listdir(path)

    for file in finished:
        if not file.endswith('.json'):
            continue
        with open(os.path.join(path, file), 'r') as f:
            data = json.load(f)

        if not isinstance(data, dict):
            continue

        task_id = data.get('task_id', 1000)
        # if task_id >= TASKS:
        #     continue

        task_score = data.get('score', 0)
        if task_score < 0:
            continue

        result[task_id] = task_score
        if task_id not in all_result or task_score > all_result[task_id]:
            all_result[task_id] = task_score

    get_result(result, src)
    export_result(result, src=src)

if len(file_list) > 1:
    get_result(all_result)
    export_result(all_result, show_all=True)

with open('./config_files/wa/test_webarena_lite.raw.json') as fp:
    configs = json.load(fp)
sub_results = {}
sub_ids = {}
for item in configs:
    web = tuple(item['sites'])
    task_id = int(item['task_id'])
    old_task_id = int(item['old_task_id'])
    if web not in sub_results:
        sub_results[web] = []
    if web not in sub_ids:
        sub_ids[web] = []
    if task_id in all_result:
        sub_results[web].append(all_result[task_id])
        if all_result[task_id] == 1:
            sub_ids[web].append(old_task_id)
    else:
        sub_results[web].append(0)
for web in sub_results:
    print(web, round(sum(sub_results[web]) / len(sub_results[web]) * 100, 1))

print('\n\n')
for web in sub_ids:
    print(web, sorted(sub_ids[web]), len(sub_ids[web]))
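For reference, `score.py` only relies on the `task_id` and `score` fields of each record that `update_action_history` writes to `<result_dir>/actions/<task_id>.json`; a record therefore looks roughly like the sketch below (field set inferred from the code above; other keys are ignored, and the default `score=-0.1` of a still-running task is skipped by the `task_score < 0` check):

```python
# Hypothetical example of one actions/<task_id>.json record consumed by score.py.
import json

record = {
    "task_id": 42,                                    # indexes the result table
    "score": 1.0,                                     # >= 1.0 counts as success, < 0 is skipped
    "actions": ['do(action="Click", element="7")'],   # kept for inspection, not used by score.py
}
with open("results/actions/42.json", "w") as f:
    json.dump(record, f, indent=2)
```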
@@ -187,7 +187,7 @@
             "string_match"
         ],
         "reference_answers": {
-            "must_include": [
+            "fuzzy_match": [
                 "0"
             ]
         },
@@ -774,8 +774,8 @@
             "string_match"
         ],
         "reference_answers": {
-            "fuzzy_match": [
-                "914km"
+            "must_include": [
+                "914"
             ]
         },
         "reference_url": "",
@@ -1139,7 +1139,7 @@
             "string_match"
         ],
         "reference_answers": {
-            "must_include": [
+            "fuzzy_match": [
                 "0"
             ]
         },
@@ -1679,7 +1679,7 @@
             "string_match"
         ],
         "reference_answers": {
-            "exact_match": "N/A"
+            "fuzzy_match": ["N/A"]
         },
         "reference_url": "",
         "program_html": [],
@@ -2605,7 +2605,7 @@
             "string_match"
         ],
         "reference_answers": {
-            "exact_match": "yjlou"
+            "must_include": "yjlou"
         },
         "reference_url": "",
         "program_html": [],
@@ -3712,7 +3712,7 @@
             "string_match"
         ],
         "reference_answers": {
-            "exact_match": "N/A"
+            "fuzzy_match": "N/A"
         },
         "reference_url": "",
         "program_html": [],
@@ -1,15 +1,18 @@
 from typing import Any

 import tiktoken
-from transformers import LlamaTokenizer  # type: ignore
+from transformers import LlamaTokenizer, AutoTokenizer  # type: ignore


 class Tokenizer(object):
     def __init__(self, provider: str, model_name: str) -> None:
         if provider == "openai":
-            self.tokenizer = tiktoken.encoding_for_model(model_name)
+            try:
+                self.tokenizer = tiktoken.encoding_for_model(model_name)
+            except:  # The provider is in openai format but the model is a finetuned model
+                self.tokenizer = None
         elif provider == "huggingface":
-            self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
+            self.tokenizer = LlamaTokenizer.from_pretrained(model_name, trust_remote_code=True)
             # turn off adding special tokens automatically
             self.tokenizer.add_special_tokens = False  # type: ignore[attr-defined]
             self.tokenizer.add_bos_token = False  # type: ignore[attr-defined]
@@ -21,6 +21,8 @@ APIInput = str | list[Any] | dict[str, Any]
 def call_llm(
     lm_config: lm_config.LMConfig,
     prompt: APIInput,
+    api_key = None,
+    base_url = None
 ) -> str:
     response: str
     if lm_config.provider == "openai":
@@ -39,11 +41,13 @@ def call_llm(
         assert isinstance(prompt, str)
         response = generate_from_openai_completion(
             prompt=prompt,
-            engine=lm_config.model,
+            model=lm_config.model,
             temperature=lm_config.gen_config["temperature"],
             max_tokens=lm_config.gen_config["max_tokens"],
             top_p=lm_config.gen_config["top_p"],
             stop_token=lm_config.gen_config["stop_token"],
+            api_key=api_key,
+            base_url=base_url
         )
     else:
         raise ValueError(
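Threading `api_key` and `base_url` through `call_llm` is what lets the `openai` provider talk to a locally deployed model (e.g. the vLLM server behind `--planner_ip`) instead of the official API. A hedged sketch of the resulting downstream request; the address, port, and model name are placeholders, not values from the repository:

```python
# Sketch only: an OpenAI-compatible completion request against a locally served model.
from openai import OpenAI

client = OpenAI(
    api_key="EMPTY",  # vLLM does not check the key, but the client requires one
    base_url="http://<your_deployed_model_ip>:8000/v1",  # placeholder for the --planner_ip server
)
resp = client.completions.create(
    model="<your_deployed_model_name>",
    prompt="...flattened WebRL observation...",
    temperature=0.0,
    max_tokens=2048,
    stop=["<|eot_id|>"],
)
print(resp.choices[0].text)
```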
VAB-WebArena-Lite/new/wa_parallel_run_webrl.sh (new file, 97 lines)
@@ -0,0 +1,97 @@
#!/bin/bash
DATASET='webarena'  # TODO: select from ['webarena', 'visualwebarena']
result_dir=''  # TODO: set your result_dir
provider='openai'  # TODO: select from ['openai', 'finetune', ...]
model=''  # TODO: assign model name, which is used for action generation
planner_ip=''  # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
instruction_path='agent/prompts/jsons/p_webrl.json'  # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
test_config_base_dir='config_files/wa/test_webarena_lite'  # e.g., config_files/wa/test_webarena_lite
temperature=0.0

SERVER=''  # TODO: your server address
MAP_SERVER=''  # TODO: the server address for MAP tasks
OPENAI_API_KEY=''  # TODO: if you test OpenAI APIs
OPENAI_ORGANIZATION=''
CONDA_ENV_NAME=''  # TODO: the name of your conda environment for testing WebArena

ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"

# get the number of tmux panes
num_panes=$(tmux list-panes | wc -l)

# calculate how many panes need to be created
let "panes_to_create = 7 - num_panes"

# array of tmux commands to create each pane
tmux_commands=(
    'tmux split-window -h'
    'tmux split-window -v'
    'tmux select-pane -t 0; tmux split-window -v'
    'tmux split-window -v'
    'tmux select-pane -t 3; tmux split-window -v'
    'tmux select-pane -t 5; tmux split-window -v'
)

# create panes up to 7
for ((i=0; i<$panes_to_create; i++)); do
    eval ${tmux_commands[$i]}
done

#!/bin/bash

# Function to run a job
run_job() {
    tmux select-pane -t $1
    COMMAND="python run.py \
        --instruction_path ${instruction_path} \
        --test_start_idx $2 \
        --test_end_idx $3 \
        --result_dir ${result_dir} \
        --test_config_base_dir ${test_config_base_dir} \
        --provider ${provider} \
        --model ${model} \
        --mode completion \
        --planner_ip ${planner_ip} \
        --stop_token \"<|eot_id|>\" \
        --temperature ${temperature} \
        --max_obs_length 0 \
        --max_tokens 2048 \
        --viewport_width 1280 \
        --viewport_height 720 \
        --parsing_failure_th 5 \
        --repeating_action_failure_th 5 \
        --action_set_tag webrl_id --observation_type webrl"
    tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
    sleep 3
}

TOLERANCE=2
run_batch() {
    args=("$@")           # save all arguments in an array
    num_jobs=${#args[@]}  # get number of arguments

    for ((i=1; i<$num_jobs; i++)); do
        run_job $i ${args[i-1]} ${args[i]}
    done

    # Wait for all jobs to finish
    while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
        sleep 100  # wait 100 seconds before checking again
    done

    # Run checker
    while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
        echo "Check failed, rerunning jobs..."
        for ((i=1; i<$num_jobs; i++)); do
            run_job $i ${args[i-1]} ${args[i]}
        done

        # Wait for all jobs to finish
        while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
            sleep 100  # wait 100 seconds before checking again
        done
    done
}
run_batch 0 28 56 84 112 140 165
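The closing `run_batch 0 28 56 84 112 140 165` call hands each of the six worker panes a contiguous `[start, end)` slice of the 165 test configs. A small sketch (not part of the script) of how such boundaries can be recomputed if the number of panes or tasks changes:

```python
# Sketch: recompute near-even [start, end) boundaries for N tasks across K panes.
import math


def slice_bounds(n_tasks: int = 165, n_panes: int = 6) -> list[int]:
    step = math.ceil(n_tasks / n_panes)  # 28 tasks per pane for the default split
    return [min(i * step, n_tasks) for i in range(n_panes + 1)]


print(slice_bounds())  # [0, 28, 56, 84, 112, 140, 165]
```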
VAB-WebArena-Lite/new/wa_parallel_run_webrl_chat.sh (new file, 97 lines)
@@ -0,0 +1,97 @@
#!/bin/bash
DATASET='webarena'  # TODO: select from ['webarena', 'visualwebarena']
result_dir=''  # TODO: set your result_dir
provider='openai'  # TODO: select from ['openai', 'finetune', ...]
model=''  # TODO: assign model name, which is used for action generation
planner_ip=''  # TODO: ip address of the model you are deploying (only if you are deploying your own model using e.g. vllm)
instruction_path='agent/prompts/jsons/p_webrl_chat.json'  # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
test_config_base_dir='config_files/wa/test_webarena_lite'  # e.g., config_files/wa/test_webarena_lite
temperature=0.0

SERVER=''  # TODO: your server address
MAP_SERVER=''  # TODO: the server address for MAP tasks
OPENAI_API_KEY=''  # TODO: if you test OpenAI APIs
OPENAI_ORGANIZATION=''
CONDA_ENV_NAME=''  # TODO: the name of your conda environment for testing WebArena

ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"

# get the number of tmux panes
num_panes=$(tmux list-panes | wc -l)

# calculate how many panes need to be created
let "panes_to_create = 7 - num_panes"

# array of tmux commands to create each pane
tmux_commands=(
    'tmux split-window -h'
    'tmux split-window -v'
    'tmux select-pane -t 0; tmux split-window -v'
    'tmux split-window -v'
    'tmux select-pane -t 3; tmux split-window -v'
    'tmux select-pane -t 5; tmux split-window -v'
)

# create panes up to 7
for ((i=0; i<$panes_to_create; i++)); do
    eval ${tmux_commands[$i]}
done

#!/bin/bash

# Function to run a job
run_job() {
    tmux select-pane -t $1
    COMMAND="python run.py \
        --instruction_path ${instruction_path} \
        --test_start_idx $2 \
        --test_end_idx $3 \
        --result_dir ${result_dir} \
        --test_config_base_dir ${test_config_base_dir} \
        --provider ${provider} \
        --model ${model} \
        --mode chat \
        --planner_ip ${planner_ip} \
        --stop_token \"<|eot_id|>\" \
        --temperature ${temperature} \
        --max_obs_length 0 \
        --max_tokens 2048 \
        --viewport_width 1280 \
        --viewport_height 720 \
        --parsing_failure_th 5 \
        --repeating_action_failure_th 5 \
        --action_set_tag webrl_id --observation_type webrl"
    tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until ${COMMAND}; do echo 'crashed' >&2; sleep 1; done" C-m
    sleep 3
}

TOLERANCE=2
run_batch() {
    args=("$@")           # save all arguments in an array
    num_jobs=${#args[@]}  # get number of arguments

    for ((i=1; i<$num_jobs; i++)); do
        run_job $i ${args[i-1]} ${args[i]}
    done

    # Wait for all jobs to finish
    while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
        sleep 100  # wait 100 seconds before checking again
    done

    # Run checker
    while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
        echo "Check failed, rerunning jobs..."
        for ((i=1; i<$num_jobs; i++)); do
            run_job $i ${args[i-1]} ${args[i]}
        done

        # Wait for all jobs to finish
        while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
            sleep 100  # wait 100 seconds before checking again
        done
    done
}
run_batch 0 28 56 84 112 140 165
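This chat variant differs from `wa_parallel_run_webrl.sh` only in the prompt file (`p_webrl_chat.json`) and `--mode chat`, which presumably sends role-tagged messages to a proprietary chat model instead of a single flat completion prompt. A hedged sketch of what such a request looks like; the message contents are placeholders:

```python
# Sketch only: the chat-style request shape used when --mode chat targets an API model.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment; no base_url needed here
resp = client.chat.completions.create(
    model="gpt-4o",  # corresponds to the --model flag
    messages=[
        {"role": "system", "content": "<system prompt from p_webrl_chat.json>"},      # placeholder
        {"role": "user", "content": "Task Instruction: ...\nRound 0\n<html> ... </html>"},  # placeholder
    ],
    temperature=0.0,
    max_tokens=2048,
)
print(resp.choices[0].message.content)
```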
@@ -14,6 +14,15 @@ cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
 cp -f new/run.py visualwebarena/run.py
 cp -f new/agent.py visualwebarena/agent/agent.py
 cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
+cp -f new/p_webrl.json visualwebarena/agent/prompts/jsons/p_webrl.json
+cp -f new/p_webrl_chat.json visualwebarena/agent/prompts/jsons/p_webrl_chat.json
+
+# browser_env
+cp -f new/actions.py visualwebarena/browser_env/actions.py
+cp -f new/envs.py visualwebarena/browser_env/envs.py
+cp -f new/helper_functions_browser.py visualwebarena/browser_env/helper_functions.py
+cp -f new/processors.py visualwebarena/browser_env/processors.py
+cp -rf new/html_tools visualwebarena/browser_env/
 
 # llms
 cp -f new/utils.py visualwebarena/llms/utils.py
@@ -22,15 +31,20 @@ cp -f new/lm_config.py visualwebarena/llms/lm_config.py
 cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
 cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
 cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
+cp -f new/utils.py visualwebarena/llms/utils.py
 
 # eval
 cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
-cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
+cp -f new/helper_functions_eval.py visualwebarena/evaluation_harness/helper_functions.py
 
 # misc
 cp -f README.md visualwebarena/README.md
 cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
+
+cp -f new/score.py visualwebarena/score.py
+cp -f new/wa_parallel_run_webrl.sh visualwebarena/wa_parallel_run_webrl.sh
+cp -f new/wa_parallel_run_webrl_chat.sh visualwebarena/wa_parallel_run_webrl_chat.sh
 
 # 3. remove temporary files
 mv visualwebarena/* .
 rm -rf new