# File-viewer metadata (not part of the program): 244 lines, 8.8 KiB, Python.
import json
import logging
import os
import re

from dotenv import load_dotenv
from openai import OpenAI

# Logging setup.
# Module-level logger used by every function in this script.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# File handler. The log messages contain non-ASCII text, so the encoding is
# pinned to UTF-8 instead of relying on the platform default.
file_handler = logging.FileHandler('temp_analysis/test_run.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)

# Console handler.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# One formatter shared by both handlers.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)

# Attach both handlers to the logger.
log.addHandler(file_handler)
log.addHandler(console_handler)

# Load environment variables (python-dotenv).
load_dotenv()

# Configure the OpenAI client from the environment.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    # Optional base URL, for use with another OpenAI-compatible service.
    base_url=os.getenv("OPENAI_API_BASE_URL"),
)

# modified version of the task proposer agent prompt from https://arxiv.org/pdf/2502.11357

# System prompt.
# The OUTPUT FORMAT example ends with the closing brace immediately followed by
# the closing fence, which is the shape extract_action() parses; the field name
# "grounded_action" is kept on a single line so it is not split by a newline.
cot_system_prompt = """
What does this webpage show? Imagine you are a real user on this webpage. Given the webpage
screenshot or ocr result and parsed HTML/accessibility tree and the task description, please provide
the first action towards completing that task.

Do the following step by step:
1. Given the webpage screenshot or ocr result and parsed HTML/accessibility tree, generate the first action
towards completing that task (in natural language form).
2. Given the webpage screenshot or ocr result, parsed HTML/accessibility tree, and the natural language
action, generate the grounded version of that action.

ACTION SPACE: Your action space is: ['click [element ID]', 'type [element ID] [content]',
'select [element ID] [content of option to select]', 'scroll [up]', 'scroll [down]', and 'stop'].
Action output should follow the syntax as given below:
click [element ID]: This action clicks on an element with a specific ID on the webpage.
type [element ID] [content]: Use this to type the content into the field with id. By default, the
"Enter" key is pressed after typing. Both the content and the ID should be within square braces
as per the syntax.
select [element ID] [content of option to select]: Select an option from a dropdown menu. The
content of the option to select should be within square braces. When you get (select an option)
tags from the accessibility tree, you need to select the serial number (element_id) corresponding
to the select tag, not the option, and select the most likely content corresponding to the option as
input.
scroll [down]: Scroll the page down.
scroll [up]: Scroll the page up.

IMPORTANT:
To be successful, it is important to STRICTLY follow the below rules:
Action generation rules:
1. You should generate a single atomic action at each step.
2. The action should be an atomic action from the given vocabulary - click, type, select, scroll
(up or down), or stop.
3. The arguments to each action should be within square braces. For example, "click [127]",
"type [43] [content to type]", "scroll [up]", "scroll [down]".
4. The natural language form of action (corresponding to the field "action_in_natural_language")
should be consistent with the grounded version of the action (corresponding to the field
"grounded_action"). Do NOT add any additional information in the grounded action. For example, if a
particular element ID is specified in the grounded action, a description of that element must be
present in the natural language action.
5. If the type action is selected, the natural language form of action ("action_in_natural_language") should always specify the actual text to be typed.
6. You should issue a "stop" action if the current webpage asks to log in or for credit card
information.
7. To input text, there is NO need to click the textbox first, directly type content. After typing,
the system automatically hits the 'ENTER' key.
8. STRICTLY Avoid repeating the same action (click/type) if the webpage remains unchanged.
You may have selected the wrong web element.
9. Do NOT use quotation marks in the action generation.

OUTPUT FORMAT:
Please give a short analysis of the screenshot, parsed
HTML/accessibility tree, then put your answer within ``` ```, for example,
"In summary, the proposed task and the corresponding action is: ```{
"action_in_natural_language":<ACTION_IN_NATURAL_LANGUAGE>:str,
"grounded_action": <ACTION>:str}```"
"""

# User prompt template; filled with str.format() using the four named fields.
cot_user_prompt = """
Website URL: {INIT_URL}
Parsed HTML/Accessibility Tree: {A11Y_TREE}
Screenshot ocr result: {SCREENSHOT}
Task description: {TASK_DESCRIPTION}
"""

def call_api(messages, model="gpt-4o-mini"):
    """Send a chat-completion request through the module-level OpenAI client.

    Args:
        messages: Chat messages, forwarded unchanged to
            ``client.chat.completions.create``.
        model: Name of the model to query. Defaults to ``"gpt-4o-mini"``,
            the value that was previously hard-coded.

    Returns:
        The response object returned by the client, or ``None`` if the
        request raised an exception. The error is logged, not re-raised.
    """
    try:
        return client.chat.completions.create(messages=messages, model=model)
    except Exception as e:
        # Deliberate best-effort boundary: one failed request must not abort
        # the whole run, so the caller receives None instead.
        log.error(f"API调用出错: {e}")
        return None

