# webrl/VAB-WebArena-Lite/new/envs.py
# 2024-11-14 15:51:41 +08:00
#
# 320 lines
# 11 KiB
# Python

import json
import os
import re
import subprocess
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
import numpy.typing as npt
import requests
from beartype import beartype
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
CDPSession,
Page,
Playwright,
ViewportSize,
expect,
sync_playwright,
)
# Benchmark selector, read once at import time. The DATASET environment
# variable must therefore be set before this module is imported; indexing
# os.environ raises KeyError when it is missing.
DATASET = os.environ["DATASET"]
if DATASET == "visualwebarena":
    # The Classifieds site URL and its reset token are only imported for
    # VisualWebArena. ScriptBrowserEnv.setup() uses these names when a task
    # config requests a reset of the "classifieds" site, so that path would
    # hit a NameError under any other DATASET value.
    from browser_env.env_config import (
        CLASSIFIEDS,
        CLASSIFIEDS_RESET_TOKEN,
    )
from .actions import Action, execute_action, get_action_space, execute_action_webrl
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
AccessibilityTree,
DetachedPage,
Observation,
png_bytes_to_numpy,
)
@dataclass
class PlaywrightScript:
    """One parsed Playwright-style action (the return type of ``parse_action``).

    Example field values:
        function    -- "goto" or "get_by_role"
        destination -- a URL ("https://www.google.com/") or an ARIA role ("combobox")
        name        -- accessible name of the element, e.g. "Search", "Avatar 2009"
        operation   -- action on the element, e.g. "click", "fill", "press"
        value       -- operand for the operation, e.g. "avatar movie", "Enter"

    Only ``function`` and ``destination`` are required; the remaining fields
    default to ``None`` (a ``goto`` script uses none of them).
    """

    function: str
    destination: str
    name: Union[str, None] = None
    operation: Union[str, None] = None
    value: Union[str, None] = None
