416 lines
16 KiB
Python
416 lines
16 KiB
Python
import json
|
||
import os
|
||
import re
|
||
import logging
|
||
import concurrent.futures
|
||
import argparse
|
||
import threading
|
||
import datetime
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
|
||
# Logging: every record at INFO or above goes both to
# temp_analysis/test_run.log and to the console.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# FileHandler opens its file immediately and raises FileNotFoundError when the
# directory is missing, so make sure temp_analysis/ exists first.
os.makedirs("temp_analysis", exist_ok=True)

# File handler; explicit UTF-8 because most log messages are Chinese and the
# platform default encoding is not UTF-8 everywhere.
file_handler = logging.FileHandler('temp_analysis/test_run.log', encoding='utf-8')
file_handler.setLevel(logging.INFO)

# Console handler (logging.StreamHandler writes to stderr by default).
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# One shared format for both handlers.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)

# Attach both handlers to the module logger.
log.addHandler(file_handler)
log.addHandler(console_handler)
|
||
|
||
# Read a local .env file (python-dotenv) so the os.getenv() lookups below can
# see the variables it defines.
load_dotenv()

# Shared OpenAI client used by call_api(). Set OPENAI_API_BASE_URL to target an
# OpenAI-compatible service; when it is unset, base_url=None is passed and the
# choice is left to the openai library.
client = OpenAI(
    base_url=os.getenv("OPENAI_API_BASE_URL"),
    api_key=os.getenv("OPENAI_API_KEY"),
)
|
||
|
||
# Prompts below are a modified version of the task proposer agent prompt from
# https://arxiv.org/pdf/2502.11357

# System prompt: asks the model for a short analysis followed by ONE grounded
# action, emitted as a JSON object between ``` fences. extract_action() parses
# the JSON found between those fences, so the example at the end keeps "}" and
# the closing fence adjacent (no stray quote in between), and the field name
# "grounded_action" is written on a single line.
cot_system_prompt = """
What does this webpage show? Imagine you are a real user on this webpage. Given the webpage
screenshot or ocr result and parsed HTML/accessibility tree and the task description, please provide
the first action towards completing that task.

Do the following step by step:
1. Given the webpage screenshot or ocr result and parsed HTML/accessibility tree, generate the first action
towards completing that task (in natural language form).
2. Given the webpage screenshot or ocr result, parsed HTML/accessibility tree, and the natural language
action, generate the grounded version of that action.

ACTION SPACE: Your action space is: ['click [element ID]', 'type [element ID] [content]',
'select [element ID] [content of option to select]', 'scroll [up]', 'scroll [down]', and 'stop'].
Action output should follow the syntax as given below:
click [element ID]: This action clicks on an element with a specific ID on the webpage.
type [element ID] [content]: Use this to type the content into the field with id. By default, the
"Enter" key is pressed after typing. Both the content and the ID should be within square braces
as per the syntax.
select [element ID] [content of option to select]: Select an option from a dropdown menu. The
content of the option to select should be within square braces. When you get (select an option)
tags from the accessibility tree, you need to select the serial number (element_id) corresponding
to the select tag, not the option, and select the most likely content corresponding to the option as
input.
scroll [down]: Scroll the page down.
scroll [up]: Scroll the page up.

IMPORTANT:
To be successful, it is important to STRICTLY follow the below rules:
Action generation rules:
1. You should generate a single atomic action at each step.
2. The action should be an atomic action from the given vocabulary - click, type, select, scroll
(up or down), or stop.
3. The arguments to each action should be within square braces. For example, "click [127]",
"type [43] [content to type]", "scroll [up]", "scroll [down]".
4. The natural language form of action (corresponding to the field "action_in_natural_language")
should be consistent with the grounded version of the action (corresponding to the field
"grounded_action"). Do NOT add any additional information in the grounded action. For example, if a
particular element ID is specified in the grounded action, a description of that element must be
present in the natural language action.
5. If the type action is selected, the natural language form of action ("action_in_natural_language") should always specify the actual text to be typed.
6. You should issue a "stop" action if the current webpage asks to log in or for credit card
information.
7. To input text, there is NO need to click the textbox first, directly type content. After typing,
the system automatically hits the 'ENTER' key.
8. STRICTLY Avoid repeating the same action (click/type) if the webpage remains unchanged.
You may have selected the wrong web element.
9. Do NOT use quotation marks in the action generation.

OUTPUT FORMAT:
Please give a short analysis of the screenshot, parsed
HTML/accessibility tree, then put your answer within ``` ```, for example,
In summary, the proposed task and the corresponding action is: ```{
"action_in_natural_language":<ACTION_IN_NATURAL_LANGUAGE>:str,
"grounded_action": <ACTION>:str}```
"""

