107 lines
2.5 KiB
Python
107 lines
2.5 KiB
Python
import base64
|
|
from dataclasses import dataclass
|
|
from io import BytesIO
|
|
from typing import Any, Dict, TypedDict, Union
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
from beartype import beartype
|
|
from PIL import Image
|
|
|
|
# Optional dependency: the Vertex AI SDK is only needed by code that uses
# ``VertexImage``. The import is best-effort, so a missing or broken SDK does
# not prevent this module from loading; in that case ``VertexImage`` stays
# unbound.
try:
    from vertexai.preview.generative_models import Image as VertexImage
except Exception:
    # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
    # SystemExit still propagate while any import-time failure is tolerated.
    print('Google Cloud not set up, skipping import of vertexai.preview.generative_models.Image')
|
|
|
|
|
|
@dataclass
class DetachedPage:
    """Plain record pairing a page URL with that page's content.

    Attributes:
        url: The URL string of the page.
        content: The page content, stored as an HTML string.
    """

    url: str
    content: str  # html
|
|
|
|
|
|
@beartype
def png_bytes_to_numpy(png: bytes) -> npt.NDArray[np.uint8]:
    """Decode encoded PNG bytes into a numpy array.

    Args:
        png: The raw bytes of an encoded image (e.g. the contents of a
            ``.png`` file).

    Returns:
        The array produced by passing the PIL image opened from ``png``
        to ``np.array``.

    Example:

    >>> fig = go.Figure(go.Scatter(x=[1], y=[1]))
    >>> plt.imshow(png_bytes_to_numpy(fig.to_image('png')))
    """
    # Wrap the bytes in a file-like object so PIL can read them.
    stream = BytesIO(png)
    decoded = Image.open(stream)
    return np.array(decoded)
|
|
|
|
|
|
def pil_to_b64(img: Image.Image) -> str:
    """Serialize an image as a base64-encoded PNG data URL.

    Args:
        img: The image to encode. It is saved in PNG format into an
            in-memory buffer.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``, where the
        payload is the base64 text of the PNG bytes.
    """
    with BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        # Read the bytes out before the buffer is closed by the ``with``.
        encoded = base64.b64encode(png_stream.getvalue())
    return "data:image/png;base64," + encoded.decode("utf-8")
|
|
|
|
|
|
def pil_to_vertex(img: Image.Image) -> "VertexImage":
    """Convert a PIL image into a Vertex AI ``Image`` object.

    The image is saved as PNG into an in-memory buffer and the resulting
    bytes are handed to ``VertexImage.from_bytes``.

    Args:
        img: The image to convert.

    Returns:
        The object returned by ``VertexImage.from_bytes`` (not a ``str``).
        The annotation is a quoted forward reference so that defining this
        function does not fail when the optional ``vertexai`` import at the
        top of the module was skipped.

    Raises:
        NameError: If ``VertexImage`` was never bound because the optional
            ``vertexai`` import failed.
    """
    with BytesIO() as png_stream:
        img.save(png_stream, format="PNG")
        # Read the bytes out before the buffer is closed by the ``with``.
        png_bytes = png_stream.getvalue()
    return VertexImage.from_bytes(png_bytes)
|
|
|
|
|
|
class DOMNode(TypedDict):
|
|
nodeId: str
|
|
nodeType: str
|
|
nodeName: str
|
|
nodeValue: str
|
|
attributes: str
|
|
backendNodeId: str
|
|
parentId: str
|
|
childIds: list[str]
|
|
cursor: int
|
|
union_bound: list[float] | None
|
|
center: list[float] | None
|
|
|
|
|
|
class AccessibilityTreeNode(TypedDict):
|
|
nodeId: str
|
|
ignored: bool
|
|
role: dict[str, Any]
|
|
chromeRole: dict[str, Any]
|
|
name: dict[str, Any]
|
|
properties: list[dict[str, Any]]
|
|
childIds: list[str]
|
|
parentId: str
|
|
backendDOMNodeId: int
|
|
frameId: str
|
|
bound: list[float] | None
|
|
union_bound: list[float] | None
|
|
offsetrect_bound: list[float] | None
|
|
center: list[float] | None
|
|
|
|
|
|
class BrowserConfig(TypedDict):
    """Dictionary schema for browser window geometry; every value is a float."""

    # Upper and left window bounds.
    win_upper_bound: float
    win_left_bound: float

    # Window size.
    win_width: float
    win_height: float

    # Right and lower window bounds.
    win_right_bound: float
    win_lower_bound: float

    device_pixel_ratio: float
|
|
|
|
|
|
class BrowserInfo(TypedDict):
    """Dictionary schema pairing a DOM tree payload with a ``BrowserConfig``."""

    # String-keyed dict holding the DOM tree data.
    DOMTree: dict[str, Any]
    # Window geometry settings.
    config: BrowserConfig
|
|
|
|
|
|
# Whole-tree aliases: each tree is a list of its node dicts.
AccessibilityTree = list[AccessibilityTreeNode]
DOMTree = list[DOMNode]

# An observation is either a string or a uint8 numpy array.
Observation = str | npt.NDArray[np.uint8]
|
|
|
|
|
|
class StateInfo(TypedDict):
    """Dictionary schema bundling observations with auxiliary info.

    Uses the builtin ``dict[...]`` generic for both fields, matching the
    other TypedDict declarations in this module.
    """

    # Observations keyed by string; each value is an ``Observation``.
    observation: dict[str, Observation]
    # Free-form string-keyed metadata.
    info: dict[str, Any]
|