def extract_action(response_text):
    """Extract the natural-language and grounded action from a model reply.

    The reply is expected to contain a JSON object inside a triple-backtick
    fence. The fence may carry an optional ``json`` language tag, and
    whitespace is allowed both after the opening fence and before the
    closing one.

    Args:
        response_text: Text content of the model reply. ``None`` or an empty
            string is treated as "nothing to extract".

    Returns:
        A tuple ``(action_in_natural_language, grounded_action)``. Both
        entries are ``None`` when no fenced JSON object is found or when it
        cannot be parsed; a key missing from the object yields ``None`` for
        that entry.
    """
    if not response_text:
        return None, None

    # Non-greedy body: backtracking extends the match until a closing brace
    # that is followed (after optional whitespace) by the closing fence.
    match = re.search(
        r'```(?:json)?\s*\{(.+?)\}\s*```',
        response_text,
        re.DOTALL | re.IGNORECASE,
    )
    if not match:
        return None, None

    try:
        # Re-attach the braces stripped by the capture group, then parse.
        action_data = json.loads("{" + match.group(1) + "}")
    except json.JSONDecodeError:
        log.error(f"无法解析JSON: {match.group(1)}")
        return None, None

    return action_data.get("action_in_natural_language"), action_data.get("grounded_action")

def check_answer(grounded_action, answer):
    """Check whether a grounded click action targets an accepted element ID.

    Only ``click [<digits>]`` actions can be correct; any other action,
    a missing action, or a missing answer yields ``False``.

    Args:
        grounded_action: Action string such as ``"click [12]"``. Surrounding
            whitespace is ignored. Non-string values are treated as wrong.
        answer: Accepted element IDs. Either a comma-separated string
            (``"12, 15"``), a single number (``12``), or a list/tuple/set of
            IDs. Each ID is compared as a stripped string.

    Returns:
        ``True`` if the clicked element ID is one of the accepted IDs,
        otherwise ``False``.
    """
    if not isinstance(grounded_action, str) or answer is None:
        return False

    action = grounded_action.strip()
    if not action.startswith("click"):
        return False

    # Pull the element ID out of the action.
    match = re.search(r'click \[(\d+)\]', action)
    if not match:
        return False

    # Normalise the accepted IDs to a set of stripped strings.
    if isinstance(answer, (list, tuple, set, frozenset)):
        candidates = answer
    else:
        candidates = str(answer).split(",")
    answer_ids = {str(candidate).strip() for candidate in candidates}

    return match.group(1) in answer_ids

def main():
    """Run every exam item through the model and report click accuracy.

    Steps:
        1. Load the exam items from ``temp_analysis/exam.json``. Each item is
           indexed by the keys ``id``, ``url``, ``question``, ``answer`` and
           ``answer_text``.
        2. For each item, read its accessibility tree from
           ``axtrees/<id>.txt``, build the system/user messages and call the
           API.
        3. Extract the grounded action from the reply and compare it with the
           expected answer.
        4. Write all per-item results to ``temp_analysis/results.json`` and
           log the overall accuracy.

    Items whose accessibility-tree file is missing, or whose API call returns
    no usable response, are logged and skipped. They appear in neither the
    results file nor the accuracy denominator.

    All files are read and written as UTF-8, and the results file keeps
    non-ASCII text unescaped.
    """
    # Load the exam definition.
    with open("temp_analysis/exam.json", "r", encoding="utf-8") as f:
        exam_data = json.load(f)

    results = []

    for item in exam_data:
        item_id = item["id"]
        url = item["url"]
        task_description = item["question"]
        answer = item["answer"]
        answer_text = item["answer_text"]

        log.info(f"处理ID: {item_id}, URL: {url}")

        # Read the accessibility tree for this item; skip the item if absent.
        tree_path = f"axtrees/{item_id}.txt"
        try:
            with open(tree_path, "r", encoding="utf-8") as f:
                a11y_tree = f.read()
        except FileNotFoundError:
            log.error(f"找不到文件: {tree_path}")
            continue

        # Build the API request.
        user_content = cot_user_prompt.format(
            INIT_URL=url,
            A11Y_TREE=a11y_tree,
            SCREENSHOT="",  # no screenshot OCR result is available here
            TASK_DESCRIPTION=task_description,
        )
        messages = [
            {"role": "system", "content": cot_system_prompt},
            {"role": "user", "content": user_content},
        ]

        log.info(f"task_description: {task_description}")
        log.info(f"answer: {answer}, answer_text: {answer_text}")

        response = call_api(messages)

        # Guard clause: no usable response means the item is skipped.
        if not (response and hasattr(response, 'choices') and len(response.choices) > 0):
            log.error(f"API调用失败: {response}")
            continue

        content = response.choices[0].message.content
        log.info(f"content: {content}")

        # Parse the reply and score it.
        action_nl, grounded_action = extract_action(content)
        is_correct = check_answer(grounded_action, answer)

        results.append({
            "id": item_id,
            "url": url,
            "task": task_description,
            "action_nl": action_nl,
            "grounded_action": grounded_action,
            "expected_answer": answer,
            "is_correct": is_correct,
        })

        log.info(f"ID: {item_id}")
        log.info(f"任务: {task_description}")
        log.info(f"动作: {grounded_action}")
        log.info(f"是否正确: {is_correct}")
        log.info("-" * 50)

    # Persist the per-item results.
    with open("temp_analysis/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    # Compute and report the accuracy over the scored items.
    correct_count = sum(1 for r in results if r["is_correct"])
    total_count = len(results)
    accuracy = correct_count / total_count if total_count > 0 else 0

    log.info(f"总计: {total_count}题,正确: {correct_count}题,正确率: {accuracy:.2%}")

# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()