# User prompt: filled in per exam item with str.format(INIT_URL=...,
# A11Y_TREE=..., SCREENSHOT=..., TASK_DESCRIPTION=...).
cot_user_prompt = """
Website URL: {INIT_URL}
Parsed HTML/Accessibility Tree: {A11Y_TREE}
Screenshot ocr result: {SCREENSHOT}
Task description: {TASK_DESCRIPTION}
"""
|
||
|
||
# One lock per results file, so concurrent workers never interleave the
# read-modify-write cycle in append_to_result_file().
file_locks = {
    path: threading.Lock()
    for path in (
        "temp_analysis/results.json",
        "temp_analysis/results_success.json",
        "temp_analysis/results_failure.json",
    )
}
|
||
|
||
def call_api(messages, model="aiproxy/deepseek-reasoner"):
    """Send one chat-completion request through the shared OpenAI client.

    Args:
        messages: Chat messages in OpenAI format, i.e. a list of
            {"role": ..., "content": ...} dicts.
        model: Name of the model to request. The default is the model this
            script has always used; pass e.g. "gpt-4o-mini" to switch.

    Returns:
        The response object returned by client.chat.completions.create, or
        None if the request raised. The error is logged (with traceback)
        instead of propagated so the caller's retry loop can simply try again.
    """
    try:
        return client.chat.completions.create(messages=messages, model=model)
    except Exception as e:
        # Deliberately broad: any transport/API error counts as one failed
        # attempt for the caller; log.exception keeps the traceback.
        log.exception(f"API调用出错: {e}")
        return None
|
||
|
||
def extract_action(response_text):
    """Extract the natural-language and grounded actions from a model reply.

    The system prompt asks the model to end with a JSON object between ```
    fences. Replies often decorate the fence (a language tag such as ```json,
    newlines around the object, or a stray quote before the closing fence), so
    every fenced block is scanned and the span from its first "{" to its last
    "}" is parsed as JSON.

    Args:
        response_text: The model's message content; may be None or empty.

    Returns:
        A tuple (action_in_natural_language, grounded_action) taken from the
        first fenced JSON object that contains at least one of those keys
        (a missing key yields None), or (None, None) if there is no such object.
    """
    if not response_text:
        return None, None

    # Non-greedy body, so consecutive fenced blocks are paired up in order.
    for fenced in re.findall(r"```[A-Za-z]*\s*(.*?)```", response_text, re.DOTALL):
        start = fenced.find("{")
        end = fenced.rfind("}")
        if start == -1 or end < start:
            continue  # not a JSON block (e.g. a code or HTML snippet)

        json_str = fenced[start:end + 1]
        try:
            action_data = json.loads(json_str)
        except json.JSONDecodeError:
            log.error(f"无法解析JSON: {json_str}")
            continue

        if isinstance(action_data, dict) and (
            "action_in_natural_language" in action_data
            or "grounded_action" in action_data
        ):
            return (
                action_data.get("action_in_natural_language"),
                action_data.get("grounded_action"),
            )

    return None, None
