AgentOccam/browser_env/envs.py
yuyr bb7393bcb3 增加环境变量指向g14部署webarena环境;
map网站使用socks5连接到官网openstreetmap;
模型改成gpt-4o; llm fuzzy math也使用4o;
增加generate_test_data.py脚本
2025-04-15 18:16:52 +08:00

344 lines
12 KiB
Python

import json
import re
import time
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Union
import numpy as np
import numpy.typing as npt
from beartype import beartype
from beartype.door import is_bearable
from gymnasium import Env
from gymnasium.spaces import Box, Text
from playwright.sync_api import (
CDPSession,
Page,
Playwright,
ViewportSize,
expect,
sync_playwright,
)
from .actions import Action, execute_action, get_action_space
from .processors import ObservationHandler, ObservationMetadata
from .utils import (
AccessibilityTree,
DetachedPage,
Observation,
png_bytes_to_numpy,
)
import base64
from .scripts import *
@dataclass
class PlaywrightScript:
function: str # goto, get_by_role
destination: str # https://www.google.com/, combobox
name: str | None = None # Search, Avatar 2009
operation: str | None = None # click, fill, press
value: str | None = None # avatar movie, Enter
def parse_action(action: str) -> PlaywrightScript:
splitted = action.strip().split(" ")
assert len(splitted) >= 2
match splitted[:2]:
case ["goto", url]:
assert len(splitted) == 2
return PlaywrightScript("goto", url)
case ["get_by_role", destination]:
assert len(splitted) >= 4
match splitted[2:]:
case [name, operation]:
return PlaywrightScript(
"get_by_role", destination, name, operation
)
case [name, operation, value]:
return PlaywrightScript(
"get_by_role", destination, name, operation, value
)
case _:
raise ValueError("Invalid action")
case _:
raise ValueError(f"Invalid action {action}")
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
"""
The goal of this environment is to produce a prototype of a browser environment.
In the end, we want to support a fully configurable browser environment with wide
range of action spaces and observation spaces, both structured and unstructured.
But in this prototype, we just support action space specified by Playwright script,
and observation space is the html content of the page.
"""
@beartype
def __init__(
self,
max_page_length: int = 8192,
headless: bool = True,
slow_mo: int = 0,
observation_type: str = "html",
current_viewport_only: bool = False,
viewport_size: ViewportSize = {"width": 1280, "height": 720},
save_trace_enabled: bool = False,
sleep_after_execution: float = 5.0,
global_config = None,
proxy_url: str = "",
):
# TODO: make Space[Action] = ActionSpace
self.action_space = get_action_space() # type: ignore[assignment]
self.headless = headless
self.slow_mo = slow_mo
self.current_viewport_only = current_viewport_only
self.reset_finished = False
self.viewport_size = viewport_size
self.save_trace_enabled = save_trace_enabled
self.sleep_after_execution = sleep_after_execution
self.global_config = global_config
self.proxy_url = proxy_url
print(f"ScriptBrowserEnv proxy_url: {self.proxy_url}")
match observation_type:
case "html" | "accessibility_tree":
self.text_observation_type = observation_type
self.image_observation_type = ""
self.main_observation_type = "text"
case "image":
self.image_observation_type = observation_type
self.text_observation_type = "" # type: ignore[assignment]
self.main_observation_type = "image"
case _:
raise ValueError(
f"Unsupported observation type: {observation_type}"
)
self.observation_handler = ObservationHandler(
self.main_observation_type,
self.text_observation_type,
self.image_observation_type,
self.current_viewport_only,
self.viewport_size,
)
self.observation_space = (
self.observation_handler.get_observation_space()
)
@beartype
def setup(self, config_file: Path | None = None) -> None:
def handle_dialog(dialog):
self.page.dialog_message = dialog.message
dialog.dismiss()
self.context_manager = sync_playwright()
self.playwright = self.context_manager.__enter__()
self.browser = self.playwright.chromium.launch(
headless=self.headless, slow_mo=self.slow_mo
)
if config_file:
with open(config_file, "r") as f:
instance_config = json.load(f)
else:
instance_config = {}
storage_state = instance_config.get("storage_state", None)
start_url = instance_config.get("start_url", None)
geolocation = instance_config.get("geolocation", None)
self.context = self.browser.new_context(
viewport=self.viewport_size,
storage_state=storage_state,
geolocation=geolocation,
device_scale_factor=1,
proxy={
"server": self.proxy_url,
"bypass": "127.0.0.1,localhost",
}
if self.proxy_url
else None,
)
if self.save_trace_enabled:
self.context.tracing.start(screenshots=True, snapshots=True)
if start_url:
start_urls = start_url.split(" |AND| ")
for url in start_urls:
page = self.context.new_page()
page.on("dialog", handle_dialog)
client = page.context.new_cdp_session(
page
) # talk to chrome devtools
if self.text_observation_type == "accessibility_tree":
client.send("Accessibility.enable")
page.client = client # type: ignore # TODO[shuyanzh], fix this hackey client
page.goto(url, timeout=10000)
# set the first page as the current page
self.page = self.context.pages[0]
self.page.bring_to_front()
else:
self.page = self.context.new_page()
page.on("dialog", handle_dialog)
client = self.page.context.new_cdp_session(self.page)
if self.text_observation_type == "accessibility_tree":
client.send("Accessibility.enable")
self.page.client = client # type: ignore
def get_page_client(self, page: Page) -> CDPSession:
return page.client # type: ignore
def _get_obs(self) -> dict[str, Observation]:
obs = self.observation_handler.get_observation(
self.page, self.get_page_client(self.page)
)
return obs
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
metadata = self.observation_handler.get_observation_metadata()
return metadata
@beartype
def reset(
self,
*,
seed: int | None = None,
options: dict[str, str] | None = None,
) -> tuple[dict[str, Observation], dict[str, Any]]:
"""
Reset the environment.
:param options: options for the environment. The current supported options are:
- "storage_state": the storage state of the browser. It is a file path to a json file.
"""
super().reset(seed=seed, options=options)
if self.reset_finished:
self.context_manager.__exit__()
if options is not None and "config_file" in options:
config_file = Path(options["config_file"])
if config_file.exists():
self.setup(config_file=config_file)
else:
raise ValueError(f"Config file {config_file} does not exist.")
else:
self.setup()
self.reset_finished = True
if self.sleep_after_execution > 0:
time.sleep(self.sleep_after_execution)
images = self.modify_page()
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, ""),
"fail_error": "",
"observation_metadata": observation_metadata,
"images": images,
}
return (observation, info)
def save_trace(self, trace_path: str | Path) -> None:
if self.save_trace_enabled:
self.context.tracing.stop(path=trace_path)
def close(self) -> None:
if self.reset_finished:
self.context_manager.__exit__()
def step(
self, action: Action
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
if not self.reset_finished:
raise RuntimeError("Call reset first before calling step.")
success = False
fail_error = ""
try:
self.page = execute_action(
action,
self.page,
self.context,
self.observation_handler.action_processor,
)
success = True
except Exception as e:
fail_error = str(e)
raise e
# hard sleep TODO[shuyanzh] suboptimal, may need to check network
if self.sleep_after_execution > 0:
time.sleep(self.sleep_after_execution)
images = self.modify_page()
observation = self._get_obs()
observation_metadata = self._get_obs_metadata()
info = {
"page": DetachedPage(self.page.url, self.page.content()),
"fail_error": fail_error,
"observation_metadata": observation_metadata,
"images": images,
}
msg = (
observation,
float(success), # reward
False, # terminated
False, # truncated
info,
)
return msg
def modify_page(self):
self.page.wait_for_timeout(500)
try:
self.page.evaluate(remove_id_script)
except:
pass
suffix = getattr(self.global_config, "logname", "")
if suffix:
img_bytes = self.page.screenshot(path=f"output/screenshot-{suffix}.png", full_page=True)
else:
img_bytes = self.page.screenshot(path="output/screenshot_raw.png")
raw_image = base64.b64encode(img_bytes).decode()
self.page.evaluate(mix_marker_script)
self.page.wait_for_timeout(100)
# get all clickable elements
start_id = 0
elem_items, start_id = self.page.evaluate(get_rect_script, {
"selector": ".possible-clickable-element",
"startIndex": start_id
})
# get ocr items
ocr_items = []
# ocr_items = page.evaluate(canva_handler_script)
# svg_items, _ = page.evaluate(get_rect_script, {"selector": "svg", "startIndex": -1})
# ocr_items = ocr_items + svg_items
# ocr_items, start_id = get_canva_images(ocr_items, img_bytes, start_id)
items = elem_items + ocr_items
# mark our own labels and get the images
items = self.page.evaluate(label_marker_script, items)
if suffix:
img_bytes = self.page.screenshot(path=f"output/marked-{suffix}.png", full_page=True)
else:
img_bytes = self.page.screenshot(path="output/marked.png")
marked_image = base64.b64encode(img_bytes).decode()
self.page.evaluate(remove_label_mark_script)
return {
"raw_image": raw_image,
"marked_image": marked_image,
}