from webagents_step.agents.agent import Agent
from webagents_step.utils.stack import Stack
from webagents_step.agents.prompt_agent import PromptAgent

from typing import List, Dict
import re
import time


class StepAgent(Agent):
    """Hierarchical agent that decomposes an objective into sub-tasks.

    The root policy and its sub-policies are realized as PromptAgent
    instances kept on a stack: a high-level action pushes a new sub-agent,
    a low-level action is returned to the environment as-is, and a
    `stop [...]` action pops the current sub-agent and passes its answer
    back to the agent below it.
    """

    def __init__(self, max_actions: int = 10, verbose: bool = False, logging: bool = False,
                 debug: bool = False,
                 root_action: str = None,
                 action_to_prompt_dict: Dict = None,
                 low_level_action_list: List = None,
                 model: str = "gpt-3.5-turbo",
                 prompt_mode: str = "chat", previous_actions: List = None):
        super().__init__(max_actions=max_actions, verbose=verbose, logging=logging, previous_actions=previous_actions)
        self.debug = debug
        self.root_action = root_action
        # Maps each high-level action type to the prompt template of the
        # sub-agent that handles it; low-level actions are executed directly.
        self.action_to_prompt_dict = {} if action_to_prompt_dict is None else action_to_prompt_dict
        self.low_level_action_list = [] if low_level_action_list is None else low_level_action_list
        self.model = model
        self.prompt_mode = prompt_mode
        self.stack = Stack()

    def is_done(self, action):
        # A (sub-)task is finished when the policy issues a `stop` action.
        if action and "stop" in action:
            return True
        return False

    def is_low_level_action(self, action):
        # Low-level actions (e.g. ``click``, ``type``, depending on the
        # configured action list) are executed directly in the environment
        # rather than delegated to a sub-agent.
        if not action:
            return False
        action_type = action.split()[0]
        return action_type in self.low_level_action_list

    def is_high_level_action(self, action):
        # High-level actions have a dedicated prompt template and are handled
        # by pushing a new sub-agent onto the stack.
        if not action:
            return False
        action_type = action.split()[0]
        return action_type in self.action_to_prompt_dict

    def init_root_agent(self, objective):
        # The root agent is a PromptAgent primed with the template registered
        # for `root_action`; it works on the top-level objective.
        root_prompt_template = self.action_to_prompt_dict[self.root_action]
        agent = PromptAgent(
            prompt_template=root_prompt_template,
            model=self.model,
            prompt_mode=self.prompt_mode,
            max_actions=self.max_actions,
            verbose=self.verbose,
            logging=self.logging,
            debug=self.debug,
            previous_actions=[],
            previous_reasons=[],
            previous_responses=[]
        )
        return {'agent': agent, 'objective': objective}

    def init_agent(self, action):
        # A high-level action of the form "action_type [argument]" becomes the
        # objective of a new sub-agent built from that action type's template.
        pattern = r'(\w+)\s+\[(.*?)\]'
        matches = re.findall(pattern, action)
        action_type, _ = matches[0]
        objective = action
        prompt_template = self.action_to_prompt_dict[action_type]
        agent = PromptAgent(
            prompt_template=prompt_template,
            model=self.model,
            prompt_mode=self.prompt_mode,
            max_actions=self.max_actions,
            verbose=self.verbose,
            logging=self.logging,
            debug=self.debug,
            previous_actions=[],
            previous_reasons=[],
            previous_responses=[]
        )
        return {'agent': agent, 'objective': objective}

    def predict_action(self, objective, observation, url=None):
        observation = observation["text"] if isinstance(observation, dict) else observation
        # Lazily create the root agent on the first call.
        if self.stack.is_empty():
            new_element = self.init_root_agent(objective=objective)
            self.stack.push(new_element)

        action, reason = None, None
        while not self.stack.is_empty():
            # Query the agent currently on top of the stack.
            element = self.stack.peek()
            action, reason = element['agent'].predict_action(objective=element['objective'], observation=observation, url=url)
            if (not self.is_done(action)) and self.is_low_level_action(action):
                # Low-level action: hand it to the environment as-is.
                element['agent'].receive_response("")
                return action, reason
            if (not self.is_done(action)) and self.is_high_level_action(action):
                # High-level action: delegate to a new sub-agent.
                new_element = self.init_agent(action)
                self.stack.push(new_element)
                if self.logging:
                    self.log_step(objective=element['objective'], url=url, observation=observation, action=action, reason=reason, status={})
                continue
            if self.is_done(action):
                # `stop [answer]`: pop the finished agent and pass its answer
                # back to the agent below it, if any.
                self.stack.pop()
                if not self.stack.is_empty():
                    self.stack.peek()['agent'].receive_response(re.search(r"\[(.*?)\]", action).group(1))
                if self.logging:
                    self.log_step(objective=element['objective'], url=url, observation=observation, action=action, reason=reason, status={})
                continue
        return action, reason
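

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# The prompt templates, action names, and `env` object below are placeholders:
# in practice they come from the repo's task configs and a browser environment
# wrapper, so the sketch is kept commented out.
#
# if __name__ == "__main__":
#     action_to_prompt_dict = {
#         "root": ROOT_POLICY_PROMPT,      # placeholder template for the top-level policy
#         "find_item": FIND_ITEM_PROMPT,   # placeholder template for one sub-policy
#     }
#     agent = StepAgent(
#         root_action="root",
#         action_to_prompt_dict=action_to_prompt_dict,
#         low_level_action_list=["click", "type", "scroll", "goto"],
#         model="gpt-3.5-turbo",
#         prompt_mode="chat",
#     )
#     # `env` is assumed to expose reset()/step() returning a text observation
#     # and the current URL.
#     obs, url = env.reset()
#     for _ in range(agent.max_actions):
#         action, reason = agent.predict_action(objective="...", observation=obs, url=url)
#         if action is None or "stop" in action:
#             break
#         obs, url = env.step(action)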