|
||
|
||
def check_answer(grounded_action, answer):
    """Return True if grounded_action clicks one of the expected element IDs.

    Only click actions are scored. Any other action (type, select, scroll,
    stop), an unparsable action, or a non-string value counts as wrong.

    Args:
        grounded_action: Action string produced by the model, e.g. "click [12]".
            It comes straight out of the model's JSON, so it may be None or a
            non-string; those are treated as wrong.
        answer: Expected element ID(s): a comma-separated string such as
            "12, 15", a single number, or a list/tuple/set of IDs.

    Returns:
        True if the clicked element ID is among the expected IDs, else False.
    """
    if not isinstance(grounded_action, str) or answer is None:
        return False

    # Anchored at the start; tolerate surrounding whitespace and capitalisation,
    # e.g. "Click [ 12 ]".
    match = re.match(r"\s*click\s*\[\s*(\d+)\s*\]", grounded_action, re.IGNORECASE)
    if not match:
        return False
    element_id = match.group(1)

    # Normalise the expected answer to a set of stripped strings so that
    # "12", 12 and ["12"] all compare equal to the extracted ID.
    if isinstance(answer, (list, tuple, set)):
        candidates = answer
    else:
        candidates = str(answer).split(",")
    expected_ids = {str(candidate).strip() for candidate in candidates}

    return element_id in expected_ids
|
||
|
||
def load_successful_items(path="temp_analysis/results_success.json"):
    """Return the set of item IDs already recorded as successful.

    main() uses this to resume a run: items whose ID is returned are skipped.

    Args:
        path: JSON file holding a list of result dicts, as written by
            append_to_result_file(). Defaults to the success file main() uses.

    Returns:
        The set of "id" values of the dict entries that have one. A missing
        file yields an empty set; an unreadable or malformed file is logged and
        also yields an empty set. Entries that are not dicts, or whose id is a
        JSON list/object, are skipped instead of discarding the whole file.
    """
    try:
        with open(path, "r", encoding="utf-8") as f:
            success_data = json.load(f)
    except FileNotFoundError:
        # First run: nothing has succeeded yet.
        return set()
    except json.JSONDecodeError:
        log.warning(f"无法解析{os.path.basename(path)},将视为空文件")
        return set()
    except Exception as e:
        log.error(f"加载成功项目时出错: {e}")
        return set()

    if not isinstance(success_data, list):
        log.warning(f"无法解析{os.path.basename(path)},将视为空文件")
        return set()

    successful_ids = set()
    for item in success_data:
        # A bare string entry would otherwise pass `"id" in item` (substring
        # test) and then fail on item["id"].
        if not isinstance(item, dict) or "id" not in item:
            continue
        item_id = item["id"]
        if isinstance(item_id, (list, dict)):
            continue  # unhashable JSON value; cannot be stored in a set
        successful_ids.add(item_id)

    return successful_ids
|
||
|
||
def _run_attempt(messages, answer, attempt_no):
    """Make one API call, score it, and return (attempt_record, api_ok)."""
    timestamp = datetime.datetime.now().isoformat()
    response = call_api(messages)

    if not (response and hasattr(response, 'choices') and len(response.choices) > 0):
        log.error(f"API调用失败: {response}")
        failed_attempt = {
            "attempt": attempt_no,
            "timestamp": timestamp,
            "action_nl": None,
            "grounded_action": None,
            "is_correct": False,
            "thinking": "API调用失败",
        }
        return failed_attempt, False

    message = response.choices[0].message
    content = message.content
    # Only some models/providers expose reasoning_content on the message.
    reasoning_content = getattr(message, 'reasoning_content', "")

    log.info(f"reasoning_content: {reasoning_content}")
    log.info(f"content: {content}")

    # content may be None (e.g. an empty completion); extract_action wants text.
    action_nl, grounded_action = extract_action(content or "")
    attempt_record = {
        "attempt": attempt_no,
        "timestamp": timestamp,
        "action_nl": action_nl,
        "grounded_action": grounded_action,
        "is_correct": check_answer(grounded_action, answer),
        "thinking": reasoning_content,
    }
    return attempt_record, True


