Add documentation for VAB-WebArena-Lite
This commit is contained in:
parent
9b7a654aaa
commit
0c6f549214
50
README.md
50
README.md
|
@ -26,32 +26,30 @@ Compared to its predecessor [AgentBench](https://github.com/THUDM/AgentBench), V
|
||||||
|
|
||||||
## Table of Contents
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Quick Start](#quick-start)
|
||||||
- [Dataset Summary](#dataset-summary)
|
- [Dataset Summary](#dataset-summary)
|
||||||
- [Leaderboard](#leaderboard)
|
- [Leaderboard](#leaderboard)
|
||||||
- [Quick Start](#quick-start)
|
- [Quick Start](#quick-start)
|
||||||
- [Acknowledgement](#acknowledgement)
|
- [Acknowledgement](#acknowledgement)
|
||||||
- [Citation](#citation)
|
- [Citation](#citation)
|
||||||
|
|
||||||
## Dataset Summary
|
|
||||||
|
|
||||||
We offer two splits for each dataset: Testing and Training. Different from its predecessor [AgentBench](https://github.com/THUDM/AgentBench), VAB is accompanied with a trajectory training set for behavior cloning (BC) training, which allows development of more potent visual foundation agents with emerging open LMMs.
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## Leaderboard
|
|
||||||
|
|
||||||
Here is the scores on test set results of VAB. All metrics are task Success Rate (SR). Noted that proprietary LMMs are tested with mere **Prompting**, and open LMMs are tested after **Multitask Finetuning** on VAB training set, as they usually fail to follow complicated agent task instructions.
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
This section will guide you on how to use `gpt-4o-2024-05-13` as an agent to launch 4 concurrent `VAB-Minecraft` tasks.
|
This section will first give you an overview to the use and architecture of VAB.
|
||||||
For the specific framework structure, please refer to AgentBench's [Framework Introduction](https://github.com/THUDM/AgentBench/blob/main/docs/Introduction_en.md).
|
Next, it will guide you on how to use `gpt-4o-2024-05-13` as an exemplar agent to launch 4 concurrent `VAB-Minecraft` tasks.
|
||||||
|
|
||||||
|
### Overview on VAB Framework
|
||||||
|
|
||||||
|
To allow fast evaluation over agent tasks, we leverage AgentBench's framework as the backbone (currently for VAB-OmniGibson, VAB-Minecraft, and VAB-CSS).
|
||||||
|
If you are interested in its detailed implementation, please refer to AgentBench's [Framework Introduction](https://github.com/THUDM/AgentBench/blob/main/docs/Introduction_en.md) (which may not be necessary).
|
||||||
|
Basically, the framework calls all LLM/LMM in API formats via `Agent-Controller`, and accesses to environments via `Task-Controller`.
|
||||||
|
The `Assigner` will automatically assign evaluation tasks by pairing `Agent-Controller` and `Task-Controller` to optimize the overall evaluation speed.
|
||||||
For more detailed configuration and launch methods, please check [Configuration Guide](docs/Config_en.md)
|
For more detailed configuration and launch methods, please check [Configuration Guide](docs/Config_en.md)
|
||||||
and [Program Entrance Guide](docs/Entrance_en.md).
|
and [Program Entrance Guide](docs/Entrance_en.md).
|
||||||
|
|
||||||
### Step 1. Prerequisites
|

|
||||||
|
|
||||||
|
### Step 1. Prerequisites for All Environments
|
||||||
|
|
||||||
Clone this repo and install the dependencies.
|
Clone this repo and install the dependencies.
|
||||||
|
|
||||||
|
@ -68,7 +66,14 @@ Ensure that [Docker](https://www.docker.com/) is properly installed.
|
||||||
docker ps
|
docker ps
|
||||||
```
|
```
|
||||||
|
|
||||||
For specific environments, please refer to their respective prerequisites: [VAB-OmniGibson](docs/README_setup.md#Setup-for-VAB-OmniGibson), [VAB-Minecraft](docs/README_setup.md#Setup-for-VAB-Minecraft), [VAB-CSS](docs/README_setup.md#Setup-for-VAB-CSS).
|
For specific environments, please refer to their additional prerequisites respectively.
|
||||||
|
For VAB-WebArena-Lite, it is based on [WebArena](https://github.com/webarena-x/webarena) with some modifications, so please read its individual setup carefully.
|
||||||
|
|
||||||
|
* [VAB-OmniGibson Setup](docs/detailed_setups/VAB-OmniGibson.md)
|
||||||
|
* [VAB-Minecraft Setup](docs/detailed_setups/VAB-Minecraft.md)
|
||||||
|
* VAB-Mobile: Ongoing
|
||||||
|
* [VAB-WebArena-Lite Setup](VAB-WebArena-Lite/README.md) (Separate installation and evaluation method)
|
||||||
|
* VAB-CSS: Ongoing
|
||||||
|
|
||||||
### Step 2. Configure the Agent
|
### Step 2. Configure the Agent
|
||||||
|
|
||||||
|
@ -117,6 +122,19 @@ python -m src.assigner --auto-retry --config configs/assignments/omnigibson.yaml
|
||||||
|
|
||||||
You can modify the config files to launch other tasks or change task concurrency.
|
You can modify the config files to launch other tasks or change task concurrency.
|
||||||
|
|
||||||
|
## Dataset Summary
|
||||||
|
|
||||||
|
We offer two splits for each dataset: Testing and Training. Different from its predecessor [AgentBench](https://github.com/THUDM/AgentBench), VAB is accompanied with a trajectory training set for behavior cloning (BC) training, which allows development of more potent visual foundation agents with emerging open LMMs.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Leaderboard
|
||||||
|
|
||||||
|
Here is the scores on test set results of VAB. All metrics are task Success Rate (SR). Noted that proprietary LMMs are tested with mere **Prompting**, and open LMMs are tested after **Multitask Finetuning** on VAB training set, as they usually fail to follow complicated agent task instructions.
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
|
||||||
## Acknowledgement
|
## Acknowledgement
|
||||||
This project is heavily built upon the following repositories (to be updated):
|
This project is heavily built upon the following repositories (to be updated):
|
||||||
|
|
||||||
|
|
218
VAB-WebArena-Lite/README.md
Normal file
218
VAB-WebArena-Lite/README.md
Normal file
|
@ -0,0 +1,218 @@
|
||||||
|
# Setup for VAB-WebArena-Lite
|
||||||
|
|
||||||
|
## Brief Introduction
|
||||||
|
|
||||||
|
VAB-WebArena-Lite is a 165-task refined subset from <a href="https://webarena.dev/" target="_blank">WebArena</a>.
|
||||||
|
The purpose of building this subset is to manually ensure task correctness & feasibility, and speed up testing (original 812-task WebArena usually takes more than 6h to run through, while VAB-WebArena-Lite takes around 40m in practice).
|
||||||
|
The modified version of the test cases can be found in `config_files/wa/test_webarena_lite.raw.json`.
|
||||||
|
|
||||||
|
|
||||||
|
## Install
|
||||||
|
|
||||||
|
First, you should clone the official repository of <a href="https://github.com/web-arena-x/visualwebarena">VisualWebArena</a> to this directory
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Assume you have cloned VAB and is now in the `VAB-WebArena-Lite` directory
|
||||||
|
git clone https://github.com/web-arena-x/visualwebarena.git visualwebarena
|
||||||
|
cd visualwebarena
|
||||||
|
git reset --hard ad57aae4dad71531504726900b80db02e0526158
|
||||||
|
cd ..
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, you should substitute the file with the commands below:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash replace.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
After that, you should install the dependencies for VAB-WebArena-Lite (recommend using a independent conda environment to VAB):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Python 3.10 (or 3.11, but not 3.12 cause 3.12 deprecated distutils needed here)
|
||||||
|
python -m wal wal
|
||||||
|
source venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
playwright install
|
||||||
|
pip install -e .
|
||||||
|
```
|
||||||
|
|
||||||
|
You can also run the unit tests to ensure that WebArena-Lite is installed correctly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pytest -x
|
||||||
|
```
|
||||||
|
|
||||||
|
## Setup WebArena-Lite Environments
|
||||||
|
|
||||||
|
1. Setup the standalone environments.
|
||||||
|
Please check out [this page](https://github.com/web-arena-x/webarena/tree/main/environment_docker) for details.
|
||||||
|
|
||||||
|
2. Configurate the urls for each website.
|
||||||
|
First, export the `DATASET` to be `webarena`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export DATASET=webarena
|
||||||
|
```
|
||||||
|
|
||||||
|
Then, set the URL for the websites
|
||||||
|
(🚨 Notice: check if default ports of websites below correspond to those you setup in the first step)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Actually, the CLASSIFIEDS environment is not included in the WebArena-Lite evaluation, we keep the environment variables here just for consistency.
|
||||||
|
export CLASSIFIEDS="<your_classifieds_domain>:9980"
|
||||||
|
export CLASSIFIEDS_RESET_TOKEN="4b61655535e7ed388f0d40a93600254c"
|
||||||
|
|
||||||
|
# Below are the variables you should set for the evaluation.
|
||||||
|
export SHOPPING="<your_shopping_site_domain>:7770"
|
||||||
|
export REDDIT="<your_reddit_domain>:9999"
|
||||||
|
export SHOPPING_ADMIN="<your_e_commerce_cms_domain>:7780/admin"
|
||||||
|
export GITLAB="<your_gitlab_domain>:8023"
|
||||||
|
export MAP="<your_map_domain>:3000"
|
||||||
|
export WIKIPEDIA="<your_wikipedia_domain>:8888"
|
||||||
|
export HOMEPAGE="<your_homepage_domain>:4399"
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Generate config files for each test example:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/generate_test_data.py
|
||||||
|
```
|
||||||
|
|
||||||
|
You will see `*.json` files generated in the [config_files](./config_files) folder. Each file contains the configuration for one test example.
|
||||||
|
|
||||||
|
4. Obtain and save the auto-login cookies for all websites:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
bash prepare.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
5. Set up API keys.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export OPENAI_API_KEY=your_key
|
||||||
|
|
||||||
|
# Optional: if you use a different OpenAI model source
|
||||||
|
export OPENAI_API_URL=your_url
|
||||||
|
|
||||||
|
# Optional: you can set the following variables to evaluate the preset model in llms/providers/api_utils.py
|
||||||
|
export GEMENI_API_KEY=your_key
|
||||||
|
export QWEN_API_KEY=your_key
|
||||||
|
export CLAUDE_API_KEY=your_key
|
||||||
|
|
||||||
|
# Optional: if you have trained your model, we recommend deploying it as an API service, where you can set a FINETUNED_URL to evaluate it.
|
||||||
|
export FINETUNED_URL=your_url
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
If using Gemini, first install the [gcloud CLI](https://cloud.google.com/sdk/docs/install). Configure the API key by authenticating with Google Cloud:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
gcloud auth login
|
||||||
|
gcloud config set project <your_project_name>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🖼️ Evaluating in VAB Standard Setting with SoM (Set-of-Marks) Visual Agents
|
||||||
|
|
||||||
|
### 👎 Run Single Agent For Evalution (Slow, but please read to understand meaning of arguments)
|
||||||
|
|
||||||
|
To run your own model with SoM visual agent, you can run evaluation with the following flags:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run.py \
|
||||||
|
--instruction_path agent/prompts/jsons/p_som_cot_id_actree_3s.json \
|
||||||
|
--test_start_idx 0 \
|
||||||
|
--test_end_idx 1 \
|
||||||
|
--result_dir <your_result_dir> \
|
||||||
|
--test_config_base_dir config_files/wa/test_webarena_lite \
|
||||||
|
--provider api \
|
||||||
|
--model openai_gpt-4-vision-preview \
|
||||||
|
--action_set_tag som --observation_type image_som
|
||||||
|
```
|
||||||
|
|
||||||
|
Besides the original model providers (OpenAI, Google), you can also add your models in `llms/providers/api_utils.py`. Remember to set `--provider` to:
|
||||||
|
|
||||||
|
- `api`: Keep the same input style as WebArena, suitable for regular API calls
|
||||||
|
- `finetune`: This is required for models trained with the data we provide.
|
||||||
|
|
||||||
|
For the `--model` variable, we use the format `<source>_<model-name>` .
|
||||||
|
|
||||||
|
- If there is no more optional models under source, you can set it to just `source`.
|
||||||
|
- Remember that the source name here should be added in the init function of `APIModel` in `llms/providers/api_utils.py`.
|
||||||
|
- For example, if you want to use the openai model "gpt-4o", you can set the flag like this: `--model openai_gpt-4o`.
|
||||||
|
|
||||||
|
Finally, run `score.py` to get the pass rate
|
||||||
|
```bash
|
||||||
|
python score.py <your_result_dir>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 👍 Run Parallel Agent For Evaluation (Recommended)
|
||||||
|
|
||||||
|
To run the tests in parallel, you can first configure `wa_parallel_run.sh`, then run it. We default split the test set to 5 parallel-group for evaluation in VAB.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Remember to first launch a tmux session
|
||||||
|
tmux
|
||||||
|
bash wa_parallel_run.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The script is enabled with auto-resuming if it is interrupted or met unexpected error. Please feel free to rerun the above command until all tasks finish.
|
||||||
|
|
||||||
|
After all parallel groupes finish, run `score.py` to get the pass rate
|
||||||
|
```bash
|
||||||
|
python score.py <your_result_dir>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🚨 Important: Refresh all websites before re-run another round of testing!
|
||||||
|
Since tasks in WebArena may involve changing status and database of websites (e.g., posting comments on Reddit), if websites are not all refreshed before another round of evaluation, the results would be problematic.
|
||||||
|
|
||||||
|
Please remember to run following command (assume you are hosting WebArena websites on your own) to restart and refresh all website dockers to avoid potential contamination.
|
||||||
|
The process usually takes 3-5 minites.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Make sure the script is executed on the machine that you run those website dockers
|
||||||
|
bash refresh_website_docker.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
You may need to change some contents in the script (e.g. configured ports of websites, names of dockers, etc.).
|
||||||
|
|
||||||
|
## Run Visualized Demostration
|
||||||
|
Original WebArena have also prepared a demo for you to run the agents on your own task on an arbitrary webpage. An example is shown above where the agent is tasked to find the best Thai restaurant in Pittsburgh.
|
||||||
|
|
||||||
|
After following the setup instructions above and setting the OpenAI API key (the other environment variables for website URLs aren't really used, so you should be able to set them to some dummy variable), you can run the GPT-4V + SoM agent with the following command:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python run_demo.py \
|
||||||
|
--instruction_path agent/prompts/jsons/p_som_cot_id_actree_3s.json \
|
||||||
|
--start_url "https://www.amazon.com" \
|
||||||
|
--image "https://media.npr.org/assets/img/2023/01/14/this-is-fine_wide-0077dc0607062e15b476fb7f3bd99c5f340af356-s1400-c100.jpg" \
|
||||||
|
--intent "Help me navigate to a shirt that has this on it." \
|
||||||
|
--result_dir demo_test_amazon \
|
||||||
|
--model gpt-4-vision-preview \
|
||||||
|
--action_set_tag som --observation_type image_som \
|
||||||
|
--render
|
||||||
|
```
|
||||||
|
|
||||||
|
This tasks the agent to find a shirt that looks like the provided image (the "This is fine" dog) from Amazon. Have fun!
|
||||||
|
|
||||||
|
## Acknowledgements
|
||||||
|
|
||||||
|
Our code is heavily based off the <a href="https://github.com/web-arena-x/webarena">WebArena codebase</a> and <a href="https://github.com/web-arena-x/visualwebarena">VisualWebArena codebase</a>.
|
||||||
|
|
||||||
|
If you find this environment useful, please consider citing <a href="https://jykoh.com/vwa" target="_blank">VisualWebArena</a> as well as <a href="https://webarena.dev/" target="_blank">WebArena</a>:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
@article{koh2024visualwebarena,
|
||||||
|
title={VisualWebArena: Evaluating Multimodal Agents on Realistic Visual Web Tasks},
|
||||||
|
author={Koh, Jing Yu and Lo, Robert and Jang, Lawrence and Duvvur, Vikram and Lim, Ming Chong and Huang, Po-Yu and Neubig, Graham and Zhou, Shuyan and Salakhutdinov, Ruslan and Fried, Daniel},
|
||||||
|
journal={arXiv preprint arXiv:2401.13649},
|
||||||
|
year={2024}
|
||||||
|
}
|
||||||
|
|
||||||
|
@article{zhou2024webarena,
|
||||||
|
title={WebArena: A Realistic Web Environment for Building Autonomous Agents},
|
||||||
|
author={Zhou, Shuyan and Xu, Frank F and Zhu, Hao and Zhou, Xuhui and Lo, Robert and Sridhar, Abishek and Cheng, Xianyi and Bisk, Yonatan and Fried, Daniel and Alon, Uri and others},
|
||||||
|
journal={ICLR},
|
||||||
|
year={2024}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
227
VAB-WebArena-Lite/new/agent.py
Normal file
227
VAB-WebArena-Lite/new/agent.py
Normal file
|
@ -0,0 +1,227 @@
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from beartype import beartype
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from agent.prompts import *
|
||||||
|
from browser_env import Trajectory
|
||||||
|
from browser_env.actions import (
|
||||||
|
Action,
|
||||||
|
ActionParsingError,
|
||||||
|
create_id_based_action,
|
||||||
|
create_none_action,
|
||||||
|
create_playwright_action,
|
||||||
|
)
|
||||||
|
from browser_env.utils import Observation, StateInfo
|
||||||
|
from llms import (
|
||||||
|
call_llm,
|
||||||
|
generate_from_huggingface_completion,
|
||||||
|
generate_from_openai_chat_completion,
|
||||||
|
generate_from_openai_completion,
|
||||||
|
lm_config,
|
||||||
|
)
|
||||||
|
from llms.tokenizers import Tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
class Agent:
|
||||||
|
"""Base class for the agent"""
|
||||||
|
|
||||||
|
def __init__(self, *args: Any) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def next_action(
|
||||||
|
self, trajectory: Trajectory, intent: str, meta_data: Any
|
||||||
|
) -> Action:
|
||||||
|
"""Predict the next action given the observation"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
test_config_file: str,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class TeacherForcingAgent(Agent):
|
||||||
|
"""Agent that follows a pre-defined action sequence"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def set_action_set_tag(self, tag: str) -> None:
|
||||||
|
self.action_set_tag = tag
|
||||||
|
|
||||||
|
def set_actions(self, action_seq: str | list[str]) -> None:
|
||||||
|
if isinstance(action_seq, str):
|
||||||
|
action_strs = action_seq.strip().split("\n")
|
||||||
|
else:
|
||||||
|
action_strs = action_seq
|
||||||
|
action_strs = [a.strip() for a in action_strs]
|
||||||
|
|
||||||
|
actions = []
|
||||||
|
for a_str in action_strs:
|
||||||
|
try:
|
||||||
|
if self.action_set_tag == "playwright":
|
||||||
|
cur_action = create_playwright_action(a_str)
|
||||||
|
elif self.action_set_tag == "id_accessibility_tree":
|
||||||
|
cur_action = create_id_based_action(a_str)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown action type {self.action_set_tag}"
|
||||||
|
)
|
||||||
|
except ActionParsingError as e:
|
||||||
|
cur_action = create_none_action()
|
||||||
|
|
||||||
|
cur_action["raw_prediction"] = a_str
|
||||||
|
actions.append(cur_action)
|
||||||
|
|
||||||
|
self.actions: list[Action] = actions
|
||||||
|
|
||||||
|
def next_action(
|
||||||
|
self, trajectory: Trajectory, intent: str, meta_data: Any
|
||||||
|
) -> Action:
|
||||||
|
"""Predict the next action given the observation"""
|
||||||
|
return self.actions.pop(0)
|
||||||
|
|
||||||
|
def reset(
|
||||||
|
self,
|
||||||
|
test_config_file: str,
|
||||||
|
) -> None:
|
||||||
|
with open(test_config_file) as f:
|
||||||
|
ref_actions = json.load(f)["reference_action_sequence"]
|
||||||
|
tag = ref_actions["action_set_tag"]
|
||||||
|
action_seq = ref_actions["action_sequence"]
|
||||||
|
self.set_action_set_tag(tag)
|
||||||
|
self.set_actions(action_seq)
|
||||||
|
|
||||||
|
|
||||||
|
class PromptAgent(Agent):
|
||||||
|
"""prompt-based agent that emits action given the history"""
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
action_set_tag: str,
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
prompt_constructor: PromptConstructor,
|
||||||
|
captioning_fn = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.lm_config = lm_config
|
||||||
|
self.prompt_constructor = prompt_constructor
|
||||||
|
self.action_set_tag = action_set_tag
|
||||||
|
self.captioning_fn = captioning_fn
|
||||||
|
|
||||||
|
# Check if the model is multimodal.
|
||||||
|
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:
|
||||||
|
self.multimodal_inputs = True
|
||||||
|
else:
|
||||||
|
self.multimodal_inputs = False
|
||||||
|
|
||||||
|
def set_action_set_tag(self, tag: str) -> None:
|
||||||
|
self.action_set_tag = tag
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def next_action(
|
||||||
|
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any], images: Optional[list[Image.Image]] = None,
|
||||||
|
output_response: bool = False
|
||||||
|
) -> Action:
|
||||||
|
# Create page screenshot image for multimodal models.
|
||||||
|
if self.multimodal_inputs:
|
||||||
|
page_screenshot_arr = trajectory[-1]["observation"]["image"]
|
||||||
|
page_screenshot_img = Image.fromarray(
|
||||||
|
page_screenshot_arr
|
||||||
|
) # size = (viewport_width, viewport_width)
|
||||||
|
|
||||||
|
# Caption the input image, if provided.
|
||||||
|
if images is not None and len(images) > 0:
|
||||||
|
if self.captioning_fn is not None:
|
||||||
|
image_input_caption = ""
|
||||||
|
for image_i, image in enumerate(images):
|
||||||
|
if image_i == 0:
|
||||||
|
image_input_caption += f'Input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
|
||||||
|
else:
|
||||||
|
image_input_caption += f'input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
|
||||||
|
if len(images) > 1:
|
||||||
|
image_input_caption += ", "
|
||||||
|
# Update intent to include captions of input images.
|
||||||
|
intent = f"{image_input_caption}\nIntent: {intent}"
|
||||||
|
elif not self.multimodal_inputs:
|
||||||
|
print(
|
||||||
|
"WARNING: Input image provided but no image captioner available."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.multimodal_inputs:
|
||||||
|
prompt = self.prompt_constructor.construct(
|
||||||
|
trajectory, intent, page_screenshot_img, images, meta_data
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prompt = self.prompt_constructor.construct(
|
||||||
|
trajectory, intent, meta_data
|
||||||
|
)
|
||||||
|
lm_config = self.lm_config
|
||||||
|
n = 0
|
||||||
|
while True:
|
||||||
|
response = call_llm(lm_config, prompt)
|
||||||
|
force_prefix = self.prompt_constructor.instruction[
|
||||||
|
"meta_data"
|
||||||
|
].get("force_prefix", "")
|
||||||
|
response = f"{force_prefix}{response}"
|
||||||
|
if output_response:
|
||||||
|
print(f'Agent: {response}', flush=True)
|
||||||
|
n += 1
|
||||||
|
try:
|
||||||
|
parsed_response = self.prompt_constructor.extract_action(
|
||||||
|
response
|
||||||
|
)
|
||||||
|
if self.action_set_tag == "id_accessibility_tree":
|
||||||
|
action = create_id_based_action(parsed_response)
|
||||||
|
elif self.action_set_tag == "playwright":
|
||||||
|
action = create_playwright_action(parsed_response)
|
||||||
|
elif self.action_set_tag == "som":
|
||||||
|
action = create_id_based_action(parsed_response)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown action type {self.action_set_tag}"
|
||||||
|
)
|
||||||
|
action["raw_prediction"] = response
|
||||||
|
break
|
||||||
|
except ActionParsingError as e:
|
||||||
|
if n >= lm_config.gen_config["max_retry"]:
|
||||||
|
action = create_none_action()
|
||||||
|
action["raw_prediction"] = response
|
||||||
|
break
|
||||||
|
|
||||||
|
return action
|
||||||
|
|
||||||
|
def reset(self, test_config_file: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
|
||||||
|
llm_config = lm_config.construct_llm_config(args)
|
||||||
|
|
||||||
|
agent: Agent
|
||||||
|
if args.agent_type == "teacher_forcing":
|
||||||
|
agent = TeacherForcingAgent()
|
||||||
|
elif args.agent_type == "prompt":
|
||||||
|
with open(args.instruction_path) as f:
|
||||||
|
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
|
||||||
|
tokenizer = Tokenizer(args.provider, args.model)
|
||||||
|
prompt_constructor = eval(constructor_type)(
|
||||||
|
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
|
||||||
|
)
|
||||||
|
agent = PromptAgent(
|
||||||
|
action_set_tag=args.action_set_tag,
|
||||||
|
lm_config=llm_config,
|
||||||
|
prompt_constructor=prompt_constructor,
|
||||||
|
captioning_fn=captioning_fn
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"agent type {args.agent_type} not implemented"
|
||||||
|
)
|
||||||
|
return agent
|
682
VAB-WebArena-Lite/new/api_utils.py
Normal file
682
VAB-WebArena-Lite/new/api_utils.py
Normal file
|
@ -0,0 +1,682 @@
|
||||||
|
import os
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import base64
|
||||||
|
import shutil
|
||||||
|
import requests
|
||||||
|
|
||||||
|
import dashscope
|
||||||
|
import http.client
|
||||||
|
import anthropic
|
||||||
|
|
||||||
|
import google.auth
|
||||||
|
from google.oauth2 import service_account
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
from typing import List, Tuple, Dict
|
||||||
|
from http import HTTPStatus
|
||||||
|
from PIL import Image
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
PROXIES = { # gemini
|
||||||
|
"http": "http://127.0.0.1:7890",
|
||||||
|
"https": "http://127.0.0.1:7890"
|
||||||
|
}
|
||||||
|
|
||||||
|
SEED = int(os.environ.get("SEED", 42))
|
||||||
|
|
||||||
|
GCLOUD_KEY_FILE_PATH = "" # path to the google cloud project json file
|
||||||
|
GCLOUD_REGIONAL_CODE = "asia-east1"
|
||||||
|
OPENAI_API_URL = os.environ.get("OPENAI_API_URL")
|
||||||
|
FINETUNED_URL = os.environ.get("FINETUNED_URL") # finetuned model url
|
||||||
|
|
||||||
|
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] # you should alway setup openai api key for evaluation
|
||||||
|
GEMINI_API_KEY = os.environ.get("GEMENI_API_KEY", "") # no need when using google cloud
|
||||||
|
QWEN_API_KEY = os.environ.get("QWEN_API_KEY" , "")
|
||||||
|
CLAUDE_API_KEY = os.environ.get("CLAUDE_API_KEY", "")
|
||||||
|
|
||||||
|
class BasicModel(object):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
# make temp dir here
|
||||||
|
file_path = os.path.dirname(__file__)
|
||||||
|
self.base_dir = os.path.join(file_path, "temp", f"{int(time.time())}")
|
||||||
|
os.makedirs(self.base_dir, exist_ok=True)
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# remove temp dir
|
||||||
|
shutil.rmtree(self.base_dir, ignore_errors=True)
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def process_system_prompt(messages: List[Dict]) -> List[Dict]:
|
||||||
|
if messages[0]["role"] != "system":
|
||||||
|
return messages
|
||||||
|
|
||||||
|
new_messages = copy.deepcopy(messages[1:])
|
||||||
|
system_prompt = messages[0]["content"]
|
||||||
|
|
||||||
|
# Search for first user message and add system prompt to it
|
||||||
|
for item in new_messages:
|
||||||
|
if item.get("role") != "user":
|
||||||
|
continue
|
||||||
|
|
||||||
|
for ct in item["content"]:
|
||||||
|
# Case 1: directly appended to the text
|
||||||
|
if ct["type"] == "text":
|
||||||
|
ct["text"] = system_prompt + "\n" + ct["text"]
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
# Case 2: create a new text item
|
||||||
|
item["content"].insert(0, {
|
||||||
|
"type": "text",
|
||||||
|
"text": system_prompt
|
||||||
|
})
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
# Case 3: no user message found, add a new user message
|
||||||
|
new_messages.insert(0, {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{
|
||||||
|
"type": "text",
|
||||||
|
"text": system_prompt
|
||||||
|
}]
|
||||||
|
})
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pil_to_b64(img: Image.Image) -> str:
|
||||||
|
with BytesIO() as image_buffer:
|
||||||
|
img.save(image_buffer, format="PNG")
|
||||||
|
byte_data = image_buffer.getvalue()
|
||||||
|
img_b64 = base64.b64encode(byte_data).decode("utf-8")
|
||||||
|
img_b64 = "data:image/png;base64," + img_b64
|
||||||
|
return img_b64
|
||||||
|
|
||||||
|
# save base64 image and return filename
|
||||||
|
def b64_to_image(self, base64_data: str) -> str:
|
||||||
|
base64_data = base64_data.split(",")[1]
|
||||||
|
image_data = base64.b64decode(base64_data)
|
||||||
|
|
||||||
|
filename = os.path.join(self.base_dir, f"{int(time.time())}.png")
|
||||||
|
with open(filename, "wb") as f:
|
||||||
|
f.write(image_data)
|
||||||
|
|
||||||
|
return filename
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
raise NotImplementedError("Subclasses must implement this method")
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_URL)
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
temperature=args.get("temperature", 0.0),
|
||||||
|
max_tokens=args.get("max_tokens", 1024),
|
||||||
|
top_p=args.get("top_p", 1.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
answer: str = response.choices[0].message.content
|
||||||
|
return True, answer
|
||||||
|
except:
|
||||||
|
return False, str(response.error)
|
||||||
|
|
||||||
|
|
||||||
|
class FinetuneModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.url = FINETUNED_URL # inference api
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
dialog, images = "", []
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
dialog += f"<|system|>\n{message['content']}\n\n"
|
||||||
|
continue
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
dialog += f"<|assistant|>\n{message['content']}\n\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
dialog += f"<|user|>\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
dialog += f"{content['text']}\n"
|
||||||
|
else:
|
||||||
|
# TODO: we use filename as image url here
|
||||||
|
images.append(self.b64_to_image(content["image_url"]["url"]))
|
||||||
|
dialog += "\n\n"
|
||||||
|
|
||||||
|
dialog += "<|assistant|>\n"
|
||||||
|
images = [open(image, "rb") for image in images]
|
||||||
|
new_messages = [
|
||||||
|
{"image": images[0]},
|
||||||
|
{"prompt": dialog}
|
||||||
|
]
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
try:
|
||||||
|
response = requests.post(self.url, files=messages[0], data=messages[1], timeout=40)
|
||||||
|
response = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
return False, response["error"]["message"]
|
||||||
|
|
||||||
|
# TODO: you should change the response format here
|
||||||
|
resp = f'```\n{response["response"].split("<|end_of_text|>")[0]}\n```'
|
||||||
|
return True, resp
|
||||||
|
|
||||||
|
|
||||||
|
class QwenModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
dashscope.api_key = QWEN_API_KEY
|
||||||
|
self.seed = SEED
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
messages = self.process_system_prompt(messages)
|
||||||
|
|
||||||
|
new_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] != "user":
|
||||||
|
new_messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"text": message["content"]}]
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_content = []
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
new_content.append({
|
||||||
|
"text": content["text"],
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
filename = self.b64_to_image(content["image_url"]["url"])
|
||||||
|
new_content.append({
|
||||||
|
"image": f"file://{filename}"
|
||||||
|
})
|
||||||
|
new_messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": new_content
|
||||||
|
})
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
if "QWEN_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"QWEN_API_KEY environment variable must be set when using Qwen API."
|
||||||
|
)
|
||||||
|
|
||||||
|
response = dashscope.MultiModalConversation.call(
|
||||||
|
model=model_name,
|
||||||
|
messages=messages,
|
||||||
|
top_k=args.get("top_k"),
|
||||||
|
seed=self.seed
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == HTTPStatus.OK:
|
||||||
|
return True, response.output.choices[0].message.content[0]['text']
|
||||||
|
else:
|
||||||
|
return False, response.message
|
||||||
|
|
||||||
|
|
||||||
|
class ClaudeModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.client = anthropic.Anthropic(api_key=CLAUDE_API_KEY)
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
new_messages = []
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] in ["system", "assistant"]:
|
||||||
|
new_messages.append(message)
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_content = []
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
new_content.append(content)
|
||||||
|
continue
|
||||||
|
|
||||||
|
hdr, idata = content["image_url"]["url"].split(";base64,")
|
||||||
|
new_content.append({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": hdr.split("data:")[1],
|
||||||
|
"data": idata
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
new_messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": new_content
|
||||||
|
})
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
try:
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
system_prompt = messages[0]["content"]
|
||||||
|
messages = messages[1:]
|
||||||
|
response = self.client.messages.create(
|
||||||
|
model="claude-3-5-sonnet-20240620",
|
||||||
|
max_tokens=args.get("max_tokens"),
|
||||||
|
temperature=args.get("temperature"),
|
||||||
|
system=system_prompt,
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
response = self.client.messages.create(
|
||||||
|
model="claude-3-5-sonnet-20240620",
|
||||||
|
max_tokens=args.get("max_tokens"),
|
||||||
|
temperature=args.get("temperature"),
|
||||||
|
messages=messages
|
||||||
|
)
|
||||||
|
|
||||||
|
usage = response.usage
|
||||||
|
prompt_tokens = usage.input_tokens
|
||||||
|
completion_tokens = usage.output_tokens
|
||||||
|
|
||||||
|
# print(response)
|
||||||
|
print(response.content)
|
||||||
|
print(f"Prompt Tokens: {prompt_tokens}\nCompletion Tokens: {completion_tokens}\n")
|
||||||
|
return True, response.content
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
def get_model_response_thirdapi(self, messages) -> Tuple[bool, str]:
|
||||||
|
|
||||||
|
conn = http.client.HTTPSConnection("cn2us02.opapi.win", timeout=900)
|
||||||
|
|
||||||
|
system_prompt = None
|
||||||
|
if messages[0]["role"] == "system":
|
||||||
|
system_prompt = messages[0]["content"]
|
||||||
|
messages = messages[1:]
|
||||||
|
payload = {
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"stream": False,
|
||||||
|
"system": system_prompt,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
payload = {
|
||||||
|
"model": "claude-3-opus",
|
||||||
|
"stream": False,
|
||||||
|
"messages": messages,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens
|
||||||
|
}
|
||||||
|
payload = json.dumps(payload)
|
||||||
|
headers = {
|
||||||
|
'Accept': 'application/json',
|
||||||
|
'Authorization': f'Bearer {CLAUDE_API_KEY}',
|
||||||
|
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
conn.request("POST", "/v1/messages", payload, headers)
|
||||||
|
res = conn.getresponse()
|
||||||
|
data = res.read()
|
||||||
|
response = json.loads(data.decode("utf-8"))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
if "statusCode" in response and response["statusCode"] != 200:
|
||||||
|
return False, response["message"]
|
||||||
|
|
||||||
|
usage = response["usage"]
|
||||||
|
prompt_tokens = usage["input_tokens"]
|
||||||
|
completion_tokens = usage["output_tokens"]
|
||||||
|
print(f"Prompt Tokens: {prompt_tokens}\nCompletion Tokens: {completion_tokens}\n")
|
||||||
|
|
||||||
|
return True, response["content"][0]["text"]
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.api_key = GEMINI_API_KEY
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
parts = []
|
||||||
|
dialog = ""
|
||||||
|
sep = "\n\n###\n\n"
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
dialog += f"SYSTEM:\n{message['content']}{sep}"
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
dialog += f"ASSISTANT:\n{message['content']}{sep}"
|
||||||
|
elif message["role"] == "user":
|
||||||
|
dialog += "USER:\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
dialog += content["text"] + "\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert content["type"] == "image_url"
|
||||||
|
|
||||||
|
# save text
|
||||||
|
parts.append({ "text": dialog })
|
||||||
|
dialog = ""
|
||||||
|
|
||||||
|
# new content type for image
|
||||||
|
hdr, idata = content["image_url"]["url"].split(";base64,")
|
||||||
|
parts.append({
|
||||||
|
"inline_data": {
|
||||||
|
"mime_type": hdr.split("data:")[1],
|
||||||
|
"data": idata
|
||||||
|
}
|
||||||
|
})
|
||||||
|
dialog += sep
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid role: {message['role']}")
|
||||||
|
|
||||||
|
parts.append({
|
||||||
|
"text": dialog + "ASSISTANT:\n"
|
||||||
|
})
|
||||||
|
|
||||||
|
new_messages = [{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "user"
|
||||||
|
}]
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
proxies = PROXIES
|
||||||
|
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}-latest:generateContent?key={self.api_key}"
|
||||||
|
|
||||||
|
generation_config = {
|
||||||
|
"temperature": args.get('temperature'),
|
||||||
|
"maxOutputTokens": args.get('max_tokens'),
|
||||||
|
"stopSequences": ["\n\n###\n\n"]
|
||||||
|
}
|
||||||
|
|
||||||
|
safety_settings = [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": messages,
|
||||||
|
"generationConfig": generation_config,
|
||||||
|
"safetySettings": safety_settings
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=payload, proxies=proxies, timeout=30)
|
||||||
|
response = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
return False, response["error"]["message"]
|
||||||
|
if "content" not in response['candidates'][0]:
|
||||||
|
self.generation_config['maxOutputTokens'] *= 2
|
||||||
|
return False, "No content generated."
|
||||||
|
return True, response['candidates'][0]['content']['parts'][0]['text']
|
||||||
|
|
||||||
|
|
||||||
|
class VertexGeminiModel(BasicModel):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
|
||||||
|
parts = []
|
||||||
|
dialog = ""
|
||||||
|
sep = "\n\n###\n\n"
|
||||||
|
for message in messages:
|
||||||
|
if message["role"] == "system":
|
||||||
|
dialog += f"SYSTEM:\n{message['content']}{sep}"
|
||||||
|
elif message["role"] == "assistant":
|
||||||
|
dialog += f"ASSISTANT:\n{message['content']}{sep}"
|
||||||
|
elif message["role"] == "user":
|
||||||
|
dialog += "USER:\n"
|
||||||
|
for content in message["content"]:
|
||||||
|
if content["type"] == "text":
|
||||||
|
dialog += content["text"] + "\n"
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert content["type"] == "image_url"
|
||||||
|
# save text
|
||||||
|
parts.append({ "text": dialog })
|
||||||
|
dialog = ""
|
||||||
|
# new content type for image
|
||||||
|
hdr, idata = content["image_url"]["url"].split(";base64,")
|
||||||
|
parts.append({
|
||||||
|
"inline_data": {
|
||||||
|
"mime_type": hdr.split("data:")[1],
|
||||||
|
"data": idata
|
||||||
|
}
|
||||||
|
})
|
||||||
|
dialog += sep
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid role: {message['role']}")
|
||||||
|
|
||||||
|
parts.append({
|
||||||
|
"text": dialog + "ASSISTANT:\n"
|
||||||
|
})
|
||||||
|
|
||||||
|
new_messages = [{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "user"
|
||||||
|
}]
|
||||||
|
|
||||||
|
return new_messages
|
||||||
|
|
||||||
|
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
|
||||||
|
def get_gcloud_token():
|
||||||
|
def get_token():
|
||||||
|
try:
|
||||||
|
# Load the credentials from the key file
|
||||||
|
creds = service_account.Credentials.from_service_account_file(
|
||||||
|
GCLOUD_KEY_FILE_PATH,
|
||||||
|
# You can list multiple scopes if needed
|
||||||
|
scopes=['https://www.googleapis.com/auth/cloud-platform']
|
||||||
|
)
|
||||||
|
# Refresh the token (this is needed even for the first time)
|
||||||
|
creds.refresh(Request())
|
||||||
|
return creds.token
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"An error occurred while trying to fetch the gcloud token: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
os.environ['HTTP_PROXY'] = PROXIES['http']
|
||||||
|
os.environ['HTTPS_PROXY'] = PROXIES['https']
|
||||||
|
|
||||||
|
fail_time = 0
|
||||||
|
while not api_key and fail_time < 10:
|
||||||
|
time.sleep(5)
|
||||||
|
api_key = get_token()
|
||||||
|
fail_time += 1
|
||||||
|
|
||||||
|
return api_key
|
||||||
|
|
||||||
|
def get_url(model_name: str) -> str:
|
||||||
|
region_code = GCLOUD_REGIONAL_CODE
|
||||||
|
model_id = f"{model_name}:generateContent"
|
||||||
|
with open(GCLOUD_KEY_FILE_PATH, "r") as f:
|
||||||
|
project_id = json.load(f)["project_id"]
|
||||||
|
return f"https://{region_code}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region_code}/publishers/google/models/{model_id}"
|
||||||
|
|
||||||
|
url = get_url(model_name)
|
||||||
|
api_key = get_gcloud_token()
|
||||||
|
|
||||||
|
if not api_key:
|
||||||
|
return False, "Failed to fetch gcloud token."
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
proxies = PROXIES
|
||||||
|
safety_settings = [
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HARASSMENT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||||
|
"threshold": "BLOCK_ONLY_HIGH"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
generation_config = {
|
||||||
|
"temperature": args.get('temperature'),
|
||||||
|
"maxOutputTokens": args.get('max_tokens'),
|
||||||
|
"stopSequences": ["\n\n###\n\n"]
|
||||||
|
}
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"contents": messages,
|
||||||
|
"generationConfig": generation_config,
|
||||||
|
"safetySettings": safety_settings
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(url, headers=headers, json=payload, proxies=proxies, timeout=120)
|
||||||
|
response = response.json()
|
||||||
|
except Exception as e:
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
if "error" in response:
|
||||||
|
return False, response["error"]["message"]
|
||||||
|
if "content" not in response['candidates'][0]:
|
||||||
|
self.generation_config['maxOutputTokens'] *= 2
|
||||||
|
return False, "No content generated."
|
||||||
|
|
||||||
|
return True, response['candidates'][0]['content']['parts'][0]['text']
|
||||||
|
|
||||||
|
|
||||||
|
class APIModel(object):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.models = {
|
||||||
|
"openai": OpenAIModel(),
|
||||||
|
"gemini": VertexGeminiModel(),
|
||||||
|
"qwen": QwenModel(),
|
||||||
|
"finetuned": FinetuneModel(),
|
||||||
|
"claude": ClaudeModel()
|
||||||
|
}
|
||||||
|
|
||||||
|
def inference(self, model_id: str, messages: List[Dict], args: Dict) -> Tuple[bool, str]:
|
||||||
|
model_provider, model_name = model_id.split("_")[:2] if "_" in model_id else (model_id, model_id) # eg. "openai_gpt4o"
|
||||||
|
if model_provider not in self.models:
|
||||||
|
return False, f"Unsupported model: {model_provider} ({model_name})"
|
||||||
|
|
||||||
|
model = self.models[model_provider]
|
||||||
|
prompt = model.prompt_construct(messages)
|
||||||
|
resp = model.get_model_response(prompt, model_name, **args)
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
model = APIModel()
|
||||||
|
|
||||||
|
def generate_with_api(prompt: List[dict], model_id: str, args: Dict) -> str:
|
||||||
|
success, response = model.inference(model_id, prompt, args)
|
||||||
|
return response
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
path_to_image = "../../coco_images/000000000285.jpg"
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
image = Image.open(path_to_image)
|
||||||
|
img_str = BasicModel.pil_to_b64(image)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "You are a helpful assistant. Please response concisely."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{
|
||||||
|
"type": "text",
|
||||||
|
"text": "what's annotated in this image? Image: Omitted."
|
||||||
|
}]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "Only 5.cart is annotated in this image."
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [{
|
||||||
|
"type": "text",
|
||||||
|
"text": "What can you see?"
|
||||||
|
},{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": img_str,
|
||||||
|
"detail": "high"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = generate_with_api(messages, "openai", {
|
||||||
|
"temperature": 0.5,
|
||||||
|
"max_tokens": 1024,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"n": 1,
|
||||||
|
})
|
649
VAB-WebArena-Lite/new/evaluators.py
Normal file
649
VAB-WebArena-Lite/new/evaluators.py
Normal file
|
@ -0,0 +1,649 @@
|
||||||
|
"""base class for evaluation"""
|
||||||
|
# answer string match
|
||||||
|
import importlib
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import urllib
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Optional, Tuple, Union
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
|
import evaluate # type: ignore[import]
|
||||||
|
import requests
|
||||||
|
from beartype import beartype
|
||||||
|
from beartype.door import is_bearable
|
||||||
|
from nltk.tokenize import word_tokenize # type: ignore
|
||||||
|
from PIL import Image
|
||||||
|
from playwright.sync_api import CDPSession, Page
|
||||||
|
|
||||||
|
from browser_env.actions import Action
|
||||||
|
from browser_env.utils import StateInfo
|
||||||
|
from evaluation_harness import image_utils
|
||||||
|
from evaluation_harness.helper_functions import (
|
||||||
|
PseudoPage,
|
||||||
|
get_query_text,
|
||||||
|
get_query_text_lowercase,
|
||||||
|
gitlab_get_project_memeber_role,
|
||||||
|
llm_fuzzy_match,
|
||||||
|
llm_ua_match,
|
||||||
|
reddit_get_latest_comment_content_by_username,
|
||||||
|
reddit_get_latest_comment_obj_by_username,
|
||||||
|
reddit_get_parent_comment_username_of_latest_comment_by_username,
|
||||||
|
reddit_get_post_url,
|
||||||
|
shopping_get_latest_order_url,
|
||||||
|
shopping_get_num_reviews,
|
||||||
|
shopping_get_order_product_name_list,
|
||||||
|
shopping_get_order_product_option,
|
||||||
|
shopping_get_order_product_quantity,
|
||||||
|
shopping_get_product_attributes,
|
||||||
|
shopping_get_product_price,
|
||||||
|
shopping_get_rating_as_percentage,
|
||||||
|
shopping_get_sku_latest_review_author,
|
||||||
|
shopping_get_sku_latest_review_rating,
|
||||||
|
shopping_get_sku_latest_review_text,
|
||||||
|
)
|
||||||
|
|
||||||
|
Trajectory = list[Union[Action, StateInfo]]
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class Evaluator(object):
|
||||||
|
def __init__(self, eval_tag: str = "") -> None:
|
||||||
|
self.eval_tag = eval_tag
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage
|
||||||
|
) -> float:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_last_action(trajectory: Trajectory) -> Action:
|
||||||
|
try:
|
||||||
|
is_bearable(trajectory[-1], Action)
|
||||||
|
last_action = trajectory[-1]
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(
|
||||||
|
"The last element of trajectory should be an action, add a fake stop action if needed"
|
||||||
|
)
|
||||||
|
|
||||||
|
return last_action # type: ignore[return-value]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_last_state(trajectory: Trajectory) -> StateInfo:
|
||||||
|
try:
|
||||||
|
is_bearable(trajectory[-2], StateInfo)
|
||||||
|
last_state = trajectory[-2]
|
||||||
|
except Exception:
|
||||||
|
raise ValueError(
|
||||||
|
"The second last element of trajectory should be a state, add a fake stop action if needed"
|
||||||
|
)
|
||||||
|
|
||||||
|
return last_state # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class NumericEvaluator(Evaluator):
|
||||||
|
"""Check if the numerical relationship is correct"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def str_2_int(s: str) -> Optional[int]:
|
||||||
|
try:
|
||||||
|
s = s.strip()
|
||||||
|
if "," in s:
|
||||||
|
s = s.replace(",", "")
|
||||||
|
|
||||||
|
return int(s)
|
||||||
|
except ValueError:
|
||||||
|
# Return None if the string cannot be converted to int
|
||||||
|
print(f"[NumericEvaluator error]: Cannot convert {s} to int")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def compare_inequality(
|
||||||
|
value: Union[int, float], inequality: str, tol: float = 1e-8
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Compare a value (int or float) against an inequality string.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
- value (int/float): The value to be compared.
|
||||||
|
- inequality (str): Inequality in the form of "< 700", ">= 300", etc.
|
||||||
|
- tol (float): Tolerance for floating point comparisons.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- bool: True if the value satisfies the inequality, False otherwise.
|
||||||
|
"""
|
||||||
|
# Extract the operator and the number from the inequality string
|
||||||
|
ops = {
|
||||||
|
"<=": lambda x, y: x <= y + tol,
|
||||||
|
">=": lambda x, y: x >= y - tol,
|
||||||
|
"==": lambda x, y: abs(x - y) <= tol,
|
||||||
|
"<": lambda x, y: x < y + tol,
|
||||||
|
">": lambda x, y: x > y - tol,
|
||||||
|
}
|
||||||
|
|
||||||
|
for op, func in ops.items():
|
||||||
|
if op in inequality:
|
||||||
|
_, num = inequality.split(op)
|
||||||
|
return func(value, float(num.strip()))
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid inequality string: {inequality}")
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class StringEvaluator(Evaluator):
|
||||||
|
"""Check whether the answer is correct with:
|
||||||
|
exact match: the answer is exactly the same as the reference answer
|
||||||
|
must include: each phrase in the reference answer must be included in the answer
|
||||||
|
fuzzy match: the answer is similar to the reference answer, using LLM judge
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def clean_answer(answer: str) -> str:
|
||||||
|
if answer.startswith("'") and answer.endswith("'"):
|
||||||
|
answer = answer[1:-1]
|
||||||
|
elif answer.startswith('"') and answer.endswith('"'):
|
||||||
|
answer = answer[1:-1]
|
||||||
|
return answer.lower()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def exact_match(ref: str, pred: Union[str, int]) -> float:
|
||||||
|
if isinstance(pred, int):
|
||||||
|
pred = str(pred)
|
||||||
|
return float(
|
||||||
|
StringEvaluator.clean_answer(pred)
|
||||||
|
== StringEvaluator.clean_answer(ref)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def must_include(ref: str, pred: str) -> float:
|
||||||
|
clean_ref = StringEvaluator.clean_answer(ref)
|
||||||
|
clean_pred = StringEvaluator.clean_answer(pred)
|
||||||
|
# tokenize the answer if the ref is a single word
|
||||||
|
# prevent false positive (e.g, 0)
|
||||||
|
if len(word_tokenize(clean_ref)) == 1:
|
||||||
|
tok_pred = word_tokenize(clean_pred)
|
||||||
|
return float(clean_ref in tok_pred)
|
||||||
|
else:
|
||||||
|
return float(clean_ref in clean_pred)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def must_exclude(ref: str, pred: str) -> float:
|
||||||
|
"""Returns 1 if pred is not in ref, and 0 otherwise"""
|
||||||
|
clean_ref = StringEvaluator.clean_answer(ref)
|
||||||
|
clean_pred = StringEvaluator.clean_answer(pred)
|
||||||
|
# tokenize the answer if the ref is a single word
|
||||||
|
# prevent false positive (e.g, 0)
|
||||||
|
if len(word_tokenize(clean_ref)) == 1:
|
||||||
|
tok_pred = word_tokenize(clean_pred)
|
||||||
|
return float(clean_ref not in tok_pred)
|
||||||
|
else:
|
||||||
|
return float(clean_ref not in clean_pred)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def fuzzy_match(ref: str, pred: str, intent: str) -> float:
|
||||||
|
return llm_fuzzy_match(pred, ref, intent)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def ua_match(ref: str, pred: str, intent: str) -> float:
|
||||||
|
return llm_ua_match(pred, ref, intent)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage | None = None
|
||||||
|
) -> float:
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
last_action = self.get_last_action(trajectory)
|
||||||
|
pred = self.clean_answer(last_action["answer"])
|
||||||
|
|
||||||
|
score = 1.0
|
||||||
|
for approach, value in configs["eval"]["reference_answers"].items():
|
||||||
|
match approach:
|
||||||
|
case "exact_match":
|
||||||
|
score *= self.exact_match(ref=value, pred=pred)
|
||||||
|
case "required_values":
|
||||||
|
required_values = value
|
||||||
|
assert isinstance(required_values, list)
|
||||||
|
pred = NumericEvaluator.str_2_int(pred)
|
||||||
|
if pred is None:
|
||||||
|
score = 0.0
|
||||||
|
else:
|
||||||
|
for v in required_values:
|
||||||
|
value_or = v.split(" |OR| ")
|
||||||
|
score *= any(
|
||||||
|
[
|
||||||
|
NumericEvaluator.compare_inequality(
|
||||||
|
pred, value
|
||||||
|
)
|
||||||
|
for value in value_or
|
||||||
|
]
|
||||||
|
)
|
||||||
|
case "must_include":
|
||||||
|
assert isinstance(value, list)
|
||||||
|
for must_value in value:
|
||||||
|
value_or = must_value.split(" |OR| ")
|
||||||
|
score *= any([self.must_include(ref=v, pred=pred) for v in value_or])
|
||||||
|
case "must_exclude":
|
||||||
|
assert isinstance(value, list)
|
||||||
|
for must_excl_value in value:
|
||||||
|
score *= self.must_exclude(
|
||||||
|
ref=must_excl_value, pred=pred
|
||||||
|
)
|
||||||
|
case "one_of":
|
||||||
|
assert isinstance(value, list)
|
||||||
|
found = False
|
||||||
|
for one_of_value in value:
|
||||||
|
one_of_value = self.clean_answer(one_of_value)
|
||||||
|
if one_of_value in pred:
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
score = score * found
|
||||||
|
case "fuzzy_match":
|
||||||
|
intent = configs["intent"]
|
||||||
|
if value == "N/A":
|
||||||
|
# if the instruction only asks the model to generate N/A when encountering an unachievable task
|
||||||
|
# without more concrete reasons
|
||||||
|
score *= self.exact_match(ref=value, pred=pred)
|
||||||
|
# if the instruction also asks the model to generate the reason why the task is unachievable
|
||||||
|
# this should be the default as it will prevent false positive N/A`
|
||||||
|
if score != 1:
|
||||||
|
score = 1.0 * self.ua_match(
|
||||||
|
intent=configs["intent"],
|
||||||
|
ref=configs["eval"]["string_note"],
|
||||||
|
pred=pred,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert isinstance(value, list)
|
||||||
|
reference = ', '.join(value)
|
||||||
|
score *= self.fuzzy_match(
|
||||||
|
ref=reference, pred=pred, intent=intent
|
||||||
|
)
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class StringSoftEvaluator(Evaluator):
|
||||||
|
"""Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage | None = None
|
||||||
|
) -> float:
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
last_action = self.get_last_action(trajectory)
|
||||||
|
pred = last_action["answer"]
|
||||||
|
ref = configs["eval"]["reference_answers"]
|
||||||
|
# rouge
|
||||||
|
m = evaluate.load("rouge")
|
||||||
|
rouge = m.compute(predictions=[pred], references=[ref])
|
||||||
|
return float(rouge["rouge1"])
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class URLExactEvaluator(Evaluator):
|
||||||
|
"""Check whether the URL is exactly the same as of the reference URLs"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage
|
||||||
|
) -> float:
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
def clean_url(url: str) -> str:
|
||||||
|
url = str(url)
|
||||||
|
# Replace http://localhost with http://127.0.0.1 to keep things consistent across evals.
|
||||||
|
url = url.replace("localhost", "127.0.0.1")
|
||||||
|
if url.endswith("/"):
|
||||||
|
url = url[:-1]
|
||||||
|
return url
|
||||||
|
|
||||||
|
pred = clean_url(page.url)
|
||||||
|
ref_urls = configs["eval"]["reference_url"].split(" |OR| ")
|
||||||
|
ref_urls = [clean_url(url) for url in ref_urls]
|
||||||
|
matching_rule = configs["eval"].get("url_note", "EXACT")
|
||||||
|
if matching_rule == "EXACT":
|
||||||
|
if pred in ref_urls:
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
elif matching_rule == "GOLD in PRED":
|
||||||
|
if any([ref in pred for ref in ref_urls]):
|
||||||
|
return 1.0
|
||||||
|
else:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown matching rule: {matching_rule}")
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class HTMLContentExactEvaluator(Evaluator):
|
||||||
|
"""Check whether the contents appear in the page"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@beartype
|
||||||
|
def fuzzy_match(ref: str, pred: str, intent: str) -> float:
|
||||||
|
return llm_fuzzy_match(pred, ref, intent)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage
|
||||||
|
) -> float:
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
targets = configs["eval"]["program_html"]
|
||||||
|
|
||||||
|
score = 1.0
|
||||||
|
for target in targets:
|
||||||
|
target_url: str = target["url"] # which url to check
|
||||||
|
if target_url.startswith("func"):
|
||||||
|
func = target_url.split("func:")[1]
|
||||||
|
func = func.replace("__last_url__", page.url)
|
||||||
|
target_url = eval(func)
|
||||||
|
|
||||||
|
locator: str = target["locator"] # js element locator
|
||||||
|
|
||||||
|
# navigate to that url
|
||||||
|
if target_url != "last":
|
||||||
|
page.goto(target_url)
|
||||||
|
time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep
|
||||||
|
|
||||||
|
# empty, use the full page
|
||||||
|
if not locator.strip():
|
||||||
|
selected_element = page.content()
|
||||||
|
# use JS to select the element
|
||||||
|
elif locator.startswith("document.") or locator.startswith(
|
||||||
|
"[...document."
|
||||||
|
):
|
||||||
|
if "prep_actions" in target:
|
||||||
|
try:
|
||||||
|
for prep_action in target["prep_actions"]:
|
||||||
|
page.evaluate(f"() => {prep_action}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
selected_element = str(page.evaluate(f"() => {locator}"))
|
||||||
|
if not selected_element:
|
||||||
|
selected_element = ""
|
||||||
|
except Exception:
|
||||||
|
# the page is wrong, return empty
|
||||||
|
selected_element = ""
|
||||||
|
elif locator.startswith("lambda:"):
|
||||||
|
try:
|
||||||
|
locator = locator.lstrip("lambda:")
|
||||||
|
selected_element = page.evaluate(locator)
|
||||||
|
if not selected_element:
|
||||||
|
selected_element = None
|
||||||
|
except Exception:
|
||||||
|
# the page is wrong, return empty
|
||||||
|
selected_element = None
|
||||||
|
# run program to call API
|
||||||
|
elif locator.startswith("func:"): # a helper function
|
||||||
|
func = locator.split("func:")[1]
|
||||||
|
func = func.replace("__page__", "page")
|
||||||
|
selected_element = eval(func)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown locator: {locator}")
|
||||||
|
|
||||||
|
# If the selected element is None, then the page is wrong
|
||||||
|
if selected_element is None:
|
||||||
|
score = 0.0
|
||||||
|
break
|
||||||
|
|
||||||
|
if "exact_match" in target["required_contents"]:
|
||||||
|
required_contents = target["required_contents"]["exact_match"]
|
||||||
|
score *= StringEvaluator.exact_match(
|
||||||
|
ref=required_contents, pred=selected_element
|
||||||
|
)
|
||||||
|
elif "must_include" in target["required_contents"]:
|
||||||
|
required_contents = target["required_contents"]["must_include"]
|
||||||
|
assert isinstance(required_contents, list)
|
||||||
|
for content in required_contents:
|
||||||
|
content_or = content.split(" |OR| ")
|
||||||
|
score *= any(
|
||||||
|
[
|
||||||
|
StringEvaluator.must_include(
|
||||||
|
ref=content, pred=selected_element
|
||||||
|
)
|
||||||
|
for content in content_or
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif "must_exclude" in target["required_contents"]:
|
||||||
|
required_contents = target["required_contents"]["must_exclude"]
|
||||||
|
assert isinstance(required_contents, list)
|
||||||
|
for content in required_contents:
|
||||||
|
assert " |OR| " not in content
|
||||||
|
score *= StringEvaluator.must_exclude(
|
||||||
|
content, pred=selected_element
|
||||||
|
)
|
||||||
|
elif "required_values" in target["required_contents"]:
|
||||||
|
required_values = target["required_contents"][
|
||||||
|
"required_values"
|
||||||
|
]
|
||||||
|
assert isinstance(required_values, list)
|
||||||
|
if isinstance(selected_element, str):
|
||||||
|
selected_element = NumericEvaluator.str_2_int(
|
||||||
|
selected_element
|
||||||
|
)
|
||||||
|
if selected_element is None:
|
||||||
|
score = 0.0
|
||||||
|
else:
|
||||||
|
for value in required_values:
|
||||||
|
value_or = value.split(" |OR| ")
|
||||||
|
score *= any(
|
||||||
|
[
|
||||||
|
NumericEvaluator.compare_inequality(
|
||||||
|
selected_element, value
|
||||||
|
)
|
||||||
|
for value in value_or
|
||||||
|
]
|
||||||
|
)
|
||||||
|
elif "fuzzy_match" in target["required_contents"]:
|
||||||
|
required_contents = target["required_contents"]["fuzzy_match"]
|
||||||
|
intent = configs["intent"]
|
||||||
|
|
||||||
|
assert isinstance(required_contents, list)
|
||||||
|
reference = ', '.join(required_contents)
|
||||||
|
score *= self.fuzzy_match(
|
||||||
|
ref=reference, pred=selected_element, intent=intent
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown required_contents: {target['required_contents'].keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
class PageImageEvaluator(Evaluator):
|
||||||
|
"""Check whether the answer is correct by querying a vision model."""
|
||||||
|
|
||||||
|
def __init__(self, captioning_fn):
|
||||||
|
self.captioning_fn = captioning_fn
|
||||||
|
# Default to 0.8 as the threshold for similarity to account for compression, resizing, etc
|
||||||
|
# This might be too generous but we bias towards minimizing false negatives.
|
||||||
|
self.ssim_threshold = 0.8
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage | None = None
|
||||||
|
) -> float:
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
for query in configs["eval"]["page_image_query"]:
|
||||||
|
locator: str = query["eval_image_class"]
|
||||||
|
target_url: str = query["eval_image_url"]
|
||||||
|
if target_url.startswith("func"):
|
||||||
|
func = target_url.split("func:")[1]
|
||||||
|
func = func.replace("__last_url__", page.url)
|
||||||
|
target_url = eval(func)
|
||||||
|
|
||||||
|
# navigate to that url
|
||||||
|
if target_url != "last":
|
||||||
|
page.goto(target_url)
|
||||||
|
time.sleep(3) # TODO(jykoh): fix this hard-coded sleep
|
||||||
|
|
||||||
|
# empty, use the full page
|
||||||
|
if not locator.strip():
|
||||||
|
images = page.get_by_role("img").all()
|
||||||
|
# use JS to select the element
|
||||||
|
elif locator.startswith("."):
|
||||||
|
# Get all img children under the locator
|
||||||
|
elements = page.query_selector_all(locator)
|
||||||
|
images = []
|
||||||
|
for element in elements:
|
||||||
|
is_img = element.evaluate(
|
||||||
|
'element => element.tagName === "IMG"'
|
||||||
|
)
|
||||||
|
if is_img:
|
||||||
|
images.append(element)
|
||||||
|
else:
|
||||||
|
images.extend(element.query_selector_all("img"))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown locator: {locator}")
|
||||||
|
|
||||||
|
if images == []:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
all_image_pixels = []
|
||||||
|
for image in images:
|
||||||
|
try:
|
||||||
|
# Get image from URL.
|
||||||
|
image_url = image.get_attribute("src")
|
||||||
|
if not image_url.startswith(
|
||||||
|
("http://", "https://", "www.")
|
||||||
|
):
|
||||||
|
image_url = urljoin(page.url, image_url)
|
||||||
|
image = Image.open(
|
||||||
|
requests.get(image_url, stream=True).raw
|
||||||
|
)
|
||||||
|
all_image_pixels.append(image)
|
||||||
|
except Exception as e:
|
||||||
|
print("[WARNING]: ", e)
|
||||||
|
|
||||||
|
score = 1.0
|
||||||
|
if all_image_pixels == []:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
# Run the VQA eval on the image elements.
|
||||||
|
eval_vqas = query.get("eval_vqa", [])
|
||||||
|
assert (
|
||||||
|
len(eval_vqas) > 0 or "eval_fuzzy_image_match" in query
|
||||||
|
), "eval_vqa must have at least 2 questions or eval_fuzzy_image_match must be True"
|
||||||
|
for qa in eval_vqas:
|
||||||
|
question, answer = qa["question"], qa["answer"]
|
||||||
|
prompt = f"Q: {question} A:"
|
||||||
|
pred_ans = self.captioning_fn(
|
||||||
|
all_image_pixels, [prompt] * len(all_image_pixels)
|
||||||
|
)
|
||||||
|
score *= float(
|
||||||
|
any(
|
||||||
|
[answer.lower() in ans.lower() for ans in pred_ans]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if "eval_fuzzy_image_match" in query:
|
||||||
|
ssim_threshold = query.get(
|
||||||
|
"ssim_threshold", self.ssim_threshold
|
||||||
|
)
|
||||||
|
exact_match_imgs = query["eval_fuzzy_image_match"].split(
|
||||||
|
" |OR| "
|
||||||
|
)
|
||||||
|
all_exact_match_pixels = []
|
||||||
|
|
||||||
|
for exact_match_img in exact_match_imgs:
|
||||||
|
if exact_match_img.startswith("http"):
|
||||||
|
exact_match_pixels = Image.open(
|
||||||
|
requests.get(exact_match_img, stream=True).raw
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
exact_match_pixels = Image.open(exact_match_img)
|
||||||
|
all_exact_match_pixels.append(exact_match_pixels)
|
||||||
|
|
||||||
|
# Check if any of the images on the page match
|
||||||
|
found_exact_match = False
|
||||||
|
for exact_match_pixels in all_exact_match_pixels:
|
||||||
|
for image_pixels in all_image_pixels:
|
||||||
|
ssim = image_utils.get_image_ssim(
|
||||||
|
image_pixels, exact_match_pixels
|
||||||
|
)
|
||||||
|
if ssim > ssim_threshold:
|
||||||
|
found_exact_match = True
|
||||||
|
break
|
||||||
|
score *= float(found_exact_match)
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
class EvaluatorComb:
|
||||||
|
def __init__(self, evaluators: list[Evaluator]) -> None:
|
||||||
|
self.evaluators = evaluators
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
config_file: Path | str,
|
||||||
|
page: Page | PseudoPage
|
||||||
|
) -> float:
|
||||||
|
|
||||||
|
score = 1.0
|
||||||
|
for evaluator in self.evaluators:
|
||||||
|
cur_score = evaluator(trajectory, config_file, page)
|
||||||
|
score *= cur_score
|
||||||
|
|
||||||
|
return score
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def evaluator_router(
|
||||||
|
config_file: Path | str, captioning_fn=None
|
||||||
|
) -> EvaluatorComb:
|
||||||
|
"""Router to get the evaluator class"""
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
configs = json.load(f)
|
||||||
|
|
||||||
|
eval_types = configs["eval"]["eval_types"]
|
||||||
|
evaluators: list[Evaluator | EvaluatorPartial] = []
|
||||||
|
for eval_type in eval_types:
|
||||||
|
match eval_type:
|
||||||
|
case "string_match":
|
||||||
|
evaluators.append(StringEvaluator())
|
||||||
|
case "url_match":
|
||||||
|
evaluators.append(URLExactEvaluator())
|
||||||
|
case "program_html":
|
||||||
|
evaluators.append(HTMLContentExactEvaluator())
|
||||||
|
case "page_image_query":
|
||||||
|
evaluators.append(PageImageEvaluator(captioning_fn))
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"eval_type {eval_type} is not supported")
|
||||||
|
|
||||||
|
return EvaluatorComb(evaluators)
|
63
VAB-WebArena-Lite/new/generate_test_data.py
Normal file
63
VAB-WebArena-Lite/new/generate_test_data.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
"""Replace the website placeholders with website domains from env_config
|
||||||
|
Generate the test data"""
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
from browser_env.env_config import *
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
DATASET = os.environ["DATASET"]
|
||||||
|
if DATASET == "webarena":
|
||||||
|
print("DATASET: webarena")
|
||||||
|
print(f"REDDIT: {REDDIT}")
|
||||||
|
print(f"SHOPPING: {SHOPPING}")
|
||||||
|
print(f"SHOPPING_ADMIN: {SHOPPING_ADMIN}")
|
||||||
|
print(f"GITLAB: {GITLAB}")
|
||||||
|
print(f"WIKIPEDIA: {WIKIPEDIA}")
|
||||||
|
print(f"MAP: {MAP}")
|
||||||
|
|
||||||
|
inp_paths = ["config_files/wa/test_webarena.raw.json", "config_files/wa/test_webarena_lite.raw.json"]
|
||||||
|
replace_map = {
|
||||||
|
"__REDDIT__": REDDIT,
|
||||||
|
"__SHOPPING__": SHOPPING,
|
||||||
|
"__SHOPPING_ADMIN__": SHOPPING_ADMIN,
|
||||||
|
"__GITLAB__": GITLAB,
|
||||||
|
"__WIKIPEDIA__": WIKIPEDIA,
|
||||||
|
"__MAP__": MAP,
|
||||||
|
}
|
||||||
|
elif DATASET == "visualwebarena":
|
||||||
|
print("DATASET: visualwebarena")
|
||||||
|
print(f"CLASSIFIEDS: {CLASSIFIEDS}")
|
||||||
|
print(f"REDDIT: {REDDIT}")
|
||||||
|
print(f"SHOPPING: {SHOPPING}")
|
||||||
|
inp_paths = [
|
||||||
|
"config_files/vwa/test_classifieds.raw.json", "config_files/vwa/test_shopping.raw.json", "config_files/vwa/test_reddit.raw.json",
|
||||||
|
]
|
||||||
|
replace_map = {
|
||||||
|
"__REDDIT__": REDDIT,
|
||||||
|
"__SHOPPING__": SHOPPING,
|
||||||
|
"__WIKIPEDIA__": WIKIPEDIA,
|
||||||
|
"__CLASSIFIEDS__": CLASSIFIEDS,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Dataset not implemented: {DATASET}")
|
||||||
|
|
||||||
|
for inp_path in inp_paths:
|
||||||
|
output_dir = inp_path.replace('.raw.json', '')
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
with open(inp_path, "r") as f:
|
||||||
|
raw = f.read()
|
||||||
|
for k, v in replace_map.items():
|
||||||
|
raw = raw.replace(k, v)
|
||||||
|
|
||||||
|
with open(inp_path.replace(".raw", ""), "w") as f:
|
||||||
|
f.write(raw)
|
||||||
|
data = json.loads(raw)
|
||||||
|
for idx, item in enumerate(data):
|
||||||
|
with open(os.path.join(output_dir, f"{idx}.json"), "w") as f:
|
||||||
|
json.dump(item, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
647
VAB-WebArena-Lite/new/helper_functions.py
Normal file
647
VAB-WebArena-Lite/new/helper_functions.py
Normal file
|
@ -0,0 +1,647 @@
|
||||||
|
"""Implements helper functions to assist evaluation cases where other evaluators are not suitable."""
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Union
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from beartype import beartype
|
||||||
|
from beartype.typing import Dict, List
|
||||||
|
from playwright.sync_api import CDPSession, Page
|
||||||
|
|
||||||
|
from browser_env.env_config import (
|
||||||
|
ACCOUNTS,
|
||||||
|
REDDIT,
|
||||||
|
SHOPPING,
|
||||||
|
WIKIPEDIA,
|
||||||
|
)
|
||||||
|
from llms.providers.openai_utils import (
|
||||||
|
generate_from_openai_chat_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger("logger")
|
||||||
|
|
||||||
|
class PseudoPage:
|
||||||
|
def __init__(self, original_page: Page, url: str):
|
||||||
|
self.url = url
|
||||||
|
self.original_page = original_page
|
||||||
|
|
||||||
|
def __getattr__(self, attr: str) -> Any:
|
||||||
|
# Delegate attribute access to the original page object
|
||||||
|
if attr not in ["url"]:
|
||||||
|
return getattr(self.original_page, attr)
|
||||||
|
else:
|
||||||
|
return getattr(self, attr)
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_auth_token() -> str:
|
||||||
|
response = requests.post(
|
||||||
|
url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
|
||||||
|
headers={"content-type": "application/json"},
|
||||||
|
data=json.dumps(
|
||||||
|
{
|
||||||
|
"username": ACCOUNTS["shopping_site_admin"]["username"],
|
||||||
|
"password": ACCOUNTS["shopping_site_admin"]["password"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
token: str = response.json()
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_latest_order_url() -> str:
|
||||||
|
"""Get the latest order url from the shopping website."""
|
||||||
|
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"searchCriteria[sortOrders][0][field]": "created_at",
|
||||||
|
"searchCriteria[sortOrders][0][direction]": "DESC",
|
||||||
|
"searchCriteria[pageSize]": "1",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/orders", params=params, headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()["items"][0]
|
||||||
|
order_id = int(response_obj["increment_id"])
|
||||||
|
order_url = f"{SHOPPING}/sales/order/view/order_id/{order_id}/"
|
||||||
|
return order_url
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_sku_latest_review_author(sku: str) -> str:
|
||||||
|
"""Get the latest review for shopping admin."""
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()
|
||||||
|
if len(response_obj) == 0:
|
||||||
|
return ""
|
||||||
|
author: str = response_obj[-1]["nickname"]
|
||||||
|
return author
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_sku_latest_review_rating(sku: str) -> str:
|
||||||
|
"""Get the latest review for shopping admin."""
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()
|
||||||
|
if len(response_obj) == 0:
|
||||||
|
return ""
|
||||||
|
assert response_obj[0]["ratings"][0]["rating_name"] == "Rating"
|
||||||
|
rating: str = str(response_obj[-1]["ratings"][0]["percent"])
|
||||||
|
return rating
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_sku_latest_review_text(sku: str) -> str:
|
||||||
|
"""Get the latest review text for shopping admin."""
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()
|
||||||
|
if len(response_obj) == 0:
|
||||||
|
return ""
|
||||||
|
text: str = response_obj[-1]["detail"]
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_sku_latest_review_title(sku: str) -> str:
|
||||||
|
"""Get the latest review title for shopping admin."""
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()
|
||||||
|
if len(response_obj) == 0:
|
||||||
|
return ""
|
||||||
|
title: str = response_obj[-1]["title"]
|
||||||
|
return title
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_sku_product_page_url(sku: str) -> str:
|
||||||
|
"""Get product page url from sku"""
|
||||||
|
header = {
|
||||||
|
"Authorization": f"Bearer {shopping_get_auth_token()}",
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
response = requests.get(
|
||||||
|
f"{SHOPPING}/rest/V1/products/{sku}", headers=header
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_obj = response.json()
|
||||||
|
if len(response_obj) == 0:
|
||||||
|
return ""
|
||||||
|
for custom_attributes in response_obj["custom_attributes"]:
|
||||||
|
if custom_attributes["attribute_code"] == "url_key":
|
||||||
|
return f"{SHOPPING}/{custom_attributes['value']}.html"
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_all_product_order(
|
||||||
|
page: Page | PseudoPage,
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""
|
||||||
|
Get info of all product in a given order page.
|
||||||
|
|
||||||
|
Example output:
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"name": "Kellogg's Special K Protein Bars, Meal Replacement, Protein Snacks, Value Size, Strawberry, 19oz Box (12 Bars)\nSize\n12 Count (Pack of 1)",
|
||||||
|
"options": {
|
||||||
|
"Size": "12 Count (Pack of 1)"
|
||||||
|
},
|
||||||
|
"sku": "B00MXUFL0E",
|
||||||
|
"price": "$24.50",
|
||||||
|
"qty": "Ordered2",
|
||||||
|
"subtotal": "$49.00"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Kellogg's Special K Protein Bars, Meal Replacement, Protein Snacks, Value Size, Chocolatey Chip Cookie Dough, 19oz Box (12 Bars)",
|
||||||
|
"sku": "B07ZD2PB9F",
|
||||||
|
"price": "$42.30",
|
||||||
|
"qty": "Ordered2",
|
||||||
|
"subtotal": "$84.60"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = page.evaluate(
|
||||||
|
f"""
|
||||||
|
(() => {{
|
||||||
|
try {{
|
||||||
|
const products = [...document.querySelector("#my-orders-table").getElementsByTagName('tbody')].map(
|
||||||
|
(x) => {{
|
||||||
|
return [...x.getElementsByTagName('td')].reduce(function(obj, y) {{
|
||||||
|
const key = y.className.split(' ')[1];
|
||||||
|
obj[key] = y.outerText;
|
||||||
|
// check if options exist
|
||||||
|
if (key === 'name' && y.querySelector('dl')) {{
|
||||||
|
var option_dict = {{}}
|
||||||
|
const options = [...y.querySelector('dl').children];
|
||||||
|
for (let i = 0; i < options.length; i += 2) {{
|
||||||
|
option_dict[options[i].outerText] = options[i+1].outerText;
|
||||||
|
}}
|
||||||
|
obj['options'] = option_dict;
|
||||||
|
}}
|
||||||
|
return obj;
|
||||||
|
}}, {{}})
|
||||||
|
}}
|
||||||
|
);
|
||||||
|
return products;
|
||||||
|
}} catch (e) {{
|
||||||
|
// If any errors are caught, return an empty string
|
||||||
|
return e;
|
||||||
|
return [];
|
||||||
|
}}
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
result = []
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_order_product_name_list(page: Page | PseudoPage) -> str:
|
||||||
|
try:
|
||||||
|
products = shopping_get_all_product_order(page)
|
||||||
|
|
||||||
|
return " |OR| ".join([p["name"] for p in products])
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_order_product_quantity(
|
||||||
|
page: Page | PseudoPage, sku: str
|
||||||
|
) -> int:
|
||||||
|
try:
|
||||||
|
if "|OR|" in sku:
|
||||||
|
skus = sku.split(" |OR| ")
|
||||||
|
else:
|
||||||
|
skus = [sku]
|
||||||
|
|
||||||
|
products = shopping_get_all_product_order(page)
|
||||||
|
for product in products:
|
||||||
|
if product["sku"].strip() in skus:
|
||||||
|
# Ordered{qty}
|
||||||
|
return int(product["qty"][7:])
|
||||||
|
return 0
|
||||||
|
except Exception:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_order_product_option(
|
||||||
|
page: Page | PseudoPage, sku: str, option_name: str
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
products = shopping_get_all_product_order(page)
|
||||||
|
for product in products:
|
||||||
|
if product["sku"].strip() == sku:
|
||||||
|
# Ordered{qty}
|
||||||
|
return product["options"][option_name]
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_product_attributes(
|
||||||
|
page: Page | PseudoPage, attribute: str
|
||||||
|
) -> str:
|
||||||
|
# Get the values of all cells in the table for the given attribute
|
||||||
|
try:
|
||||||
|
result = page.evaluate(
|
||||||
|
f"""
|
||||||
|
(() => {{
|
||||||
|
try {{
|
||||||
|
// Create an array of search terms, splitting the string by ' |OR| '
|
||||||
|
const searchTerms = '{attribute}'.toLowerCase().split(' |or| ');
|
||||||
|
// Convert the children of the tbody inside the element with the given ID into an array
|
||||||
|
return Array.from(
|
||||||
|
document.querySelector('#productDetails_detailBullets_sections1 > tbody').children
|
||||||
|
)
|
||||||
|
// Filter the array to only include elements where the first child's text includes any of the search terms
|
||||||
|
.filter(x =>
|
||||||
|
searchTerms.some(term => x.children[0].outerText.toLowerCase().includes(term))
|
||||||
|
)
|
||||||
|
// Map over the filtered elements to get the outerText of their second child
|
||||||
|
.map(x => x.children[1].outerText)
|
||||||
|
// Join all the resulting strings with a comma and a space
|
||||||
|
.join(', ')
|
||||||
|
}} catch (e) {{
|
||||||
|
// If any errors are caught, return an empty string
|
||||||
|
return ''
|
||||||
|
}}
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
result = ""
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_product_price(page: Page | PseudoPage) -> Union[float, int]:
|
||||||
|
"""Get the price of the product on the shopping website."""
|
||||||
|
try:
|
||||||
|
result = page.evaluate(
|
||||||
|
"""
|
||||||
|
(() => {{
|
||||||
|
res = parseFloat(document.querySelector(\"#maincontent > div.columns > div > div.product-info-main > div.product-info-price > div.price-box.price-final_price > span > span\")
|
||||||
|
.outerText.substr(1));
|
||||||
|
return res ? res : 0;
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
result = 0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_num_reviews(page: Page | PseudoPage) -> int:
|
||||||
|
"""Get the price of the product on the shopping website."""
|
||||||
|
try:
|
||||||
|
result = page.evaluate(
|
||||||
|
"""
|
||||||
|
(() => {{
|
||||||
|
res = parseInt(document.querySelector(\"#tab-label-reviews-title\")
|
||||||
|
.outerText.split(' ')[1]);
|
||||||
|
return res ? res : 0; }}
|
||||||
|
)();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
result = 0
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def shopping_get_rating_as_percentage(page: Page | PseudoPage) -> int:
|
||||||
|
"""Get the rating of the product on the shopping website as a percentage out of 100."""
|
||||||
|
try:
|
||||||
|
rating = page.evaluate(
|
||||||
|
"""
|
||||||
|
(() => {{
|
||||||
|
ratingPercentage = parseFloat(document.querySelector('.rating-result').title.replace('%', ''));
|
||||||
|
return ratingPercentage ? ratingPercentage : 0;
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
rating = 0
|
||||||
|
|
||||||
|
return rating
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def get_query_text(page: Page | PseudoPage, selector: str) -> str:
|
||||||
|
"""Get the text content of the element matching the given selector.
|
||||||
|
|
||||||
|
Note that this function DOES NOT perform downcasing.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = page.evaluate(
|
||||||
|
f"""
|
||||||
|
(() => {{
|
||||||
|
try {{
|
||||||
|
return document.querySelector('{selector}').textContent;
|
||||||
|
}} catch (e) {{
|
||||||
|
return '';
|
||||||
|
}}
|
||||||
|
}})();
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
result = ""
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def get_query_text_lowercase(page: Page | PseudoPage, selector: str) -> str:
|
||||||
|
"""Get the lowercase text content of the element matching the given selector."""
|
||||||
|
return get_query_text(page, selector).lower()
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_post_url(url: str) -> str:
|
||||||
|
"""Get the post url"""
|
||||||
|
# Url is http://domain/f/subreddit/post_id/...
|
||||||
|
# get domain, subreddit, post_id
|
||||||
|
domain = urlparse(url).netloc
|
||||||
|
tok_url = urlparse(url).path.split("/")
|
||||||
|
# not a valid post/comment url, return the url as is
|
||||||
|
if len(tok_url) < 4:
|
||||||
|
return url
|
||||||
|
if tok_url[1] != "f":
|
||||||
|
return url
|
||||||
|
subreddit = urlparse(url).path.split("/")[2]
|
||||||
|
post_id = urlparse(url).path.split("/")[3]
|
||||||
|
scheme = urlparse(url).scheme
|
||||||
|
post_url = f"{scheme}://{domain}/f/{subreddit}/{post_id}/"
|
||||||
|
return post_url
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_post_comment_tree(page: Page | PseudoPage) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
comment_tree = page.evaluate(
|
||||||
|
f"""(function buildCommentTree(node, data_level) {{
|
||||||
|
let tree = {{
|
||||||
|
"username": node.querySelector(".fg-inherit").outerText,
|
||||||
|
"net_score": parseInt(node.querySelector(".vote__net-score").outerText),
|
||||||
|
"content": node.querySelector(".comment__content").outerText,
|
||||||
|
"time": new Date(node.querySelector('.comment__main > header > h1 > span > time').dateTime),
|
||||||
|
"children": []
|
||||||
|
}};
|
||||||
|
node.querySelectorAll(".comment").forEach((child) => {{
|
||||||
|
if (parseInt(child.getAttribute('data-level')) === data_level+1) {{
|
||||||
|
tree['children'].push(buildCommentTree(child, data_level+1));
|
||||||
|
}}
|
||||||
|
}})
|
||||||
|
|
||||||
|
return tree;
|
||||||
|
}})(document.querySelector("#main"), 0)"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
comment_tree = {}
|
||||||
|
|
||||||
|
return comment_tree
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_latest_comment_obj_by_username(
|
||||||
|
page: Page | PseudoPage, username: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
comment_tree = reddit_get_post_comment_tree(page)
|
||||||
|
latest_time = datetime.min.replace(tzinfo=timezone.utc)
|
||||||
|
comment = {}
|
||||||
|
|
||||||
|
def dfs(node):
|
||||||
|
nonlocal latest_time
|
||||||
|
nonlocal comment
|
||||||
|
if node["username"] == username:
|
||||||
|
if node["time"] > latest_time:
|
||||||
|
comment = {
|
||||||
|
"username": node["username"],
|
||||||
|
"net_score": node["net_score"],
|
||||||
|
"content": node["content"],
|
||||||
|
"time": node["time"],
|
||||||
|
}
|
||||||
|
latest_time = node["time"]
|
||||||
|
|
||||||
|
for child in node["children"]:
|
||||||
|
dfs(child)
|
||||||
|
|
||||||
|
dfs(comment_tree)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
comment = {}
|
||||||
|
return comment
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_latest_comment_content_by_username(
|
||||||
|
page: Page | PseudoPage, username: str
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
comment = reddit_get_latest_comment_obj_by_username(page, username)
|
||||||
|
content = comment["content"]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
content = ""
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_parent_comment_obj_of_latest_comment_by_username(
|
||||||
|
page: Page | PseudoPage, username: str
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
try:
|
||||||
|
comment_tree = reddit_get_post_comment_tree(page)
|
||||||
|
latest_time = datetime.min.replace(tzinfo=timezone.utc)
|
||||||
|
comment = {}
|
||||||
|
|
||||||
|
def dfs(node):
|
||||||
|
nonlocal latest_time
|
||||||
|
nonlocal comment
|
||||||
|
for child in node["children"]:
|
||||||
|
if child["username"] == username:
|
||||||
|
if child["time"] > latest_time:
|
||||||
|
comment = {
|
||||||
|
"username": node["username"],
|
||||||
|
"net_score": node["net_score"],
|
||||||
|
"content": node["content"],
|
||||||
|
"time": node["time"],
|
||||||
|
}
|
||||||
|
latest_time = child["time"]
|
||||||
|
else:
|
||||||
|
dfs(child)
|
||||||
|
|
||||||
|
dfs(comment_tree)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
comment = {}
|
||||||
|
return comment
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def reddit_get_parent_comment_username_of_latest_comment_by_username(
|
||||||
|
page: Page | PseudoPage, username: str
|
||||||
|
) -> str:
|
||||||
|
try:
|
||||||
|
comment = reddit_get_parent_comment_obj_of_latest_comment_by_username(
|
||||||
|
page, username
|
||||||
|
)
|
||||||
|
username = comment["username"]
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
username = ""
|
||||||
|
|
||||||
|
return username
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def gitlab_get_project_memeber_role(
|
||||||
|
page: Page | PseudoPage, account_name: str
|
||||||
|
) -> str:
|
||||||
|
# get the account index
|
||||||
|
try:
|
||||||
|
account_idx = page.evaluate(
|
||||||
|
f"""(() => {{
|
||||||
|
const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel");
|
||||||
|
let index = -1; // Default value if not found
|
||||||
|
|
||||||
|
for(let i = 0; i < elements.length; i++) {{
|
||||||
|
if(elements[i].outerText === '@{account_name}') {{
|
||||||
|
index = i;
|
||||||
|
break;
|
||||||
|
}}
|
||||||
|
}}
|
||||||
|
|
||||||
|
return index;
|
||||||
|
}})()"""
|
||||||
|
)
|
||||||
|
|
||||||
|
# get the role
|
||||||
|
role: str = page.evaluate(
|
||||||
|
f"""(() => {{
|
||||||
|
return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText;
|
||||||
|
}})()"""
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
role = ""
|
||||||
|
|
||||||
|
return role
|
||||||
|
|
||||||
|
|
||||||
|
@beartype
|
||||||
|
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
|
||||||
|
"""Check whether the prediction matches the reference with GPT-4-turbo"""
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
# construct the question to ask
|
||||||
|
message = "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n"
|
||||||
|
message += f"question: {question}\n"
|
||||||
|
message += f"reference answer: {reference}\n"
|
||||||
|
message += "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n"
|
||||||
|
message += f"student answer: {pred}\n"
|
||||||
|
message += "Conclude the judgement by 'correct', 'incorrect', or 'partially correct'. Only output one of these options, and nothing else."
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
]
|
||||||
|
|
||||||
|
logger.info(f'[R] {reference}')
|
||||||
|
logger.info(f'[P] {pred}')
|
||||||
|
|
||||||
|
response = generate_from_openai_chat_completion(
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
messages=messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=768,
|
||||||
|
top_p=1.0,
|
||||||
|
context_length=0,
|
||||||
|
).lower()
|
||||||
|
if "partially correct" in response or "incorrect" in response:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
assert "correct" in response, response
|
||||||
|
return 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def llm_ua_match(pred: str, reference: str, question: str) -> float:
|
||||||
|
"""Check whether the prediction matches the reference with GPT-4-turbo"""
|
||||||
|
messages: list[dict[str, Any]] = []
|
||||||
|
# construct the question to ask
|
||||||
|
message = ""
|
||||||
|
message += f"task: {question}\n"
|
||||||
|
message += f"actual unachievable reason: {reference}\n"
|
||||||
|
message += f"reported unachievable reason: {pred}\n"
|
||||||
|
message += (
|
||||||
|
"The task described above is inherently unachievable due to the reason specified under 'actual unachievable reason'. "
|
||||||
|
"An individual previously attempted this task and was unable to complete it. They provided a reason for their failure, "
|
||||||
|
"which is listed under 'reported unachievable reason'. Your role is to review both the actual and reported reasons. "
|
||||||
|
"Determine if the reported reason aligns with the actual reason, even if implicitly. "
|
||||||
|
"If the stated reason is in line with the actual reason, respond with 'same'. Otherwise, respond with 'different'."
|
||||||
|
)
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": "You are a helpful assistant"},
|
||||||
|
{"role": "user", "content": message},
|
||||||
|
]
|
||||||
|
|
||||||
|
response = generate_from_openai_chat_completion(
|
||||||
|
model="gpt-4-1106-preview",
|
||||||
|
messages=messages,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=768,
|
||||||
|
top_p=1.0,
|
||||||
|
context_length=0,
|
||||||
|
).lower()
|
||||||
|
if "different" in response:
|
||||||
|
return 0.0
|
||||||
|
else:
|
||||||
|
assert "same" in response
|
||||||
|
return 1.0
|
23
VAB-WebArena-Lite/new/llms_init.py
Normal file
23
VAB-WebArena-Lite/new/llms_init.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
"""This module is adapt from https://github.com/zeno-ml/zeno-build"""
|
||||||
|
try:
|
||||||
|
from .providers.gemini_utils import generate_from_gemini_completion
|
||||||
|
except:
|
||||||
|
print('Google Cloud not set up, skipping import of providers.gemini_utils.generate_from_gemini_completion')
|
||||||
|
|
||||||
|
from .providers.hf_utils import generate_from_huggingface_completion
|
||||||
|
from .providers.openai_utils import (
|
||||||
|
generate_from_openai_chat_completion,
|
||||||
|
generate_from_openai_completion,
|
||||||
|
)
|
||||||
|
from .providers.api_utils import (
|
||||||
|
generate_with_api,
|
||||||
|
)
|
||||||
|
from .utils import call_llm
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"generate_from_openai_completion",
|
||||||
|
"generate_from_openai_chat_completion",
|
||||||
|
"generate_from_huggingface_completion",
|
||||||
|
"generate_from_gemini_completion",
|
||||||
|
"call_llm",
|
||||||
|
]
|
57
VAB-WebArena-Lite/new/lm_config.py
Normal file
57
VAB-WebArena-Lite/new/lm_config.py
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
"""Config for language models."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import dataclasses
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class LMConfig:
|
||||||
|
"""A config for a language model.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
provider: The name of the API provider.
|
||||||
|
model: The name of the model.
|
||||||
|
model_cls: The Python class corresponding to the model, mostly for
|
||||||
|
Hugging Face transformers.
|
||||||
|
tokenizer_cls: The Python class corresponding to the tokenizer, mostly
|
||||||
|
for Hugging Face transformers.
|
||||||
|
mode: The mode of the API calls, e.g., "chat" or "generation".
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
model: str
|
||||||
|
model_cls: type | None = None
|
||||||
|
tokenizer_cls: type | None = None
|
||||||
|
mode: str | None = None
|
||||||
|
gen_config: dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
def construct_llm_config(args: argparse.Namespace) -> LMConfig:
|
||||||
|
llm_config = LMConfig(
|
||||||
|
provider=args.provider, model=args.model, mode=args.mode
|
||||||
|
)
|
||||||
|
if args.provider in ["openai", "google", "api", "finetune"]:
|
||||||
|
llm_config.gen_config["temperature"] = args.temperature
|
||||||
|
llm_config.gen_config["top_p"] = args.top_p
|
||||||
|
llm_config.gen_config["context_length"] = args.context_length
|
||||||
|
llm_config.gen_config["max_tokens"] = args.max_tokens
|
||||||
|
llm_config.gen_config["stop_token"] = args.stop_token
|
||||||
|
llm_config.gen_config["max_obs_length"] = args.max_obs_length
|
||||||
|
llm_config.gen_config["max_retry"] = args.max_retry
|
||||||
|
elif args.provider == "huggingface":
|
||||||
|
llm_config.gen_config["temperature"] = args.temperature
|
||||||
|
llm_config.gen_config["top_p"] = args.top_p
|
||||||
|
llm_config.gen_config["max_new_tokens"] = args.max_tokens
|
||||||
|
llm_config.gen_config["stop_sequences"] = (
|
||||||
|
[args.stop_token] if args.stop_token else None
|
||||||
|
)
|
||||||
|
llm_config.gen_config["max_obs_length"] = args.max_obs_length
|
||||||
|
llm_config.gen_config["model_endpoint"] = args.model_endpoint
|
||||||
|
llm_config.gen_config["max_retry"] = args.max_retry
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"provider {args.provider} not implemented")
|
||||||
|
return llm_config
|
286
VAB-WebArena-Lite/new/openai_utils.py
Normal file
286
VAB-WebArena-Lite/new/openai_utils.py
Normal file
|
@ -0,0 +1,286 @@
|
||||||
|
"""Tools to generate from OpenAI prompts.
|
||||||
|
Adopted from https://github.com/zeno-ml/zeno-build/"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import aiolimiter
|
||||||
|
import openai
|
||||||
|
from openai import AsyncOpenAI, OpenAI
|
||||||
|
|
||||||
|
base_url = os.environ.get("OPENAI_API_URL")
|
||||||
|
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
|
||||||
|
aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
|
||||||
|
from tqdm.asyncio import tqdm_asyncio
|
||||||
|
|
||||||
|
|
||||||
|
def retry_with_exponential_backoff( # type: ignore
|
||||||
|
func,
|
||||||
|
initial_delay: float = 1,
|
||||||
|
exponential_base: float = 2,
|
||||||
|
jitter: bool = True,
|
||||||
|
max_retries: int = 3,
|
||||||
|
errors: tuple[Any] = (
|
||||||
|
openai.RateLimitError,
|
||||||
|
openai.BadRequestError,
|
||||||
|
openai.InternalServerError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""Retry a function with exponential backoff."""
|
||||||
|
|
||||||
|
def wrapper(*args, **kwargs): # type: ignore
|
||||||
|
# Initialize variables
|
||||||
|
num_retries = 0
|
||||||
|
delay = initial_delay
|
||||||
|
|
||||||
|
# Loop until a successful response or max_retries is hit or an exception is raised
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
# Retry on specified errors
|
||||||
|
except errors as e:
|
||||||
|
# Increment retries
|
||||||
|
num_retries += 1
|
||||||
|
|
||||||
|
# Check if max retries has been reached
|
||||||
|
if num_retries > max_retries:
|
||||||
|
raise Exception(
|
||||||
|
f"Maximum number of retries ({max_retries}) exceeded."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Increment the delay
|
||||||
|
delay *= exponential_base * (1 + jitter * random.random())
|
||||||
|
|
||||||
|
# Sleep for the delay
|
||||||
|
time.sleep(delay)
|
||||||
|
|
||||||
|
# Raise exceptions for any errors not specified
|
||||||
|
except Exception as e:
|
||||||
|
raise e
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
async def _throttled_openai_completion_acreate(
|
||||||
|
engine: str,
|
||||||
|
prompt: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
limiter: aiolimiter.AsyncLimiter,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
async with limiter:
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
return await aclient.completions.create(
|
||||||
|
engine=engine,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
except openai.RateLimitError:
|
||||||
|
logging.warning(
|
||||||
|
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
|
||||||
|
)
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
except openai.APIError as e:
|
||||||
|
logging.warning(f"OpenAI API error: {e}")
|
||||||
|
break
|
||||||
|
return {"choices": [{"message": {"content": ""}}]}
|
||||||
|
|
||||||
|
|
||||||
|
async def agenerate_from_openai_completion(
|
||||||
|
prompts: list[str],
|
||||||
|
engine: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
context_length: int,
|
||||||
|
requests_per_minute: int = 300,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate from OpenAI Completion API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompts: list of prompts
|
||||||
|
temperature: Temperature to use.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
top_p: Top p to use.
|
||||||
|
context_length: Length of context to use.
|
||||||
|
requests_per_minute: Number of requests per minute to allow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of generated responses.
|
||||||
|
"""
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
|
||||||
|
async_responses = [
|
||||||
|
_throttled_openai_completion_acreate(
|
||||||
|
engine=engine,
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
limiter=limiter,
|
||||||
|
)
|
||||||
|
for prompt in prompts
|
||||||
|
]
|
||||||
|
responses = await tqdm_asyncio.gather(*async_responses)
|
||||||
|
return [x["choices"][0]["text"] for x in responses]
|
||||||
|
|
||||||
|
|
||||||
|
@retry_with_exponential_backoff
|
||||||
|
def generate_from_openai_completion(
|
||||||
|
prompt: str,
|
||||||
|
engine: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
context_length: int,
|
||||||
|
stop_token: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client.completions.create(
|
||||||
|
prompt=prompt,
|
||||||
|
engine=engine,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
stop=[stop_token],
|
||||||
|
)
|
||||||
|
answer: str = response["choices"][0]["text"]
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
async def _throttled_openai_chat_completion_acreate(
|
||||||
|
model: str,
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
limiter: aiolimiter.AsyncLimiter,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
async with limiter:
|
||||||
|
for _ in range(3):
|
||||||
|
try:
|
||||||
|
return await aclient.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
except openai.RateLimitError:
|
||||||
|
logging.warning(
|
||||||
|
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
|
||||||
|
)
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
except asyncio.exceptions.TimeoutError:
|
||||||
|
logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
except openai.APIError as e:
|
||||||
|
logging.warning(f"OpenAI API error: {e}")
|
||||||
|
break
|
||||||
|
return {"choices": [{"message": {"content": ""}}]}
|
||||||
|
|
||||||
|
|
||||||
|
async def agenerate_from_openai_chat_completion(
|
||||||
|
messages_list: list[list[dict[str, str]]],
|
||||||
|
engine: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
context_length: int,
|
||||||
|
requests_per_minute: int = 300,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Generate from OpenAI Chat Completion API.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages_list: list of message list
|
||||||
|
temperature: Temperature to use.
|
||||||
|
max_tokens: Maximum number of tokens to generate.
|
||||||
|
top_p: Top p to use.
|
||||||
|
context_length: Length of context to use.
|
||||||
|
requests_per_minute: Number of requests per minute to allow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of generated responses.
|
||||||
|
"""
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
|
||||||
|
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
|
||||||
|
async_responses = [
|
||||||
|
_throttled_openai_chat_completion_acreate(
|
||||||
|
model=engine,
|
||||||
|
messages=message,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
limiter=limiter,
|
||||||
|
)
|
||||||
|
for message in messages_list
|
||||||
|
]
|
||||||
|
responses = await tqdm_asyncio.gather(*async_responses)
|
||||||
|
return [x["choices"][0]["message"]["content"] for x in responses]
|
||||||
|
|
||||||
|
|
||||||
|
@retry_with_exponential_backoff
|
||||||
|
def generate_from_openai_chat_completion(
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
context_length: int,
|
||||||
|
stop_token: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
response = client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
top_p=top_p,
|
||||||
|
)
|
||||||
|
answer: str = response.choices[0].message.content
|
||||||
|
return answer
|
||||||
|
|
||||||
|
|
||||||
|
@retry_with_exponential_backoff
|
||||||
|
# debug only
|
||||||
|
def fake_generate_from_openai_chat_completion(
|
||||||
|
messages: list[dict[str, str]],
|
||||||
|
model: str,
|
||||||
|
temperature: float,
|
||||||
|
max_tokens: int,
|
||||||
|
top_p: float,
|
||||||
|
context_length: int,
|
||||||
|
stop_token: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
if "OPENAI_API_KEY" not in os.environ:
|
||||||
|
raise ValueError(
|
||||||
|
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
|
||||||
|
)
|
||||||
|
|
||||||
|
answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
|
||||||
|
return answer
|
501
VAB-WebArena-Lite/new/prompt_constructor.py
Normal file
501
VAB-WebArena-Lite/new/prompt_constructor.py
Normal file
|
@ -0,0 +1,501 @@
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, TypedDict
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from browser_env import Action, ActionParsingError, Trajectory
|
||||||
|
from browser_env.env_config import URL_MAPPINGS
|
||||||
|
from browser_env.utils import StateInfo, pil_to_b64, pil_to_vertex
|
||||||
|
from llms import lm_config
|
||||||
|
from llms.tokenizers import Tokenizer
|
||||||
|
from llms.utils import APIInput
|
||||||
|
|
||||||
|
|
||||||
|
class Instruction(TypedDict):
|
||||||
|
"""Instruction for constructing prompt"""
|
||||||
|
|
||||||
|
intro: str
|
||||||
|
examples: list[tuple[str, str]]
|
||||||
|
template: str
|
||||||
|
meta_data: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class PromptConstructor(object):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
instruction_path: str | Path,
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
):
|
||||||
|
self.instruction_path = Path(instruction_path)
|
||||||
|
self.obs_modality = "text"
|
||||||
|
self.lm_config = lm_config
|
||||||
|
instruction = json.load(open(self.instruction_path))
|
||||||
|
instruction["examples"] = [tuple(e) for e in instruction["examples"]]
|
||||||
|
self.instruction: Instruction = instruction
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
|
||||||
|
def get_lm_api_input(
|
||||||
|
self, intro: str, examples: list[tuple[str, str]], current: str
|
||||||
|
) -> APIInput:
|
||||||
|
|
||||||
|
"""Return the require format for an API"""
|
||||||
|
message: list[dict[str, str]] | str
|
||||||
|
if "openai" in self.lm_config.provider:
|
||||||
|
if self.lm_config.mode == "chat":
|
||||||
|
message = [{"role": "system", "content": intro}]
|
||||||
|
for (x, y) in examples:
|
||||||
|
message.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"name": "example_user",
|
||||||
|
"content": x,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
message.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"name": "example_assistant",
|
||||||
|
"content": y,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
message.append({"role": "user", "content": current})
|
||||||
|
return message
|
||||||
|
elif self.lm_config.mode == "completion":
|
||||||
|
message = f"{intro}\n\n"
|
||||||
|
message += "Here are a few examples:\n"
|
||||||
|
for example in examples:
|
||||||
|
message += f"Observation\n:{example[0]}\n\n"
|
||||||
|
message += f"Action: {example[1]}\n\n"
|
||||||
|
message += "Now make prediction given the observation\n\n"
|
||||||
|
message += f"Observation\n:{current}\n\n"
|
||||||
|
message += "Action:"
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"OpenAI models do not support mode {self.lm_config.mode}"
|
||||||
|
)
|
||||||
|
elif "huggingface" in self.lm_config.provider:
|
||||||
|
# https://huggingface.co/blog/llama2#how-to-prompt-llama-2
|
||||||
|
# https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
|
||||||
|
if "Llama-2" in self.lm_config.model:
|
||||||
|
if self.lm_config.mode == "chat":
|
||||||
|
B_INST, E_INST = "[INST]", "[/INST]"
|
||||||
|
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
|
||||||
|
BOS, EOS = "<s>", "</s>"
|
||||||
|
# adding the system message to be the starting of the first example
|
||||||
|
examples = [
|
||||||
|
(
|
||||||
|
B_SYS + intro + E_SYS + examples[0][0],
|
||||||
|
examples[0][1],
|
||||||
|
)
|
||||||
|
] + examples[1:]
|
||||||
|
message = "".join(
|
||||||
|
[
|
||||||
|
f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
|
||||||
|
for (x, y) in examples
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# add the current observation
|
||||||
|
message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}"
|
||||||
|
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
raise ValueError("Only chat mode is supported for Llama-2")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Provider {self.lm_config.provider} not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
def construct(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
intent: str,
|
||||||
|
meta_data: dict[str, Any] = {},
|
||||||
|
) -> APIInput:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def map_url_to_real(self, url: str) -> str:
|
||||||
|
"""Map the urls to their real world counterparts"""
|
||||||
|
for i, j in URL_MAPPINGS.items():
|
||||||
|
if i in url:
|
||||||
|
url = url.replace(i, j)
|
||||||
|
return url
|
||||||
|
|
||||||
|
def map_url_to_local(self, url: str) -> str:
|
||||||
|
"""Map the urls to their local counterparts"""
|
||||||
|
for i, j in URL_MAPPINGS.items():
|
||||||
|
if j in url:
|
||||||
|
url = url.replace(j, i)
|
||||||
|
# https
|
||||||
|
if j.replace("http", "https") in url:
|
||||||
|
url = url.replace(j.replace("http", "https"), i)
|
||||||
|
return url
|
||||||
|
|
||||||
|
def _extract_action(self, response: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def extract_action(self, response: str) -> str:
|
||||||
|
response = self._extract_action(response)
|
||||||
|
response = self.map_url_to_local(response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class DirectPromptConstructor(PromptConstructor):
|
||||||
|
"""The agent will direct predict the action"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
instruction_path: str | Path,
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
):
|
||||||
|
super().__init__(instruction_path, lm_config, tokenizer)
|
||||||
|
|
||||||
|
def construct(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
intent: str,
|
||||||
|
meta_data: dict[str, Any] = {},
|
||||||
|
) -> APIInput:
|
||||||
|
"""Construct prompt given the trajectory"""
|
||||||
|
intro = self.instruction["intro"]
|
||||||
|
examples = self.instruction["examples"]
|
||||||
|
template = self.instruction["template"]
|
||||||
|
keywords = self.instruction["meta_data"]["keywords"]
|
||||||
|
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
|
||||||
|
|
||||||
|
obs = state_info["observation"][self.obs_modality]
|
||||||
|
max_obs_length = self.lm_config.gen_config["max_obs_length"]
|
||||||
|
if max_obs_length:
|
||||||
|
if self.lm_config.provider == "google":
|
||||||
|
print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
|
||||||
|
obs = obs[:max_obs_length]
|
||||||
|
else:
|
||||||
|
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
page = state_info["info"]["page"]
|
||||||
|
url = page.url
|
||||||
|
previous_action_str = meta_data["action_history"][-1]
|
||||||
|
|
||||||
|
# input x
|
||||||
|
current = template.format(
|
||||||
|
objective=intent,
|
||||||
|
url=self.map_url_to_real(url),
|
||||||
|
observation=obs,
|
||||||
|
previous_action=previous_action_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
# make sure all keywords are replaced
|
||||||
|
assert all([f"{{k}}" not in current for k in keywords])
|
||||||
|
prompt = self.get_lm_api_input(intro, examples, current)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _extract_action(self, response: str) -> str:
|
||||||
|
action_splitter = self.instruction["meta_data"]["action_splitter"]
|
||||||
|
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
|
||||||
|
match = re.search(pattern, response)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
else:
|
||||||
|
raise ActionParsingError(
|
||||||
|
f"Cannot parse action from response {response}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CoTPromptConstructor(PromptConstructor):
|
||||||
|
"""The agent will perform step-by-step reasoning before the answer"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
instruction_path: str | Path,
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
):
|
||||||
|
super().__init__(instruction_path, lm_config, tokenizer)
|
||||||
|
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
|
||||||
|
|
||||||
|
def construct(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
intent: str,
|
||||||
|
meta_data: dict[str, Any] = {},
|
||||||
|
) -> APIInput:
|
||||||
|
intro = self.instruction["intro"]
|
||||||
|
examples = self.instruction["examples"]
|
||||||
|
template = self.instruction["template"]
|
||||||
|
keywords = self.instruction["meta_data"]["keywords"]
|
||||||
|
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
|
||||||
|
|
||||||
|
obs = state_info["observation"][self.obs_modality]
|
||||||
|
max_obs_length = self.lm_config.gen_config["max_obs_length"]
|
||||||
|
if max_obs_length:
|
||||||
|
if self.lm_config.provider == "google":
|
||||||
|
print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
|
||||||
|
obs = obs[:max_obs_length]
|
||||||
|
else:
|
||||||
|
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
page = state_info["info"]["page"]
|
||||||
|
url = page.url
|
||||||
|
previous_action_str = meta_data["action_history"][-1]
|
||||||
|
current = template.format(
|
||||||
|
objective=intent,
|
||||||
|
url=self.map_url_to_real(url),
|
||||||
|
observation=obs,
|
||||||
|
previous_action=previous_action_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all([f"{{k}}" not in current for k in keywords])
|
||||||
|
|
||||||
|
prompt = self.get_lm_api_input(intro, examples, current)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def _extract_action(self, response: str) -> str:
|
||||||
|
# find the first occurence of action
|
||||||
|
action_splitter = self.instruction["meta_data"]["action_splitter"]
|
||||||
|
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
|
||||||
|
match = re.search(pattern, response)
|
||||||
|
if match:
|
||||||
|
return match.group(1).strip()
|
||||||
|
else:
|
||||||
|
raise ActionParsingError(
|
||||||
|
f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MultimodalCoTPromptConstructor(CoTPromptConstructor):
|
||||||
|
"""The agent will perform step-by-step reasoning before the answer"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
instruction_path: str | Path,
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
):
|
||||||
|
super().__init__(instruction_path, lm_config, tokenizer)
|
||||||
|
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
|
||||||
|
|
||||||
|
def construct(
|
||||||
|
self,
|
||||||
|
trajectory: Trajectory,
|
||||||
|
intent: str,
|
||||||
|
page_screenshot_img: Image.Image,
|
||||||
|
images: list[Image.Image],
|
||||||
|
meta_data: dict[str, Any] = {},
|
||||||
|
) -> APIInput:
|
||||||
|
intro = self.instruction["intro"]
|
||||||
|
examples = self.instruction["examples"]
|
||||||
|
template = self.instruction["template"]
|
||||||
|
keywords = self.instruction["meta_data"]["keywords"]
|
||||||
|
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
|
||||||
|
|
||||||
|
obs = state_info["observation"][self.obs_modality]
|
||||||
|
max_obs_length = self.lm_config.gen_config["max_obs_length"]
|
||||||
|
if max_obs_length:
|
||||||
|
if self.lm_config.provider in ["google", "api", "finetune"]:
|
||||||
|
print("NOTE: This is a Gemini / API model, so we use characters instead of tokens for max_obs_length.")
|
||||||
|
obs = obs[:max_obs_length]
|
||||||
|
else:
|
||||||
|
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
page = state_info["info"]["page"]
|
||||||
|
url = page.url
|
||||||
|
previous_action_str = meta_data["action_history"][-1]
|
||||||
|
current = template.format(
|
||||||
|
objective=intent,
|
||||||
|
url=self.map_url_to_real(url),
|
||||||
|
observation=obs,
|
||||||
|
previous_action=previous_action_str,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert all([f"{{k}}" not in current for k in keywords])
|
||||||
|
|
||||||
|
# TODO: for your finetune model, you can config you prompt here
|
||||||
|
if self.lm_config.provider == "finetune":
|
||||||
|
current = ""
|
||||||
|
traj = trajectory[1::2]
|
||||||
|
for rnd, tra in enumerate(traj):
|
||||||
|
tar = '** screenshot **' if rnd > 0 else intent
|
||||||
|
raw = tra["raw_prediction"]
|
||||||
|
current += f"Round {rnd}\n\n<|user|>\n\n** node_info **\n\n{tar}\n\n<|assistant|>\n{raw}\n\n"""
|
||||||
|
|
||||||
|
current += f"Round {len(traj)}\n\n<|user|>\n\n{obs}\n\n{'** screenshot **' if len(traj) > 0 else intent}\n"
|
||||||
|
|
||||||
|
prompt = self.get_lm_api_input(
|
||||||
|
intro, examples, current, page_screenshot_img, images
|
||||||
|
)
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
def get_lm_api_input(
|
||||||
|
self,
|
||||||
|
intro: str,
|
||||||
|
examples: list[tuple[str, str, str]],
|
||||||
|
current: str,
|
||||||
|
page_screenshot_img: Image.Image,
|
||||||
|
images: list[Image.Image],
|
||||||
|
) -> APIInput:
|
||||||
|
"""Return the require format for an API"""
|
||||||
|
message: list[dict[str, str]] | str | list[str | Image.Image]
|
||||||
|
if "openai" in self.lm_config.provider:
|
||||||
|
if self.lm_config.mode == "chat":
|
||||||
|
message = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": [{"type": "text", "text": intro}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
for (x, y, z) in examples:
|
||||||
|
example_img = Image.open(z)
|
||||||
|
message.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"name": "example_user",
|
||||||
|
"content": [
|
||||||
|
{"type": "text", "text": x},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "IMAGES: (1) current page screenshot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": pil_to_b64(example_img)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
message.append(
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"name": "example_assistant",
|
||||||
|
"content": [{"type": "text", "text": y}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Encode images and page_screenshot_img as base64 strings.
|
||||||
|
current_prompt = current
|
||||||
|
content = [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "IMAGES: (1) current page screenshot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": pil_to_b64(page_screenshot_img)},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
for image_i, image in enumerate(images):
|
||||||
|
content.extend(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": f"({image_i+2}) input image {image_i+1}",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": pil_to_b64(image)},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
content = [{"type": "text", "text": current_prompt}] + content
|
||||||
|
|
||||||
|
message.append({"role": "user", "content": content})
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"GPT-4V models do not support mode {self.lm_config.mode}"
|
||||||
|
)
|
||||||
|
elif "google" in self.lm_config.provider:
|
||||||
|
if self.lm_config.mode == "completion":
|
||||||
|
message = [
|
||||||
|
intro,
|
||||||
|
"Here are a few examples:",
|
||||||
|
]
|
||||||
|
for (x, y, z) in examples:
|
||||||
|
example_img = Image.open(z)
|
||||||
|
message.append(f"Observation\n:{x}\n")
|
||||||
|
message.extend(
|
||||||
|
[
|
||||||
|
"IMAGES:",
|
||||||
|
"(1) current page screenshot:",
|
||||||
|
pil_to_vertex(example_img),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
message.append(f"Action: {y}")
|
||||||
|
message.append("Now make prediction given the observation")
|
||||||
|
message.append(f"Observation\n:{current}\n")
|
||||||
|
message.extend(
|
||||||
|
[
|
||||||
|
"IMAGES:",
|
||||||
|
"(1) current page screenshot:",
|
||||||
|
pil_to_vertex(page_screenshot_img),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
for image_i, image in enumerate(images):
|
||||||
|
message.extend(
|
||||||
|
[
|
||||||
|
f"({image_i+2}) input image {image_i+1}",
|
||||||
|
pil_to_vertex(image),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
message.append("Action:")
|
||||||
|
return message
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Gemini models do not support mode {self.lm_config.mode}"
|
||||||
|
)
|
||||||
|
elif self.lm_config.provider in ["api", "finetune"]:
|
||||||
|
message = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": intro,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
# we keep few-shot here, but remove the image corresponding to the current page.
|
||||||
|
for (x, y, _) in examples:
|
||||||
|
message.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{ "type": "text", "text": x },
|
||||||
|
{ "type": "text", "text": "IMAGES: (1) current page screenshot\n\n** Screenshot **\n" },
|
||||||
|
],
|
||||||
|
})
|
||||||
|
message.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": y,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Encode images and page_screenshot_img as base64 strings, we only keep screenshot of current page.
|
||||||
|
current_prompt = current
|
||||||
|
content = []
|
||||||
|
|
||||||
|
if self.lm_config.provider != "finetune":
|
||||||
|
content.append({
|
||||||
|
"type": "text",
|
||||||
|
"text": "IMAGES: (1) current page screenshot",
|
||||||
|
})
|
||||||
|
|
||||||
|
if "text" not in self.lm_config.model:
|
||||||
|
content.append({
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {"url": pil_to_b64(page_screenshot_img)},
|
||||||
|
})
|
||||||
|
|
||||||
|
content = [{"type": "text", "text": current_prompt}] + content
|
||||||
|
|
||||||
|
message.append({"role": "user", "content": content})
|
||||||
|
return message
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Provider {self.lm_config.provider} not implemented"
|
||||||
|
)
|
562
VAB-WebArena-Lite/new/run.py
Normal file
562
VAB-WebArena-Lite/new/run.py
Normal file
|
@ -0,0 +1,562 @@
|
||||||
|
"""Script to run end-to-end evaluation on the benchmark.
|
||||||
|
|
||||||
|
Modified from https://github.com/web-arena-x/webarena/blob/main/run.py.
|
||||||
|
"""
|
||||||
|
import argparse
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import openai
|
||||||
|
import requests
|
||||||
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from agent import (
|
||||||
|
PromptAgent,
|
||||||
|
construct_agent,
|
||||||
|
)
|
||||||
|
from agent.prompts import *
|
||||||
|
from browser_env import (
|
||||||
|
Action,
|
||||||
|
ActionTypes,
|
||||||
|
ScriptBrowserEnv,
|
||||||
|
StateInfo,
|
||||||
|
Trajectory,
|
||||||
|
create_stop_action,
|
||||||
|
)
|
||||||
|
from browser_env.actions import is_equivalent
|
||||||
|
from browser_env.auto_login import get_site_comb_from_filepath
|
||||||
|
from browser_env.helper_functions import (
|
||||||
|
RenderHelper,
|
||||||
|
get_action_description,
|
||||||
|
)
|
||||||
|
from evaluation_harness import evaluator_router, image_utils
|
||||||
|
|
||||||
|
DATASET = os.environ["DATASET"]
|
||||||
|
|
||||||
|
LOG_FOLDER = "log_files"
|
||||||
|
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
|
||||||
|
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"
|
||||||
|
|
||||||
|
logger = logging.getLogger("logger")
|
||||||
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
|
console_handler = logging.StreamHandler()
|
||||||
|
console_handler.setLevel(logging.DEBUG)
|
||||||
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
file_handler = logging.FileHandler(LOG_FILE_NAME)
|
||||||
|
file_handler.setLevel(logging.DEBUG)
|
||||||
|
logger.addHandler(file_handler)
|
||||||
|
|
||||||
|
# Set the log format
|
||||||
|
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
|
||||||
|
console_handler.setFormatter(formatter)
|
||||||
|
file_handler.setFormatter(formatter)
|
||||||
|
|
||||||
|
|
||||||
|
def config() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run end-to-end evaluation on the benchmark"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--render", action="store_true", help="Render the browser"
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--slow_mo",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Slow down the browser by the specified amount",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--action_set_tag", default="id_accessibility_tree", help="Action type"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--observation_type",
|
||||||
|
choices=[
|
||||||
|
"accessibility_tree",
|
||||||
|
"accessibility_tree_with_captioner",
|
||||||
|
"html",
|
||||||
|
"image",
|
||||||
|
"image_som",
|
||||||
|
],
|
||||||
|
default="accessibility_tree",
|
||||||
|
help="Observation type",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--current_viewport_only",
|
||||||
|
action="store_true",
|
||||||
|
help="Only use the current viewport for the observation",
|
||||||
|
)
|
||||||
|
parser.add_argument("--viewport_width", type=int, default=1280)
|
||||||
|
parser.add_argument("--viewport_height", type=int, default=2048)
|
||||||
|
parser.add_argument("--save_trace_enabled", action="store_true")
|
||||||
|
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
|
||||||
|
|
||||||
|
parser.add_argument("--max_steps", type=int, default=30)
|
||||||
|
|
||||||
|
# agent config
|
||||||
|
parser.add_argument("--agent_type", type=str, default="prompt")
|
||||||
|
parser.add_argument(
|
||||||
|
"--instruction_path",
|
||||||
|
type=str,
|
||||||
|
default="agents/prompts/state_action_agent.json",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--parsing_failure_th",
|
||||||
|
help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
|
||||||
|
type=int,
|
||||||
|
default=3,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--repeating_action_failure_th",
|
||||||
|
help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
|
||||||
|
type=int,
|
||||||
|
default=5,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument("--test_config_base_dir", type=str)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_captioning_model_device",
|
||||||
|
type=str,
|
||||||
|
default="cpu",
|
||||||
|
choices=["cpu", "cuda"],
|
||||||
|
help="Device to run eval captioning model on. By default, runs it on CPU.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--eval_captioning_model",
|
||||||
|
type=str,
|
||||||
|
default="Salesforce/blip2-flan-t5-xl",
|
||||||
|
choices=["Salesforce/blip2-flan-t5-xl"],
|
||||||
|
help="Captioning backbone for VQA-type evals.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--captioning_model",
|
||||||
|
type=str,
|
||||||
|
default="Salesforce/blip2-flan-t5-xl",
|
||||||
|
choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
|
||||||
|
help="Captioning backbone for accessibility tree alt text.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# lm config
|
||||||
|
parser.add_argument("--provider", type=str, default="openai")
|
||||||
|
parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
|
||||||
|
parser.add_argument("--mode", type=str, default="chat")
|
||||||
|
parser.add_argument("--temperature", type=float, default=1.0)
|
||||||
|
parser.add_argument("--top_p", type=float, default=0.9)
|
||||||
|
parser.add_argument("--context_length", type=int, default=0)
|
||||||
|
parser.add_argument("--max_tokens", type=int, default=384)
|
||||||
|
parser.add_argument("--stop_token", type=str, default=None)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_retry",
|
||||||
|
type=int,
|
||||||
|
help="max retry times to perform generations when parsing fails",
|
||||||
|
default=1,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max_obs_length",
|
||||||
|
type=int,
|
||||||
|
help="when not zero, will truncate the observation to this length before feeding to the model",
|
||||||
|
default=3840,
|
||||||
|
)
|
||||||
|
|
||||||
|
# example config
|
||||||
|
parser.add_argument("--test_start_idx", type=int, default=0)
|
||||||
|
parser.add_argument("--test_end_idx", type=int, default=910)
|
||||||
|
|
||||||
|
# logging related
|
||||||
|
parser.add_argument("--result_dir", type=str, default="")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# check the whether the action space is compatible with the observation space
|
||||||
|
if (
|
||||||
|
args.action_set_tag == "id_accessibility_tree"
|
||||||
|
and args.observation_type
|
||||||
|
not in [
|
||||||
|
"accessibility_tree",
|
||||||
|
"accessibility_tree_with_captioner",
|
||||||
|
"image_som",
|
||||||
|
]
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def early_stop(
|
||||||
|
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""Check whether need to stop early"""
|
||||||
|
|
||||||
|
# reach the max step
|
||||||
|
num_steps = (len(trajectory) - 1) / 2
|
||||||
|
if num_steps >= max_steps:
|
||||||
|
return True, f"Reach max steps {max_steps}"
|
||||||
|
|
||||||
|
last_k_actions: list[Action]
|
||||||
|
action_seq: list[Action]
|
||||||
|
|
||||||
|
# Case: parsing failure for k times
|
||||||
|
k = thresholds["parsing_failure"]
|
||||||
|
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
|
||||||
|
if len(last_k_actions) >= k:
|
||||||
|
if all(
|
||||||
|
[
|
||||||
|
action["action_type"] == ActionTypes.NONE
|
||||||
|
for action in last_k_actions
|
||||||
|
]
|
||||||
|
):
|
||||||
|
return True, f"Failed to parse actions for {k} times"
|
||||||
|
|
||||||
|
# Case: same action for k times
|
||||||
|
k = thresholds["repeating_action"]
|
||||||
|
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
|
||||||
|
action_seq = trajectory[1::2] # type: ignore[assignment]
|
||||||
|
|
||||||
|
if len(action_seq) == 0:
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
last_action: Action = action_seq[-1]
|
||||||
|
|
||||||
|
if last_action["action_type"] != ActionTypes.TYPE:
|
||||||
|
if len(last_k_actions) >= k:
|
||||||
|
if all(
|
||||||
|
[
|
||||||
|
is_equivalent(action, last_action)
|
||||||
|
for action in last_k_actions
|
||||||
|
]
|
||||||
|
):
|
||||||
|
return True, f"Same action for {k} times"
|
||||||
|
|
||||||
|
else:
|
||||||
|
# check the action sequence
|
||||||
|
if (
|
||||||
|
sum([is_equivalent(action, last_action) for action in action_seq])
|
||||||
|
>= k
|
||||||
|
):
|
||||||
|
return True, f"Same typing action for {k} times"
|
||||||
|
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
|
||||||
|
def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
|
||||||
|
obj = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"score": score,
|
||||||
|
"actions": actions
|
||||||
|
}
|
||||||
|
json.dump(obj, open(path, "w"), indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
def test(
|
||||||
|
args: argparse.Namespace,
|
||||||
|
config_file_list: list[str]
|
||||||
|
) -> None:
|
||||||
|
scores = []
|
||||||
|
max_steps = args.max_steps
|
||||||
|
|
||||||
|
early_stop_thresholds = {
|
||||||
|
"parsing_failure": args.parsing_failure_th,
|
||||||
|
"repeating_action": args.repeating_action_failure_th,
|
||||||
|
}
|
||||||
|
|
||||||
|
if args.observation_type in [
|
||||||
|
"accessibility_tree_with_captioner",
|
||||||
|
# "image_som",
|
||||||
|
]:
|
||||||
|
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
|
||||||
|
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||||
|
caption_image_fn = image_utils.get_captioning_fn(
|
||||||
|
device, dtype, args.captioning_model
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
caption_image_fn = None
|
||||||
|
|
||||||
|
# Load a (possibly different) captioning model for running VQA evals.
|
||||||
|
if DATASET == 'visualwebarena':
|
||||||
|
if (
|
||||||
|
caption_image_fn
|
||||||
|
and args.eval_captioning_model == args.captioning_model
|
||||||
|
):
|
||||||
|
eval_caption_image_fn = caption_image_fn
|
||||||
|
else:
|
||||||
|
eval_caption_image_fn = image_utils.get_captioning_fn(
|
||||||
|
args.eval_captioning_model_device,
|
||||||
|
torch.float16
|
||||||
|
if (
|
||||||
|
torch.cuda.is_available()
|
||||||
|
and args.eval_captioning_model_device == "cuda"
|
||||||
|
)
|
||||||
|
else torch.float32,
|
||||||
|
args.eval_captioning_model,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
caption_image_fn = None
|
||||||
|
eval_caption_image_fn = None
|
||||||
|
|
||||||
|
agent = construct_agent(
|
||||||
|
args,
|
||||||
|
captioning_fn=caption_image_fn
|
||||||
|
if args.observation_type == "accessibility_tree_with_captioner"
|
||||||
|
else None,
|
||||||
|
) # NOTE: captioning_fn here is used for captioning input images.
|
||||||
|
|
||||||
|
env = ScriptBrowserEnv(
|
||||||
|
headless=not args.render,
|
||||||
|
slow_mo=args.slow_mo,
|
||||||
|
observation_type=args.observation_type,
|
||||||
|
current_viewport_only=args.current_viewport_only,
|
||||||
|
viewport_size={
|
||||||
|
"width": args.viewport_width,
|
||||||
|
"height": args.viewport_height,
|
||||||
|
},
|
||||||
|
save_trace_enabled=args.save_trace_enabled,
|
||||||
|
sleep_after_execution=args.sleep_after_execution,
|
||||||
|
# NOTE: captioning_fn here is used for LLM + captioning baselines.
|
||||||
|
# This can be different from the captioning model used for evals.
|
||||||
|
captioning_fn=caption_image_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
for config_file in config_file_list:
|
||||||
|
try:
|
||||||
|
render_helper = RenderHelper(
|
||||||
|
config_file, args.result_dir, args.action_set_tag
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load task.
|
||||||
|
with open(config_file) as f:
|
||||||
|
_c = json.load(f)
|
||||||
|
intent = _c["intent"]
|
||||||
|
task_id = _c["task_id"]
|
||||||
|
image_paths = _c.get("image", None)
|
||||||
|
images = []
|
||||||
|
|
||||||
|
# automatically login
|
||||||
|
if _c["storage_state"]:
|
||||||
|
cookie_file_name = os.path.basename(_c["storage_state"])
|
||||||
|
comb = get_site_comb_from_filepath(cookie_file_name)
|
||||||
|
temp_dir = tempfile.mkdtemp()
|
||||||
|
# subprocess to renew the cookie
|
||||||
|
subprocess.run(
|
||||||
|
[
|
||||||
|
"python",
|
||||||
|
"browser_env/auto_login.py",
|
||||||
|
"--auth_folder",
|
||||||
|
temp_dir,
|
||||||
|
"--site_list",
|
||||||
|
*comb,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
_c["storage_state"] = f"{temp_dir}/{cookie_file_name}"
|
||||||
|
assert os.path.exists(_c["storage_state"])
|
||||||
|
# update the config file
|
||||||
|
config_file = f"{temp_dir}/{os.path.basename(config_file)}"
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
json.dump(_c, f)
|
||||||
|
|
||||||
|
# Load input images for the task, if any.
|
||||||
|
if image_paths is not None:
|
||||||
|
if isinstance(image_paths, str):
|
||||||
|
image_paths = [image_paths]
|
||||||
|
for image_path in image_paths:
|
||||||
|
# Load image either from the web or from a local path.
|
||||||
|
if image_path.startswith("http"):
|
||||||
|
input_image = Image.open(requests.get(image_path, stream=True).raw)
|
||||||
|
else:
|
||||||
|
input_image = Image.open(image_path)
|
||||||
|
|
||||||
|
images.append(input_image)
|
||||||
|
|
||||||
|
logger.info(f"[Config file]: {config_file}")
|
||||||
|
logger.info(f"[Intent]: {intent}")
|
||||||
|
|
||||||
|
agent.reset(config_file)
|
||||||
|
trajectory: Trajectory = []
|
||||||
|
obs, info = env.reset(options={"config_file": config_file})
|
||||||
|
state_info: StateInfo = {"observation": obs, "info": info}
|
||||||
|
trajectory.append(state_info)
|
||||||
|
|
||||||
|
meta_data = {"action_history": ["None"]}
|
||||||
|
out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
|
||||||
|
actions = []
|
||||||
|
|
||||||
|
while True:
|
||||||
|
update_action_history(out_path, task_id, actions=actions)
|
||||||
|
|
||||||
|
early_stop_flag, stop_info = early_stop(
|
||||||
|
trajectory, max_steps, early_stop_thresholds
|
||||||
|
)
|
||||||
|
|
||||||
|
if early_stop_flag:
|
||||||
|
action = create_stop_action(f"Early stop: {stop_info}")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
action = agent.next_action(
|
||||||
|
trajectory,
|
||||||
|
intent,
|
||||||
|
images=images,
|
||||||
|
meta_data=meta_data,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
# get the error message
|
||||||
|
action = create_stop_action(f"ERROR: {str(e)}")
|
||||||
|
|
||||||
|
trajectory.append(action)
|
||||||
|
|
||||||
|
action_str = get_action_description(
|
||||||
|
action,
|
||||||
|
state_info["info"]["observation_metadata"],
|
||||||
|
action_set_tag=args.action_set_tag,
|
||||||
|
prompt_constructor=agent.prompt_constructor
|
||||||
|
if isinstance(agent, PromptAgent)
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
render_helper.render(
|
||||||
|
action, state_info, meta_data, args.render_screenshot
|
||||||
|
)
|
||||||
|
meta_data["action_history"].append(action_str)
|
||||||
|
actions.append(action_str)
|
||||||
|
print(action_str)
|
||||||
|
|
||||||
|
if action["action_type"] == ActionTypes.STOP:
|
||||||
|
break
|
||||||
|
|
||||||
|
obs, _, terminated, _, info = env.step(action)
|
||||||
|
state_info = {"observation": obs, "info": info}
|
||||||
|
trajectory.append(state_info)
|
||||||
|
|
||||||
|
if terminated:
|
||||||
|
# add a action place holder
|
||||||
|
trajectory.append(create_stop_action(""))
|
||||||
|
break
|
||||||
|
|
||||||
|
# NOTE: eval_caption_image_fn is used for running eval_vqa functions.
|
||||||
|
evaluator = evaluator_router(
|
||||||
|
config_file, captioning_fn=eval_caption_image_fn
|
||||||
|
)
|
||||||
|
score = evaluator(
|
||||||
|
trajectory=trajectory,
|
||||||
|
config_file=config_file,
|
||||||
|
page=env.page
|
||||||
|
)
|
||||||
|
|
||||||
|
update_action_history(out_path, task_id, actions=actions, score=score)
|
||||||
|
scores.append(score)
|
||||||
|
|
||||||
|
if score == 1:
|
||||||
|
logger.info(f"[Result] (PASS) {config_file}")
|
||||||
|
else:
|
||||||
|
logger.info(f"[Result] (FAIL) {config_file}")
|
||||||
|
|
||||||
|
if args.save_trace_enabled:
|
||||||
|
env.save_trace(
|
||||||
|
Path(args.result_dir) / "traces" / f"{task_id}.zip"
|
||||||
|
)
|
||||||
|
except openai.OpenAIError as e:
|
||||||
|
logger.info(f"[OpenAI Error] {repr(e)}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"[Unhandled Error] {repr(e)}]")
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
# write to error file
|
||||||
|
with open(Path(args.result_dir) / "error.txt", "a") as f:
|
||||||
|
f.write(f"[Config file]: {config_file}\n")
|
||||||
|
f.write(f"[Unhandled Error] {repr(e)}\n")
|
||||||
|
f.write(traceback.format_exc()) # write stack trace to file
|
||||||
|
|
||||||
|
render_helper.close()
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
if len(scores):
|
||||||
|
logger.info(f"Average score: {sum(scores) / len(scores)}")
|
||||||
|
|
||||||
|
|
||||||
|
def prepare(args: argparse.Namespace) -> None:
|
||||||
|
# convert prompt python files to json
|
||||||
|
from agent.prompts import to_json
|
||||||
|
|
||||||
|
to_json.run()
|
||||||
|
|
||||||
|
# prepare result dir
|
||||||
|
result_dir = args.result_dir
|
||||||
|
if not result_dir:
|
||||||
|
result_dir = (
|
||||||
|
f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
|
||||||
|
)
|
||||||
|
if not Path(result_dir).exists():
|
||||||
|
Path(result_dir).mkdir(parents=True, exist_ok=True)
|
||||||
|
args.result_dir = result_dir
|
||||||
|
logger.info(f"Create result dir: {result_dir}")
|
||||||
|
|
||||||
|
if not (Path(result_dir) / "traces").exists():
|
||||||
|
(Path(result_dir) / "traces").mkdir(parents=True)
|
||||||
|
|
||||||
|
os.makedirs(os.path.join(result_dir, "actions"), exist_ok=True)
|
||||||
|
|
||||||
|
# log the log file
|
||||||
|
with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
|
||||||
|
f.write(f"{LOG_FILE_NAME}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
|
||||||
|
result_files = glob.glob(f"{result_dir}/*.html")
|
||||||
|
task_ids = [
|
||||||
|
os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
|
||||||
|
]
|
||||||
|
unfinished_configs = []
|
||||||
|
for config_file in config_files:
|
||||||
|
task_id = os.path.basename(config_file).split(".")[0]
|
||||||
|
try:
|
||||||
|
with open(f"{result_dir}/actions/{task_id}.json", "r") as f:
|
||||||
|
jd = json.load(f)
|
||||||
|
except:
|
||||||
|
jd = {}
|
||||||
|
if task_id not in task_ids or jd.get('score', -1) < 0:
|
||||||
|
unfinished_configs.append(config_file)
|
||||||
|
return unfinished_configs
|
||||||
|
|
||||||
|
|
||||||
|
def dump_config(args: argparse.Namespace) -> None:
|
||||||
|
config_file = Path(args.result_dir) / "config.json"
|
||||||
|
if not config_file.exists():
|
||||||
|
with open(config_file, "w") as f:
|
||||||
|
json.dump(vars(args), f, indent=4)
|
||||||
|
logger.info(f"Dump config to {config_file}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
args = config()
|
||||||
|
args.sleep_after_execution = 2.5
|
||||||
|
prepare(args)
|
||||||
|
|
||||||
|
test_config_base_dir = args.test_config_base_dir
|
||||||
|
|
||||||
|
test_file_list = []
|
||||||
|
st_idx = args.test_start_idx
|
||||||
|
ed_idx = args.test_end_idx
|
||||||
|
for i in range(st_idx, ed_idx):
|
||||||
|
test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
|
||||||
|
test_file_list = get_unfinished(test_file_list, args.result_dir)
|
||||||
|
print(f"Total {len(test_file_list)} tasks left")
|
||||||
|
args.render = False
|
||||||
|
args.render_screenshot = True
|
||||||
|
args.save_trace_enabled = True
|
||||||
|
|
||||||
|
args.current_viewport_only = True
|
||||||
|
dump_config(args)
|
||||||
|
|
||||||
|
test(args, test_file_list)
|
5838
VAB-WebArena-Lite/new/test_webarena_lite.raw.json
Normal file
5838
VAB-WebArena-Lite/new/test_webarena_lite.raw.json
Normal file
File diff suppressed because it is too large
Load Diff
29
VAB-WebArena-Lite/new/tokenizers.py
Normal file
29
VAB-WebArena-Lite/new/tokenizers.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import tiktoken
|
||||||
|
from transformers import LlamaTokenizer # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class Tokenizer(object):
|
||||||
|
def __init__(self, provider: str, model_name: str) -> None:
|
||||||
|
if provider == "openai":
|
||||||
|
self.tokenizer = tiktoken.encoding_for_model(model_name)
|
||||||
|
elif provider == "huggingface":
|
||||||
|
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
|
||||||
|
# turn off adding special tokens automatically
|
||||||
|
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
|
||||||
|
self.tokenizer.add_bos_token = False # type: ignore[attr-defined]
|
||||||
|
self.tokenizer.add_eos_token = False # type: ignore[attr-defined]
|
||||||
|
elif provider in ["google", "api", "finetune"]:
|
||||||
|
self.tokenizer = None # Not used for input length computation, as Gemini is based on characters
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def encode(self, text: str) -> list[int]:
|
||||||
|
return self.tokenizer.encode(text)
|
||||||
|
|
||||||
|
def decode(self, ids: list[int]) -> str:
|
||||||
|
return self.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
def __call__(self, text: str) -> list[int]:
|
||||||
|
return self.tokenizer.encode(text)
|
87
VAB-WebArena-Lite/new/utils.py
Normal file
87
VAB-WebArena-Lite/new/utils.py
Normal file
|
@ -0,0 +1,87 @@
|
||||||
|
import argparse
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
from vertexai.preview.generative_models import Image
|
||||||
|
from llms import generate_from_gemini_completion
|
||||||
|
except:
|
||||||
|
print('Google Cloud not set up, skipping import of vertexai.preview.generative_models.Image and llms.generate_from_gemini_completion')
|
||||||
|
|
||||||
|
from llms import (
|
||||||
|
generate_from_huggingface_completion,
|
||||||
|
generate_from_openai_chat_completion,
|
||||||
|
generate_from_openai_completion,
|
||||||
|
generate_with_api,
|
||||||
|
lm_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
APIInput = str | list[Any] | dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def call_llm(
|
||||||
|
lm_config: lm_config.LMConfig,
|
||||||
|
prompt: APIInput,
|
||||||
|
) -> str:
|
||||||
|
response: str
|
||||||
|
if lm_config.provider == "openai":
|
||||||
|
if lm_config.mode == "chat":
|
||||||
|
assert isinstance(prompt, list)
|
||||||
|
response = generate_from_openai_chat_completion(
|
||||||
|
messages=prompt,
|
||||||
|
model=lm_config.model,
|
||||||
|
temperature=lm_config.gen_config["temperature"],
|
||||||
|
top_p=lm_config.gen_config["top_p"],
|
||||||
|
context_length=lm_config.gen_config["context_length"],
|
||||||
|
max_tokens=lm_config.gen_config["max_tokens"],
|
||||||
|
stop_token=None,
|
||||||
|
)
|
||||||
|
elif lm_config.mode == "completion":
|
||||||
|
assert isinstance(prompt, str)
|
||||||
|
response = generate_from_openai_completion(
|
||||||
|
prompt=prompt,
|
||||||
|
engine=lm_config.model,
|
||||||
|
temperature=lm_config.gen_config["temperature"],
|
||||||
|
max_tokens=lm_config.gen_config["max_tokens"],
|
||||||
|
top_p=lm_config.gen_config["top_p"],
|
||||||
|
stop_token=lm_config.gen_config["stop_token"],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"OpenAI models do not support mode {lm_config.mode}"
|
||||||
|
)
|
||||||
|
elif lm_config.provider == "huggingface":
|
||||||
|
assert isinstance(prompt, str)
|
||||||
|
response = generate_from_huggingface_completion(
|
||||||
|
prompt=prompt,
|
||||||
|
model_endpoint=lm_config.gen_config["model_endpoint"],
|
||||||
|
temperature=lm_config.gen_config["temperature"],
|
||||||
|
top_p=lm_config.gen_config["top_p"],
|
||||||
|
stop_sequences=lm_config.gen_config["stop_sequences"],
|
||||||
|
max_new_tokens=lm_config.gen_config["max_new_tokens"],
|
||||||
|
)
|
||||||
|
elif lm_config.provider == "google":
|
||||||
|
assert isinstance(prompt, list)
|
||||||
|
assert all(
|
||||||
|
[isinstance(p, str) or isinstance(p, Image) for p in prompt]
|
||||||
|
)
|
||||||
|
response = generate_from_gemini_completion(
|
||||||
|
prompt=prompt,
|
||||||
|
engine=lm_config.model,
|
||||||
|
temperature=lm_config.gen_config["temperature"],
|
||||||
|
max_tokens=lm_config.gen_config["max_tokens"],
|
||||||
|
top_p=lm_config.gen_config["top_p"],
|
||||||
|
)
|
||||||
|
elif lm_config.provider in ["api", "finetune"]:
|
||||||
|
args = {
|
||||||
|
"temperature": lm_config.gen_config["temperature"], # openai, gemini, claude
|
||||||
|
"max_tokens": lm_config.gen_config["max_tokens"], # openai, gemini, claude
|
||||||
|
"top_k": lm_config.gen_config["top_p"], # qwen
|
||||||
|
}
|
||||||
|
response = generate_with_api(prompt, lm_config.model, args)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Provider {lm_config.provider} not implemented"
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
78
VAB-WebArena-Lite/new/wa_parallel_run.sh
Normal file
78
VAB-WebArena-Lite/new/wa_parallel_run.sh
Normal file
|
@ -0,0 +1,78 @@
|
||||||
|
#!/bin/bash
|
||||||
|
DATASET='webarena' # webarena, visualwebarena
|
||||||
|
result_dir=''
|
||||||
|
provider=''
|
||||||
|
model=''
|
||||||
|
instruction_path='agent/prompts/jsons/p_som_cot_id_actree_3s.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
|
||||||
|
test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
|
||||||
|
temperature=0.0
|
||||||
|
|
||||||
|
SERVER='' # your server address
|
||||||
|
MAP_SERVER='' # the same as SERVER
|
||||||
|
OPENAI_API_KEY=''
|
||||||
|
OPENAI_ORGANIZATION=''
|
||||||
|
CONDA_ENV_NAME='' # the name of your conda environment
|
||||||
|
|
||||||
|
ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
|
||||||
|
|
||||||
|
# get the number of tmux panes
|
||||||
|
num_panes=$(tmux list-panes | wc -l)
|
||||||
|
|
||||||
|
# calculate how many panes need to be created
|
||||||
|
let "panes_to_create = 8 - num_panes"
|
||||||
|
|
||||||
|
# array of tmux commands to create each pane
|
||||||
|
tmux_commands=(
|
||||||
|
'tmux split-window -h'
|
||||||
|
'tmux split-window -v'
|
||||||
|
'tmux select-pane -t 0; tmux split-window -v'
|
||||||
|
'tmux select-pane -t 0; tmux split-window -v'
|
||||||
|
'tmux select-pane -t 2; tmux split-window -v'
|
||||||
|
'tmux select-pane -t 4; tmux split-window -v'
|
||||||
|
'tmux select-pane -t 6; tmux split-window -v'
|
||||||
|
)
|
||||||
|
|
||||||
|
# create panes up to 8
|
||||||
|
for ((i=0; i<$panes_to_create; i++)); do
|
||||||
|
eval ${tmux_commands[$i]}
|
||||||
|
done
|
||||||
|
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Function to run a job
|
||||||
|
run_job() {
|
||||||
|
tmux select-pane -t $1
|
||||||
|
tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until python run.py --viewport_width 1280 --viewport_height 720 --test_start_idx $2 --test_end_idx $3 --provider ${provider} --model ${model} --instruction_path ${instruction_path} --temperature ${temperature} --test_config_base_dir ${test_config_base_dir} --result_dir ${result_dir}; do echo 'crashed' >&2; sleep 1; done" C-m
|
||||||
|
sleep 3
|
||||||
|
}
|
||||||
|
|
||||||
|
TOLERANCE=2
|
||||||
|
run_batch() {
|
||||||
|
args=("$@") # save all arguments in an array
|
||||||
|
num_jobs=${#args[@]} # get number of arguments
|
||||||
|
|
||||||
|
for ((i=1; i<$num_jobs; i++)); do
|
||||||
|
run_job $i ${args[i-1]} ${args[i]}
|
||||||
|
done
|
||||||
|
|
||||||
|
# Wait for all jobs to finish
|
||||||
|
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
|
||||||
|
sleep 100 # wait for 10 seconds before checking again
|
||||||
|
done
|
||||||
|
|
||||||
|
# Run checker
|
||||||
|
while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
|
||||||
|
echo "Check failed, rerunning jobs..."
|
||||||
|
for ((i=1; i<$num_jobs; i++)); do
|
||||||
|
run_job $i ${args[i-1]} ${args[i]}
|
||||||
|
done
|
||||||
|
|
||||||
|
# Wait for all jobs to finish
|
||||||
|
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
|
||||||
|
sleep 100 # wait for 10 seconds before checking again
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
run_batch 0 24 48 72 96 120 143 165
|
71
VAB-WebArena-Lite/refresh_website_dockers.sh
Normal file
71
VAB-WebArena-Lite/refresh_website_dockers.sh
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
docker stop shopping
|
||||||
|
docker stop shopping_admin
|
||||||
|
docker stop forum
|
||||||
|
docker stop gitlab
|
||||||
|
|
||||||
|
docker rm shopping
|
||||||
|
docker rm shopping_admin
|
||||||
|
docker rm forum
|
||||||
|
docker rm gitlab
|
||||||
|
|
||||||
|
# One Stop Shop
|
||||||
|
# docker load --input shopping_final_0712.tar
|
||||||
|
docker run --name shopping -p 7770:80 -d shopping_final_0712
|
||||||
|
|
||||||
|
# CMS
|
||||||
|
# docker load --input shopping_admin_final_0719.tar
|
||||||
|
docker run --name shopping_admin -p 7780:80 -d shopping_admin_final_0719
|
||||||
|
|
||||||
|
# Reddit
|
||||||
|
# docker load --input postmill-populated-exposed-withimg.tar
|
||||||
|
docker run --name forum -p 9999:80 -d postmill-populated-exposed-withimg
|
||||||
|
|
||||||
|
# GitLab
|
||||||
|
# docker load --input gitlab-populated-final-port8023.tar
|
||||||
|
docker run --name gitlab --shm-size="10g" -d -p 8023:8023 gitlab-populated-final-port8023 /opt/gitlab/embedded/bin/runsvdir-start
|
||||||
|
|
||||||
|
sleep 60
|
||||||
|
|
||||||
|
# Define your actual server hostname
|
||||||
|
YOUR_ACTUAL_HOSTNAME="http://localhost"
|
||||||
|
# Remove trailing / if it exists
|
||||||
|
YOUR_ACTUAL_HOSTNAME=${YOUR_ACTUAL_HOSTNAME%/}
|
||||||
|
|
||||||
|
# OSS
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento setup:store-config:set --base-url="${YOUR_ACTUAL_HOSTNAME}:7770"
|
||||||
|
docker exec shopping mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value='${YOUR_ACTUAL_HOSTNAME}:7775/' WHERE path = 'web/secure/base_url';"
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento cache:flush
|
||||||
|
|
||||||
|
# Disable re-indexing of products
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogrule_product
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogrule_rule
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogsearch_fulltext
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_category_product
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule customer_grid
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule design_config_grid
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule inventory
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_category
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_attribute
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_price
|
||||||
|
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule cataloginventory_stock
|
||||||
|
|
||||||
|
# CMS
|
||||||
|
docker exec shopping_admin /var/www/magento2/bin/magento setup:store-config:set --base-url="${YOUR_ACTUAL_HOSTNAME}:7780"
|
||||||
|
docker exec shopping_admin mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value='${YOUR_ACTUAL_HOSTNAME}:7780/' WHERE path = 'web/secure/base_url';"
|
||||||
|
docker exec shopping_admin /var/www/magento2/bin/magento cache:flush
|
||||||
|
|
||||||
|
# Forum
|
||||||
|
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/CommentData.php
|
||||||
|
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/SubmissionData.php
|
||||||
|
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/UserData.php
|
||||||
|
docker exec forum bin/console cache:clear --env=prod
|
||||||
|
|
||||||
|
sleep 60
|
||||||
|
|
||||||
|
# Gitlab
|
||||||
|
docker exec gitlab sed -i "s|^external_url.*|external_url '${YOUR_ACTUAL_HOSTNAME}:8023'|" /etc/gitlab/gitlab.rb
|
||||||
|
docker exec gitlab sed -i "s/.*postgresql\['max_connections'.*/postgresql\['max_connections'\] = 2000/g" /etc/gitlab/gitlab.rb
|
||||||
|
docker exec gitlab gitlab-ctl reconfigure
|
||||||
|
docker exec gitlab gitlab-ctl restart
|
37
VAB-WebArena-Lite/replace.sh
Normal file
37
VAB-WebArena-Lite/replace.sh
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 1. reset the original environment
|
||||||
|
cd visualwebarena
|
||||||
|
rm -rf .git
|
||||||
|
cd ..
|
||||||
|
|
||||||
|
# 2. replace files
|
||||||
|
# webarena-lite
|
||||||
|
cp -f new/test_webarena_lite.raw.json visualwebarena/config_files/wa/test_webarena_lite.raw.json
|
||||||
|
cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
|
||||||
|
|
||||||
|
# agent
|
||||||
|
cp -f new/run.py visualwebarena/run.py
|
||||||
|
cp -f new/agent.py visualwebarena/agent/agent.py
|
||||||
|
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
|
||||||
|
|
||||||
|
# llms
|
||||||
|
cp -f new/utils.py visualwebarena/llms/utils.py
|
||||||
|
cp -f new/llms_init.py visualwebarena/llms/__init__.py
|
||||||
|
cp -f new/lm_config.py visualwebarena/llms/lm_config.py
|
||||||
|
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
|
||||||
|
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
|
||||||
|
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
|
||||||
|
|
||||||
|
# eval
|
||||||
|
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
|
||||||
|
cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
|
||||||
|
|
||||||
|
# misc
|
||||||
|
cp -f README.md visualwebarena/README.md
|
||||||
|
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
|
||||||
|
|
||||||
|
# 3. remove temporary files
|
||||||
|
mv visualwebarena/* .
|
||||||
|
rm -rf new
|
||||||
|
rm -rf visualwebarena
|
BIN
assets/framework.png
Normal file
BIN
assets/framework.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 407 KiB |
49
docs/detailed_setups/VAB-Minecraft.md
Normal file
49
docs/detailed_setups/VAB-Minecraft.md
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
# Setup for VAB-Minecraft
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
1. We have tested on Ubuntu. VAB-Minecraft requires at least 4 GB NVIDIA GPU and NVIDIA GPU driver version >= 530.30.02.
|
||||||
|
|
||||||
|
2. Besides [docker](https://www.docker.com/), install [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) on your machine.
|
||||||
|
|
||||||
|
3. Get pre-built docker image.
|
||||||
|
|
||||||
|
- If you have access to docker hub:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker pull tianjiezhang/vab_minecraft:latest
|
||||||
|
```
|
||||||
|
|
||||||
|
- Or you can download from ModelScope.
|
||||||
|
|
||||||
|
1. Make sure `git-lfs` is installed.
|
||||||
|
|
||||||
|
2. Download from ModelScope:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git lfs install
|
||||||
|
|
||||||
|
git clone https://www.modelscope.cn/datasets/VisualAgentBench/VAB-Minecraft.git
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Load the docker image from ModelScope dataset.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker load -i VAB-Minecraft/vab_minecraft.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Download weights of Steve-1 to `data/minecraft`. Please make sure you have access to google drive.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python scripts/minecraft_download.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Get Started
|
||||||
|
|
||||||
|
According to your hardware equipment, fill `available_ports` and `available_devices` in the task configuration file `configs/tasks/minecraft.yaml`.
|
||||||
|
|
||||||
|
- `available_ports`: Please fill in available ports in your machine. Each concurrent docker container requires 1 port for communication with the task server. Ensure that you provide enough ports to accommodate the expected concurrency.
|
||||||
|
|
||||||
|
- `available_devices`: Please fill in GPU IDs and their corresponding capability of concurrency. Each concurrent docker container occupies about **3.3 GB** memory. Ensure that you provide enough GPU memory to accommodate the expected concurrency.
|
||||||
|
|
||||||
|
**Note: If you manually shut down the task server and assigner, please ensure you also stop the Minecraft containers to free up the ports!**
|
|
@ -57,57 +57,3 @@
|
||||||
2. Load the new value: `sudo sysctl -p`.
|
2. Load the new value: `sudo sysctl -p`.
|
||||||
|
|
||||||
**Note: If you manually shut down the task server and assigner, please ensure you also stop the OmniGibson containers to free up the ports!**
|
**Note: If you manually shut down the task server and assigner, please ensure you also stop the OmniGibson containers to free up the ports!**
|
||||||
|
|
||||||
# Setup for VAB-Minecraft
|
|
||||||
|
|
||||||
## Installation
|
|
||||||
|
|
||||||
1. We have tested on Ubuntu. VAB-Minecraft requires at least 4 GB NVIDIA GPU and NVIDIA GPU driver version >= 530.30.02.
|
|
||||||
|
|
||||||
2. Besides [docker](https://www.docker.com/), install [NVIDIA container toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) on your machine.
|
|
||||||
|
|
||||||
3. Get pre-built docker image.
|
|
||||||
|
|
||||||
- If you have access to docker hub:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker pull tianjiezhang/vab_minecraft:latest
|
|
||||||
```
|
|
||||||
|
|
||||||
- Or you can download from ModelScope.
|
|
||||||
|
|
||||||
1. Make sure `git-lfs` is installed.
|
|
||||||
|
|
||||||
2. Download from ModelScope:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git lfs install
|
|
||||||
|
|
||||||
git clone https://www.modelscope.cn/datasets/VisualAgentBench/VAB-Minecraft.git
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Load the docker image from ModelScope dataset.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker load -i VAB-Minecraft/vab_minecraft.tar
|
|
||||||
```
|
|
||||||
|
|
||||||
4. Download weights of Steve-1 to `data/minecraft`. Please make sure you have access to google drive.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python scripts/minecraft_download.py
|
|
||||||
```
|
|
||||||
|
|
||||||
## Get Started
|
|
||||||
|
|
||||||
According to your hardware equipment, fill `available_ports` and `available_devices` in the task configuration file `configs/tasks/minecraft.yaml`.
|
|
||||||
|
|
||||||
- `available_ports`: Please fill in available ports in your machine. Each concurrent docker container requires 1 port for communication with the task server. Ensure that you provide enough ports to accommodate the expected concurrency.
|
|
||||||
|
|
||||||
- `available_devices`: Please fill in GPU IDs and their corresponding capability of concurrency. Each concurrent docker container occupies about **3.3 GB** memory. Ensure that you provide enough GPU memory to accommodate the expected concurrency.
|
|
||||||
|
|
||||||
**Note: If you manually shut down the task server and assigner, please ensure you also stop the Minecraft containers to free up the ports!**
|
|
||||||
|
|
||||||
# Setup for VAB-CSS
|
|
||||||
|
|
||||||
TODO
|
|
Loading…
Reference in New Issue
Block a user