Add documentation for VAB-WebArena-Lite

This commit is contained in:
xiao9905 2024-10-20 00:10:34 +08:00
parent 9b7a654aaa
commit 0c6f549214
21 changed files with 10138 additions and 70 deletions

View File

@ -26,32 +26,30 @@ Compared to its predecessor [AgentBench](https://github.com/THUDM/AgentBench), V
## Table of Contents
- [Quick Start](#quick-start)
- [Dataset Summary](#dataset-summary)
- [Leaderboard](#leaderboard)
- [Acknowledgement](#acknowledgement)
- [Citation](#citation)
## Quick Start
This section will first give you an overview of the use and architecture of VAB.
Next, it will guide you through using `gpt-4o-2024-05-13` as an exemplar agent to launch 4 concurrent `VAB-Minecraft` tasks.
### Overview on VAB Framework
To allow fast evaluation over agent tasks, we leverage AgentBench's framework as the backbone (currently for VAB-OmniGibson, VAB-Minecraft, and VAB-CSS).
If you are interested in its detailed implementation, please refer to AgentBench's [Framework Introduction](https://github.com/THUDM/AgentBench/blob/main/docs/Introduction_en.md) (which may not be necessary).
Basically, the framework calls all LLMs/LMMs through an API interface via the `Agent-Controller`, and accesses environments via the `Task-Controller`.
The `Assigner` automatically assigns evaluation tasks by pairing an `Agent-Controller` with a `Task-Controller` to optimize the overall evaluation speed.
For more detailed configuration and launch methods, please check [Configuration Guide](docs/Config_en.md)
and [Program Entrance Guide](docs/Entrance_en.md).
![](./assets/framework.png)
### Step 1. Prerequisites for All Environments
Clone this repo and install the dependencies.
@ -68,7 +66,14 @@ Ensure that [Docker](https://www.docker.com/) is properly installed.
docker ps
```
For specific environments, please refer to their respective additional prerequisites below.
VAB-WebArena-Lite is based on [WebArena](https://github.com/web-arena-x/webarena) with some modifications, so please read its separate setup instructions carefully.
* [VAB-OmniGibson Setup](docs/detailed_setups/VAB-OmniGibson.md)
* [VAB-Minecraft Setup](docs/detailed_setups/VAB-Minecraft.md)
* VAB-Mobile: Ongoing
* [VAB-WebArena-Lite Setup](VAB-WebArena-Lite/README.md) (Separate installation and evaluation method)
* VAB-CSS: Ongoing
### Step 2. Configure the Agent
@ -117,6 +122,19 @@ python -m src.assigner --auto-retry --config configs/assignments/omnigibson.yaml
You can modify the config files to launch other tasks or change task concurrency.
## Dataset Summary
We offer two splits for each dataset: Testing and Training. Unlike its predecessor [AgentBench](https://github.com/THUDM/AgentBench), VAB is accompanied by a trajectory training set for behavior cloning (BC), which enables the development of more capable visual foundation agents with emerging open LMMs.
![](./assets/statistics.png)
## Leaderboard
Here are the test set results of VAB. All metrics are task Success Rate (SR). Note that proprietary LMMs are tested with mere **Prompting**, while open LMMs are tested after **Multitask Finetuning** on the VAB training set, as they usually fail to follow complicated agent task instructions.
![](./assets/leaderboard.png)
## Acknowledgement
This project is heavily built upon the following repositories (to be updated):

VAB-WebArena-Lite/README.md (new file, 218 lines)
View File

@ -0,0 +1,218 @@
# Setup for VAB-WebArena-Lite
## Brief Introduction
VAB-WebArena-Lite is a 165-task refined subset from <a href="https://webarena.dev/" target="_blank">WebArena</a>.
The purpose of building this subset is to manually ensure task correctness and feasibility, and to speed up testing (the original 812-task WebArena usually takes more than 6 hours to run through, while VAB-WebArena-Lite takes around 40 minutes in practice).
The modified version of the test cases can be found in `config_files/wa/test_webarena_lite.raw.json`.
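If you want to sanity-check the raw task file before setup, a minimal sketch like the following works (our own snippet, not part of the repo; it assumes each entry is a JSON object with an `intent` field, as used by the evaluation code later in this commit):
```python
import json

# The raw file still contains placeholders such as __GITLAB__; they are substituted
# by scripts/generate_test_data.py during setup (see the steps below).
with open("config_files/wa/test_webarena_lite.raw.json") as f:
    tasks = json.load(f)

print("Number of tasks:", len(tasks))         # expected: 165
print("Example intent:", tasks[0]["intent"])  # assumed field name
```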
## Install
First, clone the official repository of <a href="https://github.com/web-arena-x/visualwebarena">VisualWebArena</a> into this directory:
```bash
# Assume you have cloned VAB and are now in the `VAB-WebArena-Lite` directory
git clone https://github.com/web-arena-x/visualwebarena.git visualwebarena
cd visualwebarena
git reset --hard ad57aae4dad71531504726900b80db02e0526158
cd ..
```
Then, replace the original files with the command below:
```bash
bash replace.sh
```
After that, install the dependencies for VAB-WebArena-Lite (we recommend using a virtual environment independent from the one used for VAB):
```bash
# Python 3.10 or 3.11 (not 3.12, which removed the distutils module needed here)
python -m venv venv
source venv/bin/activate
pip install -r requirements.txt
playwright install
pip install -e .
```
You can also run the unit tests to ensure that WebArena-Lite is installed correctly:
```bash
pytest -x
```
## Setup WebArena-Lite Environments
1. Set up the standalone environments.
Please check out [this page](https://github.com/web-arena-x/webarena/tree/main/environment_docker) for details.
2. Configure the URLs for each website.
First, export `DATASET` as `webarena`:
```bash
export DATASET=webarena
```
Then, set the URLs for the websites.
(🚨 Notice: check that the default ports of the websites below match those you set up in the first step.)
```bash
# The CLASSIFIEDS environment is not included in the WebArena-Lite evaluation; we keep its variables here only for consistency.
export CLASSIFIEDS="<your_classifieds_domain>:9980"
export CLASSIFIEDS_RESET_TOKEN="4b61655535e7ed388f0d40a93600254c"
# Below are the variables you should set for the evaluation.
export SHOPPING="<your_shopping_site_domain>:7770"
export REDDIT="<your_reddit_domain>:9999"
export SHOPPING_ADMIN="<your_e_commerce_cms_domain>:7780/admin"
export GITLAB="<your_gitlab_domain>:8023"
export MAP="<your_map_domain>:3000"
export WIKIPEDIA="<your_wikipedia_domain>:8888"
export HOMEPAGE="<your_homepage_domain>:4399"
```
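As a quick check (our own sketch, not part of the repo), you can verify that the URL variables above are set before moving on:
```python
import os

# Variable names follow the export list above; CLASSIFIEDS is left out because it is
# not used in the WebArena-Lite evaluation.
required = ["DATASET", "SHOPPING", "REDDIT", "SHOPPING_ADMIN",
            "GITLAB", "MAP", "WIKIPEDIA", "HOMEPAGE"]
missing = [name for name in required if not os.environ.get(name)]
if missing:
    raise SystemExit("Missing environment variables: " + ", ".join(missing))
print("All WebArena-Lite URL variables are set.")
```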
3. Generate config files for each test example:
```bash
python scripts/generate_test_data.py
```
You will see `*.json` files generated in the [config_files](./config_files) folder. Each file contains the configuration for one test example.
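To confirm that the placeholder substitution worked, you can inspect one generated config (a sketch of ours; the `intent` and `eval` field names follow the evaluation code shown later in this commit):
```python
import json

with open("config_files/wa/test_webarena_lite/0.json") as f:
    cfg = json.load(f)

print("Intent:", cfg["intent"])
print("Eval types:", cfg["eval"]["eval_types"])
# Placeholders such as __GITLAB__ should be gone after generation.
assert "__GITLAB__" not in json.dumps(cfg)
```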
4. Obtain and save the auto-login cookies for all websites:
```bash
bash prepare.sh
```
5. Set up API keys.
```bash
export OPENAI_API_KEY=your_key
# Optional: if you use a different OpenAI API source
export OPENAI_API_URL=your_url
# Optional: set the following variables to evaluate the preset models in llms/providers/api_utils.py
export GEMINI_API_KEY=your_key
export QWEN_API_KEY=your_key
export CLAUDE_API_KEY=your_key
# Optional: if you have trained your own model, we recommend deploying it as an API service and setting FINETUNED_URL to evaluate it.
export FINETUNED_URL=your_url
```
If using Gemini, first install the [gcloud CLI](https://cloud.google.com/sdk/docs/install). Configure the API key by authenticating with Google Cloud:
```bash
gcloud auth login
gcloud config set project <your_project_name>
```
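Before launching the evaluation, you can also check which optional providers are available (our own sketch; the variable names match the exports above, and `OPENAI_API_KEY` is the only one that `llms/providers/api_utils.py` strictly requires):
```python
import os

print("OPENAI_API_KEY :", "ok" if os.environ.get("OPENAI_API_KEY") else "MISSING (required)")
for name in ["GEMINI_API_KEY", "QWEN_API_KEY", "CLAUDE_API_KEY", "FINETUNED_URL"]:
    print(f"{name:15s}:", "set" if os.environ.get(name) else "not set (optional)")
```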
## 🖼️ Evaluating in VAB Standard Setting with SoM (Set-of-Marks) Visual Agents
### 👎 Run a Single Agent for Evaluation (slow, but please read this to understand the meaning of the arguments)
To run your own model with the SoM visual agent, run the evaluation with the following flags:
```bash
python run.py \
--instruction_path agent/prompts/jsons/p_som_cot_id_actree_3s.json \
--test_start_idx 0 \
--test_end_idx 1 \
--result_dir <your_result_dir> \
--test_config_base_dir config_files/wa/test_webarena_lite \
--provider api \
--model openai_gpt-4-vision-preview \
--action_set_tag som --observation_type image_som
```
Besides the original model providers (OpenAI, Google), you can also add your own models in `llms/providers/api_utils.py`. Remember to set `--provider` to one of:
- `api`: keeps the same input style as WebArena, suitable for regular API calls
- `finetune`: required for models trained with the data we provide
For the `--model` variable, we use the format `<source>_<model-name>` (see the sketch after this list).
- If there are no further model options under a source, you can set it to just `<source>`.
- Remember that the source name must be registered in the init function of `APIModel` in `llms/providers/api_utils.py`.
- For example, to use the OpenAI model "gpt-4o", set the flag as `--model openai_gpt-4o`.
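A minimal sketch (ours, not part of the repo) of how the flag is interpreted, mirroring `APIModel.inference` in `llms/providers/api_utils.py` shown later in this commit:
```python
def parse_model_flag(model_id: str) -> tuple[str, str]:
    """Split `--model` into (source, model-name); a bare source maps to itself."""
    if "_" in model_id:
        source, model_name = model_id.split("_", 1)
    else:
        source = model_name = model_id
    return source, model_name

print(parse_model_flag("openai_gpt-4o"))  # ('openai', 'gpt-4o')
print(parse_model_flag("finetuned"))      # ('finetuned', 'finetuned')
```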
Finally, run `score.py` to get the pass rate:
```bash
python score.py <your_result_dir>
```
### 👍 Run Parallel Agents for Evaluation (Recommended)
To run the tests in parallel, first configure `wa_parallel_run.sh`, then run it. By default, we split the test set into 5 parallel groups for evaluation in VAB (see the sketch after the command below).
```bash
# Remember to first launch a tmux session
tmux
bash wa_parallel_run.sh
```
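If you want a different number of groups, the index ranges can be derived as follows (our own sketch; it assumes `wa_parallel_run.sh` simply launches `run.py` with different `--test_start_idx`/`--test_end_idx` ranges, and that the end index is exclusive, as the single-task example above suggests):
```python
def index_ranges(num_tasks: int = 165, num_groups: int = 5) -> list[tuple[int, int]]:
    size = (num_tasks + num_groups - 1) // num_groups  # ceiling division
    return [(start, min(start + size, num_tasks)) for start in range(0, num_tasks, size)]

print(index_ranges())  # [(0, 33), (33, 66), (66, 99), (99, 132), (132, 165)]
```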
The script supports auto-resuming if it is interrupted or hits an unexpected error, so feel free to rerun the command above until all tasks finish.
After all parallel groups finish, run `score.py` to get the pass rate:
```bash
python score.py <your_result_dir>
```
### 🚨 Important: Refresh all websites before running another round of testing!
Since tasks in WebArena may change the state and databases of the websites (e.g., posting comments on Reddit), the results will be unreliable if the websites are not all refreshed before another round of evaluation.
Please remember to run the following command (assuming you host the WebArena websites yourself) to restart and refresh all website Docker containers and avoid potential contamination.
The process usually takes 3-5 minutes.
```bash
# Make sure the script is executed on the machine that hosts the website Docker containers
bash refresh_website_docker.sh
```
You may need to adjust some contents of the script (e.g., the configured website ports or container names).
## Run a Visualized Demonstration
The original WebArena also provides a demo for running the agents on your own task on an arbitrary webpage (for example, finding the best Thai restaurant in Pittsburgh).
After following the setup instructions above and setting the OpenAI API key (the other environment variables for website URLs aren't really used, so you should be able to set them to dummy values), you can run the GPT-4V + SoM agent with the following command:
```bash
python run_demo.py \
--instruction_path agent/prompts/jsons/p_som_cot_id_actree_3s.json \
--start_url "https://www.amazon.com" \
--image "https://media.npr.org/assets/img/2023/01/14/this-is-fine_wide-0077dc0607062e15b476fb7f3bd99c5f340af356-s1400-c100.jpg" \
--intent "Help me navigate to a shirt that has this on it." \
--result_dir demo_test_amazon \
--model gpt-4-vision-preview \
--action_set_tag som --observation_type image_som \
--render
```
This tasks the agent to find a shirt that looks like the provided image (the "This is fine" dog) from Amazon. Have fun!
## Acknowledgements
Our code is heavily based on the <a href="https://github.com/web-arena-x/webarena">WebArena codebase</a> and the <a href="https://github.com/web-arena-x/visualwebarena">VisualWebArena codebase</a>.
If you find this environment useful, please consider citing <a href="https://jykoh.com/vwa" target="_blank">VisualWebArena</a> as well as <a href="https://webarena.dev/" target="_blank">WebArena</a>:
```bibtex
@article{koh2024visualwebarena,
title={VisualWebArena: Evaluating Multimodal Agents on Realistic Visual Web Tasks},
author={Koh, Jing Yu and Lo, Robert and Jang, Lawrence and Duvvur, Vikram and Lim, Ming Chong and Huang, Po-Yu and Neubig, Graham and Zhou, Shuyan and Salakhutdinov, Ruslan and Fried, Daniel},
journal={arXiv preprint arXiv:2401.13649},
year={2024}
}
@article{zhou2024webarena,
title={WebArena: A Realistic Web Environment for Building Autonomous Agents},
author={Zhou, Shuyan and Xu, Frank F and Zhu, Hao and Zhou, Xuhui and Lo, Robert and Sridhar, Abishek and Cheng, Xianyi and Bisk, Yonatan and Fried, Daniel and Alon, Uri and others},
journal={ICLR},
year={2024}
}
```

View File

@ -0,0 +1,227 @@
import argparse
import json
from typing import Any, Optional
import tiktoken
from beartype import beartype
from PIL import Image
from agent.prompts import *
from browser_env import Trajectory
from browser_env.actions import (
Action,
ActionParsingError,
create_id_based_action,
create_none_action,
create_playwright_action,
)
from browser_env.utils import Observation, StateInfo
from llms import (
call_llm,
generate_from_huggingface_completion,
generate_from_openai_chat_completion,
generate_from_openai_completion,
lm_config,
)
from llms.tokenizers import Tokenizer
class Agent:
"""Base class for the agent"""
def __init__(self, *args: Any) -> None:
pass
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: Any
) -> Action:
"""Predict the next action given the observation"""
raise NotImplementedError
def reset(
self,
test_config_file: str,
) -> None:
raise NotImplementedError
class TeacherForcingAgent(Agent):
"""Agent that follows a pre-defined action sequence"""
def __init__(self) -> None:
super().__init__()
def set_action_set_tag(self, tag: str) -> None:
self.action_set_tag = tag
def set_actions(self, action_seq: str | list[str]) -> None:
if isinstance(action_seq, str):
action_strs = action_seq.strip().split("\n")
else:
action_strs = action_seq
action_strs = [a.strip() for a in action_strs]
actions = []
for a_str in action_strs:
try:
if self.action_set_tag == "playwright":
cur_action = create_playwright_action(a_str)
elif self.action_set_tag == "id_accessibility_tree":
cur_action = create_id_based_action(a_str)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
)
except ActionParsingError as e:
cur_action = create_none_action()
cur_action["raw_prediction"] = a_str
actions.append(cur_action)
self.actions: list[Action] = actions
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: Any
) -> Action:
"""Predict the next action given the observation"""
return self.actions.pop(0)
def reset(
self,
test_config_file: str,
) -> None:
with open(test_config_file) as f:
ref_actions = json.load(f)["reference_action_sequence"]
tag = ref_actions["action_set_tag"]
action_seq = ref_actions["action_sequence"]
self.set_action_set_tag(tag)
self.set_actions(action_seq)
class PromptAgent(Agent):
"""prompt-based agent that emits action given the history"""
@beartype
def __init__(
self,
action_set_tag: str,
lm_config: lm_config.LMConfig,
prompt_constructor: PromptConstructor,
captioning_fn = None,
) -> None:
super().__init__()
self.lm_config = lm_config
self.prompt_constructor = prompt_constructor
self.action_set_tag = action_set_tag
self.captioning_fn = captioning_fn
# Check if the model is multimodal.
if ("gemini" in lm_config.model or "gpt-4" in lm_config.model and "vision" in lm_config.model or lm_config.provider in ["api", "finetune"]) and type(prompt_constructor) == MultimodalCoTPromptConstructor:
self.multimodal_inputs = True
else:
self.multimodal_inputs = False
def set_action_set_tag(self, tag: str) -> None:
self.action_set_tag = tag
@beartype
def next_action(
self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any], images: Optional[list[Image.Image]] = None,
output_response: bool = False
) -> Action:
# Create page screenshot image for multimodal models.
if self.multimodal_inputs:
page_screenshot_arr = trajectory[-1]["observation"]["image"]
page_screenshot_img = Image.fromarray(
page_screenshot_arr
) # size = (viewport_width, viewport_width)
# Caption the input image, if provided.
if images is not None and len(images) > 0:
if self.captioning_fn is not None:
image_input_caption = ""
for image_i, image in enumerate(images):
if image_i == 0:
image_input_caption += f'Input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
else:
image_input_caption += f'input image {image_i+1}: "{self.captioning_fn([image])[0]}"'
if len(images) > 1:
image_input_caption += ", "
# Update intent to include captions of input images.
intent = f"{image_input_caption}\nIntent: {intent}"
elif not self.multimodal_inputs:
print(
"WARNING: Input image provided but no image captioner available."
)
if self.multimodal_inputs:
prompt = self.prompt_constructor.construct(
trajectory, intent, page_screenshot_img, images, meta_data
)
else:
prompt = self.prompt_constructor.construct(
trajectory, intent, meta_data
)
lm_config = self.lm_config
n = 0
while True:
response = call_llm(lm_config, prompt)
force_prefix = self.prompt_constructor.instruction[
"meta_data"
].get("force_prefix", "")
response = f"{force_prefix}{response}"
if output_response:
print(f'Agent: {response}', flush=True)
n += 1
try:
parsed_response = self.prompt_constructor.extract_action(
response
)
if self.action_set_tag == "id_accessibility_tree":
action = create_id_based_action(parsed_response)
elif self.action_set_tag == "playwright":
action = create_playwright_action(parsed_response)
elif self.action_set_tag == "som":
action = create_id_based_action(parsed_response)
else:
raise ValueError(
f"Unknown action type {self.action_set_tag}"
)
action["raw_prediction"] = response
break
except ActionParsingError as e:
if n >= lm_config.gen_config["max_retry"]:
action = create_none_action()
action["raw_prediction"] = response
break
return action
def reset(self, test_config_file: str) -> None:
pass
def construct_agent(args: argparse.Namespace, captioning_fn=None) -> Agent:
llm_config = lm_config.construct_llm_config(args)
agent: Agent
if args.agent_type == "teacher_forcing":
agent = TeacherForcingAgent()
elif args.agent_type == "prompt":
with open(args.instruction_path) as f:
constructor_type = json.load(f)["meta_data"]["prompt_constructor"]
tokenizer = Tokenizer(args.provider, args.model)
prompt_constructor = eval(constructor_type)(
args.instruction_path, lm_config=llm_config, tokenizer=tokenizer
)
agent = PromptAgent(
action_set_tag=args.action_set_tag,
lm_config=llm_config,
prompt_constructor=prompt_constructor,
captioning_fn=captioning_fn
)
else:
raise NotImplementedError(
f"agent type {args.agent_type} not implemented"
)
return agent

View File

@ -0,0 +1,682 @@
import os
import copy
import json
import time
import base64
import shutil
import requests
import dashscope
import http.client
import anthropic
import google.auth
from google.oauth2 import service_account
from google.auth.transport.requests import Request
from openai import OpenAI
from typing import List, Tuple, Dict
from http import HTTPStatus
from PIL import Image
from io import BytesIO
PROXIES = { # gemini
"http": "http://127.0.0.1:7890",
"https": "http://127.0.0.1:7890"
}
SEED = int(os.environ.get("SEED", 42))
GCLOUD_KEY_FILE_PATH = "" # path to the google cloud project json file
GCLOUD_REGIONAL_CODE = "asia-east1"
OPENAI_API_URL = os.environ.get("OPENAI_API_URL")
FINETUNED_URL = os.environ.get("FINETUNED_URL") # finetuned model url
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"] # you should always set up an OpenAI API key for evaluation
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY", "") # no need when using google cloud
QWEN_API_KEY = os.environ.get("QWEN_API_KEY" , "")
CLAUDE_API_KEY = os.environ.get("CLAUDE_API_KEY", "")
class BasicModel(object):
def __init__(self):
super().__init__()
# make temp dir here
file_path = os.path.dirname(__file__)
self.base_dir = os.path.join(file_path, "temp", f"{int(time.time())}")
os.makedirs(self.base_dir, exist_ok=True)
def __del__(self):
# remove temp dir
shutil.rmtree(self.base_dir, ignore_errors=True)
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
return messages
@staticmethod
def process_system_prompt(messages: List[Dict]) -> List[Dict]:
if messages[0]["role"] != "system":
return messages
new_messages = copy.deepcopy(messages[1:])
system_prompt = messages[0]["content"]
# Search for first user message and add system prompt to it
for item in new_messages:
if item.get("role") != "user":
continue
for ct in item["content"]:
# Case 1: directly appended to the text
if ct["type"] == "text":
ct["text"] = system_prompt + "\n" + ct["text"]
return new_messages
# Case 2: create a new text item
item["content"].insert(0, {
"type": "text",
"text": system_prompt
})
return new_messages
# Case 3: no user message found, add a new user message
new_messages.insert(0, {
"role": "user",
"content": [{
"type": "text",
"text": system_prompt
}]
})
return new_messages
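# Example: a leading {"role": "system", "content": "..."} message is folded into the
# first user turn, either by prefixing the system prompt to that turn's first text item
# or, if the turn has no text item, by inserting a new text item; if there is no user
# turn at all, a new user message carrying the system prompt is prepended instead.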
@staticmethod
def pil_to_b64(img: Image.Image) -> str:
with BytesIO() as image_buffer:
img.save(image_buffer, format="PNG")
byte_data = image_buffer.getvalue()
img_b64 = base64.b64encode(byte_data).decode("utf-8")
img_b64 = "data:image/png;base64," + img_b64
return img_b64
# save base64 image and return filename
def b64_to_image(self, base64_data: str) -> str:
base64_data = base64_data.split(",")[1]
image_data = base64.b64decode(base64_data)
filename = os.path.join(self.base_dir, f"{int(time.time())}.png")
with open(filename, "wb") as f:
f.write(image_data)
return filename
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
raise NotImplementedError("Subclasses must implement this method")
class OpenAIModel(BasicModel):
def __init__(self):
super().__init__()
self.client = OpenAI(api_key=OPENAI_API_KEY, base_url=OPENAI_API_URL)
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
return messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
response = self.client.chat.completions.create(
model=model_name,
messages=messages,
temperature=args.get("temperature", 0.0),
max_tokens=args.get("max_tokens", 1024),
top_p=args.get("top_p", 1.0),
)
try:
answer: str = response.choices[0].message.content
return True, answer
except:
return False, str(response.error)
class FinetuneModel(BasicModel):
def __init__(self):
super().__init__()
self.url = FINETUNED_URL # inference api
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
dialog, images = "", []
for message in messages:
if message["role"] == "system":
dialog += f"<|system|>\n{message['content']}\n\n"
continue
elif message["role"] == "assistant":
dialog += f"<|assistant|>\n{message['content']}\n\n"
continue
dialog += f"<|user|>\n"
for content in message["content"]:
if content["type"] == "text":
dialog += f"{content['text']}\n"
else:
# TODO: we use filename as image url here
images.append(self.b64_to_image(content["image_url"]["url"]))
dialog += "\n\n"
dialog += "<|assistant|>\n"
images = [open(image, "rb") for image in images]
new_messages = [
{"image": images[0]},
{"prompt": dialog}
]
return new_messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
try:
response = requests.post(self.url, files=messages[0], data=messages[1], timeout=40)
response = response.json()
except Exception as e:
return False, str(e)
if "error" in response:
return False, response["error"]["message"]
# TODO: you should change the response format here
resp = f'```\n{response["response"].split("<|end_of_text|>")[0]}\n```'
return True, resp
class QwenModel(BasicModel):
def __init__(self):
super().__init__()
dashscope.api_key = QWEN_API_KEY
self.seed = SEED
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
messages = self.process_system_prompt(messages)
new_messages = []
for message in messages:
if message["role"] != "user":
new_messages.append({
"role": "assistant",
"content": [{"text": message["content"]}]
})
continue
new_content = []
for content in message["content"]:
if content["type"] == "text":
new_content.append({
"text": content["text"],
})
else:
filename = self.b64_to_image(content["image_url"]["url"])
new_content.append({
"image": f"file://{filename}"
})
new_messages.append({
"role": "user",
"content": new_content
})
return new_messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
if "QWEN_API_KEY" not in os.environ:
raise ValueError(
"QWEN_API_KEY environment variable must be set when using Qwen API."
)
response = dashscope.MultiModalConversation.call(
model=model_name,
messages=messages,
top_k=args.get("top_k"),
seed=self.seed
)
if response.status_code == HTTPStatus.OK:
return True, response.output.choices[0].message.content[0]['text']
else:
return False, response.message
class ClaudeModel(BasicModel):
def __init__(self):
super().__init__()
self.client = anthropic.Anthropic(api_key=CLAUDE_API_KEY)
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
new_messages = []
for message in messages:
if message["role"] in ["system", "assistant"]:
new_messages.append(message)
continue
new_content = []
for content in message["content"]:
if content["type"] == "text":
new_content.append(content)
continue
hdr, idata = content["image_url"]["url"].split(";base64,")
new_content.append({
"type": "image",
"source": {
"type": "base64",
"media_type": hdr.split("data:")[1],
"data": idata
}
})
new_messages.append({
"role": "user",
"content": new_content
})
return new_messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
try:
if messages[0]["role"] == "system":
system_prompt = messages[0]["content"]
messages = messages[1:]
response = self.client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=args.get("max_tokens"),
temperature=args.get("temperature"),
system=system_prompt,
messages=messages
)
else:
response = self.client.messages.create(
model="claude-3-5-sonnet-20240620",
max_tokens=args.get("max_tokens"),
temperature=args.get("temperature"),
messages=messages
)
usage = response.usage
prompt_tokens = usage.input_tokens
completion_tokens = usage.output_tokens
# print(response)
print(response.content)
print(f"Prompt Tokens: {prompt_tokens}\nCompletion Tokens: {completion_tokens}\n")
return True, response.content
except Exception as e:
return False, str(e)
def get_model_response_thirdapi(self, messages) -> Tuple[bool, str]:
conn = http.client.HTTPSConnection("cn2us02.opapi.win", timeout=900)
system_prompt = None
if messages[0]["role"] == "system":
system_prompt = messages[0]["content"]
messages = messages[1:]
payload = {
"model": "claude-3-opus",
"stream": False,
"system": system_prompt,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens
}
else:
payload = {
"model": "claude-3-opus",
"stream": False,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens
}
payload = json.dumps(payload)
headers = {
'Accept': 'application/json',
'Authorization': f'Bearer {CLAUDE_API_KEY}',
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
'Content-Type': 'application/json'
}
try:
conn.request("POST", "/v1/messages", payload, headers)
res = conn.getresponse()
data = res.read()
response = json.loads(data.decode("utf-8"))
except Exception as e:
return False, str(e)
if "statusCode" in response and response["statusCode"] != 200:
return False, response["message"]
usage = response["usage"]
prompt_tokens = usage["input_tokens"]
completion_tokens = usage["output_tokens"]
print(f"Prompt Tokens: {prompt_tokens}\nCompletion Tokens: {completion_tokens}\n")
return True, response["content"][0]["text"]
class GeminiModel(BasicModel):
def __init__(self):
super().__init__()
self.api_key = GEMINI_API_KEY
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
parts = []
dialog = ""
sep = "\n\n###\n\n"
for message in messages:
if message["role"] == "system":
dialog += f"SYSTEM:\n{message['content']}{sep}"
elif message["role"] == "assistant":
dialog += f"ASSISTANT:\n{message['content']}{sep}"
elif message["role"] == "user":
dialog += "USER:\n"
for content in message["content"]:
if content["type"] == "text":
dialog += content["text"] + "\n"
continue
assert content["type"] == "image_url"
# save text
parts.append({ "text": dialog })
dialog = ""
# new content type for image
hdr, idata = content["image_url"]["url"].split(";base64,")
parts.append({
"inline_data": {
"mime_type": hdr.split("data:")[1],
"data": idata
}
})
dialog += sep
else:
raise ValueError(f"Invalid role: {message['role']}")
parts.append({
"text": dialog + "ASSISTANT:\n"
})
new_messages = [{
"parts": parts,
"role": "user"
}]
return new_messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
headers = {
"Content-Type": "application/json"
}
proxies = PROXIES
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}-latest:generateContent?key={self.api_key}"
generation_config = {
"temperature": args.get('temperature'),
"maxOutputTokens": args.get('max_tokens'),
"stopSequences": ["\n\n###\n\n"]
}
safety_settings = [
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
}
]
payload = {
"contents": messages,
"generationConfig": generation_config,
"safetySettings": safety_settings
}
try:
response = requests.post(url, headers=headers, json=payload, proxies=proxies, timeout=30)
response = response.json()
except Exception as e:
return False, str(e)
if "error" in response:
return False, response["error"]["message"]
if "content" not in response['candidates'][0]:
generation_config['maxOutputTokens'] *= 2
return False, "No content generated."
return True, response['candidates'][0]['content']['parts'][0]['text']
class VertexGeminiModel(BasicModel):
def __init__(self):
super().__init__()
def prompt_construct(self, messages: List[Dict]) -> List[Dict]:
parts = []
dialog = ""
sep = "\n\n###\n\n"
for message in messages:
if message["role"] == "system":
dialog += f"SYSTEM:\n{message['content']}{sep}"
elif message["role"] == "assistant":
dialog += f"ASSISTANT:\n{message['content']}{sep}"
elif message["role"] == "user":
dialog += "USER:\n"
for content in message["content"]:
if content["type"] == "text":
dialog += content["text"] + "\n"
continue
assert content["type"] == "image_url"
# save text
parts.append({ "text": dialog })
dialog = ""
# new content type for image
hdr, idata = content["image_url"]["url"].split(";base64,")
parts.append({
"inline_data": {
"mime_type": hdr.split("data:")[1],
"data": idata
}
})
dialog += sep
else:
raise ValueError(f"Invalid role: {message['role']}")
parts.append({
"text": dialog + "ASSISTANT:\n"
})
new_messages = [{
"parts": parts,
"role": "user"
}]
return new_messages
def get_model_response(self, messages: List[Dict], model_name: str, **args) -> Tuple[bool, str]:
def get_gcloud_token():
def get_token():
try:
# Load the credentials from the key file
creds = service_account.Credentials.from_service_account_file(
GCLOUD_KEY_FILE_PATH,
# You can list multiple scopes if needed
scopes=['https://www.googleapis.com/auth/cloud-platform']
)
# Refresh the token (this is needed even for the first time)
creds.refresh(Request())
return creds.token
except Exception as e:
print(f"An error occurred while trying to fetch the gcloud token: {str(e)}")
return None
os.environ['HTTP_PROXY'] = PROXIES['http']
os.environ['HTTPS_PROXY'] = PROXIES['https']
api_key = None
fail_time = 0
while not api_key and fail_time < 10:
time.sleep(5)
api_key = get_token()
fail_time += 1
return api_key
def get_url(model_name: str) -> str:
region_code = GCLOUD_REGIONAL_CODE
model_id = f"{model_name}:generateContent"
with open(GCLOUD_KEY_FILE_PATH, "r") as f:
project_id = json.load(f)["project_id"]
return f"https://{region_code}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region_code}/publishers/google/models/{model_id}"
url = get_url(model_name)
api_key = get_gcloud_token()
if not api_key:
return False, "Failed to fetch gcloud token."
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {api_key}"
}
proxies = PROXIES
safety_settings = [
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
}
]
generation_config = {
"temperature": args.get('temperature'),
"maxOutputTokens": args.get('max_tokens'),
"stopSequences": ["\n\n###\n\n"]
}
payload = {
"contents": messages,
"generationConfig": generation_config,
"safetySettings": safety_settings
}
try:
response = requests.post(url, headers=headers, json=payload, proxies=proxies, timeout=120)
response = response.json()
except Exception as e:
return False, str(e)
if "error" in response:
return False, response["error"]["message"]
if "content" not in response['candidates'][0]:
generation_config['maxOutputTokens'] *= 2
return False, "No content generated."
return True, response['candidates'][0]['content']['parts'][0]['text']
class APIModel(object):
def __init__(self):
super().__init__()
self.models = {
"openai": OpenAIModel(),
"gemini": VertexGeminiModel(),
"qwen": QwenModel(),
"finetuned": FinetuneModel(),
"claude": ClaudeModel()
}
def inference(self, model_id: str, messages: List[Dict], args: Dict) -> Tuple[bool, str]:
model_provider, model_name = model_id.split("_", 1) if "_" in model_id else (model_id, model_id) # eg. "openai_gpt-4o"
if model_provider not in self.models:
return False, f"Unsupported model: {model_provider} ({model_name})"
model = self.models[model_provider]
prompt = model.prompt_construct(messages)
resp = model.get_model_response(prompt, model_name, **args)
return resp
model = APIModel()
def generate_with_api(prompt: List[dict], model_id: str, args: Dict) -> str:
success, response = model.inference(model_id, prompt, args)
return response
if __name__ == "__main__":
path_to_image = "../../coco_images/000000000285.jpg"
from PIL import Image
image = Image.open(path_to_image)
img_str = BasicModel.pil_to_b64(image)
messages = [
{
"role": "system",
"content": "You are a helpful assistant. Please response concisely."
},
{
"role": "user",
"content": [{
"type": "text",
"text": "what's annotated in this image? Image: Omitted."
}]
},
{
"role": "assistant",
"content": "Only 5.cart is annotated in this image."
},
{
"role": "user",
"content": [{
"type": "text",
"text": "What can you see?"
},{
"type": "image_url",
"image_url": {
"url": img_str,
"detail": "high"
}
}]
}
]
response = generate_with_api(messages, "openai", {
"temperature": 0.5,
"max_tokens": 1024,
"top_p": 0.9,
"n": 1,
})

View File

@ -0,0 +1,649 @@
"""base class for evaluation"""
# answer string match
import importlib
import json
import re
import time
import urllib
from pathlib import Path
from typing import Any, Optional, Tuple, Union
from urllib.parse import urljoin
import evaluate # type: ignore[import]
import requests
from beartype import beartype
from beartype.door import is_bearable
from nltk.tokenize import word_tokenize # type: ignore
from PIL import Image
from playwright.sync_api import CDPSession, Page
from browser_env.actions import Action
from browser_env.utils import StateInfo
from evaluation_harness import image_utils
from evaluation_harness.helper_functions import (
PseudoPage,
get_query_text,
get_query_text_lowercase,
gitlab_get_project_memeber_role,
llm_fuzzy_match,
llm_ua_match,
reddit_get_latest_comment_content_by_username,
reddit_get_latest_comment_obj_by_username,
reddit_get_parent_comment_username_of_latest_comment_by_username,
reddit_get_post_url,
shopping_get_latest_order_url,
shopping_get_num_reviews,
shopping_get_order_product_name_list,
shopping_get_order_product_option,
shopping_get_order_product_quantity,
shopping_get_product_attributes,
shopping_get_product_price,
shopping_get_rating_as_percentage,
shopping_get_sku_latest_review_author,
shopping_get_sku_latest_review_rating,
shopping_get_sku_latest_review_text,
)
Trajectory = list[Union[Action, StateInfo]]
@beartype
class Evaluator(object):
def __init__(self, eval_tag: str = "") -> None:
self.eval_tag = eval_tag
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage
) -> float:
raise NotImplementedError
@staticmethod
def get_last_action(trajectory: Trajectory) -> Action:
try:
is_bearable(trajectory[-1], Action)
last_action = trajectory[-1]
except Exception:
raise ValueError(
"The last element of trajectory should be an action, add a fake stop action if needed"
)
return last_action # type: ignore[return-value]
@staticmethod
def get_last_state(trajectory: Trajectory) -> StateInfo:
try:
is_bearable(trajectory[-2], StateInfo)
last_state = trajectory[-2]
except Exception:
raise ValueError(
"The second last element of trajectory should be a state, add a fake stop action if needed"
)
return last_state # type: ignore[return-value]
@beartype
class NumericEvaluator(Evaluator):
"""Check if the numerical relationship is correct"""
@staticmethod
@beartype
def str_2_int(s: str) -> Optional[int]:
try:
s = s.strip()
if "," in s:
s = s.replace(",", "")
return int(s)
except ValueError:
# Return None if the string cannot be converted to int
print(f"[NumericEvaluator error]: Cannot convert {s} to int")
return None
@staticmethod
@beartype
def compare_inequality(
value: Union[int, float], inequality: str, tol: float = 1e-8
) -> bool:
"""
Compare a value (int or float) against an inequality string.
Args:
- value (int/float): The value to be compared.
- inequality (str): Inequality in the form of "< 700", ">= 300", etc.
- tol (float): Tolerance for floating point comparisons.
Returns:
- bool: True if the value satisfies the inequality, False otherwise.
"""
# Extract the operator and the number from the inequality string
ops = {
"<=": lambda x, y: x <= y + tol,
">=": lambda x, y: x >= y - tol,
"==": lambda x, y: abs(x - y) <= tol,
"<": lambda x, y: x < y + tol,
">": lambda x, y: x > y - tol,
}
for op, func in ops.items():
if op in inequality:
_, num = inequality.split(op)
return func(value, float(num.strip()))
raise ValueError(f"Invalid inequality string: {inequality}")
@beartype
class StringEvaluator(Evaluator):
"""Check whether the answer is correct with:
exact match: the answer is exactly the same as the reference answer
must include: each phrase in the reference answer must be included in the answer
fuzzy match: the answer is similar to the reference answer, using LLM judge
"""
@staticmethod
@beartype
def clean_answer(answer: str) -> str:
if answer.startswith("'") and answer.endswith("'"):
answer = answer[1:-1]
elif answer.startswith('"') and answer.endswith('"'):
answer = answer[1:-1]
return answer.lower()
@staticmethod
@beartype
def exact_match(ref: str, pred: Union[str, int]) -> float:
if isinstance(pred, int):
pred = str(pred)
return float(
StringEvaluator.clean_answer(pred)
== StringEvaluator.clean_answer(ref)
)
@staticmethod
@beartype
def must_include(ref: str, pred: str) -> float:
clean_ref = StringEvaluator.clean_answer(ref)
clean_pred = StringEvaluator.clean_answer(pred)
# tokenize the answer if the ref is a single word
# prevent false positive (e.g, 0)
if len(word_tokenize(clean_ref)) == 1:
tok_pred = word_tokenize(clean_pred)
return float(clean_ref in tok_pred)
else:
return float(clean_ref in clean_pred)
@staticmethod
@beartype
def must_exclude(ref: str, pred: str) -> float:
"""Returns 1 if pred is not in ref, and 0 otherwise"""
clean_ref = StringEvaluator.clean_answer(ref)
clean_pred = StringEvaluator.clean_answer(pred)
# tokenize the answer if the ref is a single word
# prevent false positive (e.g, 0)
if len(word_tokenize(clean_ref)) == 1:
tok_pred = word_tokenize(clean_pred)
return float(clean_ref not in tok_pred)
else:
return float(clean_ref not in clean_pred)
@staticmethod
@beartype
def fuzzy_match(ref: str, pred: str, intent: str) -> float:
return llm_fuzzy_match(pred, ref, intent)
@staticmethod
@beartype
def ua_match(ref: str, pred: str, intent: str) -> float:
return llm_ua_match(pred, ref, intent)
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage | None = None
) -> float:
with open(config_file, "r") as f:
configs = json.load(f)
last_action = self.get_last_action(trajectory)
pred = self.clean_answer(last_action["answer"])
score = 1.0
for approach, value in configs["eval"]["reference_answers"].items():
match approach:
case "exact_match":
score *= self.exact_match(ref=value, pred=pred)
case "required_values":
required_values = value
assert isinstance(required_values, list)
pred = NumericEvaluator.str_2_int(pred)
if pred is None:
score = 0.0
else:
for v in required_values:
value_or = v.split(" |OR| ")
score *= any(
[
NumericEvaluator.compare_inequality(
pred, value
)
for value in value_or
]
)
case "must_include":
assert isinstance(value, list)
for must_value in value:
value_or = must_value.split(" |OR| ")
score *= any([self.must_include(ref=v, pred=pred) for v in value_or])
case "must_exclude":
assert isinstance(value, list)
for must_excl_value in value:
score *= self.must_exclude(
ref=must_excl_value, pred=pred
)
case "one_of":
assert isinstance(value, list)
found = False
for one_of_value in value:
one_of_value = self.clean_answer(one_of_value)
if one_of_value in pred:
found = True
break
score = score * found
case "fuzzy_match":
intent = configs["intent"]
if value == "N/A":
# if the instruction only asks the model to generate N/A when encountering an unachievable task
# without more concrete reasons
score *= self.exact_match(ref=value, pred=pred)
# if the instruction also asks the model to generate the reason why the task is unachievable
# this should be the default as it will prevent false positive N/A`
if score != 1:
score = 1.0 * self.ua_match(
intent=configs["intent"],
ref=configs["eval"]["string_note"],
pred=pred,
)
else:
assert isinstance(value, list)
reference = ', '.join(value)
score *= self.fuzzy_match(
ref=reference, pred=pred, intent=intent
)
return score
@beartype
class StringSoftEvaluator(Evaluator):
"""Use text generation metrics such as BLEU, ROUGE, etc. to evaluate the answer"""
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage | None = None
) -> float:
with open(config_file, "r") as f:
configs = json.load(f)
last_action = self.get_last_action(trajectory)
pred = last_action["answer"]
ref = configs["eval"]["reference_answers"]
# rouge
m = evaluate.load("rouge")
rouge = m.compute(predictions=[pred], references=[ref])
return float(rouge["rouge1"])
@beartype
class URLExactEvaluator(Evaluator):
"""Check whether the URL is exactly the same as of the reference URLs"""
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage
) -> float:
with open(config_file, "r") as f:
configs = json.load(f)
def clean_url(url: str) -> str:
url = str(url)
# Replace http://localhost with http://127.0.0.1 to keep things consistent across evals.
url = url.replace("localhost", "127.0.0.1")
if url.endswith("/"):
url = url[:-1]
return url
pred = clean_url(page.url)
ref_urls = configs["eval"]["reference_url"].split(" |OR| ")
ref_urls = [clean_url(url) for url in ref_urls]
matching_rule = configs["eval"].get("url_note", "EXACT")
if matching_rule == "EXACT":
if pred in ref_urls:
return 1.0
else:
return 0.0
elif matching_rule == "GOLD in PRED":
if any([ref in pred for ref in ref_urls]):
return 1.0
else:
return 0.0
else:
raise ValueError(f"Unknown matching rule: {matching_rule}")
@beartype
class HTMLContentExactEvaluator(Evaluator):
"""Check whether the contents appear in the page"""
@staticmethod
@beartype
def fuzzy_match(ref: str, pred: str, intent: str) -> float:
return llm_fuzzy_match(pred, ref, intent)
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage
) -> float:
with open(config_file, "r") as f:
configs = json.load(f)
targets = configs["eval"]["program_html"]
score = 1.0
for target in targets:
target_url: str = target["url"] # which url to check
if target_url.startswith("func"):
func = target_url.split("func:")[1]
func = func.replace("__last_url__", page.url)
target_url = eval(func)
locator: str = target["locator"] # js element locator
# navigate to that url
if target_url != "last":
page.goto(target_url)
time.sleep(3) # TODO [shuyanzh]: fix this hard-coded sleep
# empty, use the full page
if not locator.strip():
selected_element = page.content()
# use JS to select the element
elif locator.startswith("document.") or locator.startswith(
"[...document."
):
if "prep_actions" in target:
try:
for prep_action in target["prep_actions"]:
page.evaluate(f"() => {prep_action}")
except Exception:
pass
try:
selected_element = str(page.evaluate(f"() => {locator}"))
if not selected_element:
selected_element = ""
except Exception:
# the page is wrong, return empty
selected_element = ""
elif locator.startswith("lambda:"):
try:
locator = locator.lstrip("lambda:")
selected_element = page.evaluate(locator)
if not selected_element:
selected_element = None
except Exception:
# the page is wrong, return empty
selected_element = None
# run program to call API
elif locator.startswith("func:"): # a helper function
func = locator.split("func:")[1]
func = func.replace("__page__", "page")
selected_element = eval(func)
else:
raise ValueError(f"Unknown locator: {locator}")
# If the selected element is None, then the page is wrong
if selected_element is None:
score = 0.0
break
if "exact_match" in target["required_contents"]:
required_contents = target["required_contents"]["exact_match"]
score *= StringEvaluator.exact_match(
ref=required_contents, pred=selected_element
)
elif "must_include" in target["required_contents"]:
required_contents = target["required_contents"]["must_include"]
assert isinstance(required_contents, list)
for content in required_contents:
content_or = content.split(" |OR| ")
score *= any(
[
StringEvaluator.must_include(
ref=content, pred=selected_element
)
for content in content_or
]
)
elif "must_exclude" in target["required_contents"]:
required_contents = target["required_contents"]["must_exclude"]
assert isinstance(required_contents, list)
for content in required_contents:
assert " |OR| " not in content
score *= StringEvaluator.must_exclude(
content, pred=selected_element
)
elif "required_values" in target["required_contents"]:
required_values = target["required_contents"][
"required_values"
]
assert isinstance(required_values, list)
if isinstance(selected_element, str):
selected_element = NumericEvaluator.str_2_int(
selected_element
)
if selected_element is None:
score = 0.0
else:
for value in required_values:
value_or = value.split(" |OR| ")
score *= any(
[
NumericEvaluator.compare_inequality(
selected_element, value
)
for value in value_or
]
)
elif "fuzzy_match" in target["required_contents"]:
required_contents = target["required_contents"]["fuzzy_match"]
intent = configs["intent"]
assert isinstance(required_contents, list)
reference = ', '.join(required_contents)
score *= self.fuzzy_match(
ref=reference, pred=selected_element, intent=intent
)
else:
raise ValueError(
f"Unknown required_contents: {target['required_contents'].keys()}"
)
return score
@beartype
class PageImageEvaluator(Evaluator):
"""Check whether the answer is correct by querying a vision model."""
def __init__(self, captioning_fn):
self.captioning_fn = captioning_fn
# Default to 0.8 as the threshold for similarity to account for compression, resizing, etc
# This might be too generous but we bias towards minimizing false negatives.
self.ssim_threshold = 0.8
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage | None = None
) -> float:
with open(config_file, "r") as f:
configs = json.load(f)
for query in configs["eval"]["page_image_query"]:
locator: str = query["eval_image_class"]
target_url: str = query["eval_image_url"]
if target_url.startswith("func"):
func = target_url.split("func:")[1]
func = func.replace("__last_url__", page.url)
target_url = eval(func)
# navigate to that url
if target_url != "last":
page.goto(target_url)
time.sleep(3) # TODO(jykoh): fix this hard-coded sleep
# empty, use the full page
if not locator.strip():
images = page.get_by_role("img").all()
# use JS to select the element
elif locator.startswith("."):
# Get all img children under the locator
elements = page.query_selector_all(locator)
images = []
for element in elements:
is_img = element.evaluate(
'element => element.tagName === "IMG"'
)
if is_img:
images.append(element)
else:
images.extend(element.query_selector_all("img"))
else:
raise ValueError(f"Unknown locator: {locator}")
if images == []:
return 0.0
all_image_pixels = []
for image in images:
try:
# Get image from URL.
image_url = image.get_attribute("src")
if not image_url.startswith(
("http://", "https://", "www.")
):
image_url = urljoin(page.url, image_url)
image = Image.open(
requests.get(image_url, stream=True).raw
)
all_image_pixels.append(image)
except Exception as e:
print("[WARNING]: ", e)
score = 1.0
if all_image_pixels == []:
return 0.0
else:
# Run the VQA eval on the image elements.
eval_vqas = query.get("eval_vqa", [])
assert (
len(eval_vqas) > 0 or "eval_fuzzy_image_match" in query
), "eval_vqa must have at least 2 questions or eval_fuzzy_image_match must be True"
for qa in eval_vqas:
question, answer = qa["question"], qa["answer"]
prompt = f"Q: {question} A:"
pred_ans = self.captioning_fn(
all_image_pixels, [prompt] * len(all_image_pixels)
)
score *= float(
any(
[answer.lower() in ans.lower() for ans in pred_ans]
)
)
if "eval_fuzzy_image_match" in query:
ssim_threshold = query.get(
"ssim_threshold", self.ssim_threshold
)
exact_match_imgs = query["eval_fuzzy_image_match"].split(
" |OR| "
)
all_exact_match_pixels = []
for exact_match_img in exact_match_imgs:
if exact_match_img.startswith("http"):
exact_match_pixels = Image.open(
requests.get(exact_match_img, stream=True).raw
)
else:
exact_match_pixels = Image.open(exact_match_img)
all_exact_match_pixels.append(exact_match_pixels)
# Check if any of the images on the page match
found_exact_match = False
for exact_match_pixels in all_exact_match_pixels:
for image_pixels in all_image_pixels:
ssim = image_utils.get_image_ssim(
image_pixels, exact_match_pixels
)
if ssim > ssim_threshold:
found_exact_match = True
break
score *= float(found_exact_match)
return score
class EvaluatorComb:
def __init__(self, evaluators: list[Evaluator]) -> None:
self.evaluators = evaluators
def __call__(
self,
trajectory: Trajectory,
config_file: Path | str,
page: Page | PseudoPage
) -> float:
score = 1.0
for evaluator in self.evaluators:
cur_score = evaluator(trajectory, config_file, page)
score *= cur_score
return score
@beartype
def evaluator_router(
config_file: Path | str, captioning_fn=None
) -> EvaluatorComb:
"""Router to get the evaluator class"""
with open(config_file, "r") as f:
configs = json.load(f)
eval_types = configs["eval"]["eval_types"]
evaluators: list[Evaluator | EvaluatorPartial] = []
for eval_type in eval_types:
match eval_type:
case "string_match":
evaluators.append(StringEvaluator())
case "url_match":
evaluators.append(URLExactEvaluator())
case "program_html":
evaluators.append(HTMLContentExactEvaluator())
case "page_image_query":
evaluators.append(PageImageEvaluator(captioning_fn))
case _:
raise ValueError(f"eval_type {eval_type} is not supported")
return EvaluatorComb(evaluators)

View File

@ -0,0 +1,63 @@
"""Replace the website placeholders with website domains from env_config
Generate the test data"""
import json
import os
from browser_env.env_config import *
def main() -> None:
DATASET = os.environ["DATASET"]
if DATASET == "webarena":
print("DATASET: webarena")
print(f"REDDIT: {REDDIT}")
print(f"SHOPPING: {SHOPPING}")
print(f"SHOPPING_ADMIN: {SHOPPING_ADMIN}")
print(f"GITLAB: {GITLAB}")
print(f"WIKIPEDIA: {WIKIPEDIA}")
print(f"MAP: {MAP}")
inp_paths = ["config_files/wa/test_webarena.raw.json", "config_files/wa/test_webarena_lite.raw.json"]
replace_map = {
"__REDDIT__": REDDIT,
"__SHOPPING__": SHOPPING,
"__SHOPPING_ADMIN__": SHOPPING_ADMIN,
"__GITLAB__": GITLAB,
"__WIKIPEDIA__": WIKIPEDIA,
"__MAP__": MAP,
}
elif DATASET == "visualwebarena":
print("DATASET: visualwebarena")
print(f"CLASSIFIEDS: {CLASSIFIEDS}")
print(f"REDDIT: {REDDIT}")
print(f"SHOPPING: {SHOPPING}")
inp_paths = [
"config_files/vwa/test_classifieds.raw.json", "config_files/vwa/test_shopping.raw.json", "config_files/vwa/test_reddit.raw.json",
]
replace_map = {
"__REDDIT__": REDDIT,
"__SHOPPING__": SHOPPING,
"__WIKIPEDIA__": WIKIPEDIA,
"__CLASSIFIEDS__": CLASSIFIEDS,
}
else:
raise ValueError(f"Dataset not implemented: {DATASET}")
for inp_path in inp_paths:
output_dir = inp_path.replace('.raw.json', '')
os.makedirs(output_dir, exist_ok=True)
with open(inp_path, "r") as f:
raw = f.read()
for k, v in replace_map.items():
raw = raw.replace(k, v)
with open(inp_path.replace(".raw", ""), "w") as f:
f.write(raw)
data = json.loads(raw)
for idx, item in enumerate(data):
with open(os.path.join(output_dir, f"{idx}.json"), "w") as f:
json.dump(item, f, indent=2)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,647 @@
"""Implements helper functions to assist evaluation cases where other evaluators are not suitable."""
import json
from datetime import datetime, timezone
from typing import Any, Union
from urllib.parse import urlparse
import requests
from beartype import beartype
from beartype.typing import Dict, List
from playwright.sync_api import CDPSession, Page
from browser_env.env_config import (
ACCOUNTS,
REDDIT,
SHOPPING,
WIKIPEDIA,
)
from llms.providers.openai_utils import (
generate_from_openai_chat_completion,
)
import logging
logger = logging.getLogger("logger")
class PseudoPage:
def __init__(self, original_page: Page, url: str):
self.url = url
self.original_page = original_page
def __getattr__(self, attr: str) -> Any:
# Delegate attribute access to the original page object
if attr not in ["url"]:
return getattr(self.original_page, attr)
else:
return getattr(self, attr)
@beartype
def shopping_get_auth_token() -> str:
response = requests.post(
url=f"{SHOPPING}/rest/default/V1/integration/admin/token",
headers={"content-type": "application/json"},
data=json.dumps(
{
"username": ACCOUNTS["shopping_site_admin"]["username"],
"password": ACCOUNTS["shopping_site_admin"]["password"],
}
),
)
token: str = response.json()
return token
@beartype
def shopping_get_latest_order_url() -> str:
"""Get the latest order url from the shopping website."""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
params = {
"searchCriteria[sortOrders][0][field]": "created_at",
"searchCriteria[sortOrders][0][direction]": "DESC",
"searchCriteria[pageSize]": "1",
}
response = requests.get(
f"{SHOPPING}/rest/V1/orders", params=params, headers=header
)
assert response.status_code == 200
response_obj = response.json()["items"][0]
order_id = int(response_obj["increment_id"])
order_url = f"{SHOPPING}/sales/order/view/order_id/{order_id}/"
return order_url
@beartype
def shopping_get_sku_latest_review_author(sku: str) -> str:
"""Get the latest review for shopping admin."""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
response = requests.get(
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
)
assert response.status_code == 200
response_obj = response.json()
if len(response_obj) == 0:
return ""
author: str = response_obj[-1]["nickname"]
return author
@beartype
def shopping_get_sku_latest_review_rating(sku: str) -> str:
"""Get the latest review for shopping admin."""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
response = requests.get(
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
)
assert response.status_code == 200
response_obj = response.json()
if len(response_obj) == 0:
return ""
assert response_obj[0]["ratings"][0]["rating_name"] == "Rating"
rating: str = str(response_obj[-1]["ratings"][0]["percent"])
return rating
@beartype
def shopping_get_sku_latest_review_text(sku: str) -> str:
"""Get the latest review text for shopping admin."""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
response = requests.get(
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
)
assert response.status_code == 200
response_obj = response.json()
if len(response_obj) == 0:
return ""
text: str = response_obj[-1]["detail"]
return text
@beartype
def shopping_get_sku_latest_review_title(sku: str) -> str:
"""Get the latest review title for shopping admin."""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
response = requests.get(
f"{SHOPPING}/rest/V1/products/{sku}/reviews", headers=header
)
assert response.status_code == 200
response_obj = response.json()
if len(response_obj) == 0:
return ""
title: str = response_obj[-1]["title"]
return title
@beartype
def shopping_get_sku_product_page_url(sku: str) -> str:
"""Get product page url from sku"""
header = {
"Authorization": f"Bearer {shopping_get_auth_token()}",
"Content-Type": "application/json",
}
response = requests.get(
f"{SHOPPING}/rest/V1/products/{sku}", headers=header
)
assert response.status_code == 200
response_obj = response.json()
if len(response_obj) == 0:
return ""
for custom_attributes in response_obj["custom_attributes"]:
if custom_attributes["attribute_code"] == "url_key":
return f"{SHOPPING}/{custom_attributes['value']}.html"
return ""
@beartype
def shopping_get_all_product_order(
page: Page | PseudoPage,
) -> List[Dict[str, str]]:
"""
    Get info of all products in a given order page.
Example output:
[
{
"name": "Kellogg's Special K Protein Bars, Meal Replacement, Protein Snacks, Value Size, Strawberry, 19oz Box (12 Bars)\nSize\n12 Count (Pack of 1)",
"options": {
"Size": "12 Count (Pack of 1)"
},
"sku": "B00MXUFL0E",
"price": "$24.50",
"qty": "Ordered2",
"subtotal": "$49.00"
},
{
"name": "Kellogg's Special K Protein Bars, Meal Replacement, Protein Snacks, Value Size, Chocolatey Chip Cookie Dough, 19oz Box (12 Bars)",
"sku": "B07ZD2PB9F",
"price": "$42.30",
"qty": "Ordered2",
"subtotal": "$84.60"
}
]
"""
try:
result = page.evaluate(
f"""
(() => {{
try {{
const products = [...document.querySelector("#my-orders-table").getElementsByTagName('tbody')].map(
(x) => {{
return [...x.getElementsByTagName('td')].reduce(function(obj, y) {{
const key = y.className.split(' ')[1];
obj[key] = y.outerText;
// check if options exist
if (key === 'name' && y.querySelector('dl')) {{
var option_dict = {{}}
const options = [...y.querySelector('dl').children];
for (let i = 0; i < options.length; i += 2) {{
option_dict[options[i].outerText] = options[i+1].outerText;
}}
obj['options'] = option_dict;
}}
return obj;
}}, {{}})
}}
);
return products;
}} catch (e) {{
                // If any errors are caught, return an empty list
                return [];
}}
}})();
"""
)
return result
except Exception as e:
result = []
return result
@beartype
def shopping_get_order_product_name_list(page: Page | PseudoPage) -> str:
try:
products = shopping_get_all_product_order(page)
return " |OR| ".join([p["name"] for p in products])
except Exception:
return ""
@beartype
def shopping_get_order_product_quantity(
page: Page | PseudoPage, sku: str
) -> int:
try:
if "|OR|" in sku:
skus = sku.split(" |OR| ")
else:
skus = [sku]
products = shopping_get_all_product_order(page)
for product in products:
if product["sku"].strip() in skus:
# Ordered{qty}
return int(product["qty"][7:])
return 0
except Exception:
return 0
@beartype
def shopping_get_order_product_option(
page: Page | PseudoPage, sku: str, option_name: str
) -> str:
try:
products = shopping_get_all_product_order(page)
for product in products:
if product["sku"].strip() == sku:
return product["options"][option_name]
return ""
except Exception as e:
return ""
@beartype
def shopping_get_product_attributes(
page: Page | PseudoPage, attribute: str
) -> str:
# Get the values of all cells in the table for the given attribute
try:
result = page.evaluate(
f"""
(() => {{
try {{
// Create an array of search terms, splitting the string by ' |OR| '
const searchTerms = '{attribute}'.toLowerCase().split(' |or| ');
// Convert the children of the tbody inside the element with the given ID into an array
return Array.from(
document.querySelector('#productDetails_detailBullets_sections1 > tbody').children
)
// Filter the array to only include elements where the first child's text includes any of the search terms
.filter(x =>
searchTerms.some(term => x.children[0].outerText.toLowerCase().includes(term))
)
// Map over the filtered elements to get the outerText of their second child
.map(x => x.children[1].outerText)
// Join all the resulting strings with a comma and a space
.join(', ')
}} catch (e) {{
// If any errors are caught, return an empty string
return ''
}}
}})();
"""
)
except Exception:
result = ""
return result
@beartype
def shopping_get_product_price(page: Page | PseudoPage) -> Union[float, int]:
"""Get the price of the product on the shopping website."""
try:
result = page.evaluate(
"""
(() => {{
res = parseFloat(document.querySelector(\"#maincontent > div.columns > div > div.product-info-main > div.product-info-price > div.price-box.price-final_price > span > span\")
.outerText.substr(1));
return res ? res : 0;
}})();
"""
)
except Exception:
result = 0
return result
@beartype
def shopping_get_num_reviews(page: Page | PseudoPage) -> int:
"""Get the price of the product on the shopping website."""
try:
result = page.evaluate(
"""
(() => {{
res = parseInt(document.querySelector(\"#tab-label-reviews-title\")
.outerText.split(' ')[1]);
return res ? res : 0; }}
)();
"""
)
except Exception:
result = 0
return result
@beartype
def shopping_get_rating_as_percentage(page: Page | PseudoPage) -> int:
"""Get the rating of the product on the shopping website as a percentage out of 100."""
try:
rating = page.evaluate(
"""
(() => {{
ratingPercentage = parseFloat(document.querySelector('.rating-result').title.replace('%', ''));
return ratingPercentage ? ratingPercentage : 0;
}})();
"""
)
except Exception:
rating = 0
return rating
@beartype
def get_query_text(page: Page | PseudoPage, selector: str) -> str:
"""Get the text content of the element matching the given selector.
Note that this function DOES NOT perform downcasing.
"""
try:
result = page.evaluate(
f"""
(() => {{
try {{
return document.querySelector('{selector}').textContent;
}} catch (e) {{
return '';
}}
}})();
"""
)
except Exception:
result = ""
return result
@beartype
def get_query_text_lowercase(page: Page | PseudoPage, selector: str) -> str:
"""Get the lowercase text content of the element matching the given selector."""
return get_query_text(page, selector).lower()
@beartype
def reddit_get_post_url(url: str) -> str:
"""Get the post url"""
# Url is http://domain/f/subreddit/post_id/...
# get domain, subreddit, post_id
domain = urlparse(url).netloc
tok_url = urlparse(url).path.split("/")
# not a valid post/comment url, return the url as is
if len(tok_url) < 4:
return url
if tok_url[1] != "f":
return url
subreddit = urlparse(url).path.split("/")[2]
post_id = urlparse(url).path.split("/")[3]
scheme = urlparse(url).scheme
post_url = f"{scheme}://{domain}/f/{subreddit}/{post_id}/"
return post_url
@beartype
def reddit_get_post_comment_tree(page: Page | PseudoPage) -> Dict[str, Any]:
try:
comment_tree = page.evaluate(
f"""(function buildCommentTree(node, data_level) {{
let tree = {{
"username": node.querySelector(".fg-inherit").outerText,
"net_score": parseInt(node.querySelector(".vote__net-score").outerText),
"content": node.querySelector(".comment__content").outerText,
"time": new Date(node.querySelector('.comment__main > header > h1 > span > time').dateTime),
"children": []
}};
node.querySelectorAll(".comment").forEach((child) => {{
if (parseInt(child.getAttribute('data-level')) === data_level+1) {{
tree['children'].push(buildCommentTree(child, data_level+1));
}}
}})
return tree;
}})(document.querySelector("#main"), 0)"""
)
except Exception:
comment_tree = {}
return comment_tree
@beartype
def reddit_get_latest_comment_obj_by_username(
page: Page | PseudoPage, username: str
) -> Dict[str, Any]:
try:
comment_tree = reddit_get_post_comment_tree(page)
latest_time = datetime.min.replace(tzinfo=timezone.utc)
comment = {}
def dfs(node):
nonlocal latest_time
nonlocal comment
if node["username"] == username:
if node["time"] > latest_time:
comment = {
"username": node["username"],
"net_score": node["net_score"],
"content": node["content"],
"time": node["time"],
}
latest_time = node["time"]
for child in node["children"]:
dfs(child)
dfs(comment_tree)
except Exception as e:
comment = {}
return comment
@beartype
def reddit_get_latest_comment_content_by_username(
page: Page | PseudoPage, username: str
) -> str:
try:
comment = reddit_get_latest_comment_obj_by_username(page, username)
content = comment["content"]
except Exception:
content = ""
return content
@beartype
def reddit_get_parent_comment_obj_of_latest_comment_by_username(
page: Page | PseudoPage, username: str
) -> Dict[str, Any]:
try:
comment_tree = reddit_get_post_comment_tree(page)
latest_time = datetime.min.replace(tzinfo=timezone.utc)
comment = {}
def dfs(node):
nonlocal latest_time
nonlocal comment
for child in node["children"]:
if child["username"] == username:
if child["time"] > latest_time:
comment = {
"username": node["username"],
"net_score": node["net_score"],
"content": node["content"],
"time": node["time"],
}
latest_time = child["time"]
else:
dfs(child)
dfs(comment_tree)
except Exception:
comment = {}
return comment
@beartype
def reddit_get_parent_comment_username_of_latest_comment_by_username(
page: Page | PseudoPage, username: str
) -> str:
try:
comment = reddit_get_parent_comment_obj_of_latest_comment_by_username(
page, username
)
username = comment["username"]
except Exception:
username = ""
return username
@beartype
def gitlab_get_project_memeber_role(
page: Page | PseudoPage, account_name: str
) -> str:
# get the account index
try:
account_idx = page.evaluate(
f"""(() => {{
const elements = document.querySelectorAll("td[data-label='Account'] span.gl-avatar-labeled-sublabel");
let index = -1; // Default value if not found
for(let i = 0; i < elements.length; i++) {{
if(elements[i].outerText === '@{account_name}') {{
index = i;
break;
}}
}}
return index;
}})()"""
)
# get the role
role: str = page.evaluate(
f"""(() => {{
return document.querySelectorAll("td.col-max-role span")[{account_idx}].outerText;
}})()"""
)
except Exception:
role = ""
return role
@beartype
def llm_fuzzy_match(pred: str, reference: str, question: str) -> float:
"""Check whether the prediction matches the reference with GPT-4-turbo"""
messages: list[dict[str, Any]] = []
# construct the question to ask
message = "Help a teacher to grade the answer of a student given a question. Keep in mind that the student may use different phrasing or wording to answer the question. The goal is to evaluate whether the answer is semantically equivalent to the reference answer.\n"
message += f"question: {question}\n"
message += f"reference answer: {reference}\n"
message += "all the string 'N/A' that you see is a special sequence that means 'not achievable'\n"
message += f"student answer: {pred}\n"
message += "Conclude the judgement by 'correct', 'incorrect', or 'partially correct'. Only output one of these options, and nothing else."
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": message},
]
logger.info(f'[R] {reference}')
logger.info(f'[P] {pred}')
response = generate_from_openai_chat_completion(
model="gpt-4-1106-preview",
messages=messages,
temperature=0,
max_tokens=768,
top_p=1.0,
context_length=0,
).lower()
if "partially correct" in response or "incorrect" in response:
return 0.0
else:
assert "correct" in response, response
return 1.0
def llm_ua_match(pred: str, reference: str, question: str) -> float:
"""Check whether the prediction matches the reference with GPT-4-turbo"""
messages: list[dict[str, Any]] = []
# construct the question to ask
message = ""
message += f"task: {question}\n"
message += f"actual unachievable reason: {reference}\n"
message += f"reported unachievable reason: {pred}\n"
message += (
"The task described above is inherently unachievable due to the reason specified under 'actual unachievable reason'. "
"An individual previously attempted this task and was unable to complete it. They provided a reason for their failure, "
"which is listed under 'reported unachievable reason'. Your role is to review both the actual and reported reasons. "
"Determine if the reported reason aligns with the actual reason, even if implicitly. "
"If the stated reason is in line with the actual reason, respond with 'same'. Otherwise, respond with 'different'."
)
messages = [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": message},
]
response = generate_from_openai_chat_completion(
model="gpt-4-1106-preview",
messages=messages,
temperature=0,
max_tokens=768,
top_p=1.0,
context_length=0,
).lower()
if "different" in response:
return 0.0
else:
assert "same" in response
return 1.0
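
A minimal usage sketch of the helpers above. It assumes the site URL environment variables and `OPENAI_API_KEY` are already exported (the imports read them at module load time) and that a shopping instance is reachable at a hypothetical local address:

```python
from playwright.sync_api import sync_playwright

from evaluation_harness.helper_functions import (
    PseudoPage,
    get_query_text_lowercase,
)

with sync_playwright() as p:
    browser = p.chromium.launch(headless=True)
    page = browser.new_page()
    page.goto("http://localhost:7770")  # hypothetical shopping instance
    # PseudoPage reports a fixed URL while delegating everything else
    # to the underlying Playwright page.
    pseudo = PseudoPage(page, page.url)
    print(get_query_text_lowercase(pseudo, "title"))
    browser.close()
```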

View File

@ -0,0 +1,23 @@
"""This module is adapt from https://github.com/zeno-ml/zeno-build"""
try:
from .providers.gemini_utils import generate_from_gemini_completion
except:
print('Google Cloud not set up, skipping import of providers.gemini_utils.generate_from_gemini_completion')
from .providers.hf_utils import generate_from_huggingface_completion
from .providers.openai_utils import (
generate_from_openai_chat_completion,
generate_from_openai_completion,
)
from .providers.api_utils import (
generate_with_api,
)
from .utils import call_llm
__all__ = [
"generate_from_openai_completion",
"generate_from_openai_chat_completion",
"generate_from_huggingface_completion",
"generate_from_gemini_completion",
"call_llm",
]

View File

@ -0,0 +1,57 @@
"""Config for language models."""
from __future__ import annotations
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True)
class LMConfig:
"""A config for a language model.
Attributes:
provider: The name of the API provider.
model: The name of the model.
model_cls: The Python class corresponding to the model, mostly for
Hugging Face transformers.
tokenizer_cls: The Python class corresponding to the tokenizer, mostly
for Hugging Face transformers.
mode: The mode of the API calls, e.g., "chat" or "generation".
"""
provider: str
model: str
model_cls: type | None = None
tokenizer_cls: type | None = None
mode: str | None = None
gen_config: dict[str, Any] = dataclasses.field(default_factory=dict)
def construct_llm_config(args: argparse.Namespace) -> LMConfig:
llm_config = LMConfig(
provider=args.provider, model=args.model, mode=args.mode
)
if args.provider in ["openai", "google", "api", "finetune"]:
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["context_length"] = args.context_length
llm_config.gen_config["max_tokens"] = args.max_tokens
llm_config.gen_config["stop_token"] = args.stop_token
llm_config.gen_config["max_obs_length"] = args.max_obs_length
llm_config.gen_config["max_retry"] = args.max_retry
elif args.provider == "huggingface":
llm_config.gen_config["temperature"] = args.temperature
llm_config.gen_config["top_p"] = args.top_p
llm_config.gen_config["max_new_tokens"] = args.max_tokens
llm_config.gen_config["stop_sequences"] = (
[args.stop_token] if args.stop_token else None
)
llm_config.gen_config["max_obs_length"] = args.max_obs_length
llm_config.gen_config["model_endpoint"] = args.model_endpoint
llm_config.gen_config["max_retry"] = args.max_retry
else:
raise NotImplementedError(f"provider {args.provider} not implemented")
return llm_config
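
A minimal sketch of building an `LMConfig` without run.py's full argument parser; the attribute values below are illustrative, and importing the `llms` package assumes `OPENAI_API_KEY` is set because the OpenAI client is created at import time:

```python
import argparse

from llms.lm_config import construct_llm_config

args = argparse.Namespace(
    provider="openai",
    model="gpt-4o-2024-05-13",
    mode="chat",
    temperature=0.0,
    top_p=0.9,
    context_length=0,
    max_tokens=384,
    stop_token=None,
    max_obs_length=3840,
    max_retry=1,
)

cfg = construct_llm_config(args)
print(cfg.provider, cfg.model, cfg.gen_config["temperature"])
```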

View File

@ -0,0 +1,286 @@
"""Tools to generate from OpenAI prompts.
Adopted from https://github.com/zeno-ml/zeno-build/"""
import asyncio
import logging
import os
import random
import time
from typing import Any
import aiolimiter
import openai
from openai import AsyncOpenAI, OpenAI
base_url = os.environ.get("OPENAI_API_URL")
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
aclient = AsyncOpenAI(api_key=os.environ["OPENAI_API_KEY"], base_url=base_url)
from tqdm.asyncio import tqdm_asyncio
def retry_with_exponential_backoff( # type: ignore
func,
initial_delay: float = 1,
exponential_base: float = 2,
jitter: bool = True,
max_retries: int = 3,
errors: tuple[Any] = (
openai.RateLimitError,
openai.BadRequestError,
openai.InternalServerError,
),
):
"""Retry a function with exponential backoff."""
def wrapper(*args, **kwargs): # type: ignore
# Initialize variables
num_retries = 0
delay = initial_delay
# Loop until a successful response or max_retries is hit or an exception is raised
while True:
try:
return func(*args, **kwargs)
# Retry on specified errors
except errors as e:
# Increment retries
num_retries += 1
# Check if max retries has been reached
if num_retries > max_retries:
raise Exception(
f"Maximum number of retries ({max_retries}) exceeded."
)
# Increment the delay
delay *= exponential_base * (1 + jitter * random.random())
# Sleep for the delay
time.sleep(delay)
# Raise exceptions for any errors not specified
except Exception as e:
raise e
return wrapper
async def _throttled_openai_completion_acreate(
engine: str,
prompt: str,
temperature: float,
max_tokens: int,
top_p: float,
limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
async with limiter:
for _ in range(3):
try:
return await aclient.completions.create(
engine=engine,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
except openai.RateLimitError:
logging.warning(
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
)
await asyncio.sleep(10)
except openai.APIError as e:
logging.warning(f"OpenAI API error: {e}")
break
return {"choices": [{"message": {"content": ""}}]}
async def agenerate_from_openai_completion(
prompts: list[str],
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
requests_per_minute: int = 300,
) -> list[str]:
"""Generate from OpenAI Completion API.
Args:
prompts: list of prompts
temperature: Temperature to use.
max_tokens: Maximum number of tokens to generate.
top_p: Top p to use.
context_length: Length of context to use.
requests_per_minute: Number of requests per minute to allow.
Returns:
List of generated responses.
"""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
async_responses = [
_throttled_openai_completion_acreate(
engine=engine,
prompt=prompt,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
limiter=limiter,
)
for prompt in prompts
]
responses = await tqdm_asyncio.gather(*async_responses)
return [x["choices"][0]["text"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_completion(
prompt: str,
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
    response = client.completions.create(
        prompt=prompt,
        model=engine,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        stop=[stop_token] if stop_token else None,
    )
    answer: str = response.choices[0].text
return answer
async def _throttled_openai_chat_completion_acreate(
model: str,
messages: list[dict[str, str]],
temperature: float,
max_tokens: int,
top_p: float,
limiter: aiolimiter.AsyncLimiter,
) -> dict[str, Any]:
async with limiter:
for _ in range(3):
try:
return await aclient.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
except openai.RateLimitError:
logging.warning(
"OpenAI API rate limit exceeded. Sleeping for 10 seconds."
)
await asyncio.sleep(10)
except asyncio.exceptions.TimeoutError:
logging.warning("OpenAI API timeout. Sleeping for 10 seconds.")
await asyncio.sleep(10)
except openai.APIError as e:
logging.warning(f"OpenAI API error: {e}")
break
return {"choices": [{"message": {"content": ""}}]}
async def agenerate_from_openai_chat_completion(
messages_list: list[list[dict[str, str]]],
engine: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
requests_per_minute: int = 300,
) -> list[str]:
"""Generate from OpenAI Chat Completion API.
Args:
messages_list: list of message list
temperature: Temperature to use.
max_tokens: Maximum number of tokens to generate.
top_p: Top p to use.
context_length: Length of context to use.
requests_per_minute: Number of requests per minute to allow.
Returns:
List of generated responses.
"""
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
limiter = aiolimiter.AsyncLimiter(requests_per_minute)
async_responses = [
_throttled_openai_chat_completion_acreate(
model=engine,
messages=message,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
limiter=limiter,
)
for message in messages_list
]
responses = await tqdm_asyncio.gather(*async_responses)
return [x["choices"][0]["message"]["content"] for x in responses]
@retry_with_exponential_backoff
def generate_from_openai_chat_completion(
messages: list[dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)
answer: str = response.choices[0].message.content
return answer
@retry_with_exponential_backoff
# debug only
def fake_generate_from_openai_chat_completion(
messages: list[dict[str, str]],
model: str,
temperature: float,
max_tokens: int,
top_p: float,
context_length: int,
stop_token: str | None = None,
) -> str:
if "OPENAI_API_KEY" not in os.environ:
raise ValueError(
"OPENAI_API_KEY environment variable must be set when using OpenAI API."
)
answer = "Let's think step-by-step. This page shows a list of links and buttons. There is a search box with the label 'Search query'. I will click on the search box to type the query. So the action I will perform is \"click [60]\"."
return answer
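
A minimal sketch of the synchronous chat helper above; it requires `OPENAI_API_KEY` to be exported before import (the client is created at module load time), and the model name is only an example:

```python
from llms.providers.openai_utils import generate_from_openai_chat_completion

answer = generate_from_openai_chat_completion(
    messages=[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "Reply with a single word: hello."},
    ],
    model="gpt-4o-2024-05-13",
    temperature=0.0,
    max_tokens=16,
    top_p=1.0,
    context_length=0,
)
print(answer)
```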

View File

@ -0,0 +1,501 @@
import json
import re
from pathlib import Path
from typing import Any, TypedDict
from PIL import Image
from browser_env import Action, ActionParsingError, Trajectory
from browser_env.env_config import URL_MAPPINGS
from browser_env.utils import StateInfo, pil_to_b64, pil_to_vertex
from llms import lm_config
from llms.tokenizers import Tokenizer
from llms.utils import APIInput
class Instruction(TypedDict):
"""Instruction for constructing prompt"""
intro: str
examples: list[tuple[str, str]]
template: str
meta_data: dict[str, Any]
class PromptConstructor(object):
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: Tokenizer,
):
self.instruction_path = Path(instruction_path)
self.obs_modality = "text"
self.lm_config = lm_config
instruction = json.load(open(self.instruction_path))
instruction["examples"] = [tuple(e) for e in instruction["examples"]]
self.instruction: Instruction = instruction
self.tokenizer = tokenizer
def get_lm_api_input(
self, intro: str, examples: list[tuple[str, str]], current: str
) -> APIInput:
"""Return the require format for an API"""
message: list[dict[str, str]] | str
if "openai" in self.lm_config.provider:
if self.lm_config.mode == "chat":
message = [{"role": "system", "content": intro}]
for (x, y) in examples:
message.append(
{
"role": "system",
"name": "example_user",
"content": x,
}
)
message.append(
{
"role": "system",
"name": "example_assistant",
"content": y,
}
)
message.append({"role": "user", "content": current})
return message
elif self.lm_config.mode == "completion":
message = f"{intro}\n\n"
message += "Here are a few examples:\n"
for example in examples:
message += f"Observation\n:{example[0]}\n\n"
message += f"Action: {example[1]}\n\n"
message += "Now make prediction given the observation\n\n"
message += f"Observation\n:{current}\n\n"
message += "Action:"
return message
else:
raise ValueError(
f"OpenAI models do not support mode {self.lm_config.mode}"
)
elif "huggingface" in self.lm_config.provider:
# https://huggingface.co/blog/llama2#how-to-prompt-llama-2
# https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320
if "Llama-2" in self.lm_config.model:
if self.lm_config.mode == "chat":
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
BOS, EOS = "<s>", "</s>"
# adding the system message to be the starting of the first example
examples = [
(
B_SYS + intro + E_SYS + examples[0][0],
examples[0][1],
)
] + examples[1:]
message = "".join(
[
f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}"
for (x, y) in examples
]
)
# add the current observation
message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}"
return message
else:
raise ValueError("Only chat mode is supported for Llama-2")
else:
raise ValueError(
f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}"
)
else:
raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented"
)
def construct(
self,
trajectory: Trajectory,
intent: str,
meta_data: dict[str, Any] = {},
) -> APIInput:
raise NotImplementedError
def map_url_to_real(self, url: str) -> str:
"""Map the urls to their real world counterparts"""
for i, j in URL_MAPPINGS.items():
if i in url:
url = url.replace(i, j)
return url
def map_url_to_local(self, url: str) -> str:
"""Map the urls to their local counterparts"""
for i, j in URL_MAPPINGS.items():
if j in url:
url = url.replace(j, i)
# https
if j.replace("http", "https") in url:
url = url.replace(j.replace("http", "https"), i)
return url
def _extract_action(self, response: str) -> str:
raise NotImplementedError
def extract_action(self, response: str) -> str:
response = self._extract_action(response)
response = self.map_url_to_local(response)
return response
class DirectPromptConstructor(PromptConstructor):
"""The agent will direct predict the action"""
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)
def construct(
self,
trajectory: Trajectory,
intent: str,
meta_data: dict[str, Any] = {},
) -> APIInput:
"""Construct prompt given the trajectory"""
intro = self.instruction["intro"]
examples = self.instruction["examples"]
template = self.instruction["template"]
keywords = self.instruction["meta_data"]["keywords"]
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
obs = state_info["observation"][self.obs_modality]
max_obs_length = self.lm_config.gen_config["max_obs_length"]
if max_obs_length:
if self.lm_config.provider == "google":
print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
obs = obs[:max_obs_length]
else:
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
page = state_info["info"]["page"]
url = page.url
previous_action_str = meta_data["action_history"][-1]
# input x
current = template.format(
objective=intent,
url=self.map_url_to_real(url),
observation=obs,
previous_action=previous_action_str,
)
# make sure all keywords are replaced
assert all([f"{{k}}" not in current for k in keywords])
prompt = self.get_lm_api_input(intro, examples, current)
return prompt
def _extract_action(self, response: str) -> str:
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
match = re.search(pattern, response)
if match:
return match.group(1).strip()
else:
raise ActionParsingError(
f"Cannot parse action from response {response}"
)
class CoTPromptConstructor(PromptConstructor):
"""The agent will perform step-by-step reasoning before the answer"""
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
def construct(
self,
trajectory: Trajectory,
intent: str,
meta_data: dict[str, Any] = {},
) -> APIInput:
intro = self.instruction["intro"]
examples = self.instruction["examples"]
template = self.instruction["template"]
keywords = self.instruction["meta_data"]["keywords"]
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
obs = state_info["observation"][self.obs_modality]
max_obs_length = self.lm_config.gen_config["max_obs_length"]
if max_obs_length:
if self.lm_config.provider == "google":
print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.")
obs = obs[:max_obs_length]
else:
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
page = state_info["info"]["page"]
url = page.url
previous_action_str = meta_data["action_history"][-1]
current = template.format(
objective=intent,
url=self.map_url_to_real(url),
observation=obs,
previous_action=previous_action_str,
)
assert all([f"{{k}}" not in current for k in keywords])
prompt = self.get_lm_api_input(intro, examples, current)
return prompt
def _extract_action(self, response: str) -> str:
        # find the first occurrence of the action
action_splitter = self.instruction["meta_data"]["action_splitter"]
pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
match = re.search(pattern, response)
if match:
return match.group(1).strip()
else:
raise ActionParsingError(
f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"'
)
class MultimodalCoTPromptConstructor(CoTPromptConstructor):
"""The agent will perform step-by-step reasoning before the answer"""
def __init__(
self,
instruction_path: str | Path,
lm_config: lm_config.LMConfig,
tokenizer: Tokenizer,
):
super().__init__(instruction_path, lm_config, tokenizer)
self.answer_phrase = self.instruction["meta_data"]["answer_phrase"]
def construct(
self,
trajectory: Trajectory,
intent: str,
page_screenshot_img: Image.Image,
images: list[Image.Image],
meta_data: dict[str, Any] = {},
) -> APIInput:
intro = self.instruction["intro"]
examples = self.instruction["examples"]
template = self.instruction["template"]
keywords = self.instruction["meta_data"]["keywords"]
state_info: StateInfo = trajectory[-1] # type: ignore[assignment]
obs = state_info["observation"][self.obs_modality]
max_obs_length = self.lm_config.gen_config["max_obs_length"]
if max_obs_length:
if self.lm_config.provider in ["google", "api", "finetune"]:
print("NOTE: This is a Gemini / API model, so we use characters instead of tokens for max_obs_length.")
obs = obs[:max_obs_length]
else:
obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type]
page = state_info["info"]["page"]
url = page.url
previous_action_str = meta_data["action_history"][-1]
current = template.format(
objective=intent,
url=self.map_url_to_real(url),
observation=obs,
previous_action=previous_action_str,
)
assert all([f"{{k}}" not in current for k in keywords])
        # TODO: for your finetuned model, you can configure your prompt here
if self.lm_config.provider == "finetune":
current = ""
traj = trajectory[1::2]
for rnd, tra in enumerate(traj):
tar = '** screenshot **' if rnd > 0 else intent
raw = tra["raw_prediction"]
current += f"Round {rnd}\n\n<|user|>\n\n** node_info **\n\n{tar}\n\n<|assistant|>\n{raw}\n\n"""
current += f"Round {len(traj)}\n\n<|user|>\n\n{obs}\n\n{'** screenshot **' if len(traj) > 0 else intent}\n"
prompt = self.get_lm_api_input(
intro, examples, current, page_screenshot_img, images
)
return prompt
def get_lm_api_input(
self,
intro: str,
examples: list[tuple[str, str, str]],
current: str,
page_screenshot_img: Image.Image,
images: list[Image.Image],
) -> APIInput:
"""Return the require format for an API"""
message: list[dict[str, str]] | str | list[str | Image.Image]
if "openai" in self.lm_config.provider:
if self.lm_config.mode == "chat":
message = [
{
"role": "system",
"content": [{"type": "text", "text": intro}],
}
]
for (x, y, z) in examples:
example_img = Image.open(z)
message.append(
{
"role": "system",
"name": "example_user",
"content": [
{"type": "text", "text": x},
{
"type": "text",
"text": "IMAGES: (1) current page screenshot",
},
{
"type": "image_url",
"image_url": {
"url": pil_to_b64(example_img)
},
},
],
}
)
message.append(
{
"role": "system",
"name": "example_assistant",
"content": [{"type": "text", "text": y}],
}
)
# Encode images and page_screenshot_img as base64 strings.
current_prompt = current
content = [
{
"type": "text",
"text": "IMAGES: (1) current page screenshot",
},
{
"type": "image_url",
"image_url": {"url": pil_to_b64(page_screenshot_img)},
},
]
for image_i, image in enumerate(images):
content.extend(
[
{
"type": "text",
"text": f"({image_i+2}) input image {image_i+1}",
},
{
"type": "image_url",
"image_url": {"url": pil_to_b64(image)},
},
]
)
content = [{"type": "text", "text": current_prompt}] + content
message.append({"role": "user", "content": content})
return message
else:
raise ValueError(
f"GPT-4V models do not support mode {self.lm_config.mode}"
)
elif "google" in self.lm_config.provider:
if self.lm_config.mode == "completion":
message = [
intro,
"Here are a few examples:",
]
for (x, y, z) in examples:
example_img = Image.open(z)
message.append(f"Observation\n:{x}\n")
message.extend(
[
"IMAGES:",
"(1) current page screenshot:",
pil_to_vertex(example_img),
]
)
message.append(f"Action: {y}")
message.append("Now make prediction given the observation")
message.append(f"Observation\n:{current}\n")
message.extend(
[
"IMAGES:",
"(1) current page screenshot:",
pil_to_vertex(page_screenshot_img),
]
)
for image_i, image in enumerate(images):
message.extend(
[
f"({image_i+2}) input image {image_i+1}",
pil_to_vertex(image),
]
)
message.append("Action:")
return message
else:
raise ValueError(
f"Gemini models do not support mode {self.lm_config.mode}"
)
elif self.lm_config.provider in ["api", "finetune"]:
message = [
{
"role": "system",
"content": intro,
}
]
# we keep few-shot here, but remove the image corresponding to the current page.
for (x, y, _) in examples:
message.append({
"role": "user",
"content": [
{ "type": "text", "text": x },
{ "type": "text", "text": "IMAGES: (1) current page screenshot\n\n** Screenshot **\n" },
],
})
message.append({
"role": "assistant",
"content": y,
})
# TODO: Encode images and page_screenshot_img as base64 strings, we only keep screenshot of current page.
current_prompt = current
content = []
if self.lm_config.provider != "finetune":
content.append({
"type": "text",
"text": "IMAGES: (1) current page screenshot",
})
if "text" not in self.lm_config.model:
content.append({
"type": "image_url",
"image_url": {"url": pil_to_b64(page_screenshot_img)},
})
content = [{"type": "text", "text": current_prompt}] + content
message.append({"role": "user", "content": content})
return message
else:
raise NotImplementedError(
f"Provider {self.lm_config.provider} not implemented"
)
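
For reference, a minimal sketch of the action-extraction convention implemented by `_extract_action` above. The splitter value is an assumption here; in practice it comes from the instruction JSON's `meta_data['action_splitter']`:

```python
import re

# Assume a triple-backtick splitter (constructed rather than written literally).
action_splitter = "`" * 3
response = (
    "Let's think step-by-step. The search box has id 1585. In summary, the next "
    f"action I will perform is {action_splitter}type [1585] [laptop] [1]{action_splitter}"
)

pattern = rf"{action_splitter}((.|\n)*?){action_splitter}"
match = re.search(pattern, response)
print(match.group(1).strip() if match else "no parsable action")
```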

View File

@ -0,0 +1,562 @@
"""Script to run end-to-end evaluation on the benchmark.
Modified from https://github.com/web-arena-x/webarena/blob/main/run.py.
"""
import argparse
import glob
import json
import logging
import os
import random
import subprocess
import tempfile
import time
from pathlib import Path
from typing import List
import openai
import requests
import torch
from PIL import Image
from agent import (
PromptAgent,
construct_agent,
)
from agent.prompts import *
from browser_env import (
Action,
ActionTypes,
ScriptBrowserEnv,
StateInfo,
Trajectory,
create_stop_action,
)
from browser_env.actions import is_equivalent
from browser_env.auto_login import get_site_comb_from_filepath
from browser_env.helper_functions import (
RenderHelper,
get_action_description,
)
from evaluation_harness import evaluator_router, image_utils
DATASET = os.environ["DATASET"]
LOG_FOLDER = "log_files"
Path(LOG_FOLDER).mkdir(parents=True, exist_ok=True)
LOG_FILE_NAME = f"{LOG_FOLDER}/log_{time.strftime('%Y%m%d%H%M%S', time.localtime())}_{random.randint(0, 10000)}.log"
logger = logging.getLogger("logger")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
file_handler = logging.FileHandler(LOG_FILE_NAME)
file_handler.setLevel(logging.DEBUG)
logger.addHandler(file_handler)
# Set the log format
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)
def config() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Run end-to-end evaluation on the benchmark"
)
parser.add_argument(
"--render", action="store_true", help="Render the browser"
)
parser.add_argument(
"--slow_mo",
type=int,
default=0,
help="Slow down the browser by the specified amount",
)
parser.add_argument(
"--action_set_tag", default="id_accessibility_tree", help="Action type"
)
parser.add_argument(
"--observation_type",
choices=[
"accessibility_tree",
"accessibility_tree_with_captioner",
"html",
"image",
"image_som",
],
default="accessibility_tree",
help="Observation type",
)
parser.add_argument(
"--current_viewport_only",
action="store_true",
help="Only use the current viewport for the observation",
)
parser.add_argument("--viewport_width", type=int, default=1280)
parser.add_argument("--viewport_height", type=int, default=2048)
parser.add_argument("--save_trace_enabled", action="store_true")
parser.add_argument("--sleep_after_execution", type=float, default=0.0)
parser.add_argument("--max_steps", type=int, default=30)
# agent config
parser.add_argument("--agent_type", type=str, default="prompt")
parser.add_argument(
"--instruction_path",
type=str,
default="agents/prompts/state_action_agent.json",
)
parser.add_argument(
"--parsing_failure_th",
help="When consecutive parsing failures exceed this threshold, the agent will terminate early.",
type=int,
default=3,
)
parser.add_argument(
"--repeating_action_failure_th",
help="When consecutive repeated actions exceed this threshold, the agent will terminate early.",
type=int,
default=5,
)
parser.add_argument("--test_config_base_dir", type=str)
parser.add_argument(
"--eval_captioning_model_device",
type=str,
default="cpu",
choices=["cpu", "cuda"],
help="Device to run eval captioning model on. By default, runs it on CPU.",
)
parser.add_argument(
"--eval_captioning_model",
type=str,
default="Salesforce/blip2-flan-t5-xl",
choices=["Salesforce/blip2-flan-t5-xl"],
help="Captioning backbone for VQA-type evals.",
)
parser.add_argument(
"--captioning_model",
type=str,
default="Salesforce/blip2-flan-t5-xl",
choices=["Salesforce/blip2-flan-t5-xl", "llava-hf/llava-1.5-7b-hf"],
help="Captioning backbone for accessibility tree alt text.",
)
# lm config
parser.add_argument("--provider", type=str, default="openai")
parser.add_argument("--model", type=str, default="gpt-3.5-turbo-0613")
parser.add_argument("--mode", type=str, default="chat")
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--top_p", type=float, default=0.9)
parser.add_argument("--context_length", type=int, default=0)
parser.add_argument("--max_tokens", type=int, default=384)
parser.add_argument("--stop_token", type=str, default=None)
parser.add_argument(
"--max_retry",
type=int,
help="max retry times to perform generations when parsing fails",
default=1,
)
parser.add_argument(
"--max_obs_length",
type=int,
help="when not zero, will truncate the observation to this length before feeding to the model",
default=3840,
)
# example config
parser.add_argument("--test_start_idx", type=int, default=0)
parser.add_argument("--test_end_idx", type=int, default=910)
# logging related
parser.add_argument("--result_dir", type=str, default="")
args = parser.parse_args()
# check the whether the action space is compatible with the observation space
if (
args.action_set_tag == "id_accessibility_tree"
and args.observation_type
not in [
"accessibility_tree",
"accessibility_tree_with_captioner",
"image_som",
]
):
raise ValueError(
f"Action type {args.action_set_tag} is incompatible with the observation type {args.observation_type}"
)
return args
def early_stop(
trajectory: Trajectory, max_steps: int, thresholds: dict[str, int]
) -> tuple[bool, str]:
"""Check whether need to stop early"""
# reach the max step
num_steps = (len(trajectory) - 1) / 2
if num_steps >= max_steps:
return True, f"Reach max steps {max_steps}"
last_k_actions: list[Action]
action_seq: list[Action]
# Case: parsing failure for k times
k = thresholds["parsing_failure"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
if len(last_k_actions) >= k:
if all(
[
action["action_type"] == ActionTypes.NONE
for action in last_k_actions
]
):
return True, f"Failed to parse actions for {k} times"
# Case: same action for k times
k = thresholds["repeating_action"]
last_k_actions = trajectory[1::2][-k:] # type: ignore[assignment]
action_seq = trajectory[1::2] # type: ignore[assignment]
if len(action_seq) == 0:
return False, ""
last_action: Action = action_seq[-1]
if last_action["action_type"] != ActionTypes.TYPE:
if len(last_k_actions) >= k:
if all(
[
is_equivalent(action, last_action)
for action in last_k_actions
]
):
return True, f"Same action for {k} times"
else:
# check the action sequence
if (
sum([is_equivalent(action, last_action) for action in action_seq])
>= k
):
return True, f"Same typing action for {k} times"
return False, ""
def update_action_history(path: str, task_id: int, actions: List[str], score: float=-0.1):
obj = {
"task_id": task_id,
"score": score,
"actions": actions
}
json.dump(obj, open(path, "w"), indent=4)
def test(
args: argparse.Namespace,
config_file_list: list[str]
) -> None:
scores = []
max_steps = args.max_steps
early_stop_thresholds = {
"parsing_failure": args.parsing_failure_th,
"repeating_action": args.repeating_action_failure_th,
}
if args.observation_type in [
"accessibility_tree_with_captioner",
# "image_som",
]:
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
caption_image_fn = image_utils.get_captioning_fn(
device, dtype, args.captioning_model
)
else:
caption_image_fn = None
# Load a (possibly different) captioning model for running VQA evals.
if DATASET == 'visualwebarena':
if (
caption_image_fn
and args.eval_captioning_model == args.captioning_model
):
eval_caption_image_fn = caption_image_fn
else:
eval_caption_image_fn = image_utils.get_captioning_fn(
args.eval_captioning_model_device,
torch.float16
if (
torch.cuda.is_available()
and args.eval_captioning_model_device == "cuda"
)
else torch.float32,
args.eval_captioning_model,
)
else:
caption_image_fn = None
eval_caption_image_fn = None
agent = construct_agent(
args,
captioning_fn=caption_image_fn
if args.observation_type == "accessibility_tree_with_captioner"
else None,
) # NOTE: captioning_fn here is used for captioning input images.
env = ScriptBrowserEnv(
headless=not args.render,
slow_mo=args.slow_mo,
observation_type=args.observation_type,
current_viewport_only=args.current_viewport_only,
viewport_size={
"width": args.viewport_width,
"height": args.viewport_height,
},
save_trace_enabled=args.save_trace_enabled,
sleep_after_execution=args.sleep_after_execution,
# NOTE: captioning_fn here is used for LLM + captioning baselines.
# This can be different from the captioning model used for evals.
captioning_fn=caption_image_fn,
)
for config_file in config_file_list:
try:
render_helper = RenderHelper(
config_file, args.result_dir, args.action_set_tag
)
# Load task.
with open(config_file) as f:
_c = json.load(f)
intent = _c["intent"]
task_id = _c["task_id"]
image_paths = _c.get("image", None)
images = []
# automatically login
if _c["storage_state"]:
cookie_file_name = os.path.basename(_c["storage_state"])
comb = get_site_comb_from_filepath(cookie_file_name)
temp_dir = tempfile.mkdtemp()
# subprocess to renew the cookie
subprocess.run(
[
"python",
"browser_env/auto_login.py",
"--auth_folder",
temp_dir,
"--site_list",
*comb,
]
)
_c["storage_state"] = f"{temp_dir}/{cookie_file_name}"
assert os.path.exists(_c["storage_state"])
# update the config file
config_file = f"{temp_dir}/{os.path.basename(config_file)}"
with open(config_file, "w") as f:
json.dump(_c, f)
# Load input images for the task, if any.
if image_paths is not None:
if isinstance(image_paths, str):
image_paths = [image_paths]
for image_path in image_paths:
# Load image either from the web or from a local path.
if image_path.startswith("http"):
input_image = Image.open(requests.get(image_path, stream=True).raw)
else:
input_image = Image.open(image_path)
images.append(input_image)
logger.info(f"[Config file]: {config_file}")
logger.info(f"[Intent]: {intent}")
agent.reset(config_file)
trajectory: Trajectory = []
obs, info = env.reset(options={"config_file": config_file})
state_info: StateInfo = {"observation": obs, "info": info}
trajectory.append(state_info)
meta_data = {"action_history": ["None"]}
out_path = os.path.join(args.result_dir, "actions", f"{task_id}.json")
actions = []
while True:
update_action_history(out_path, task_id, actions=actions)
early_stop_flag, stop_info = early_stop(
trajectory, max_steps, early_stop_thresholds
)
if early_stop_flag:
action = create_stop_action(f"Early stop: {stop_info}")
else:
try:
action = agent.next_action(
trajectory,
intent,
images=images,
meta_data=meta_data,
)
except ValueError as e:
# get the error message
action = create_stop_action(f"ERROR: {str(e)}")
trajectory.append(action)
action_str = get_action_description(
action,
state_info["info"]["observation_metadata"],
action_set_tag=args.action_set_tag,
prompt_constructor=agent.prompt_constructor
if isinstance(agent, PromptAgent)
else None,
)
render_helper.render(
action, state_info, meta_data, args.render_screenshot
)
meta_data["action_history"].append(action_str)
actions.append(action_str)
print(action_str)
if action["action_type"] == ActionTypes.STOP:
break
obs, _, terminated, _, info = env.step(action)
state_info = {"observation": obs, "info": info}
trajectory.append(state_info)
if terminated:
# add a action place holder
trajectory.append(create_stop_action(""))
break
# NOTE: eval_caption_image_fn is used for running eval_vqa functions.
evaluator = evaluator_router(
config_file, captioning_fn=eval_caption_image_fn
)
score = evaluator(
trajectory=trajectory,
config_file=config_file,
page=env.page
)
update_action_history(out_path, task_id, actions=actions, score=score)
scores.append(score)
if score == 1:
logger.info(f"[Result] (PASS) {config_file}")
else:
logger.info(f"[Result] (FAIL) {config_file}")
if args.save_trace_enabled:
env.save_trace(
Path(args.result_dir) / "traces" / f"{task_id}.zip"
)
except openai.OpenAIError as e:
logger.info(f"[OpenAI Error] {repr(e)}")
except Exception as e:
logger.info(f"[Unhandled Error] {repr(e)}]")
import traceback
# write to error file
with open(Path(args.result_dir) / "error.txt", "a") as f:
f.write(f"[Config file]: {config_file}\n")
f.write(f"[Unhandled Error] {repr(e)}\n")
f.write(traceback.format_exc()) # write stack trace to file
render_helper.close()
env.close()
if len(scores):
logger.info(f"Average score: {sum(scores) / len(scores)}")
def prepare(args: argparse.Namespace) -> None:
# convert prompt python files to json
from agent.prompts import to_json
to_json.run()
# prepare result dir
result_dir = args.result_dir
if not result_dir:
result_dir = (
f"cache/results_{time.strftime('%Y%m%d%H%M%S', time.localtime())}"
)
if not Path(result_dir).exists():
Path(result_dir).mkdir(parents=True, exist_ok=True)
args.result_dir = result_dir
logger.info(f"Create result dir: {result_dir}")
if not (Path(result_dir) / "traces").exists():
(Path(result_dir) / "traces").mkdir(parents=True)
os.makedirs(os.path.join(result_dir, "actions"), exist_ok=True)
# log the log file
with open(os.path.join(result_dir, "log_files.txt"), "a+") as f:
f.write(f"{LOG_FILE_NAME}\n")
def get_unfinished(config_files: list[str], result_dir: str) -> list[str]:
result_files = glob.glob(f"{result_dir}/*.html")
task_ids = [
os.path.basename(f).split(".")[0].split("_")[1] for f in result_files
]
unfinished_configs = []
for config_file in config_files:
task_id = os.path.basename(config_file).split(".")[0]
try:
with open(f"{result_dir}/actions/{task_id}.json", "r") as f:
jd = json.load(f)
        except Exception:
jd = {}
if task_id not in task_ids or jd.get('score', -1) < 0:
unfinished_configs.append(config_file)
return unfinished_configs
def dump_config(args: argparse.Namespace) -> None:
config_file = Path(args.result_dir) / "config.json"
if not config_file.exists():
with open(config_file, "w") as f:
json.dump(vars(args), f, indent=4)
logger.info(f"Dump config to {config_file}")
if __name__ == "__main__":
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = config()
args.sleep_after_execution = 2.5
prepare(args)
test_config_base_dir = args.test_config_base_dir
test_file_list = []
st_idx = args.test_start_idx
ed_idx = args.test_end_idx
for i in range(st_idx, ed_idx):
test_file_list.append(os.path.join(test_config_base_dir, f"{i}.json"))
test_file_list = get_unfinished(test_file_list, args.result_dir)
print(f"Total {len(test_file_list)} tasks left")
args.render = False
args.render_screenshot = True
args.save_trace_enabled = True
args.current_viewport_only = True
dump_config(args)
test(args, test_file_list)
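
A minimal sketch of launching a single shard of this script from Python instead of the tmux wrapper below. The flag values are examples that mirror the parser above, and the site URL environment variables (SHOPPING, REDDIT, GITLAB, etc.) plus `OPENAI_API_KEY` must already be exported:

```python
import os
import subprocess

env = dict(os.environ, DATASET="webarena")  # run.py reads DATASET at import time

subprocess.run(
    [
        "python", "run.py",
        "--provider", "openai",
        "--model", "gpt-4o-2024-05-13",
        "--mode", "chat",
        "--instruction_path", "agent/prompts/jsons/p_som_cot_id_actree_3s.json",
        "--test_config_base_dir", "config_files/wa/test_webarena_lite",
        "--test_start_idx", "0",
        "--test_end_idx", "24",
        "--result_dir", "results/demo",
    ],
    env=env,
    check=False,
)
```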

File diff suppressed because it is too large

View File

@ -0,0 +1,29 @@
from typing import Any
import tiktoken
from transformers import LlamaTokenizer # type: ignore
class Tokenizer(object):
def __init__(self, provider: str, model_name: str) -> None:
if provider == "openai":
self.tokenizer = tiktoken.encoding_for_model(model_name)
elif provider == "huggingface":
self.tokenizer = LlamaTokenizer.from_pretrained(model_name)
# turn off adding special tokens automatically
self.tokenizer.add_special_tokens = False # type: ignore[attr-defined]
self.tokenizer.add_bos_token = False # type: ignore[attr-defined]
self.tokenizer.add_eos_token = False # type: ignore[attr-defined]
elif provider in ["google", "api", "finetune"]:
self.tokenizer = None # Not used for input length computation, as Gemini is based on characters
else:
raise NotImplementedError
def encode(self, text: str) -> list[int]:
return self.tokenizer.encode(text)
def decode(self, ids: list[int]) -> str:
return self.tokenizer.decode(ids)
def __call__(self, text: str) -> list[int]:
return self.tokenizer.encode(text)
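
A minimal usage sketch of the tokenizer wrapper above (the model name is an example that tiktoken recognizes; importing the `llms` package assumes its import-time requirements, e.g. `OPENAI_API_KEY`, are satisfied):

```python
from llms.tokenizers import Tokenizer

tok = Tokenizer(provider="openai", model_name="gpt-3.5-turbo")
ids = tok.encode("click [1234]")
print(len(ids), tok.decode(ids))
```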

View File

@ -0,0 +1,87 @@
import argparse
from typing import Any
try:
from vertexai.preview.generative_models import Image
from llms import generate_from_gemini_completion
except:
print('Google Cloud not set up, skipping import of vertexai.preview.generative_models.Image and llms.generate_from_gemini_completion')
from llms import (
generate_from_huggingface_completion,
generate_from_openai_chat_completion,
generate_from_openai_completion,
generate_with_api,
lm_config,
)
APIInput = str | list[Any] | dict[str, Any]
def call_llm(
lm_config: lm_config.LMConfig,
prompt: APIInput,
) -> str:
response: str
if lm_config.provider == "openai":
if lm_config.mode == "chat":
assert isinstance(prompt, list)
response = generate_from_openai_chat_completion(
messages=prompt,
model=lm_config.model,
temperature=lm_config.gen_config["temperature"],
top_p=lm_config.gen_config["top_p"],
context_length=lm_config.gen_config["context_length"],
max_tokens=lm_config.gen_config["max_tokens"],
stop_token=None,
)
elif lm_config.mode == "completion":
assert isinstance(prompt, str)
response = generate_from_openai_completion(
prompt=prompt,
engine=lm_config.model,
temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"],
stop_token=lm_config.gen_config["stop_token"],
)
else:
raise ValueError(
f"OpenAI models do not support mode {lm_config.mode}"
)
elif lm_config.provider == "huggingface":
assert isinstance(prompt, str)
response = generate_from_huggingface_completion(
prompt=prompt,
model_endpoint=lm_config.gen_config["model_endpoint"],
temperature=lm_config.gen_config["temperature"],
top_p=lm_config.gen_config["top_p"],
stop_sequences=lm_config.gen_config["stop_sequences"],
max_new_tokens=lm_config.gen_config["max_new_tokens"],
)
elif lm_config.provider == "google":
assert isinstance(prompt, list)
assert all(
[isinstance(p, str) or isinstance(p, Image) for p in prompt]
)
response = generate_from_gemini_completion(
prompt=prompt,
engine=lm_config.model,
temperature=lm_config.gen_config["temperature"],
max_tokens=lm_config.gen_config["max_tokens"],
top_p=lm_config.gen_config["top_p"],
)
elif lm_config.provider in ["api", "finetune"]:
args = {
"temperature": lm_config.gen_config["temperature"], # openai, gemini, claude
"max_tokens": lm_config.gen_config["max_tokens"], # openai, gemini, claude
"top_k": lm_config.gen_config["top_p"], # qwen
}
response = generate_with_api(prompt, lm_config.model, args)
else:
raise NotImplementedError(
f"Provider {lm_config.provider} not implemented"
)
return response
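
A minimal sketch wiring `LMConfig` and `call_llm` together for the OpenAI chat path; it assumes `OPENAI_API_KEY` is exported and uses example generation settings:

```python
from llms import call_llm, lm_config

cfg = lm_config.LMConfig(provider="openai", model="gpt-4o-2024-05-13", mode="chat")
cfg.gen_config.update(
    {"temperature": 0.0, "top_p": 1.0, "context_length": 0, "max_tokens": 64}
)

prompt = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Reply with OK."},
]
print(call_llm(cfg, prompt))
```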

View File

@ -0,0 +1,78 @@
#!/bin/bash
DATASET='webarena' # webarena, visualwebarena
result_dir=''
provider=''
model=''
instruction_path='agent/prompts/jsons/p_som_cot_id_actree_3s.json' # e.g., agent/prompts/jsons/p_cot_id_actree_2s.json
test_config_base_dir='config_files/wa/test_webarena_lite' # e.g., config_files/wa/test_webarena_lite
temperature=0.0
SERVER='' # your server address
MAP_SERVER='' # the same as SERVER
OPENAI_API_KEY=''
OPENAI_ORGANIZATION=''
CONDA_ENV_NAME='' # the name of your conda environment
ENV_VARIABLES="export DATASET=${DATASET}; export SHOPPING='http://${SERVER}:7770';export SHOPPING_ADMIN='http://${SERVER}:7780/admin';export REDDIT='http://${SERVER}:9999';export GITLAB='http://${SERVER}:8023';export MAP='http://${MAP_SERVER}:3000';export WIKIPEDIA='http://${SERVER}:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing';export HOMEPAGE='http://${SERVER}:4399';export OPENAI_API_KEY=${OPENAI_API_KEY};export OPENAI_ORGANIZATION=${OPENAI_ORGANIZATION}"
# get the number of tmux panes
num_panes=$(tmux list-panes | wc -l)
# calculate how many panes need to be created
let "panes_to_create = 8 - num_panes"
# array of tmux commands to create each pane
tmux_commands=(
'tmux split-window -h'
'tmux split-window -v'
'tmux select-pane -t 0; tmux split-window -v'
'tmux select-pane -t 0; tmux split-window -v'
'tmux select-pane -t 2; tmux split-window -v'
'tmux select-pane -t 4; tmux split-window -v'
'tmux select-pane -t 6; tmux split-window -v'
)
# create panes up to 8
for ((i=0; i<$panes_to_create; i++)); do
eval ${tmux_commands[$i]}
done
# Function to run a job
run_job() {
tmux select-pane -t $1
tmux send-keys "tmux set mouse on; conda activate ${CONDA_ENV_NAME}; ${ENV_VARIABLES}; until python run.py --viewport_width 1280 --viewport_height 720 --test_start_idx $2 --test_end_idx $3 --provider ${provider} --model ${model} --instruction_path ${instruction_path} --temperature ${temperature} --test_config_base_dir ${test_config_base_dir} --result_dir ${result_dir}; do echo 'crashed' >&2; sleep 1; done" C-m
sleep 3
}
TOLERANCE=2
run_batch() {
args=("$@") # save all arguments in an array
num_jobs=${#args[@]} # get number of arguments
for ((i=1; i<$num_jobs; i++)); do
run_job $i ${args[i-1]} ${args[i]}
done
# Wait for all jobs to finish
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
        sleep 100 # wait for 100 seconds before checking again
done
# Run checker
while ! python scripts/check_error_runs.py ${result_dir} --delete_errors --tolerance ${TOLERANCE}; do
echo "Check failed, rerunning jobs..."
for ((i=1; i<$num_jobs; i++)); do
run_job $i ${args[i-1]} ${args[i]}
done
# Wait for all jobs to finish
while tmux list-panes -F "#{pane_pid} #{pane_current_command}" | grep -q python; do
            sleep 100 # wait for 100 seconds before checking again
done
done
}
run_batch 0 24 48 72 96 120 143 165

View File

@ -0,0 +1,71 @@
#!/bin/bash
docker stop shopping
docker stop shopping_admin
docker stop forum
docker stop gitlab
docker rm shopping
docker rm shopping_admin
docker rm forum
docker rm gitlab
# One Stop Shop
# docker load --input shopping_final_0712.tar
docker run --name shopping -p 7770:80 -d shopping_final_0712
# CMS
# docker load --input shopping_admin_final_0719.tar
docker run --name shopping_admin -p 7780:80 -d shopping_admin_final_0719
# Reddit
# docker load --input postmill-populated-exposed-withimg.tar
docker run --name forum -p 9999:80 -d postmill-populated-exposed-withimg
# GitLab
# docker load --input gitlab-populated-final-port8023.tar
docker run --name gitlab --shm-size="10g" -d -p 8023:8023 gitlab-populated-final-port8023 /opt/gitlab/embedded/bin/runsvdir-start
sleep 60
# Define your actual server hostname
YOUR_ACTUAL_HOSTNAME="http://localhost"
# Remove trailing / if it exists
YOUR_ACTUAL_HOSTNAME=${YOUR_ACTUAL_HOSTNAME%/}
# OSS
docker exec shopping /var/www/magento2/bin/magento setup:store-config:set --base-url="${YOUR_ACTUAL_HOSTNAME}:7770"
docker exec shopping mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value='${YOUR_ACTUAL_HOSTNAME}:7770/' WHERE path = 'web/secure/base_url';"
docker exec shopping /var/www/magento2/bin/magento cache:flush
# Disable re-indexing of products
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogrule_product
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogrule_rule
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalogsearch_fulltext
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_category_product
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule customer_grid
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule design_config_grid
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule inventory
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_category
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_attribute
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule catalog_product_price
docker exec shopping /var/www/magento2/bin/magento indexer:set-mode schedule cataloginventory_stock
# CMS
docker exec shopping_admin /var/www/magento2/bin/magento setup:store-config:set --base-url="${YOUR_ACTUAL_HOSTNAME}:7780"
docker exec shopping_admin mysql -u magentouser -pMyPassword magentodb -e "UPDATE core_config_data SET value='${YOUR_ACTUAL_HOSTNAME}:7780/' WHERE path = 'web/secure/base_url';"
docker exec shopping_admin /var/www/magento2/bin/magento cache:flush
# Forum
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/CommentData.php
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/SubmissionData.php
docker exec forum sed -i '/@RateLimit/,/)/d' /var/www/html/src/DataObject/UserData.php
docker exec forum bin/console cache:clear --env=prod
sleep 60
# Gitlab
docker exec gitlab sed -i "s|^external_url.*|external_url '${YOUR_ACTUAL_HOSTNAME}:8023'|" /etc/gitlab/gitlab.rb
docker exec gitlab sed -i "s/.*postgresql\['max_connections'.*/postgresql\['max_connections'\] = 2000/g" /etc/gitlab/gitlab.rb
docker exec gitlab gitlab-ctl reconfigure
docker exec gitlab gitlab-ctl restart
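After the script finishes, the sites may need a few more minutes (GitLab in particular) before they respond. A quick health-check sketch, assuming the default ports above and `localhost` as the host:

```python
# Sketch: poll each relaunched site once and report whether it answers over HTTP.
import urllib.request

SITES = {
    "shopping (OSS)": "http://localhost:7770",
    "shopping_admin (CMS)": "http://localhost:7780/admin",
    "forum (Reddit)": "http://localhost:9999",
    "gitlab": "http://localhost:8023",
}

for name, url in SITES.items():
    try:
        code = urllib.request.urlopen(url, timeout=15).getcode()
        print(f"{name}: HTTP {code}")
    except Exception as exc:  # connection refused, timeout, 4xx/5xx, ...
        print(f"{name}: not ready yet ({exc})")
```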

View File

@ -0,0 +1,37 @@
#!/bin/bash
# 1. strip the git history of the original visualwebarena checkout
cd visualwebarena
rm -rf .git
cd ..
# 2. replace files
# webarena-lite
cp -f new/test_webarena_lite.raw.json visualwebarena/config_files/wa/test_webarena_lite.raw.json
cp -f new/generate_test_data.py visualwebarena/scripts/generate_test_data.py
# agent
cp -f new/run.py visualwebarena/run.py
cp -f new/agent.py visualwebarena/agent/agent.py
cp -f new/prompt_constructor.py visualwebarena/agent/prompts/prompt_constructor.py
# llms
cp -f new/utils.py visualwebarena/llms/utils.py
cp -f new/llms_init.py visualwebarena/llms/__init__.py
cp -f new/lm_config.py visualwebarena/llms/lm_config.py
cp -f new/tokenizers.py visualwebarena/llms/tokenizers.py
cp -f new/api_utils.py visualwebarena/llms/providers/api_utils.py
cp -f new/openai_utils.py visualwebarena/llms/providers/openai_utils.py
# eval
cp -f new/evaluators.py visualwebarena/evaluation_harness/evaluators.py
cp -f new/helper_functions.py visualwebarena/evaluation_harness/helper_functions.py
# misc
cp -f README.md visualwebarena/README.md
cp -f new/wa_parallel_run.sh visualwebarena/wa_parallel_run.sh
# 3. move the patched repo to the top level and remove temporary directories
mv visualwebarena/* .
rm -rf new
rm -rf visualwebarena
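After the replacement script runs, the patched files should sit at the top level of the working directory. A minimal sanity-check sketch (paths taken from the `cp` commands above):

```python
# Sketch: verify the replaced files ended up in place after the merge.
from pathlib import Path

expected = [
    "run.py",
    "agent/agent.py",
    "agent/prompts/prompt_constructor.py",
    "llms/utils.py",
    "llms/providers/api_utils.py",
    "evaluation_harness/evaluators.py",
    "wa_parallel_run.sh",
]
missing = [p for p in expected if not Path(p).exists()]
print("all replaced files in place" if not missing else f"missing: {missing}")
```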

BIN
assets/framework.png Normal file

Binary file not shown.

After

Width: | Height: | Size: 407 KiB

View File

@ -0,0 +1,49 @@
# Setup for VAB-Minecraft
## Installation
1. We have tested on Ubuntu. VAB-Minecraft requires an NVIDIA GPU with at least 4 GB of memory and an NVIDIA GPU driver version >= 530.30.02.
2. Besides [Docker](https://www.docker.com/), install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) on your machine.
3. Get the pre-built Docker image.
- If you have access to docker hub:
```bash
docker pull tianjiezhang/vab_minecraft:latest
```
- Alternatively, you can download it from ModelScope:
1. Make sure `git-lfs` is installed.
2. Download from ModelScope:
```bash
git lfs install
git clone https://www.modelscope.cn/datasets/VisualAgentBench/VAB-Minecraft.git
```
3. Load the Docker image from the ModelScope dataset.
```bash
docker load -i VAB-Minecraft/vab_minecraft.tar
```
4. Download the weights of Steve-1 to `data/minecraft`. Please make sure you have access to Google Drive.
```bash
python scripts/minecraft_download.py
```
## Get Started
According to your hardware, fill in `available_ports` and `available_devices` in the task configuration file `configs/tasks/minecraft.yaml`.
- `available_ports`: Please fill in ports that are available on your machine; see the port-check sketch below. Each concurrent Docker container requires one port for communication with the task server. Ensure that you provide enough ports to accommodate the expected concurrency.
- `available_devices`: Please fill in GPU IDs and the number of concurrent containers each GPU can host. Each concurrent Docker container occupies about **3.3 GB** of GPU memory. Ensure that you provide enough GPU memory to accommodate the expected concurrency.
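A small sketch (not part of the repo) for finding candidate ports to list under `available_ports`: a port that can still be bound locally is currently free.

```python
# Sketch: list candidate ports that are currently free on this machine.
import socket

def is_free(port: int) -> bool:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("", port))
            return True
        except OSError:
            return False

candidates = range(16000, 16050)  # arbitrary example range; pick any range you like
print([p for p in candidates if is_free(p)][:8])
```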
**Note: If you manually shut down the task server and assigner, please ensure you also stop the Minecraft containers to free up the ports!**

View File

@ -57,57 +57,3 @@
2. Load the new value: `sudo sysctl -p`.
**Note: If you manually shut down the task server and assigner, please ensure you also stop the OmniGibson containers to free up the ports!**
# Setup for VAB-CSS
TODO