# Listing metadata: 365 lines, 14 KiB, Python
import json
|
||
import os
|
||
import re
|
||
import logging
|
||
import concurrent.futures
|
||
import argparse
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
|
||
# Logging setup: one module-level logger that sends INFO and above to both a
# log file and the console.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# logging.FileHandler opens its file on construction and fails if the parent
# directory is missing, so make sure temp_analysis/ exists first.
os.makedirs('temp_analysis', exist_ok=True)

# File handler. The log messages contain Chinese text, so the encoding is
# pinned to UTF-8 instead of depending on the platform default.
file_handler = logging.FileHandler('temp_analysis/test_run.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)

# Console handler.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# Both handlers share the same "time - level - message" line format.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)

# Attach the handlers to the logger.
log.addHandler(file_handler)
log.addHandler(console_handler)

# Load environment variables via python-dotenv.
load_dotenv()

# OpenAI client configured from the environment.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    # Optional: set this variable to target another OpenAI-compatible service.
    base_url=os.getenv("OPENAI_API_BASE_URL"),
)
|
||
|
||
# The prompts below are a modified version of the task proposer agent prompt
# from https://arxiv.org/pdf/2502.11357

# System prompt: asks the model for the first action towards a task in two
# forms (natural language and grounded), restricted to the click / type /
# select / scroll / stop action space, with the answer placed inside a
# ``` ``` fence.
# NOTE(review): rule 4 spells the field name as "grounded" / "_action" split
# across two lines, and the OUTPUT FORMAT example puts the closing quote
# before the closing fence. Both look like copy artifacts -- confirm before
# editing, since any change alters the text sent to the model.
cot_system_prompt = """
What does this webpage show? Imagine you are a real user on this webpage. Given the webpage
screenshot or ocr result and parsed HTML/accessibility tree and the task description, please provide
the first action towards completing that task.

Do the following step by step:
1. Given the webpage screenshot or ocr result and parsed HTML/accessibility tree, generate the first action
towards completing that task (in natural language form).
2. Given the webpage screenshot or ocr result, parsed HTML/accessibility tree, and the natural language
action, generate the grounded version of that action.

ACTION SPACE: Your action space is: ['click [element ID]', 'type [element ID] [content]',
'select [element ID] [content of option to select]', 'scroll [up]', 'scroll [down]', and 'stop'].
Action output should follow the syntax as given below:
click [element ID]: This action clicks on an element with a specific ID on the webpage.
type [element ID] [content]: Use this to type the content into the field with id. By default, the
"Enter" key is pressed after typing. Both the content and the ID should be within square braces
as per the syntax.
select [element ID] [content of option to select]: Select an option from a dropdown menu. The
content of the option to select should be within square braces. When you get (select an option)
tags from the accessibility tree, you need to select the serial number (element_id) corresponding
to the select tag, not the option, and select the most likely content corresponding to the option as
input.
scroll [down]: Scroll the page down.
scroll [up]: Scroll the page up.

IMPORTANT:
To be successful, it is important to STRICTLY follow the below rules:
Action generation rules:
1. You should generate a single atomic action at each step.
2. The action should be an atomic action from the given vocabulary - click, type, select, scroll
(up or down), or stop.
3. The arguments to each action should be within square braces. For example, "click [127]",
"type [43] [content to type]", "scroll [up]", "scroll [down]".
4. The natural language form of action (corresponding to the field "action_in_natural_language")
should be consistent with the grounded version of the action (corresponding to the field "grounded
_action"). Do NOT add any additional information in the grounded action. For example, if a
particular element ID is specified in the grounded action, a description of that element must be
present in the natural language action.
5. If the type action is selected, the natural language form of action ("action_in_natural_language") should always specify the actual text to be typed.
6. You should issue a "stop" action if the current webpage asks to log in or for credit card
information.
7. To input text, there is NO need to click the textbox first, directly type content. After typing,
the system automatically hits the 'ENTER' key.
8. STRICTLY Avoid repeating the same action (click/type) if the webpage remains unchanged.
You may have selected the wrong web element.
9. Do NOT use quotation marks in the action generation.

OUTPUT FORMAT:
Please give a short analysis of the screenshot, parsed
HTML/accessibility tree, then put your answer within ``` ```, for example,
"In summary, the proposed task and the corresponding action is: ```{
"action_in_natural_language":<ACTION_IN_NATURAL_LANGUAGE>:str,
"grounded_action": <ACTION>:str}"```
"""

# User prompt template. Its placeholders (INIT_URL, A11Y_TREE, SCREENSHOT,
# TASK_DESCRIPTION) are str.format fields.
cot_user_prompt = """
Website URL: {INIT_URL}
Parsed HTML/Accessibility Tree: {A11Y_TREE}
Screenshot ocr result: {SCREENSHOT}
Task description: {TASK_DESCRIPTION}
"""
|
||
|
||
def call_api(messages, model="aiproxy/deepseek-reasoner"):
    """Send a chat-completion request through the module-level OpenAI client.

    Args:
        messages: List of chat messages (dicts with "role" and "content").
        model: Model identifier passed to the API. Defaults to the value that
            used to be hard-coded here, so existing callers are unaffected.

    Returns:
        The object returned by ``client.chat.completions.create``, or ``None``
        if the call raised any exception (the error is logged).
    """
    try:
        return client.chat.completions.create(messages=messages, model=model)
    except Exception as e:
        # Best-effort boundary: callers treat None as "API call failed".
        log.error(f"API调用出错: {e}")
        return None
