334 lines
11 KiB
Python
334 lines
11 KiB
Python
import json
|
|
import re
|
|
import time
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any, Union
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
from beartype import beartype
|
|
from beartype.door import is_bearable
|
|
from gymnasium import Env
|
|
from gymnasium.spaces import Box, Text
|
|
from playwright.sync_api import (
|
|
CDPSession,
|
|
Page,
|
|
Playwright,
|
|
ViewportSize,
|
|
expect,
|
|
sync_playwright,
|
|
)
|
|
|
|
from .actions import Action, execute_action, get_action_space
|
|
from .processors import ObservationHandler, ObservationMetadata
|
|
from .utils import (
|
|
AccessibilityTree,
|
|
DetachedPage,
|
|
Observation,
|
|
png_bytes_to_numpy,
|
|
)
|
|
|
|
import base64
|
|
from .scripts import *
|
|
|
|
@dataclass
|
|
class PlaywrightScript:
|
|
function: str # goto, get_by_role
|
|
destination: str # https://www.google.com/, combobox
|
|
name: str | None = None # Search, Avatar 2009
|
|
operation: str | None = None # click, fill, press
|
|
value: str | None = None # avatar movie, Enter
|
|
|
|
|
|
def parse_action(action: str) -> PlaywrightScript:
|
|
splitted = action.strip().split(" ")
|
|
assert len(splitted) >= 2
|
|
match splitted[:2]:
|
|
case ["goto", url]:
|
|
assert len(splitted) == 2
|
|
return PlaywrightScript("goto", url)
|
|
case ["get_by_role", destination]:
|
|
assert len(splitted) >= 4
|
|
match splitted[2:]:
|
|
case [name, operation]:
|
|
return PlaywrightScript(
|
|
"get_by_role", destination, name, operation
|
|
)
|
|
case [name, operation, value]:
|
|
return PlaywrightScript(
|
|
"get_by_role", destination, name, operation, value
|
|
)
|
|
case _:
|
|
raise ValueError("Invalid action")
|
|
case _:
|
|
raise ValueError(f"Invalid action {action}")
|
|
|
|
|
|
class ScriptBrowserEnv(Env[dict[str, Observation], Action]):
|
|
"""
|
|
The goal of this environment is to produce a prototype of a browser environment.
|
|
In the end, we want to support a fully configurable browser environment with wide
|
|
range of action spaces and observation spaces, both structured and unstructured.
|
|
But in this prototype, we just support action space specified by Playwright script,
|
|
and observation space is the html content of the page.
|
|
"""
|
|
|
|
@beartype
|
|
def __init__(
|
|
self,
|
|
max_page_length: int = 8192,
|
|
headless: bool = True,
|
|
slow_mo: int = 0,
|
|
observation_type: str = "html",
|
|
current_viewport_only: bool = False,
|
|
viewport_size: ViewportSize = {"width": 1280, "height": 720},
|
|
save_trace_enabled: bool = False,
|
|
sleep_after_execution: float = 5.0,
|
|
global_config = None,
|
|
):
|
|
# TODO: make Space[Action] = ActionSpace
|
|
self.action_space = get_action_space() # type: ignore[assignment]
|
|
self.headless = headless
|
|
self.slow_mo = slow_mo
|
|
self.current_viewport_only = current_viewport_only
|
|
self.reset_finished = False
|
|
self.viewport_size = viewport_size
|
|
self.save_trace_enabled = save_trace_enabled
|
|
self.sleep_after_execution = sleep_after_execution
|
|
self.global_config = global_config
|
|
|
|
match observation_type:
|
|
case "html" | "accessibility_tree":
|
|
self.text_observation_type = observation_type
|
|
self.image_observation_type = ""
|
|
self.main_observation_type = "text"
|
|
case "image":
|
|
self.image_observation_type = observation_type
|
|
self.text_observation_type = "" # type: ignore[assignment]
|
|
self.main_observation_type = "image"
|
|
case _:
|
|
raise ValueError(
|
|
f"Unsupported observation type: {observation_type}"
|
|
)
|
|
|
|
self.observation_handler = ObservationHandler(
|
|
self.main_observation_type,
|
|
self.text_observation_type,
|
|
self.image_observation_type,
|
|
self.current_viewport_only,
|
|
self.viewport_size,
|
|
)
|
|
|
|
self.observation_space = (
|
|
self.observation_handler.get_observation_space()
|
|
)
|
|
|
|
@beartype
|
|
def setup(self, config_file: Path | None = None) -> None:
|
|
def handle_dialog(dialog):
|
|
self.page.dialog_message = dialog.message
|
|
dialog.dismiss()
|
|
self.context_manager = sync_playwright()
|
|
self.playwright = self.context_manager.__enter__()
|
|
self.browser = self.playwright.chromium.launch(
|
|
headless=self.headless, slow_mo=self.slow_mo
|
|
)
|
|
|
|
if config_file:
|
|
with open(config_file, "r") as f:
|
|
instance_config = json.load(f)
|
|
else:
|
|
instance_config = {}
|
|
|
|
storage_state = instance_config.get("storage_state", None)
|
|
start_url = instance_config.get("start_url", None)
|
|
geolocation = instance_config.get("geolocation", None)
|
|
|
|
self.context = self.browser.new_context(
|
|
viewport=self.viewport_size,
|
|
storage_state=storage_state,
|
|
geolocation=geolocation,
|
|
device_scale_factor=1,
|
|
)
|
|
if self.save_trace_enabled:
|
|
self.context.tracing.start(screenshots=True, snapshots=True)
|
|
if start_url:
|
|
start_urls = start_url.split(" |AND| ")
|
|
for url in start_urls:
|
|
page = self.context.new_page()
|
|
page.on("dialog", handle_dialog)
|
|
client = page.context.new_cdp_session(
|
|
page
|
|
) # talk to chrome devtools
|
|
if self.text_observation_type == "accessibility_tree":
|
|
client.send("Accessibility.enable")
|
|
page.client = client # type: ignore # TODO[shuyanzh], fix this hackey client
|
|
page.goto(url)
|
|
# set the first page as the current page
|
|
self.page = self.context.pages[0]
|
|
self.page.bring_to_front()
|
|
else:
|
|
self.page = self.context.new_page()
|
|
page.on("dialog", handle_dialog)
|
|
client = self.page.context.new_cdp_session(self.page)
|
|
if self.text_observation_type == "accessibility_tree":
|
|
client.send("Accessibility.enable")
|
|
self.page.client = client # type: ignore
|
|
|
|
def get_page_client(self, page: Page) -> CDPSession:
|
|
return page.client # type: ignore
|
|
|
|
def _get_obs(self) -> dict[str, Observation]:
|
|
obs = self.observation_handler.get_observation(
|
|
self.page, self.get_page_client(self.page)
|
|
)
|
|
return obs
|
|
|
|
def _get_obs_metadata(self) -> dict[str, ObservationMetadata]:
|
|
metadata = self.observation_handler.get_observation_metadata()
|
|
return metadata
|
|
|
|
@beartype
|
|
def reset(
|
|
self,
|
|
*,
|
|
seed: int | None = None,
|
|
options: dict[str, str] | None = None,
|
|
) -> tuple[dict[str, Observation], dict[str, Any]]:
|
|
"""
|
|
Reset the environment.
|
|
:param options: options for the environment. The current supported options are:
|
|
- "storage_state": the storage state of the browser. It is a file path to a json file.
|
|
"""
|
|
super().reset(seed=seed, options=options)
|
|
if self.reset_finished:
|
|
self.context_manager.__exit__()
|
|
|
|
if options is not None and "config_file" in options:
|
|
config_file = Path(options["config_file"])
|
|
if config_file.exists():
|
|
self.setup(config_file=config_file)
|
|
else:
|
|
raise ValueError(f"Config file {config_file} does not exist.")
|
|
else:
|
|
self.setup()
|
|
self.reset_finished = True
|
|
|
|
if self.sleep_after_execution > 0:
|
|
time.sleep(self.sleep_after_execution)
|
|
|
|
images = self.modify_page()
|
|
|
|
observation = self._get_obs()
|
|
observation_metadata = self._get_obs_metadata()
|
|
info = {
|
|
"page": DetachedPage(self.page.url, ""),
|
|
"fail_error": "",
|
|
"observation_metadata": observation_metadata,
|
|
"images": images,
|
|
}
|
|
|
|
return (observation, info)
|
|
|
|
def save_trace(self, trace_path: str | Path) -> None:
|
|
if self.save_trace_enabled:
|
|
self.context.tracing.stop(path=trace_path)
|
|
|
|
def close(self) -> None:
|
|
if self.reset_finished:
|
|
self.context_manager.__exit__()
|
|
|
|
def step(
|
|
self, action: Action
|
|
) -> tuple[dict[str, Observation], float, bool, bool, dict[str, Any]]:
|
|
if not self.reset_finished:
|
|
raise RuntimeError("Call reset first before calling step.")
|
|
|
|
success = False
|
|
fail_error = ""
|
|
try:
|
|
self.page = execute_action(
|
|
action,
|
|
self.page,
|
|
self.context,
|
|
self.observation_handler.action_processor,
|
|
)
|
|
success = True
|
|
except Exception as e:
|
|
fail_error = str(e)
|
|
raise e
|
|
|
|
# hard sleep TODO[shuyanzh] suboptimal, may need to check network
|
|
if self.sleep_after_execution > 0:
|
|
time.sleep(self.sleep_after_execution)
|
|
|
|
images = self.modify_page()
|
|
|
|
observation = self._get_obs()
|
|
observation_metadata = self._get_obs_metadata()
|
|
|
|
info = {
|
|
"page": DetachedPage(self.page.url, self.page.content()),
|
|
"fail_error": fail_error,
|
|
"observation_metadata": observation_metadata,
|
|
"images": images,
|
|
}
|
|
|
|
msg = (
|
|
observation,
|
|
float(success), # reward
|
|
False, # terminated
|
|
False, # truncated
|
|
info,
|
|
)
|
|
return msg
|
|
|
|
def modify_page(self):
|
|
self.page.wait_for_timeout(500)
|
|
try:
|
|
self.page.evaluate(remove_id_script)
|
|
except:
|
|
pass
|
|
|
|
suffix = getattr(self.global_config, "logname", "")
|
|
if suffix:
|
|
img_bytes = self.page.screenshot(path=f"output/screenshot-{suffix}.png", full_page=True)
|
|
else:
|
|
img_bytes = self.page.screenshot(path="output/screenshot_raw.png")
|
|
raw_image = base64.b64encode(img_bytes).decode()
|
|
|
|
self.page.evaluate(mix_marker_script)
|
|
self.page.wait_for_timeout(100)
|
|
|
|
# get all clickable elements
|
|
start_id = 0
|
|
elem_items, start_id = self.page.evaluate(get_rect_script, {
|
|
"selector": ".possible-clickable-element",
|
|
"startIndex": start_id
|
|
})
|
|
|
|
# get ocr items
|
|
ocr_items = []
|
|
# ocr_items = page.evaluate(canva_handler_script)
|
|
# svg_items, _ = page.evaluate(get_rect_script, {"selector": "svg", "startIndex": -1})
|
|
# ocr_items = ocr_items + svg_items
|
|
# ocr_items, start_id = get_canva_images(ocr_items, img_bytes, start_id)
|
|
|
|
items = elem_items + ocr_items
|
|
|
|
# mark our own labels and get the images
|
|
items = self.page.evaluate(label_marker_script, items)
|
|
if suffix:
|
|
img_bytes = self.page.screenshot(path=f"output/marked-{suffix}.png", full_page=True)
|
|
else:
|
|
img_bytes = self.page.screenshot(path="output/marked.png")
|
|
marked_image = base64.b64encode(img_bytes).decode()
|
|
|
|
self.page.evaluate(remove_label_mark_script)
|
|
|
|
return {
|
|
"raw_image": raw_image,
|
|
"marked_image": marked_image,
|
|
} |