def process_item(item, successful_ids, max_retries=3):
    """Evaluate one exam item, retrying the model up to max_retries times.

    Args:
        item: Exam entry; its "id", "url", "question", "answer" and
            "answer_text" keys are read. The item's accessibility tree is read
            from axtrees/<id>.txt.
        successful_ids: IDs already solved in an earlier run; such items are
            skipped.
        max_retries: Maximum number of API attempts. Retrying stops at the
            first attempt whose grounded action check_answer() accepts.

    Returns:
        A result dict summarising the last attempt and listing all attempts
        under "all_attempts", or None when the item is skipped or its
        accessibility-tree file cannot be read.
    """
    item_id = item["id"]

    # Resume support: skip items that already succeeded.
    if item_id in successful_ids:
        log.info(f"ID: {item_id} 已经成功完成,跳过")
        return None

    url = item["url"]
    task_description = item["question"]
    answer = item["answer"]
    answer_text = item["answer_text"]

    log.info(f"处理ID: {item_id}, URL: {url}")

    # Read the accessibility tree. Explicit UTF-8 so results do not depend on
    # the platform's default encoding; a read/decode error only fails this item.
    a11y_path = f"axtrees/{item_id}.txt"
    try:
        with open(a11y_path, "r", encoding="utf-8") as f:
            a11y_tree = f.read()
    except FileNotFoundError:
        log.error(f"找不到文件: {a11y_path}")
        return None
    except (OSError, UnicodeDecodeError) as e:
        log.error(f"读取文件失败: {a11y_path} ({e})")
        return None

    messages = [
        {"role": "system", "content": cot_system_prompt},
        {"role": "user", "content": cot_user_prompt.format(
            INIT_URL=url,
            A11Y_TREE=a11y_tree,
            SCREENSHOT="",  # no screenshot OCR result is available here
            TASK_DESCRIPTION=task_description,
        )},
    ]

    # The full prompt includes the whole a11y tree, so emit it at DEBUG level
    # through the configured logger instead of an unconditional print().
    log.debug(f"messages: #######\n {messages} \n######")
    log.info(f"task_description: {task_description}")
    log.info(f"answer: {answer}, answer_text: {answer_text}")

    all_attempts = []  # every attempt's record, in order
    for attempt_no in range(1, max_retries + 1):
        attempt_record, api_ok = _run_attempt(messages, answer, attempt_no)
        all_attempts.append(attempt_record)

        if attempt_record["is_correct"]:
            log.info(f"ID: {item_id} 测试成功!")
            break
        if api_ok:
            log.info(f"ID: {item_id} 测试失败,尝试 {attempt_no}/{max_retries}")
        elif attempt_no < max_retries:
            log.info(f"ID: {item_id} API调用失败,尝试 {attempt_no}/{max_retries}")

    attempts_made = len(all_attempts)
    # Summary fields come from the last attempt (empty when max_retries <= 0).
    last_attempt = all_attempts[-1] if all_attempts else {}

    result = {
        "id": item_id,
        "url": url,
        "task": task_description,
        "expected_answer": answer,
        "answer_text": answer_text,
        "thinking": last_attempt.get("thinking", ""),
        "action_nl": last_attempt.get("action_nl"),
        "grounded_action": last_attempt.get("grounded_action"),
        "is_correct": last_attempt.get("is_correct", False),
        "attempts": attempts_made,
        "timestamp": datetime.datetime.now().isoformat(),
        "all_attempts": all_attempts,
    }

    log.info(f"ID: {item_id}")
    log.info(f"任务: {task_description}")
    log.info(f"动作: {result['grounded_action']}")
    log.info(f"是否正确: {result['is_correct']}")
    log.info(f"尝试次数: {attempts_made}")
    log.info("-" * 50)

    return result
|
||
|
||
def append_to_result_file(result, filename):
    """Append one result record to a JSON-list file, one writer per file.

    The file is read, extended, and rewritten while holding that file's lock
    from file_locks. The new content is first written to "<filename>.tmp" and
    then moved over the original with os.replace. An interruption mid-write
    therefore leaves the previous file intact instead of a truncated one.

    Args:
        result: JSON-serialisable result record.
        filename: Path of the JSON file holding a list of records.

    Returns:
        True on success; False if anything failed (the error is logged).
    """
    try:
        # setdefault (not get(..., Lock())): an unregistered filename must still
        # map to ONE shared lock, otherwise each call locks its own private Lock.
        lock = file_locks.setdefault(filename, threading.Lock())
        with lock:
            try:
                with open(filename, 'r', encoding='utf-8') as f:
                    data = json.load(f)
            except FileNotFoundError:
                data = []
            except json.JSONDecodeError:
                # Empty or malformed file: start a fresh list, but say so, since
                # whatever was in the file is about to be overwritten.
                log.warning(f"无法解析 {filename},将重新创建结果列表")
                data = []

            data.append(result)

            # Write-then-rename so readers never see a half-written file.
            tmp_path = filename + ".tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, filename)

        return True
    except Exception as e:
        log.error(f"写入结果到文件 {filename} 时出错: {e}")
        return False