|
||
|
||
def extract_action(response_text):
    """Extract the natural-language and grounded actions from a model reply.

    The reply is expected to contain a JSON object inside a ``` fence. The
    fence may carry a ``json`` language tag, and whitespace (including
    newlines) is allowed between the fence markers and the braces. The first
    fenced object found is used.

    Args:
        response_text: Text of the model reply; may be None or empty.

    Returns:
        Tuple ``(action_in_natural_language, grounded_action)``. Either element
        is None when its key is absent; both are None when there is no reply,
        no fenced object, or the object is not valid JSON.
    """
    if not response_text:
        return None, None

    match = re.search(
        r'```(?:json)?\s*\{(.+?)\}\s*```',
        response_text,
        re.DOTALL | re.IGNORECASE,
    )
    if not match:
        return None, None

    # The braces are outside the capture group, so put them back before parsing.
    json_str = "{" + match.group(1) + "}"
    try:
        action_data = json.loads(json_str)
    except json.JSONDecodeError:
        log.error(f"无法解析JSON: {match.group(1)}")
        return None, None

    return action_data.get("action_in_natural_language"), action_data.get("grounded_action")
|
||
|
||
def check_answer(grounded_action, answer):
    """Return True if the grounded click action targets an accepted element.

    Args:
        grounded_action: Action string such as "click [12]"; may be None.
        answer: Accepted element IDs, given as a comma-separated string
            ("12, 13"), a single int, or a list/tuple/set of IDs; may be None.

    Returns:
        True only when ``grounded_action`` starts with "click", contains a
        "click [<digits>]" target, and that ID is one of the accepted IDs.
        Every other case, including a missing answer, yields False.
    """
    if not grounded_action or not grounded_action.startswith("click"):
        return False

    # Extract the element ID
    match = re.search(r'click \[(\d+)\]', grounded_action)
    if not match:
        return False

    if answer is None:
        return False

    # Normalise the accepted IDs to stripped strings, whatever shape they
    # arrived in.
    if isinstance(answer, (list, tuple, set)):
        candidates = answer
    else:
        candidates = str(answer).split(",")
    answer_ids = {str(candidate).strip() for candidate in candidates}

    return match.group(1) in answer_ids
|
||
|
||
def process_item(item):
    """Run one exam item through the model and score the predicted action.

    Args:
        item: Dict with the keys "id", "url", "question", "answer" and
            "answer_text".

    Returns:
        A result dict with the keys id, url, task, expected_answer, thinking,
        action_nl, grounded_action and is_correct, or None when the
        accessibility-tree file is missing or the API call returned no choices.
    """
    item_id = item["id"]
    url = item["url"]
    task_description = item["question"]
    answer = item["answer"]
    answer_text = item["answer_text"]

    log.info(f"处理ID: {item_id}, URL: {url}")

    # Read the accessibility tree for this item
    try:
        with open(f"axtrees/{item_id}.txt", "r") as f:
            a11y_tree = f.read()
    except FileNotFoundError:
        log.error(f"找不到文件: axtrees/{item_id}.txt")
        return None

    # Build the API request
    messages = [
        {"role": "system", "content": cot_system_prompt},
        {"role": "user", "content": cot_user_prompt.format(
            INIT_URL=url,
            A11Y_TREE=a11y_tree,
            SCREENSHOT="",  # no screenshot OCR result is provided here
            TASK_DESCRIPTION=task_description
        )}
    ]

    log.info(f"task_description: {task_description}")
    log.info(f"answer: {answer}, answer_text: {answer_text}")

    # Call the API
    response = call_api(messages)

    if not (response and hasattr(response, 'choices') and len(response.choices) > 0):
        log.error(f"API调用失败: {response}")
        return None

    message = response.choices[0].message
    content = message.content
    # reasoning_content is a provider-specific extension field; read it
    # defensively so a back end that omits it does not raise AttributeError.
    reasoning_content = getattr(message, "reasoning_content", None)

    log.info(f"reasoning_content: {reasoning_content}")
    log.info(f"content: {content}")

    # Extract the action. An empty/None reply cannot contain one, so skip the
    # parser instead of handing it a non-string.
    if content:
        action_nl, grounded_action = extract_action(content)
    else:
        action_nl, grounded_action = None, None

    # Check the answer
    is_correct = check_answer(grounded_action, answer)

    # Record the result
    result = {
        "id": item_id,
        "url": url,
        "task": task_description,
        "expected_answer": answer,
        "thinking": reasoning_content,
        "action_nl": action_nl,
        "grounded_action": grounded_action,
        "is_correct": is_correct
    }

    log.info(f"ID: {item_id}")
    log.info(f"任务: {task_description}")
    log.info(f"动作: {grounded_action}")
    log.info(f"是否正确: {is_correct}")
    log.info("-" * 50)

    return result