def parse_action(action: str) -> PlaywrightScript:
splitted = action.strip().split(" ")
assert len(splitted) >= 2
match splitted[:2]:
case ["goto", url]:
assert len(splitted) == 2
return PlaywrightScript("goto", url)
case ["get_by_role", destination]:
assert len(splitted) >= 4
match splitted[2:]:
case [name, operation]:
return PlaywrightScript(
"get_by_role", destination, name, operation
)
case [name, operation, value]:
return PlaywrightScript(
"get_by_role", destination, name, operation, value
)
case _:
raise ValueError("Invalid action")
case _:
raise ValueError(f"Invalid action {action}")
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
"""
The goal of this environment is to produce a prototype of a browser environment.
In the end, we want to support a fully configurable browser environment with wide
range of action spaces and observation spaces, both structured and unstructured.
But in this prototype, we just support action space specified by Playwright script,
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 8192,
headless: bool = True,
slow_mo: int = 0,
observation_type: str = "html",
current_viewport_only: bool = False,
viewport_size: ViewportSize = {"width": 1280, "height": 720},
save_trace_enabled: bool = False,
sleep_after_execution: float = 0.0,
captioning_fn=None,
):
# TODO: make Space[Action] = ActionSpace
self.action_space = get_action_space() # type: ignore[assignment]
self.headless = headless
self.slow_mo = slow_mo
self.current_viewport_only = current_viewport_only
self.reset_finished = False
self.viewport_size = viewport_size
self.save_trace_enabled = save_trace_enabled
self.sleep_after_execution = sleep_after_execution
match observation_type:
case "html" | "accessibility_tree" | "accessibility_tree_with_captioner" | "webrl":
self.text_observation_type = observation_type
self.image_observation_type = ""
self.main_observation_type = "text"
case "image":
self.image_observation_type = observation_type
self.text_observation_type = "" # type: ignore[assignment]
self.main_observation_type = "image"
case "image_som":
self.image_observation_type = observation_type
self.text_observation_type = observation_type # type: ignore[assignment]
self.main_observation_type = "image"
case _:
raise ValueError(
f"Unsupported observation type: {observation_type}"
)
self.observation_handler = ObservationHandler(
self.main_observation_type,
self.text_observation_type,
self.image_observation_type,
self.current_viewport_only,
self.viewport_size,
captioning_fn,
)
self.observation_space = (
self.observation_handler.get_observation_space()
)
@beartype
def setup(self, config_file: Path | None = None) -> None:
self.context_manager = sync_playwright()
self.playwright = self.context_manager.__enter__()
self.browser = self.playwright.chromium.launch(
headless=self.headless, slow_mo=self.slow_mo
)
if config_file:
with open(config_file, "r") as f:
instance_config = json.load(f)
else:
instance_config = {}
# Reset site if needed. Currently only supported for Classifieds.
# TODO(jykoh): Add reset functionality for Shopping/Reddit.
if instance_config.get("require_reset", False):
if "classifieds" in instance_config["sites"]:
# Send POST request to __CLASSIFIEDS__/index.php?page=reset with token=CLASSIFIEDS_TOKEN
response = requests.post(
f"{CLASSIFIEDS}/index.php?page=reset",
data={"token": CLASSIFIEDS_RESET_TOKEN},
)
# Check if the request was successful
if response.status_code == 200:
print("Reset Classifieds site.")
else:
print(
"Failed to reset Classifieds site:",
response.status_code,
)
else:
print(
"WARNING: Reset is not supported for this site. Please manually reset the site."
)
storage_state = instance_config.get("storage_state", None)
start_url = instance_config.get("start_url", None)
geolocation = instance_config.get("geolocation", None)
# Use custom viewport size if specified in the config, otherwise use the default.
viewport_size = self.viewport_size.copy()
viewport_size.update(instance_config.get("viewport_size", {}))
self.observation_handler.viewport_size = viewport_size
self.context = self.browser.new_context(
viewport=viewport_size,
storage_state=storage_state,
geolocation=geolocation,
device_scale_factor=1,
)
if self.save_trace_enabled:
self.context.tracing.start(screenshots=True, snapshots=True)
if start_url:
start_urls = start_url.split(" |AND| ")
for url in start_urls:
page = self.context.new_page()
if self.text_observation_type in [
"accessibility_tree",
"accessibility_tree_with_captioner",
]:
client = page.context.new_cdp_session(page)
client.send("Accessibility.enable")
client.detach()
page.goto(url)
# set the first page as the current page
self.page = self.context.pages[0]
self.page.bring_to_front()
else:
self.page = self.context.new_page()
if self.text_observation_type in [
"accessibility_tree",
"accessibility_tree_with_captioner",
]:
client = self.page.context.new_cdp_session(self.page)
client.send("Accessibility.enable")
client.detach()
def _get_obs(self) -> dict[str, Observation]:
obs = self.observation_handler.get_observation(self.page)
return obs
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
metadata = self.observation_handler.get_observation_metadata()
return metadata
@beartype
def reset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
Reset the environment.
:param options: options for the environment. The current supported options are:
- "storage_state": the storage state of the browser. It is a file path to a json file.
"""
super().reset(seed=seed, options=options)
if self.reset_finished:
self.context_manager.__exit__()
if options is not None and "config_file" in options:
config_file = Path(options["config_file"])
if config_file.exists():
self.setup(config_file=config_file)
else:
raise ValueError(f"Config file {config_file} does not exist.")
else:
self.setup()
self.reset_finished = True
timeout_in_ms = 120000
self.page.set_default_timeout(timeout_in_ms)
self.page.set_default_navigation_timeout(timeout_in_ms)
self.page.wait_for_timeout(int(self.sleep_after_execution * 1000))
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, ""),
"fail_error": "",
"observation_metadata": observation_metadata,
}
return (observation, info)
def save_trace(self, trace_path: str | Path) -> None:
if self.save_trace_enabled:
self.context.tracing.stop(path=trace_path)
def close(self) -> None:
if self.reset_finished:
self.context_manager.__exit__()
def step(
self, action: Action
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
if not self.reset_finished:
raise RuntimeError("Call reset first before calling step.")
success = False
fail_error = ""
try:
if self.text_observation_type == 'webrl':
self.page = execute_action_webrl(
action,
self.page,
self.context,
self.observation_handler.action_processor,
self.sleep_after_execution,
)
else:
self.page = execute_action(
action,
self.page,
self.context,
self.observation_handler.action_processor,
self.sleep_after_execution,
)
success = True
except Exception as e:
fail_error = str(e)
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, self.page.content()),
"fail_error": fail_error,
"observation_metadata": observation_metadata,
}
msg = (
observation,
float(success), # reward
False, # terminated
False, # truncated
info,
)
return msg