|
||
|
||
def main(argv=None):
    """Evaluate every item in temp_analysis/exam.json.

    Items already present in results_success.json are skipped. Each finished
    item is appended immediately to results.json, and also to either
    results_success.json or results_failure.json.

    Args:
        argv: Optional argument list for argparse; None means sys.argv[1:].
            --workers N      number of worker threads (default 1)
            --max-retries N  API attempts per item (default 3)
    """
    parser = argparse.ArgumentParser(description="Evaluate the model on temp_analysis/exam.json")
    parser.add_argument("--workers", type=int, default=1,
                        help="number of concurrent worker threads (default: 1)")
    parser.add_argument("--max-retries", type=int, default=3,
                        help="maximum API attempts per item (default: 3)")
    args = parser.parse_args(argv)

    # Load the exam; JSON text is UTF-8 by specification.
    with open("temp_analysis/exam.json", "r", encoding="utf-8") as f:
        exam_data = json.load(f)

    # IDs that already succeeded in an earlier run.
    successful_ids = load_successful_items()
    log.info(f"已经成功完成的测试项目数: {len(successful_ids)}, 成功ID: {successful_ids}")

    total_items = len(exam_data)
    completed = 0
    success_count = 0
    fail_count = 0
    skip_count = 0

    log.info(f"开始测试,总共 {total_items} 个任务")

    # Make sure the result files exist and start as empty lists.
    for filename in ["temp_analysis/results.json", "temp_analysis/results_failure.json"]:
        if not os.path.exists(filename):
            with open(filename, 'w') as f:
                json.dump([], f)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor:
        future_to_item = {
            executor.submit(process_item, item, successful_ids, args.max_retries): item
            for item in exam_data
        }

        for future in concurrent.futures.as_completed(future_to_item):
            item = future_to_item[future]
            item_id = item.get("id") if isinstance(item, dict) else None
            completed += 1

            try:
                result = future.result()
            except Exception:
                # One broken item must not abort collection of all the others.
                log.exception(f"ID: {item_id} 处理时出现异常")
                fail_count += 1
            else:
                if result is None:
                    if item_id in successful_ids:
                        skip_count += 1
                        log.info(f"跳过一个已完成的测试项目")
                    else:
                        # process_item also returns None when the item's a11y
                        # tree could not be read: a failure, not a skip.
                        fail_count += 1
                        log.error(f"ID: {item_id} 未生成结果,计为失败")
                else:
                    # Persist immediately so an interrupted run keeps its results.
                    append_to_result_file(result, "temp_analysis/results.json")

                    if result["is_correct"]:
                        success_count += 1
                        append_to_result_file(result, "temp_analysis/results_success.json")
                    else:
                        fail_count += 1
                        append_to_result_file(result, "temp_analysis/results_failure.json")

            progress = (completed / total_items) * 100
            log.info(f"进度: {progress:.2f}% ({completed}/{total_items}) - 成功: {success_count}, 失败: {fail_count}, 跳过: {skip_count}")

    # Accuracy over the items actually evaluated in this run (skips excluded).
    total_processed = success_count + fail_count
    accuracy = success_count / total_processed if total_processed > 0 else 0

    log.info(f"测试完成! 总计: {total_items}题,成功: {success_count}题,失败: {fail_count}题,跳过: {skip_count}题,正确率: {accuracy:.2%}")
    log.info(f"成功结果已保存到 results_success.json")
    log.info(f"失败结果已保存到 results_failure.json")
    log.info(f"全部结果已保存到 results.json")


if __name__ == "__main__":
    main()
|
||
|
||
|
||
|
||
|
||
|
||
|