|
||
|
||
def main():
    """Command-line entry point.

    Without flags, runs every item in temp_analysis/exam.json. With
    ``--rerun_failure``, re-runs only the items that previously failed and
    merges the new results into temp_analysis/results.json.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='运行测试或重跑失败的测试')
    parser.add_argument('--rerun_failure', action='store_true', help='重跑失败的测试用例')
    args = parser.parse_args()

    if args.rerun_failure:
        _rerun_failures()
    else:
        _run_full_test()


def _run_items(items, max_workers, progress_label, on_result, thread_name_prefix=""):
    """Run process_item over items in a thread pool and tally the outcomes.

    Args:
        items: Exam items to process.
        max_workers: Size of the thread pool.
        progress_label: Prefix of the progress log line.
        on_result: Callable invoked (in the calling thread) with every
            non-empty result dict.
        thread_name_prefix: Optional name prefix for the worker threads.

    Returns:
        Tuple ``(success_count, fail_count)``. Items whose processing returned
        None or raised an exception are counted as failures.
    """
    total_items = len(items)
    completed = 0
    success_count = 0
    fail_count = 0

    with concurrent.futures.ThreadPoolExecutor(
        max_workers=max_workers, thread_name_prefix=thread_name_prefix
    ) as executor:
        # Submit every task up front
        futures = [executor.submit(process_item, item) for item in items]

        # Collect results as they finish
        for future in concurrent.futures.as_completed(futures):
            try:
                result = future.result()
            except Exception:
                # One broken item must not abort the run before the results
                # gathered so far are saved; count it as a failure instead.
                log.exception("任务执行出现未处理的异常")
                result = None
            completed += 1

            if result:
                on_result(result)
                if result["is_correct"]:
                    success_count += 1
                else:
                    fail_count += 1
            else:
                fail_count += 1

            # Report progress
            progress = (completed / total_items) * 100
            log.info(f"{progress_label}: {progress:.2f}% ({completed}/{total_items}) - 成功: {success_count}, 失败: {fail_count}")

    return success_count, fail_count


def _run_full_test():
    """Run every item in exam.json and write the results to results.json."""
    with open("temp_analysis/exam.json", "r") as f:
        exam_data = json.load(f)

    results = []
    total_items = len(exam_data)

    log.info(f"开始测试,总共 {total_items} 个任务")

    success_count, fail_count = _run_items(
        exam_data,
        max_workers=1,
        progress_label="进度",
        on_result=results.append,
        thread_name_prefix="test_run",
    )

    # Save the results
    with open("temp_analysis/results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Compute the accuracy
    accuracy = success_count / total_items if total_items > 0 else 0

    log.info(f"测试完成! 总计: {total_items}题,正确: {success_count}题,错误: {fail_count}题,正确率: {accuracy:.2%}")


def _rerun_failures():
    """Re-run failed items and merge the new results into results.json.

    An item is re-run when its stored result is not correct, or when it has no
    stored result at all. The full run only stores non-empty results, so items
    whose processing failed outright would otherwise never be retried.
    """
    try:
        with open("temp_analysis/results.json", "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        log.error("找不到results.json文件,请先运行完整测试")
        return

    # Read exam.json for the full test data. Its absence is reported
    # separately so the message names the file that is actually missing.
    try:
        with open("temp_analysis/exam.json", "r") as f:
            exam_data = json.load(f)
    except FileNotFoundError:
        log.error("找不到exam.json文件")
        return

    # ID -> exam item, and ID -> position in results (used to update in place)
    exam_map = {item["id"]: item for item in exam_data}
    result_indices = {r["id"]: i for i, r in enumerate(results)}

    # Stored failures first (in results order), then items never recorded.
    rerun_items = [
        exam_map[r["id"]]
        for r in results
        if not r.get("is_correct", False) and r["id"] in exam_map
    ]
    rerun_items.extend(item for item in exam_data if item["id"] not in result_indices)

    if not rerun_items:
        log.info("没有找到失败的测试用例,无需重跑")
        return

    total_items = len(rerun_items)
    log.info(f"开始重跑失败的测试,总共 {total_items} 个任务")

    def merge_result(result):
        """Replace the stored result with the same ID, or append a new one."""
        index = result_indices.get(result["id"])
        if index is None:
            result_indices[result["id"]] = len(results)
            results.append(result)
        else:
            results[index] = result

    success_count, fail_count = _run_items(
        rerun_items,
        max_workers=4,
        progress_label="重跑进度",
        on_result=merge_result,
    )

    # Save the updated results
    with open("temp_analysis/results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Compute the new overall accuracy
    all_correct = sum(1 for r in results if r.get("is_correct", False))
    all_total = len(results)
    accuracy = all_correct / all_total if all_total > 0 else 0

    log.info(f"重跑完成! 重跑: {total_items}题,成功: {success_count}题,失败: {fail_count}题")
    log.info(f"总计: {all_total}题,正确: {all_correct}题,正确率: {accuracy:.2%}")
|
||
|
||
# Run the evaluation only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()
|
||
|
||
|
||
|
||
|
||
|
||
|