import json import re from pathlib import Path from typing import Any, TypedDict from PIL import Image from browser_env import Action, ActionParsingError, Trajectory from browser_env.env_config import URL_MAPPINGS from browser_env.utils import StateInfo, pil_to_b64, pil_to_vertex from llms import lm_config from llms.tokenizers import Tokenizer from llms.utils import APIInput class Instruction(TypedDict): """Instruction for constructing prompt""" intro: str examples: list[tuple[str, str]] template: str meta_data: dict[str, Any] class PromptConstructor(object): def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, tokenizer: Tokenizer, ): self.instruction_path = Path(instruction_path) self.obs_modality = "text" self.lm_config = lm_config instruction = json.load(open(self.instruction_path)) instruction["examples"] = [tuple(e) for e in instruction["examples"]] self.instruction: Instruction = instruction self.tokenizer = tokenizer def get_lm_api_input( self, intro: str, examples: list[tuple[str, str]], current: str ) -> APIInput: """Return the require format for an API""" message: list[dict[str, str]] | str if "openai" in self.lm_config.provider: if self.lm_config.mode == "chat": message = [{"role": "system", "content": intro}] for (x, y) in examples: message.append( { "role": "system", "name": "example_user", "content": x, } ) message.append( { "role": "system", "name": "example_assistant", "content": y, } ) message.append({"role": "user", "content": current}) return message elif self.lm_config.mode == "completion": message = f"{intro}\n\n" message += "Here are a few examples:\n" for example in examples: message += f"Observation\n:{example[0]}\n\n" message += f"Action: {example[1]}\n\n" message += "Now make prediction given the observation\n\n" message += f"Observation\n:{current}\n\n" message += "Action:" return message else: raise ValueError( f"OpenAI models do not support mode {self.lm_config.mode}" ) elif "huggingface" in self.lm_config.provider: # https://huggingface.co/blog/llama2#how-to-prompt-llama-2 # https://github.com/facebookresearch/llama/blob/main/llama/generation.py#L320 if "Llama-2" in self.lm_config.model: if self.lm_config.mode == "chat": B_INST, E_INST = "[INST]", "[/INST]" B_SYS, E_SYS = "<>\n", "\n<>\n\n" BOS, EOS = "", "" # adding the system message to be the starting of the first example examples = [ ( B_SYS + intro + E_SYS + examples[0][0], examples[0][1], ) ] + examples[1:] message = "".join( [ f"{BOS}{B_INST} {x.strip()} {E_INST} {y.strip()} {EOS}" for (x, y) in examples ] ) # add the current observation message += f"{BOS}{B_INST} {current.strip()} {E_INST} {self.instruction['meta_data'].get('force_prefix', '')}" return message else: raise ValueError("Only chat mode is supported for Llama-2") else: raise ValueError( f"Huggingface models do not support model_tag {self.lm_config.gen_config['model_tag']}" ) else: raise NotImplementedError( f"Provider {self.lm_config.provider} not implemented" ) def construct( self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any] = {}, ) -> APIInput: raise NotImplementedError def map_url_to_real(self, url: str) -> str: """Map the urls to their real world counterparts""" for i, j in URL_MAPPINGS.items(): if i in url: url = url.replace(i, j) return url def map_url_to_local(self, url: str) -> str: """Map the urls to their local counterparts""" for i, j in URL_MAPPINGS.items(): if j in url: url = url.replace(j, i) # https if j.replace("http", "https") in url: url = url.replace(j.replace("http", "https"), i) return url def _extract_action(self, response: str) -> str: raise NotImplementedError def extract_action(self, response: str) -> str: response = self._extract_action(response) response = self.map_url_to_local(response) return response class DirectPromptConstructor(PromptConstructor): """The agent will direct predict the action""" def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, tokenizer: Tokenizer, ): super().__init__(instruction_path, lm_config, tokenizer) def construct( self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any] = {}, ) -> APIInput: """Construct prompt given the trajectory""" intro = self.instruction["intro"] examples = self.instruction["examples"] template = self.instruction["template"] keywords = self.instruction["meta_data"]["keywords"] state_info: StateInfo = trajectory[-1] # type: ignore[assignment] obs = state_info["observation"][self.obs_modality] max_obs_length = self.lm_config.gen_config["max_obs_length"] if max_obs_length: if self.lm_config.provider == "google": print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.") obs = obs[:max_obs_length] else: obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type] page = state_info["info"]["page"] url = page.url previous_action_str = meta_data["action_history"][-1] # input x current = template.format( objective=intent, url=self.map_url_to_real(url), observation=obs, previous_action=previous_action_str, ) # make sure all keywords are replaced assert all([f"{{k}}" not in current for k in keywords]) prompt = self.get_lm_api_input(intro, examples, current) return prompt def _extract_action(self, response: str) -> str: action_splitter = self.instruction["meta_data"]["action_splitter"] pattern = rf"{action_splitter}((.|\n)*?){action_splitter}" match = re.search(pattern, response) if match: return match.group(1).strip() else: raise ActionParsingError( f"Cannot parse action from response {response}" ) class CoTPromptConstructor(PromptConstructor): """The agent will perform step-by-step reasoning before the answer""" def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, tokenizer: Tokenizer, ): super().__init__(instruction_path, lm_config, tokenizer) self.answer_phrase = self.instruction["meta_data"]["answer_phrase"] def construct( self, trajectory: Trajectory, intent: str, meta_data: dict[str, Any] = {}, ) -> APIInput: intro = self.instruction["intro"] examples = self.instruction["examples"] template = self.instruction["template"] keywords = self.instruction["meta_data"]["keywords"] state_info: StateInfo = trajectory[-1] # type: ignore[assignment] obs = state_info["observation"][self.obs_modality] max_obs_length = self.lm_config.gen_config["max_obs_length"] if max_obs_length: if self.lm_config.provider == "google": print("NOTE: This is a Gemini model, so we use characters instead of tokens for max_obs_length.") obs = obs[:max_obs_length] else: obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type] page = state_info["info"]["page"] url = page.url previous_action_str = meta_data["action_history"][-1] current = template.format( objective=intent, url=self.map_url_to_real(url), observation=obs, previous_action=previous_action_str, ) assert all([f"{{k}}" not in current for k in keywords]) prompt = self.get_lm_api_input(intro, examples, current) return prompt def _extract_action(self, response: str) -> str: # find the first occurence of action action_splitter = self.instruction["meta_data"]["action_splitter"] pattern = rf"{action_splitter}((.|\n)*?){action_splitter}" match = re.search(pattern, response) if match: return match.group(1).strip() else: raise ActionParsingError( f'Cannot find the answer phrase "{self.answer_phrase}" in "{response}"' ) class MultimodalCoTPromptConstructor(CoTPromptConstructor): """The agent will perform step-by-step reasoning before the answer""" def __init__( self, instruction_path: str | Path, lm_config: lm_config.LMConfig, tokenizer: Tokenizer, ): super().__init__(instruction_path, lm_config, tokenizer) self.answer_phrase = self.instruction["meta_data"]["answer_phrase"] def construct( self, trajectory: Trajectory, intent: str, page_screenshot_img: Image.Image, images: list[Image.Image], meta_data: dict[str, Any] = {}, ) -> APIInput: intro = self.instruction["intro"] examples = self.instruction["examples"] template = self.instruction["template"] keywords = self.instruction["meta_data"]["keywords"] state_info: StateInfo = trajectory[-1] # type: ignore[assignment] obs = state_info["observation"][self.obs_modality] max_obs_length = self.lm_config.gen_config["max_obs_length"] if max_obs_length: if self.lm_config.provider in ["google", "api", "finetune"]: print("NOTE: This is a Gemini / API model, so we use characters instead of tokens for max_obs_length.") obs = obs[:max_obs_length] else: obs = self.tokenizer.decode(self.tokenizer.encode(obs)[:max_obs_length]) # type: ignore[arg-type] page = state_info["info"]["page"] url = page.url previous_action_str = meta_data["action_history"][-1] current = template.format( objective=intent, url=self.map_url_to_real(url), observation=obs, previous_action=previous_action_str, ) assert all([f"{{k}}" not in current for k in keywords]) # TODO: for your finetune model, you can config you prompt here if self.lm_config.provider == "finetune": current = "" traj = trajectory[1::2] for rnd, tra in enumerate(traj): tar = '** screenshot **' if rnd > 0 else intent raw = tra["raw_prediction"] current += f"Round {rnd}\n\n<|user|>\n\n** node_info **\n\n{tar}\n\n<|assistant|>\n{raw}\n\n""" current += f"Round {len(traj)}\n\n<|user|>\n\n{obs}\n\n{'** screenshot **' if len(traj) > 0 else intent}\n" prompt = self.get_lm_api_input( intro, examples, current, page_screenshot_img, images ) return prompt def get_lm_api_input( self, intro: str, examples: list[tuple[str, str, str]], current: str, page_screenshot_img: Image.Image, images: list[Image.Image], ) -> APIInput: """Return the require format for an API""" message: list[dict[str, str]] | str | list[str | Image.Image] if "openai" in self.lm_config.provider: if self.lm_config.mode == "chat": message = [ { "role": "system", "content": [{"type": "text", "text": intro}], } ] for (x, y, z) in examples: example_img = Image.open(z) message.append( { "role": "system", "name": "example_user", "content": [ {"type": "text", "text": x}, { "type": "text", "text": "IMAGES: (1) current page screenshot", }, { "type": "image_url", "image_url": { "url": pil_to_b64(example_img) }, }, ], } ) message.append( { "role": "system", "name": "example_assistant", "content": [{"type": "text", "text": y}], } ) # Encode images and page_screenshot_img as base64 strings. current_prompt = current content = [ { "type": "text", "text": "IMAGES: (1) current page screenshot", }, { "type": "image_url", "image_url": {"url": pil_to_b64(page_screenshot_img)}, }, ] for image_i, image in enumerate(images): content.extend( [ { "type": "text", "text": f"({image_i+2}) input image {image_i+1}", }, { "type": "image_url", "image_url": {"url": pil_to_b64(image)}, }, ] ) content = [{"type": "text", "text": current_prompt}] + content message.append({"role": "user", "content": content}) return message else: raise ValueError( f"GPT-4V models do not support mode {self.lm_config.mode}" ) elif "google" in self.lm_config.provider: if self.lm_config.mode == "completion": message = [ intro, "Here are a few examples:", ] for (x, y, z) in examples: example_img = Image.open(z) message.append(f"Observation\n:{x}\n") message.extend( [ "IMAGES:", "(1) current page screenshot:", pil_to_vertex(example_img), ] ) message.append(f"Action: {y}") message.append("Now make prediction given the observation") message.append(f"Observation\n:{current}\n") message.extend( [ "IMAGES:", "(1) current page screenshot:", pil_to_vertex(page_screenshot_img), ] ) for image_i, image in enumerate(images): message.extend( [ f"({image_i+2}) input image {image_i+1}", pil_to_vertex(image), ] ) message.append("Action:") return message else: raise ValueError( f"Gemini models do not support mode {self.lm_config.mode}" ) elif self.lm_config.provider in ["api", "finetune"]: message = [ { "role": "system", "content": intro, } ] # we keep few-shot here, but remove the image corresponding to the current page. for (x, y, _) in examples: message.append({ "role": "user", "content": [ { "type": "text", "text": x }, { "type": "text", "text": "IMAGES: (1) current page screenshot\n\n** Screenshot **\n" }, ], }) message.append({ "role": "assistant", "content": y, }) # TODO: Encode images and page_screenshot_img as base64 strings, we only keep screenshot of current page. current_prompt = current content = [] if self.lm_config.provider != "finetune": content.append({ "type": "text", "text": "IMAGES: (1) current page screenshot", }) if "text" not in self.lm_config.model: content.append({ "type": "image_url", "image_url": {"url": pil_to_b64(page_screenshot_img)}, }) content = [{"type": "text", "text": current_prompt}] + content message.append({"role": "user", "content": content}) return message else: raise NotImplementedError( f"Provider {self.lm_config.provider} not implemented" )