AgentOccam/AgentOccam/AgentOccam.py
2025-01-22 11:32:35 -08:00

1448 lines
76 KiB
Python

from AgentOccam.obs_opt import parse_node_descendants, parse_node_ancestors, parse_node_siblings, action_set_invisible, action_set_visible, action_set_visible_if_with_name, translate_node_to_str, construct_new_DOM_with_visible_nodes
from AgentOccam.llms.claude import call_claude, call_claude_with_messages, arrange_message_for_claude
from AgentOccam.llms.mistral import call_mistral, call_mistral_with_messages, arrange_message_for_mistral
from AgentOccam.llms.cohere import call_cohere, call_cohere_with_messages, arrange_message_for_cohere
from AgentOccam.llms.llama import call_llama, call_llama_with_messages, arrange_message_for_llama
from AgentOccam.llms.titan import call_titan, call_titan_with_messages, arrange_message_for_titan
from AgentOccam.llms.gpt import call_gpt, call_gpt_with_messages, arrange_message_for_gpt
from AgentOccam.llms.gemini import call_gemini, call_gemini_with_messages, arrange_message_for_gemini
from AgentOccam.utils import CURRENT_DIR, HOMEPAGE_URL
from typing import Dict
import re
import copy
import os
from functools import partial
import random
import json
import warnings
warnings.filterwarnings("ignore")
DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS = ["observation", "action"]
DEFAULT_ONLINE_INTERACTION_ELEMENTS = ["url", "observation"]
MODEL_FAMILIES = ["claude", "mistral", "cohere", "llama", "titan", "gpt", "gemini"]
CALL_MODEL_MAP = {
"claude": call_claude,
"mistral": call_mistral,
"cohere": call_cohere,
"llama": call_llama,
"titan": call_titan,
"gpt": call_gpt,
"gemini": call_gemini,
}
CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP = {
"claude": call_claude_with_messages,
"mistral": call_mistral_with_messages,
"cohere": call_cohere_with_messages,
"llama": call_llama_with_messages,
"titan": call_titan_with_messages,
"gpt": call_gpt_with_messages,
"gemini": call_gemini_with_messages,
}
ARRANGE_MESSAGE_FOR_MODEL_MAP = {
"claude": arrange_message_for_claude,
"mistral": arrange_message_for_mistral,
"cohere": arrange_message_for_cohere,
"llama": arrange_message_for_llama,
"titan": arrange_message_for_titan,
"gpt": arrange_message_for_gpt,
"gemini": arrange_message_for_gemini,
}
class Agent:
def __init__(self, config, objective, prompt_template):
self.config = config
self.objective = objective
self.prompt_template = prompt_template
if hasattr(self.config, "documented_interaction_elements"):
self.previous_interactions = {k: [] for k in set(DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS+self.config.documented_interaction_elements)}
else:
self.previous_interactions = {k: [] for k in DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS}
if hasattr(self.config, "online_interaction_elements"):
self.online_interaction = {k: None for k in set(DEFAULT_ONLINE_INTERACTION_ELEMENTS+self.config.online_interaction_elements)}
else:
self.online_interaction = {k: None for k in DEFAULT_ONLINE_INTERACTION_ELEMENTS}
self.model_family = [model_family for model_family in MODEL_FAMILIES if model_family in self.config.model][0]
self.call_model = partial(CALL_MODEL_MAP[self.model_family], model_id=self.config.model)
self.call_model_with_message = partial(CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[self.model_family], model_id=self.config.model)
self.arrange_message_for_model = ARRANGE_MESSAGE_FOR_MODEL_MAP[self.model_family]
def shift_model(self, model_id):
self.model_family = [model_family for model_family in MODEL_FAMILIES if model_family in model_id][0]
self.call_model = partial(CALL_MODEL_MAP[self.model_family], model_id=model_id)
self.call_model_with_message = partial(CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[self.model_family], model_id=model_id)
self.arrange_message_for_model = ARRANGE_MESSAGE_FOR_MODEL_MAP[self.model_family]
def prune_message_list(self, message_list):
return self.merge_adjacent_text([m for m in message_list if not (m[0]=="text" and len(m[1])==0)])
def merge_adjacent_text(self, message_list):
merged_list = []
current_tuple = None
for tup in message_list:
if tup[0] == "text":
if current_tuple:
current_tuple = (current_tuple[0], current_tuple[1] + tup[1])
else:
current_tuple = tup
else:
if current_tuple:
merged_list.append(current_tuple)
current_tuple = None
merged_list.append(tup)
if current_tuple:
merged_list.append(current_tuple)
return merged_list
def get_step(self):
return len(self.previous_interactions["action"])
def update_objective(self, objective):
self.objective = objective
def update_online_state(self, **online_states):
for k in online_states.keys():
if k in self.online_interaction.keys():
self.online_interaction[k] = online_states[k]
def update_history(self, **interaction_dict):
for k in interaction_dict.keys():
if k in self.previous_interactions.keys():
self.previous_interactions[k].append(interaction_dict[k])
def equal_history_length(self):
lengths = [len(self.previous_interactions[k]) for k in self.previous_interactions.keys()]
return (len(set(lengths)) == 1)
def parse_elements(self, text, key_list):
element_dict = {}
for k in key_list:
# _match = re.search(rf'{k.upper()}:\s*(.*?)\s*(?=\n[A-Z\d\s\W]*: *\n|$)', text, re.DOTALL)
_match = re.search(rf'{k.upper()}:\s*(.*?)\s*(?=\n[A-Z\s]*:|$)', text, re.DOTALL)
element_dict[k] = _match.group(1).strip() if _match else ""
return element_dict
def get_output_specifications(self):
output_specifications = "\n".join([f"{o.upper()}:\n" + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications", "{}.txt".format(o.replace(" ", "_"))), "r").readlines()) for o in self.config.output])
return output_specifications
def parse_stipulated_action_list(self, text: str, action: str, actions: list) -> str:
pattern = rf'({re.escape(action)}\s*(.*?))(?=\n(?:{"|".join(map(re.escape, actions))})|$)'
return [match[0].strip() for match in re.findall(pattern, text, re.DOTALL)]
def parse_str_to_action_list(self, text:str, actions: list):
remain_text = copy.deepcopy(text)
action_list = []
while remain_text:
find_action = False
for action in actions:
if remain_text.startswith(action):
match = re.search(rf'({re.escape(action)}\s*(.*?))(?=\n(?:{"|".join(map(re.escape, actions))})|$)', remain_text, re.DOTALL)
action_list.append(match[0])
remain_text = remain_text[len(match[0]):].strip()
find_action = True
if not find_action:
break
return action_list
def get_observation_text(self, idx=None):
if isinstance(self.online_interaction["observation"], dict):
if idx:
return self.previous_interactions["observation"][idx]["text"]
return self.online_interaction["observation"]["text"]
elif isinstance(self.online_interaction["observation"], str):
if idx:
return self.previous_interactions["observation"][idx]
return self.online_interaction["observation"]
def get_observation_image(self, idx=None):
if isinstance(self.online_interaction["observation"], dict):
if idx:
return self.previous_interactions["observation"][idx]["image"]
return self.online_interaction["observation"]["image"]
elif isinstance(self.online_interaction["observation"], str):
return None
def get_observation_node(self, idx=None):
if isinstance(self.online_interaction["observation"], dict):
if idx != None:
return self.previous_interactions["observation"][idx]["node"]
return self.online_interaction["observation"]["node"]
elif isinstance(self.online_interaction["observation"], str):
return None
def get_observation_node_str(self, idx=None):
if isinstance(self.online_interaction["observation"], dict):
if idx != None:
return self.previous_interactions["observation"][idx]["node_str"]
return translate_node_to_str(self.online_interaction["observation"]["node"], mode="name_only")
elif isinstance(self.online_interaction["observation"], str):
return None
def del_observation_node(self):
if isinstance(self.online_interaction["observation"], str):
return
if isinstance(self.online_interaction["observation"], dict):
for idx in range(len(self.previous_interactions["observation"])):
if "node" in self.previous_interactions["observation"][idx].keys() and self.previous_interactions["observation"][idx]["node"]:
node_str = translate_node_to_str(self.previous_interactions["observation"][idx]["node"], mode="name_only")
self.previous_interactions["observation"][idx]["node_str"] = node_str
self.previous_interactions["observation"][idx]["node"].delete_tree()
self.previous_interactions["observation"][idx]["node"] = None
class PlanTreeNode:
def __init__(self, id, type, text, level, url, step):
self.visible = True
self.id = id
self.type = type
self.text = text
self.level = level
self.url = url
self.step = step
self.children = []
self.parent = None
self.note = []
self.hint = []
self.resume_reason = []
self.steps_taken = []
def reset(self):
self.visible = True
self.note = []
self.hint = []
self.steps_taken = []
def add_child(self, child):
child.parent = self
self.children.append(child)
def search_node_by_id(self, target_id):
if self.visible and self.id == target_id:
return self
for child in self.children:
result = child.search_node_by_id(target_id)
if result:
return result
return None
def traverse(self, action=None, tree_buffer=[]):
res_action = action(self)
if res_action:
if isinstance(res_action, list):
tree_buffer.extend(res_action)
else:
tree_buffer.append(res_action)
for child in self.children:
child.traverse(action, tree_buffer=tree_buffer)
class QAActor(Agent):
def __init__(self, config, objective, prompt_template):
super().__init__(config, objective, prompt_template)
def get_instruction(self):
return self.prompt_template["instruction_template"]
def get_online_input(self):
return [("text", self.prompt_template["input_template"].replace("{current_observation}", self.get_observation_text()).replace("{objective}", self.objective))]
def get_action(self, instruction, online_input):
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
action = action_elements["response"]
action_elements["action"] = f"note [{action}]"
action_elements["instruction"] = instruction
action_elements["input"] = online_input
return model_response, action_elements
class PlanningActor(Agent):
def __init__(self, config, objective, prompt_template):
super().__init__(config, objective, prompt_template)
self.instruction = None
def get_planning_specifications(self):
return "\n".join(["- " + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "planning_specifications", f"{p}.txt"), "r").readlines()) for p in self.config.planning_command])
def get_instruction(self):
if self.instruction:
return self.instruction
output_specifications = self.get_output_specifications()
self.instruction = self.prompt_template["instruction_template"].replace("{output_specifications}", output_specifications).replace("{planning_specifications}", self.get_planning_specifications())
return self.instruction
def get_online_input(self):
return None
def get_action(self, instruction, online_input):
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
action_elements["action"] = copy.deepcopy(action_elements["plan"])
del action_elements["plan"]
action_elements["reason"] = "N/A"
action_elements["instruction"] = instruction
action_elements["input"] = online_input
return model_response, action_elements
class ReflectionActor(Agent):
def __init__(self, config, objective, prompt_template):
super().__init__(config, objective, prompt_template)
self.instruction = None
def get_planning_specifications(self):
return "\n".join(["- " + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "planning_specifications", f"{p}.txt"), "r").readlines()) for p in self.config.planning_command])
def get_navigation_specifications(self):
return "\n".join(["- " + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "navigation_specifications", f"{n}.txt"), "r").readlines()) for n in self.config.navigation_command])
def get_instruction(self):
if self.instruction:
return self.instruction
output_specifications = self.get_output_specifications()
planning_specifications = self.get_planning_specifications()
navigation_specifications = self.get_navigation_specifications()
instruction = self.prompt_template["instruction_template"]
instruction = instruction.replace("{output_specifications}", output_specifications)
instruction = instruction.replace("{planning_specifications}", planning_specifications)
instruction = instruction.replace("{navigation_specifications}", navigation_specifications)
self.instruction = instruction
return self.instruction
def get_online_input(self):
return None
def get_action(self, instruction, online_input):
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
action_elements["instruction"] = instruction
action_elements["input"] = online_input
return model_response, action_elements
IDENTITY_CLASS_MAP = {
"QA": QAActor,
"planning": PlanningActor,
"reflection": ReflectionActor,
}
class Actor(Agent):
def __init__(self, config, objective, prompt_template, plan_tree_node):
super().__init__(config, objective, prompt_template)
self.plan_tree_root = plan_tree_node
self.active_node = plan_tree_node
self.output_specifications = None
self.planning_specifications = None
self.navigation_specifications = None
self.criticism_element_list = None
self.output_play_path = os.path.join(CURRENT_DIR, f"play-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"play.txt")
self.output_trash_path = os.path.join(CURRENT_DIR, f"trash-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"trash.txt")
self.identities = []
if hasattr(self.config, "identities"):
i = 0
while hasattr(self.config.identities, f"identity_{i}"):
identity_config = getattr(self.config.identities, f"identity_{i}")
self.identities.append(IDENTITY_CLASS_MAP[identity_config.name](identity_config, objective=objective, prompt_template=prompt_template[identity_config.name]))
i += 1
def update_online_state(self, **online_states):
super().update_online_state(**online_states)
for identity in self.identities:
identity.update_online_state(**online_states)
def is_planning(self, action):
for c in self.config.planning_command:
if action.startswith(c):
return c
return False
def is_navigation(self, action):
action_without_note = re.sub(rf'(note\s*(.*?))(?=\n(?:{"|".join(map(re.escape, self.config.navigation_command))})|$)', "", action).strip()
for c in self.config.navigation_command:
if action_without_note.startswith(c):
return c
return False
def is_valid_action(self, action_str):
action = (
action_str.split("[")[0].strip()
if "[" in action_str
else action_str.split()[0].strip()
)
match action:
case "click":
match = re.search(r"click ?\[(\d+)\]", action_str)
if not match:
return False
element_id = match.group(1)
if element_id in self.get_observation_text():
return True
return False
case "type":
if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
action_str += " [1]"
match = re.search(
r"type ?\[(\d+)\] ?\[(.*)\] ?\[(\d+)\]", action_str, re.DOTALL
)
if not match:
return False
element_id, text, enter_flag = (
match.group(1),
match.group(2),
match.group(3),
)
enter_flag = True if enter_flag == "1" else False
if enter_flag:
text += "\n"
if element_id in self.get_observation_text():
return True
case "go_back":
return True
case "go_home":
return True
case "note":
return True
case "stop":
return True
case "branch":
return True
case "prune":
return True
case "goto":
return True
case "scroll":
return True
def are_valid_actions(self, actions):
action_list = self.parse_str_to_action_list(actions, self.config.planning_command+self.config.navigation_command+["goto"])
if not action_list:
return False
for action in action_list:
if not self.is_valid_action(action):
return False
return True
def get_previous_plans(self, verbose=False):
def action_return_visible_node(node, verbose=False):
if node.id == self.active_node.id:
basic = "\t" * node.level + f"[{node.id}] (Active Plan) {node.text}" if node.visible else None
else:
basic = "\t" * node.level + f"[{node.id}] {node.text}" if node.visible else None
if basic and len(node.resume_reason) > 0:
basic += f" # Was resumed to this step {len(node.resume_reason)} time(s) for:"
for i, reason in enumerate(node.resume_reason):
basic += f" {i}. {reason}"
if verbose and basic and len(node.note) > 0:
for i, note in enumerate(node.note):
basic += "\n" + "\t" * node.level + f"Note {i}. {note}"
return basic
plan_tree_buffer = []
parse_node_descendants(self.plan_tree_root, partial(action_return_visible_node, verbose=verbose), tree_buffer=plan_tree_buffer)
return "\n".join(plan_tree_buffer)
def get_active_plan(self):
return f"[{self.active_node.id}] {self.active_node.text}"
def get_interaction_history(self, interaction_history_config=False, mode="highlight"):
interaction_history_config = interaction_history_config if interaction_history_config else self.config.interaction_history
previous_observation = []
for i in self.active_node.steps_taken:
if self.get_observation_node_str() and self.get_observation_node_str(i) and not self.get_observation_node_str() == self.get_observation_node_str(i):
if self.previous_interactions["observation highlight"][i] and mode == "highlight" and len(translate_node_to_str(self.previous_interactions["observation highlight"][i], mode="name_only", retained_ids=self.previous_interactions["retained element ids"][i]).split()) < 200:
try:
previous_observation.append({"text": translate_node_to_str(self.previous_interactions["observation highlight"][i], mode="name_only", retained_ids=self.previous_interactions["retained element ids"][i]), "image": self.get_observation_image(i)})
except:
print(i, self.previous_interactions["observation"][i]["text"])
raise ValueError("Cannot translate highlight node to text.")
else:
previous_observation.append({"text": self.previous_interactions["observation summary"][i], "image": self.get_observation_image(i)})
elif not self.get_observation_node() or mode == "full":
if len(self.get_observation_text(i).split()) < 200:
previous_observation.append({"text": self.get_observation_text(i), "image": self.get_observation_image(i)})
else:
previous_observation.append({"text": self.previous_interactions["observation summary"][i], "image": self.get_observation_image(i)})
else:
previous_observation.append({"text": "The same as the CURRENT OBSERVATION (see below CURRENT OBSERVATION section).", "image": self.get_observation_image(i)})
previous_observation_summary = [self.previous_interactions["observation summary"][i] for i in self.active_node.steps_taken]
def get_text(obs):
if isinstance(obs, dict):
return obs["text"]
elif isinstance(obs, str):
return obs
def get_image(obs):
if isinstance(obs, dict):
return obs["image"]
elif isinstance(obs, str):
return obs
if interaction_history_config.step_num == "all":
textual_observations = [get_text(obs) for obs in previous_observation] if interaction_history_config.verbose else previous_observation_summary
visual_observations = [get_image(obs) for obs in previous_observation]
else:
textual_observations = previous_observation_summary[:-interaction_history_config.step_num]
visual_observations = [None] * len(previous_observation_summary[:-interaction_history_config.step_num])
textual_observations += [get_text(obs) for obs in previous_observation][-interaction_history_config.step_num:] if interaction_history_config.verbose else previous_observation_summary[-interaction_history_config.step_num:]
visual_observations += [get_image(obs) for obs in previous_observation][-interaction_history_config.step_num:]
plans = [self.previous_interactions["plan"][i] for i in self.active_node.steps_taken]
reasons = [self.previous_interactions["reason"][i] for i in self.active_node.steps_taken]
actions = [self.previous_interactions["action"][i] for i in self.active_node.steps_taken]
if "image" in interaction_history_config.type:
message_list = []
for step, (obs, vi_obs, plan, reason, action) in enumerate(zip(textual_observations, visual_observations, plans, reasons, actions)):
message_list.append(("text", f"<step_{step}_interaction>\n"))
if vi_obs:
message_list.append(("text", "VISUAL OBSERVATION:\n"))
message_list.append(("image", vi_obs))
if self.active_node.id != 0:
message_list.append(("text", f"TEXTUAL OBSERVATION:\n{obs}\nACTIVE PLAN:\n{plan}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"))
else:
message_list.append(("text", f"TEXTUAL OBSERVATION:\n{obs}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"))
return self.prune_message_list(message_list=message_list)
else:
message = ""
for step, (obs, plan, reason, action) in enumerate(zip(textual_observations, plans, reasons, actions)):
if self.active_node.id != 0:
message += f"<step_{step}_interaction>\nOBSERVATION:\n{obs}\nACTIVE PLAN:\n{plan}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n" # f"<step_{step}_interaction>\nOBSERVATION:\n{obs}\nACTIVE PLAN:\n{plan}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"
else:
message += f"<step_{step}_interaction>\nOBSERVATION:\n{obs}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n" # f"<step_{step}_interaction>\nOBSERVATION:\n{obs}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"
return self.prune_message_list(message_list=[("text", message)])
def pre_process_atomic_actions(self, atomic_action_list=["combobox"]):
if self.get_observation_node() and "combobox" in atomic_action_list:
self.online_interaction["observation"]["text"] = translate_node_to_str(self.get_observation_node(), mode="concise", hidden_roles=["menu", "combobox", "listbox"])
def get_online_input(self, criticism_elements):
input_template = self.prompt_template["input_template"]
input_prefix, input_suffix = input_template.split("{input}")
INPUT_TYPE_TO_CONTENT_MAP = {
"step": self.get_step(),
"objective": self.objective,
"previous plans": self.get_previous_plans(verbose=True),
"interaction history": self.get_interaction_history(),
"current observation": self.get_observation_text(),
"current visual observation": self.get_observation_image()
}
input_list = []
for input_type in self.config.input:
input_content = None
if input_type == "current visual observation":
continue
elif input_type in INPUT_TYPE_TO_CONTENT_MAP.keys():
input_content = INPUT_TYPE_TO_CONTENT_MAP[input_type]
elif input_type.startswith("critic: ") and criticism_elements and input_type[len("critic: "):] in criticism_elements.keys() and criticism_elements[input_type[len("critic: "):]]:
input_type = input_type[len("critic: "):]
input_content = criticism_elements[input_type]
input_type = "FROM USER: " + input_type
if input_content and isinstance(input_content, str):
input_list.append(("text", f"{input_type.upper()}:\n{input_content}\n"))
elif input_content and isinstance(input_content, list):
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list += input_content if len(input_content) > 0 else ["N/A"]
if "image" in self.config.current_observation.type:
input_type = "current visual observation"
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list.append(("image", INPUT_TYPE_TO_CONTENT_MAP["current visual observation"]))
return self.prune_message_list(message_list=[("text", input_prefix)] + input_list + [("text", input_suffix)])
def get_planning_specifications(self):
if self.planning_specifications:
return self.planning_specifications
self.planning_specifications = "\n".join(["- " + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "planning_specifications", f"{p}.txt"), "r").readlines()) for p in self.config.planning_command])
return self.planning_specifications
def get_navigation_specifications(self):
if self.navigation_specifications:
return self.navigation_specifications
self.navigation_specifications = "\n".join(["- " + "".join(open(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "navigation_specifications", f"{n}.txt"), "r").readlines()) for n in self.config.navigation_command])
return self.navigation_specifications
def get_actor_instruction(self, examples=None):
if self.config.planning_command:
instruction = self.prompt_template["instruction_template"]["with_planning"]
else:
instruction = self.prompt_template["instruction_template"]["without_planning"]
output_specifications = self.get_output_specifications()
planning_specifications = self.get_planning_specifications()
navigation_specifications = self.get_navigation_specifications()
instruction = instruction.replace("{output_specifications}", output_specifications)
instruction = instruction.replace("{planning_specifications}", planning_specifications)
instruction = instruction.replace("{navigation_specifications}", navigation_specifications)
example_source = examples if examples is not None else self.prompt_template.get("examples", [])
if len(example_source) > 0:
instruction += f"\n\n## Here are a few examples:"
for i, example in enumerate(example_source):
example_input = example["input"]
example_output = example["output"]
if "example_template" in self.prompt_template.keys():
instruction += "\n\n"
instruction += self.prompt_template.get("example_template", "| Example {i}\n### Input:\n{example_input}\n### Response: Let's think step by step.\n{example_response}").replace("{i}", i).replace("{example_input}", example_input).replace("{example_output}", example_output)
else:
instruction += f"\n\n| Example {i}\n\n### Input:\n{example_input}\n\n### Response: Let's think step by step.\n{example_output}"
if self.get_step() == self.config.others.max_steps - 1:
instruction += f"\n\nWARNING: You have a {self.config.others.max_steps}-step budget, and this would be your FINAL STEP. Wrap up your observations and return your answer with `stop [answer]` to maximize the reward."
# else:
# instruction += f"\n\nWARNING: You have a {self.config.others.max_steps}-step budget, and there are {self.config.others.max_steps-self.get_step()} remaining attempts."
return instruction
def verbose(self, instruction, online_input, model_response_list, action_element_list):
action_element_keys = [k for k in self.config.play if k in action_element_list[0].keys()]
other_play_keys = [k for k in self.config.play if k not in action_element_list[0].keys()]
VERBOSE_TO_CONTENT_MAP = {
"step": self.get_step(),
"objective": self.objective,
"previous plans": self.get_previous_plans(verbose=True),
"url": self.online_interaction["url"],
"observation": self.get_observation_text(),
"response": "\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n".join([f"|\tAgent {i}:\n{model_response}" for i, model_response in enumerate(model_response_list[:self.config.number])]) if self.config.number > 1 else model_response_list[0],
"instruction": instruction,
"online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
"alter ego response": "\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n".join(["|\tAgent {}:\n{}".format(identity.config.name, response) for identity, response in zip(self.identities, model_response_list[self.config.number:])])
}
if self.config.others.verbose > 0 and self.config.verbose > 0:
with open(self.output_trash_path, "a") as af:
af.write("-"*32+"ACTOR"+"-"*32+"\n")
for t in self.config.trash:
content = VERBOSE_TO_CONTENT_MAP.get(t, "")
with open(self.output_trash_path, "a") as af:
af.write(f"{t.upper()}:\n{content}\n\n")
with open(self.output_play_path, "w") as _:
pass
for p in other_play_keys:
content = VERBOSE_TO_CONTENT_MAP.get(p, "")
with open(self.output_play_path, "a") as af:
af.write(f"{p.upper()}:\n{content}\n\n")
for i, action_elements in enumerate(action_element_list):
if len(action_element_list) > 1:
with open(self.output_play_path, "a") as af:
af.write("-"*32+f"AGENT {i}"+"-"*32+"\n")
for action_element_key in action_element_keys:
content = action_elements.get(action_element_key, "N/A")
with open(self.output_play_path, "a") as af:
af.write(f"{action_element_key.upper()}:\n{content}\n\n")
def parse_plan(self, planning):
planning_type = self.is_planning(action=planning)
match = re.search(
rf"{planning_type} ?\[(\d+)\] ?\[(.+)\]", planning, re.DOTALL
)
if not match:
raise ValueError("Invalid planning command.")
node_id, planning_content = (
int(match.group(1)),
match.group(2)
)
return planning_type, node_id, planning_content
def prune_planning(self, node:PlanTreeNode, planning_content):
def set_invisible(node:PlanTreeNode):
node.visible = False
def return_steps_taken(node:PlanTreeNode):
return [node.step] + node.steps_taken
after_node = False
if node.id > 0:
for child in node.parent.children:
if not after_node and child != node:
continue
elif child == node:
after_node = True
continue
child.visible = False
node.traverse(set_invisible)
node.reset()
steps_taken = []
node.traverse(action=return_steps_taken, tree_buffer=steps_taken)
node.steps_taken = sorted(list(set(steps_taken)), reverse=False)
node.resume_reason.append(planning_content)
navigation = f"goto [{node.url}] [1]"
self.active_node = node
return navigation
def branch_planning(self, node, planning_content):
new_node = PlanTreeNode(id=self.active_node.id+1, type=type, text=planning_content, level=node.level+1, url=self.online_interaction["url"], step=self.get_step())
self.active_node = new_node
node.add_child(new_node)
def planning(self, action):
if action and self.is_planning(action):
try:
planning_type, node_id, planning_content = self.parse_plan(planning=action)
node = self.plan_tree_root.search_node_by_id(node_id)
if not node:
raise ValueError(f"Invalid node id {node_id}: {action}.")
if planning_type == "prune":
navigation_action = self.prune_planning(node=node, planning_content=planning_content)
return navigation_action
elif planning_type == "branch":
self.branch_planning(node=node, planning_content=planning_content)
else:
raise ValueError(f"Invalid planning operation {planning_type}: {action}.")
except Exception as e:
print("Invalid plan node:", str(e))
flaw_node = self.active_node
flaw_node.note.append(f"You previously generate plan \"{action}\", which has INVALID syntax. User planning command like `branch [parent_plan_id] [new_subplan_intent]` or `prune [resume_plan_id] [reason]`.")
else:
self.active_node.steps_taken.append(self.get_step())
return None
def go_home(self, action):
if "go_home" in action:
return f"goto [{HOMEPAGE_URL}] [1]"
return None
def parse_action(self, action_str):
try:
DOM_root_node = self.get_observation_node()
action_str = action_str.strip()
action = (
action_str.split("[")[0].strip()
if "[" in action_str
else action_str.split()[0].strip()
)
match action:
case "click":
match = re.search(r"click ?\[(\d+)\]", action_str)
if not match:
raise ValueError(f"Invalid click action {action_str}")
element_id = match.group(1)
node = DOM_root_node.search_node_by_id(element_id)
return f"click [{element_id}] ({node.role} {node.name})"
case "hover":
match = re.search(r"hover ?\[(\d+)\]", action_str)
if not match:
raise ValueError(f"Invalid hover action {action_str}")
element_id = match.group(1)
node = DOM_root_node.search_node_by_id(element_id)
return f"hover [{element_id}] ({node.role} {node.name})"
case "type":
if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
action_str += " [1]"
match = re.search(
r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str
)
if not match:
raise ValueError(f"Invalid type action {action_str}")
element_id, text, enter_flag = (
match.group(1),
match.group(2),
match.group(3),
)
enter_flag = True if enter_flag == "1" else False
if enter_flag:
text += "\n"
node = DOM_root_node.search_node_by_id(element_id)
return action + f" ({node.name})"
case "scroll":
return action_str
case "goto":
return action
case "new_tab":
return action
case "go_back":
return action
case "go_forward":
return action
case "stop":
return action
return False
except:
return False
def parse_actions_to_element_ids(self, actions):
action_str_list = []
for a in self.config.navigation_command:
action_str_list += self.parse_stipulated_action_list(text=actions, action=a, actions=self.config.planning_command+self.config.navigation_command+["goto"])
retained_element_ids = []
for action_str in action_str_list:
try:
action_str = action_str.strip()
action = (
action_str.split("[")[0].strip()
if "[" in action_str
else action_str.split()[0].strip()
)
match action:
case "click":
match = re.search(r"click ?\[(\d+)\]", action_str)
if not match:
raise ValueError(f"Invalid click action {action_str}")
element_id = match.group(1)
element_id = int(element_id)
retained_element_ids.append(element_id)
case "hover":
match = re.search(r"hover ?\[(\d+)\]", action_str)
if not match:
raise ValueError(f"Invalid hover action {action_str}")
element_id = match.group(1)
element_id = int(element_id)
retained_element_ids.append(element_id)
case "type":
if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
action_str += " [1]"
match = re.search(
r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str
)
if not match:
raise ValueError(f"Invalid type action {action_str}")
element_id, text, enter_flag = (
match.group(1),
match.group(2),
match.group(3),
)
element_id = int(element_id)
retained_element_ids.append(element_id)
case "scroll":
pass
case "goto":
pass
case "new_tab":
pass
case "go_back":
pass
case "go_forward":
pass
case "stop":
pass
case "note":
pass
return retained_element_ids
except:
continue
return retained_element_ids
def take_note(self, action, note_as_action=True):
if action and "note [" in action:
none_note_action_list = []
action_list = self.parse_str_to_action_list(action, actions=self.config.planning_command+self.config.navigation_command+["goto"])
for a in action_list:
if "note [" in a:
note = re.search(r"note ?\[?(.+)", a, re.DOTALL).group(1)
if note.endswith("]"):
note = note[:-1]
self.active_node.note.append(f"STEP {self.get_step()}: {note}")
self.note_buffer = note
else:
none_note_action_list.append(a)
if note_as_action:
return action
return "\n".join(none_note_action_list)
# action_note = self.parse_action(action)
# if action_note:
# self.active_node.note.append(f"STEP {self.get_step()} ACTION: {action_note}")
return action
def get_observation_highlight(self, action_elements:dict):
action_elements["observation highlight idxs"] = copy.deepcopy(action_elements.get("observation highlight", ""))
DOM_root_node = self.get_observation_node()
if not DOM_root_node:
action_elements["observation highlight"] = None
return
observation_highlight_idxs = [int(idx.strip()) for idx in action_elements.get("observation highlight", "").split(",") if idx.strip().isdigit()]
if observation_highlight_idxs:
parse_node_descendants(node=DOM_root_node, action=action_set_invisible)
for idx in observation_highlight_idxs:
try:
node = DOM_root_node.search_node_by_id(idx)
parse_node_descendants(node=node, action=action_set_visible)
parse_node_ancestors(node=node, action=action_set_visible)
parse_node_siblings(node=node, action=action_set_visible_if_with_name)
except:
pass
try:
assert DOM_root_node.get_visible_node_number() < 30 and construct_new_DOM_with_visible_nodes(DOM_root=DOM_root_node)
action_elements["observation highlight"] = construct_new_DOM_with_visible_nodes(DOM_root=DOM_root_node)
parse_node_descendants(node=DOM_root_node, action=action_set_visible)
except:
parse_node_descendants(node=DOM_root_node, action=action_set_visible)
action_elements["observation highlight"] = None
action_elements["retained element ids"] = self.parse_actions_to_element_ids(action_elements["action"])
def parse_action_from_action_candidates(self, action_elements):
if "action" in action_elements.keys():
return action_elements
assert any("action candidates" in k for k in action_elements.keys())
action_candidates_key = [k for k in action_elements.keys() if "action candidates" in k][0]
def parse_reasons_and_actions(input_string):
pattern = r'- reason: \[(.*?)\]\s*(?:- action: \[(.*?)\])?\s*(?:\n|\Z)'
matches = re.findall(pattern, input_string, re.DOTALL)
parsed_data = []
for match in matches:
reason = match[0].strip()
action = match[1].strip()
if reason and action:
parsed_data.append({'reason': reason, 'action': action})
return parsed_data
action_elements[action_candidates_key] = parse_reasons_and_actions(action_elements[action_candidates_key])
return action_elements
def predict_action(self, criticism_elements):
if self.config.debug > 1:
action_elements = {k: "" for k in self.config.output}
human_input = input("ACTION: ")
action_elements["action"] = human_input
return [action_elements]
self.pre_process_atomic_actions()
instruction = self.get_actor_instruction()
online_input = self.get_online_input(criticism_elements=criticism_elements)
model_response_list = []
action_element_list = []
for _ in range(self.config.number):
get_valid_actions = False
repetitive_note = False
invalid_actions = False
while not get_valid_actions:
if repetitive_note:
model_response = self.call_model_with_message(system_prompt=instruction+"\nGenerating the command `note [{}]` will be severely punished! Don't generate repetitive notes!".format(getattr(self, "note_buffer", "")), messages=self.arrange_message_for_model(online_input))
elif invalid_actions:
model_response = self.call_model_with_message(system_prompt=instruction+"\nGenerating the command `{}` will be severely punished! Don't generate invalid actions! We don't have that element id in the current observation!".format(invalid_action_str), messages=self.arrange_message_for_model(online_input))
else:
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
action_elements = self.parse_action_from_action_candidates(action_elements=action_elements)
assert not ("action" in action_elements.keys() and any("action candidates" in k for k in action_elements.keys()))
if "action" in action_elements.keys():
if self.are_valid_actions(action_elements["action"]):
note_buffer = getattr(self, "note_buffer", "")
if note_buffer and f"note [{note_buffer}" in action_elements["action"]:
print(f"Repetitive note: {note_buffer}")
repetitive_note = True
continue
get_valid_actions = True
action_elements["input"] = online_input
model_response_list.append(model_response)
action_element_list.append(action_elements)
else:
invalid_action_str = action_elements["action"]
print(f"Invalid actions: {invalid_action_str}")
invalid_actions = True
elif any("action candidates" in k for k in action_elements.keys()):
action_candidates_key = [k for k in action_elements.keys() if "action candidates" in k][0]
if isinstance(action_elements[action_candidates_key], str):
continue
filtered_action_candidates = []
note_buffer = getattr(self, "note_buffer", "")
for action_reason_pair in action_elements[action_candidates_key]:
action = action_reason_pair["action"]
reason = action_reason_pair["reason"]
if self.are_valid_actions(action):
if note_buffer and f"note [{note_buffer}" in action:
print(f"Repetitive note: {note_buffer}")
repetitive_note = True
continue
filtered_action_candidates.append({'reason': reason, 'action': action})
else:
invalid_action_str = action
print(f"Invalid actions: {invalid_action_str}")
invalid_actions = True
if filtered_action_candidates:
action_elements[action_candidates_key] = filtered_action_candidates
get_valid_actions = True
action_elements["input"] = online_input
model_response_list.append(model_response)
action_element_list.append(action_elements)
else:
raise NotImplementedError("You have to generate either action or action candidates.")
# if self.config.number != 1:
if True:
for identity in self.identities:
identity_instruction = identity.get_instruction() if identity.get_instruction() else instruction
identity_online_input = identity.get_online_input() if identity.get_online_input() else online_input
get_valid_actions = False
invalid_actions = False
while not get_valid_actions:
if invalid_actions:
model_response, action_elements = identity.get_action(identity_instruction+"\nGenerating the command `{}` will be severely punished! Don't generate invalid actions! We don't have that element id in the current observation!".format(invalid_action_str), identity_online_input)
else:
model_response, action_elements = identity.get_action(identity_instruction, identity_online_input)
if self.are_valid_actions(action_elements["action"]):
get_valid_actions = True
model_response_list.append(model_response)
action_element_list.append(action_elements)
else:
invalid_action_str = action_elements["action"]
print(f"Invalid actions: {invalid_action_str}")
invalid_actions = True
self.verbose(instruction=instruction, online_input=online_input, model_response_list=model_response_list, action_element_list=action_element_list)
if self.config.others.debug or self.config.debug:
for i in range(len(action_element_list)):
human_input = input(f"ACTION {i}: ")
if human_input != "":
action_element_list[i]["action"] = human_input
return action_element_list
def finalize_action(self, action_elements):
self.get_observation_highlight(action_elements=action_elements)
action = action_elements["action"]
navigation_action = self.planning(action=action)
if navigation_action:
action_elements["navigation action"] = navigation_action
action = self.take_note(action)
action_elements["action"] = action
navigation_action = self.go_home(action=action)
if navigation_action:
action_elements["navigation action"] = navigation_action
return action_elements
class Critic(Agent):
def __init__(self, config, objective, prompt_template):
super().__init__(config, objective, prompt_template)
self.instruction = None
self.actor_basic_info_dict = None
self.output_play_path = os.path.join(CURRENT_DIR, f"play-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"play.txt")
self.output_trash_path = os.path.join(CURRENT_DIR, f"trash-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"trash.txt")
def verbose(self, instruction, online_input, model_response):
VERBOSE_TO_CONTENT_MAP = {
"url": self.online_interaction["url"],
"objective": self.objective,
"instruction": instruction,
"online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
"response": model_response
}
if self.config.others.verbose > 0 and self.config.verbose > 0:
with open(self.output_trash_path, "a") as af:
af.write("-"*32+"CRITIC"+"-"*32+"\n")
for t in self.config.trash:
content = VERBOSE_TO_CONTENT_MAP[t]
with open(self.output_trash_path, "a") as af:
af.write(f"{t.upper()}:\n{content}\n\n")
def update_actor_basic_info(self, **actor_basic_info_dict):
self.actor_basic_info_dict = actor_basic_info_dict
def get_output_specifications(self):
output_specification_filepath_list = []
for o in self.config.output:
if os.path.exists(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications", "{}_{}.txt".format(o.replace(" ", "_"), self.config.character))):
output_specification_filepath_list.append(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications", "{}_{}.txt".format(o.replace(" ", "_"), self.config.character)))
else:
output_specification_filepath_list.append(os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications", "{}.txt".format(o.replace(" ", "_"))))
output_specifications = "\n".join([f"{o.upper()}:\n" + "".join(open(filepath, "r").readlines()) for o, filepath in zip(self.config.output, output_specification_filepath_list)])
return output_specifications
def get_critic_instruction(self):
if self.instruction:
return self.instruction
instruction = self.prompt_template["instruction_template"]
output_specifications = self.get_output_specifications()
instruction = instruction.replace("{output_specifications}", output_specifications)
instruction = instruction.replace("{planning_specifications}", self.actor_basic_info_dict["planning_specifications"])
instruction = instruction.replace("{navigation_specifications}", self.actor_basic_info_dict["navigation_specifications"])
self.instruction = instruction
return self.instruction
def get_online_input(self):
input_template = self.prompt_template["input_template"]
input_prefix, input_suffix = input_template.split("{input}")
# ["objective", "previous plans", "interaction history", "step", "current observation"]
INPUT_TYPE_TO_CONTENT_MAP = {
"step": self.actor_basic_info_dict["step"],
"objective": self.objective,
"previous plans": self.actor_basic_info_dict["previous_plans"],
"interaction history": self.actor_basic_info_dict["interaction_history"],
"current observation": self.get_observation_text(),
"current visual observation": self.get_observation_image()
}
input_list = []
for input_type in self.config.input:
input_content = None
if input_type == "current visual observation":
continue
elif input_type in INPUT_TYPE_TO_CONTENT_MAP.keys():
input_content = INPUT_TYPE_TO_CONTENT_MAP[input_type]
if input_content and isinstance(input_content, str):
input_list.append(("text", f"{input_type.upper()}:\n{input_content}\n"))
elif input_content and isinstance(input_content, list):
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list += input_content if len(input_content) > 0 else ["N/A"]
if "image" in self.config.current_observation.type:
input_type = "current visual observation"
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list.append(("image", INPUT_TYPE_TO_CONTENT_MAP["current visual observation"]))
return self.prune_message_list(message_list=[("text", input_prefix)] + input_list + [("text", input_suffix)])
def get_criticism_elements(self):
if not self.config.mode:
return {}
if self.config.debug > 1:
criticism_elements = {k: random.choice(["I don't think the task is finished. Don't issue identical actions like taking the same notes. It's annoying. Continue.", "You have make a reasoning mistake. Continue.", "You have missed important details on this page. Continue.", "You don't follow the task requirements. Continue.", "The task assigner might just want to challenge you to answer no and there might be no answer for this brain teaser question. Who knows?", "You should break down the task by using the planning commands.", "You have not gone over all the relevant pages. Continue."]) for k in self.config.output}
# criticism_elements = {k: input(f"{k.upper()}: ") for k in self.config.output}
return criticism_elements
instruction = self.get_critic_instruction()
online_input = self.get_online_input()
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
self.verbose(instruction=instruction, online_input=online_input, model_response=model_response)
criticism_elements = self.parse_elements(text=model_response, key_list=self.config.output) # key_list=self.config.output)
criticism_elements["input"] = online_input
if self.config.others.debug or self.config.debug:
for k in self.config.output:
human_input = input(f"{k.upper()}: ")
if not human_input == "":
criticism_elements[k] = human_input
return criticism_elements
class Judge(Agent):
def __init__(self, config, objective, prompt_template):
super().__init__(config, objective, prompt_template)
self.instruction = None
self.actor_basic_info_dict = None
self.output_play_path = os.path.join(CURRENT_DIR, f"play-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"play.txt")
self.output_trash_path = os.path.join(CURRENT_DIR, f"trash-{self.config.others.logname}.txt") if getattr(self.config.others, "logname", "") != "" else os.path.join(CURRENT_DIR, f"trash.txt")
def update_actor_basic_info(self, **actor_basic_info_dict):
self.actor_basic_info_dict = actor_basic_info_dict
def get_judge_instruction(self):
if self.instruction:
return self.instruction
instruction = self.prompt_template["instruction_template"]
output_specifications = self.get_output_specifications()
instruction = instruction.replace("{output_specifications}", output_specifications)
instruction = instruction.replace("{planning_specifications}", self.actor_basic_info_dict["planning_specifications"])
instruction = instruction.replace("{navigation_specifications}", self.actor_basic_info_dict["navigation_specifications"])
self.instruction = instruction
return self.instruction
def get_online_input(self, action_element_list):
input_template = self.prompt_template["input_template"]
input_prefix, input_suffix = input_template.split("{input}")
INPUT_TYPE_TO_CONTENT_MAP = {
"step": self.actor_basic_info_dict["step"],
"objective": self.objective,
"previous plans": self.actor_basic_info_dict["previous_plans"],
"interaction history": self.actor_basic_info_dict["interaction_history"],
"current observation": self.get_observation_text(),
"current visual observation": self.get_observation_image(),
"action choices": "\n\n".join(["|\taction [{}]:\n{}\n|\treason for action [{}]:\n{}".format(i, action_element["action"], i, action_element.get("reason", "N/A")) for i, action_element in enumerate(action_element_list)])
}
input_list = []
for input_type in self.config.input:
input_content = None
if input_type == "current visual observation":
continue
elif input_type in INPUT_TYPE_TO_CONTENT_MAP.keys():
input_content = INPUT_TYPE_TO_CONTENT_MAP[input_type]
if input_content and isinstance(input_content, str):
input_list.append(("text", f"{input_type.upper()}:\n{input_content}\n"))
elif input_content and isinstance(input_content, list):
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list += input_content if len(input_content) > 0 else ["N/A"]
if "image" in self.config.current_observation.type:
input_type = "current visual observation"
input_list.append(("text", f"{input_type.upper()}:\n"))
input_list.append(("image", INPUT_TYPE_TO_CONTENT_MAP["current visual observation"]))
return self.prune_message_list(message_list=[("text", input_prefix)] + input_list + [("text", input_suffix)])
def verbose(self, instruction, online_input, model_response):
VERBOSE_TO_CONTENT_MAP = {
"url": self.online_interaction["url"],
"objective": self.objective,
"instruction": instruction,
"online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
"response": model_response
}
if self.config.others.verbose > 0 and self.config.verbose > 0:
with open(self.output_trash_path, "a") as af:
af.write("-"*32+"JUDGE"+"-"*32+"\n")
for t in self.config.trash:
content = VERBOSE_TO_CONTENT_MAP[t]
with open(self.output_trash_path, "a") as af:
af.write(f"{t.upper()}:\n{content}\n\n")
def flatten_action_element_list(self, action_element_list):
new_action_element_list = []
for action_element in action_element_list:
if any("action candidates" in k for k in action_element.keys()):
action_candidates_key = [k for k in action_element.keys() if "action candidates" in k][0]
new_action_element = copy.deepcopy(action_element)
for action_reason_pair in action_element[action_candidates_key]:
new_action_element["action"] = action_reason_pair["action"]
new_action_element["reason"] = action_reason_pair["reason"]
new_action_element_list.append(copy.deepcopy(new_action_element))
else:
new_action_element_list.append(action_element)
random.shuffle(new_action_element_list)
return new_action_element_list
def judge(self, action_element_list):
action_element_list = self.flatten_action_element_list(action_element_list)
if not self.config.mode or self.config.debug > 1:
return action_element_list[0], {}
if all(action_elements["action"]==action_element_list[0]["action"] for action_elements in action_element_list):
return action_element_list[0], {}
def deduplicate_action_element_list_strict(lst): # deduplicate, remove action_elements with only note or stop command
seen = set()
note_list = []
stop_list = []
deduplicated_list = []
for i, item in enumerate(lst):
item = copy.deepcopy(item)
action_list = self.parse_str_to_action_list(item["action"], self.actor_basic_info_dict["planning_command"]+self.actor_basic_info_dict["navigation_command"])
note_list.append([])
none_note_stop_action_list = []
for a in action_list:
if a.startswith("stop ["):
stop_list.append((a, i))
elif a.startswith("note ["):
note_list[-1].append(a)
else:
none_note_stop_action_list.append(a)
item["action"] = "\n".join(none_note_stop_action_list)
if item["action"] and item["action"] not in seen:
seen.add(item["action"])
deduplicated_list.append(item)
note_list = [("\n".join(notes), i) for i, notes in enumerate(note_list)]
return note_list, stop_list, deduplicated_list
def deduplicate_action_element_list(lst): # deduplicate, remove action_elements with only note or stop command
seen = set()
deduplicated_list = []
for item in lst:
item = copy.deepcopy(item)
if item["action"] and item["action"] not in seen:
seen.add(item["action"])
deduplicated_list.append(item)
return deduplicated_list
if hasattr(self.config, "strict") and self.config.strict:
note_list, stop_list, deduplicated_action_element_list = deduplicate_action_element_list_strict(action_element_list)
if len(stop_list) >= 0.6 * len(action_element_list):
stop_action_choice = max([s[0] for s in stop_list], key=len)
stop_action_id = [s[1] for s in stop_list if s[0]==stop_action_choice][0]
return action_element_list[stop_action_id], {}
if not deduplicated_action_element_list:
note_action_choice = max([n[0] for n in note_list], key=len)
note_action_id = [n[1] for n in note_list if n[0]==note_action_choice][0]
action_elements = action_element_list[note_action_id]
action_elements["action"] = note_action_choice
return action_elements, {}
elif len(deduplicated_action_element_list) == 1:
action_elements = deduplicated_action_element_list[0]
note_action_choice = max([n[0] for n in note_list], key=len)
action_elements["action"] = note_action_choice + "\n" + action_elements["action"]
return action_elements, {}
else:
deduplicated_action_element_list = deduplicate_action_element_list(action_element_list)
instruction = self.get_judge_instruction()
online_input = self.get_online_input(deduplicated_action_element_list)
model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
self.verbose(instruction=instruction, online_input=online_input, model_response=model_response)
judgement_elements = self.parse_elements(text=model_response, key_list=self.config.output) # key_list=self.config.output)
judgement_elements["input"] = online_input
if self.config.others.debug or self.config.debug:
for k in self.config.output:
human_input = input(f"{k.upper()}: ")
if not human_input == "":
judgement_elements[k] = human_input
try:
action_selection = int(re.search(r'\d+', judgement_elements["action selection"]).group())
selected_action_elements = deduplicated_action_element_list[action_selection]
if hasattr(self.config, "strict") and self.config.strict:
note_action_choice = max([n[0] for n in note_list], key=len)
if note_action_choice:
selected_action_elements["action"] = note_action_choice + "\n" + selected_action_elements["action"]
return selected_action_elements, judgement_elements
except:
return action_element_list[0], judgement_elements
class AgentOccam:
def __init__(self,
config = None,
prompt_dict: Dict = None,
):
self.config = config
self.prompt_dict = {} if prompt_dict is None else prompt_dict
self.objective = None
self.online_observation = None
self.online_url = None
self.actor = None
self.critic = None
self.trajectory = []
def get_refined_objective(self):
model_response = call_claude(self.root_prompt_template["objective_rephrasing_query"].replace("{objective}", self.objective))
objective_match = re.search(r'REFINED OBJECTIVE:\s*(.*?)\s*(?=\n[A-Z]|$)', model_response, re.DOTALL)
self.objective_refined = objective_match.group(1) if objective_match else None
def get_observation_text(self):
if isinstance(self.online_observation, dict):
return self.online_observation["text"]
else:
return self.online_observation
def init_actor(self):
self.config.actor.others = self.config.others
if len(self.sites) > 1:
self.config.actor.navigation_command += ["go_home"]
self.actor = Actor(
config=self.config.actor,
objective=self.objective,
prompt_template=self.prompt_dict["actor"],
plan_tree_node=PlanTreeNode(id=0, type="branch", text=f"Find the solution to \"{self.objective}\"", level=0, url=self.online_url, step=0)
)
with open(self.actor.output_trash_path, "w") as _:
pass
def init_critic(self):
self.config.critic.others = self.config.others
self.critic = Critic(
config=self.config.critic,
objective=self.objective,
prompt_template=self.prompt_dict["critic"][self.config.critic.character],
)
def init_judge(self):
self.config.judge.others = self.config.others
self.judge = Judge(
config=self.config.judge,
objective=self.objective,
prompt_template=self.prompt_dict["judge"],
)
def predict_action(self):
self.critic.update_actor_basic_info(step=self.get_step(), planning_specifications=self.actor.get_planning_specifications(), navigation_specifications=self.actor.get_navigation_specifications(), interaction_history=self.actor.get_interaction_history(interaction_history_config=self.critic.config.interaction_history), previous_plans=self.actor.get_previous_plans(verbose=True))
criticism_elements = self.critic.get_criticism_elements() if not self.get_step()==0 else {}
action_element_list = self.actor.predict_action(criticism_elements=criticism_elements)
self.judge.update_actor_basic_info(step=self.get_step(), planning_specifications=self.actor.get_planning_specifications(), navigation_specifications=self.actor.get_navigation_specifications(), interaction_history=self.actor.get_interaction_history(interaction_history_config=self.judge.config.interaction_history), previous_plans=self.actor.get_previous_plans(verbose=True), planning_command=self.actor.config.planning_command, navigation_command=self.actor.config.navigation_command)
selected_action_elements, judgement_elements = self.judge.judge(action_element_list)
selected_action_elements = self.actor.finalize_action(selected_action_elements)
return {**selected_action_elements, **{"critic:"+k: criticism_elements[k] for k in criticism_elements.keys()}, **{"judge:"+k: judgement_elements[k] for k in judgement_elements.keys()}}, action_element_list
def update_online_state(self, url, observation):
self.online_url = url
self.online_observation = observation
def get_step(self):
return self.actor.get_step()
def is_navigation(self, action):
return self.actor.is_navigation(action=action)
def get_actor_active_plan(self):
return self.actor.get_active_plan()
def get_trajectory(self):
return self.trajectory
def act(self, objective, env):
self.objective = objective
self.sites = env.get_sites()
observation = env.observation()
url = env.get_url()
self.update_online_state(url=url, observation=observation)
self.init_actor()
self.init_critic()
self.init_judge()
while not env.done():
observation = env.observation()
url = env.get_url()
self.update_online_state(url=url, observation=observation)
self.actor.update_online_state(url=url, observation=observation)
self.critic.update_online_state(url=url, observation=observation)
self.judge.update_online_state(url=url, observation=observation)
action_elements, action_element_list = self.predict_action()
action = action_elements["action"]
navigation_action = action_elements["action"] if not action_elements.get("navigation action", "") else action_elements.get("navigation action", "")
status = env.step(navigation_action)
if navigation_action and self.is_navigation(action=navigation_action) and status == False: # means invalid action
flaw_node = self.actor.active_node
flaw_node.note.append(f"STEP {self.get_step()}: You generate action \"{action}\", which has INVALID syntax. Strictly follow the action specifications.")
DOCUMENTED_INTERACTION_ELEMENT_KEY_TO_CONTENT_MAP = {
"observation": observation,
"action": action,
"url": url,
"plan": self.get_actor_active_plan(),
"reason": action_elements.get("reason", ""),
"observation highlight": action_elements.get("observation highlight", ""),
"retained element ids": action_elements.get("retained element ids", []),
"observation summary": action_elements.get("observation description", "")
}
self.actor.update_history(**DOCUMENTED_INTERACTION_ELEMENT_KEY_TO_CONTENT_MAP)
self.actor.del_observation_node()
assert self.actor.equal_history_length()
if len(action_element_list) > 1:
if self.config.others.logging:
self.log_step(
status=status if "status" in locals() and isinstance(status, dict) else env.status(),
plan=self.get_actor_active_plan(),
**action_elements,
**{f"actor {i}:{k}": _action_elements[k] for i, _action_elements in enumerate(action_element_list) for k in _action_elements.keys() if k != "input" and k != "instruction"}
)
else:
if self.config.others.logging:
self.log_step(
status=status if "status" in locals() and isinstance(status, dict) else env.status(),
plan=self.get_actor_active_plan(),
**action_elements,
)
return status if "status" in locals() and isinstance(status, dict) else env.status()
def log_step(self, status, **kwargs):
def serialize_message_list(message_list):
if not isinstance(message_list, list):
return message_list
return "".join([m[1] for m in message_list if m[0]=="text"])
data_to_log = {}
data_to_log['objective'] = self.objective
data_to_log['url'] = self.online_url
data_to_log['observation'] = self.get_observation_text()
for (k, v) in status.items():
data_to_log[k] = v
for k in kwargs.keys():
try:
json.dumps(kwargs[k])
data_to_log[k.replace(" ", "_")] = kwargs[k] if not "input" in k else serialize_message_list(kwargs[k])
except:
pass
self.trajectory.append(data_to_log)