crawlee/misc/temp_analysis/test_run copy 4.py
2025-04-23 12:14:50 +08:00

365 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
import logging
import concurrent.futures
import argparse
from openai import OpenAI
from dotenv import load_dotenv
# Logging: send INFO-and-above records both to a log file and to the console.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)
# FileHandler does not create missing directories; without this the module
# fails at import time when temp_analysis/ does not exist in the working dir.
os.makedirs("temp_analysis", exist_ok=True)
# Log messages are largely Chinese, so write the file as UTF-8 instead of the
# locale's default encoding.
file_handler = logging.FileHandler('temp_analysis/test_run.log', encoding="utf-8")
file_handler.setLevel(logging.INFO)
# Console handler mirrors the same records to stderr.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Shared formatter: timestamp, level, message.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
log.addHandler(file_handler)
log.addHandler(console_handler)
# Pull variables from a .env file (if one is present) into the process
# environment before the client below reads them.
load_dotenv()
# OpenAI client. OPENAI_API_BASE_URL can optionally point it at an
# OpenAI-compatible service instead of the default endpoint.
client = OpenAI(
    base_url=os.environ.get("OPENAI_API_BASE_URL"),
    api_key=os.environ.get("OPENAI_API_KEY"),
)
# Modified version of the task-proposer agent prompt from
# https://arxiv.org/pdf/2502.11357
#
# System prompt: describes the action space (click / type / select / scroll /
# stop) and asks the model to finish with a ``` fenced JSON object holding the
# keys "action_in_natural_language" and "grounded_action".
# NOTE(review): the OUTPUT FORMAT example puts a stray `"` right before the
# closing ``` fence, and rule 4 splits the key name "grounded_action" across two
# lines; models may reproduce these quirks. The text is kept verbatim because
# any edit to the prompt changes model behaviour.
cot_system_prompt = """
What does this webpage show? Imagine you are a real user on this webpage. Given the webpage
screenshot or ocr result and parsed HTML/accessibility tree and the task description, please provide
the first action towards completing that task.
Do the following step by step:
1. Given the webpage screenshot or ocr result and parsed HTML/accessibility tree, generate the first action
towards completing that task (in natural language form).
2. Given the webpage screenshot or ocr result, parsed HTML/accessibility tree, and the natural language
action, generate the grounded version of that action.
ACTION SPACE: Your action space is: ['click [element ID]', 'type [element ID] [content]',
'select [element ID] [content of option to select]', 'scroll [up]', 'scroll [down]', and 'stop'].
Action output should follow the syntax as given below:
click [element ID]: This action clicks on an element with a specific ID on the webpage.
type [element ID] [content]: Use this to type the content into the field with id. By default, the
"Enter" key is pressed after typing. Both the content and the ID should be within square braces
as per the syntax.
select [element ID] [content of option to select]: Select an option from a dropdown menu. The
content of the option to select should be within square braces. When you get (select an option)
tags from the accessibility tree, you need to select the serial number (element_id) corresponding
to the select tag, not the option, and select the most likely content corresponding to the option as
input.
scroll [down]: Scroll the page down.
scroll [up]: Scroll the page up.
IMPORTANT:
To be successful, it is important to STRICTLY follow the below rules:
Action generation rules:
1. You should generate a single atomic action at each step.
2. The action should be an atomic action from the given vocabulary - click, type, select, scroll
(up or down), or stop.
3. The arguments to each action should be within square braces. For example, "click [127]",
"type [43] [content to type]", "scroll [up]", "scroll [down]".
4. The natural language form of action (corresponding to the field "action_in_natural_language")
should be consistent with the grounded version of the action (corresponding to the field "grounded
_action"). Do NOT add any additional information in the grounded action. For example, if a
particular element ID is specified in the grounded action, a description of that element must be
present in the natural language action.
5. If the type action is selected, the natural language form of action ("action_in_natural_language") should always specify the actual text to be typed.
6. You should issue a "stop" action if the current webpage asks to log in or for credit card
information.
7. To input text, there is NO need to click the textbox first, directly type content. After typing,
the system automatically hits the 'ENTER' key.
8. STRICTLY Avoid repeating the same action (click/type) if the webpage remains unchanged.
You may have selected the wrong web element.
9. Do NOT use quotation marks in the action generation.
OUTPUT FORMAT:
Please give a short analysis of the screenshot, parsed
HTML/accessibility tree, then put your answer within ``` ```, for example,
"In summary, the proposed task and the corresponding action is: ```{
"action_in_natural_language":<ACTION_IN_NATURAL_LANGUAGE>:str,
"grounded_action": <ACTION>:str}"```
"""
# User prompt template. The {UPPER_CASE} placeholders are filled with
# str.format(): page URL, accessibility tree, screenshot OCR text and task
# description. Written as implicitly concatenated pieces so every newline is
# explicit; the resulting string starts and ends with "\n".
cot_user_prompt = (
    "\n"
    "Website URL: {INIT_URL}\n"
    "Parsed HTML/Accessibility Tree: {A11Y_TREE}\n"
    "Screenshot ocr result: {SCREENSHOT}\n"
    "Task description: {TASK_DESCRIPTION}\n"
)
def call_api(messages, model="aiproxy/deepseek-reasoner"):
    """Send a chat-completion request through the module-level OpenAI client.

    Args:
        messages: Chat messages in OpenAI format (list of role/content dicts).
        model: Model name to request. Defaults to the model previously
            hard-coded here; pass e.g. "gpt-4o-mini" to evaluate another model.

    Returns:
        The completion response object, or None if the request raised. The
        error and its traceback are logged; callers treat None as
        "API call failed".
    """
    try:
        return client.chat.completions.create(messages=messages, model=model)
    except Exception as e:
        # Deliberate best-effort boundary: one failed request must not crash
        # the worker. log.exception keeps the traceback that log.error lost.
        log.exception(f"API调用出错: {e}")
        return None
# Matches the fenced JSON answer requested by the system prompt. It tolerates
# an optional "json" language tag, whitespace/newlines around the object, and
# the stray quote the prompt's own example places before the closing fence
# (`}"```).
_ACTION_BLOCK_RE = re.compile(r'```(?:json)?\s*(\{.*?\})\s*"?\s*```', re.DOTALL)


def extract_action(response_text):
    """Extract the proposed action from a model reply.

    Args:
        response_text: The model's reply text; may be None or empty.

    Returns:
        A tuple ``(action_in_natural_language, grounded_action)``. Both items
        are None when no fenced JSON object is found or it is not valid JSON.
        An individual item is None when its key is missing from the object.
    """
    if not response_text:
        return None, None
    match = _ACTION_BLOCK_RE.search(response_text)
    if not match:
        return None, None
    json_str = match.group(1)
    try:
        # The capture starts with "{" and ends with "}", so a successful parse
        # always yields a dict.
        action_data = json.loads(json_str)
    except json.JSONDecodeError:
        log.error(f"无法解析JSON: {json_str}")
        return None, None
    return action_data.get("action_in_natural_language"), action_data.get("grounded_action")
def check_answer(grounded_action, answer):
    """Return True if *grounded_action* clicks one of the expected element IDs.

    Args:
        grounded_action: Action string produced by the model, e.g.
            ``"click [12]"``. May be None (or a non-string) when no action
            could be extracted.
        answer: Expected element ID(s). Accepts a comma-separated string such
            as ``"12, 15"``, a single number, or a list/tuple of IDs.

    Returns:
        True only for a click action whose ``click [<digits>]`` ID is among
        the expected IDs. Any other action type counts as incorrect.
    """
    if not isinstance(grounded_action, str):
        return False
    # Tolerate surrounding whitespace the model may emit around the action.
    action = grounded_action.strip()
    if not action.startswith("click"):
        return False
    match = re.search(r'click \[(\d+)\]', action)
    if not match:
        return False
    element_id = match.group(1)
    if isinstance(answer, (list, tuple)):
        expected_ids = {str(part).strip() for part in answer}
    else:
        # str() so numeric answers from exam.json don't crash on .split().
        expected_ids = {part.strip() for part in str(answer).split(",")}
    return element_id in expected_ids
def process_item(item):
    """Evaluate the model on a single exam item.

    Loads the item's accessibility tree from ``axtrees/<id>.txt`` and asks the
    model for the first action towards the task. The grounded action it
    proposes is then scored against the expected element IDs.

    Args:
        item: One exam entry. Reads the keys ``id``, ``url``, ``question``,
            ``answer`` and (optionally) ``answer_text``.

    Returns:
        A result dict with keys id, url, task, expected_answer, thinking,
        action_nl, grounded_action and is_correct. Returns None when the a11y
        tree file is missing or the API call failed.
    """
    item_id = item["id"]
    url = item["url"]
    task_description = item["question"]
    answer = item["answer"]
    # answer_text is only logged, so tolerate entries that omit it.
    answer_text = item.get("answer_text")
    log.info(f"处理ID: {item_id}, URL: {url}")
    # Load the accessibility tree captured for this item.
    try:
        # Page text is arbitrary Unicode; don't depend on the platform default encoding.
        with open(f"axtrees/{item_id}.txt", "r", encoding="utf-8") as f:
            a11y_tree = f.read()
    except FileNotFoundError:
        log.error(f"找不到文件: axtrees/{item_id}.txt")
        return None
    # Build the chat request.
    messages = [
        {"role": "system", "content": cot_system_prompt},
        {"role": "user", "content": cot_user_prompt.format(
            INIT_URL=url,
            A11Y_TREE=a11y_tree,
            SCREENSHOT="",  # no screenshot OCR result is available in this setup
            TASK_DESCRIPTION=task_description,
        )},
    ]
    log.info(f"task_description: {task_description}")
    log.info(f"answer: {answer}, answer_text: {answer_text}")
    response = call_api(messages)
    if not (response and getattr(response, "choices", None)):
        log.error(f"API调用失败: {response}")
        return None
    message = response.choices[0].message
    content = message.content
    # reasoning_content is not part of the standard OpenAI message schema
    # (presumably a provider extension filled by the reasoning model). Read it
    # defensively so other models don't raise AttributeError and abort the run.
    reasoning_content = getattr(message, "reasoning_content", None)
    log.info(f"reasoning_content: {reasoning_content}")
    log.info(f"content: {content}")
    # message.content is optional in the OpenAI schema; parse "" instead of None.
    action_nl, grounded_action = extract_action(content or "")
    is_correct = check_answer(grounded_action, answer)
    result = {
        "id": item_id,
        "url": url,
        "task": task_description,
        "expected_answer": answer,
        "thinking": reasoning_content,
        "action_nl": action_nl,
        "grounded_action": grounded_action,
        "is_correct": is_correct,
    }
    log.info(f"ID: {item_id}")
    log.info(f"任务: {task_description}")
    log.info(f"动作: {grounded_action}")
    log.info(f"是否正确: {is_correct}")
    log.info("-" * 50)
    return result
def _failure_record(item):
    """Build a placeholder result for an exam item that produced no result.

    Keeping such items in results.json (with ``is_correct`` False) lets a later
    ``--rerun_failure`` run retry them. Without a record they could never be
    selected for a rerun, and they would silently drop out of the accuracy
    denominator.
    """
    return {
        "id": item["id"],
        "url": item.get("url"),
        "task": item.get("question"),
        "expected_answer": item.get("answer"),
        "thinking": None,
        "action_nl": None,
        "grounded_action": None,
        "is_correct": False,
    }


def _run_items(items, progress_label, max_workers, **executor_kwargs):
    """Run ``process_item`` over *items* in a thread pool, logging progress.

    Args:
        items: Exam entries to evaluate.
        progress_label: Prefix of each progress log line (e.g. "进度").
        max_workers: Thread-pool size.
        **executor_kwargs: Extra ThreadPoolExecutor options
            (e.g. ``thread_name_prefix``).

    Returns:
        ``(outcomes, success_count, fail_count)``. *outcomes* is a list of
        ``(item, result)`` pairs in completion order, where *result* is None
        when the item failed or raised.
    """
    total_items = len(items)
    outcomes = []
    success_count = 0
    fail_count = 0
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers, **executor_kwargs) as executor:
        future_to_item = {executor.submit(process_item, item): item for item in items}
        for completed, future in enumerate(concurrent.futures.as_completed(future_to_item), start=1):
            item = future_to_item[future]
            try:
                result = future.result()
            except Exception as e:
                # One crashing item must not abort the run and lose every result.
                log.exception(f"处理ID {item['id']} 时发生异常: {e}")
                result = None
            if result and result["is_correct"]:
                success_count += 1
            else:
                fail_count += 1
            outcomes.append((item, result))
            progress = (completed / total_items) * 100
            log.info(f"{progress_label}: {progress:.2f}% ({completed}/{total_items}) - 成功: {success_count}, 失败: {fail_count}")
    return outcomes, success_count, fail_count


def _rerun_failures():
    """Re-evaluate every item recorded as incorrect in results.json and merge the new results."""
    try:
        with open("temp_analysis/results.json", "r", encoding="utf-8") as f:
            results = json.load(f)
    except FileNotFoundError:
        log.error("找不到results.json文件请先运行完整测试")
        return
    failed_results = [r for r in results if not r.get("is_correct", False)]
    if not failed_results:
        log.info("没有找到失败的测试用例,无需重跑")
        return
    # exam.json holds the full item data that results.json does not keep.
    with open("temp_analysis/exam.json", "r", encoding="utf-8") as f:
        exam_data = json.load(f)
    exam_map = {item["id"]: item for item in exam_data}
    rerun_items = [exam_map[r["id"]] for r in failed_results if r["id"] in exam_map]
    total_items = len(rerun_items)
    log.info(f"开始重跑失败的测试,总共 {total_items} 个任务")
    outcomes, success_count, fail_count = _run_items(rerun_items, "重跑进度", max_workers=4)
    # Replace old records in place, matched by id.
    result_indices = {r["id"]: i for i, r in enumerate(results)}
    for _item, result in outcomes:
        if result is None:
            # Keep the previous (failed) record so the item is retried next time.
            continue
        index = result_indices.get(result["id"])
        if index is None:
            results.append(result)
        else:
            results[index] = result
    with open("temp_analysis/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    all_correct = sum(1 for r in results if r.get("is_correct", False))
    all_total = len(results)
    accuracy = all_correct / all_total if all_total > 0 else 0
    log.info(f"重跑完成! 重跑: {total_items}题,成功: {success_count}题,失败: {fail_count}")
    log.info(f"总计: {all_total}题,正确: {all_correct}题,正确率: {accuracy:.2%}")


def _run_full():
    """Evaluate every item in exam.json and overwrite results.json."""
    with open("temp_analysis/exam.json", "r", encoding="utf-8") as f:
        exam_data = json.load(f)
    total_items = len(exam_data)
    log.info(f"开始测试,总共 {total_items} 个任务")
    outcomes, success_count, fail_count = _run_items(
        exam_data, "进度", max_workers=1, thread_name_prefix="test_run"
    )
    # Record failed items too, so --rerun_failure can find them later.
    results = [result if result is not None else _failure_record(item) for item, result in outcomes]
    with open("temp_analysis/results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    accuracy = success_count / total_items if total_items > 0 else 0
    log.info(f"测试完成! 总计: {total_items}题,正确: {success_count}题,错误: {fail_count}题,正确率: {accuracy:.2%}")


def main():
    """Command-line entry point: run the full exam, or rerun only the failed items.

    With ``--rerun_failure``, items recorded as incorrect in
    temp_analysis/results.json are re-evaluated and merged back into it.
    Otherwise every item in temp_analysis/exam.json is evaluated and
    results.json is overwritten.
    """
    parser = argparse.ArgumentParser(description='运行测试或重跑失败的测试')
    parser.add_argument('--rerun_failure', action='store_true', help='重跑失败的测试用例')
    args = parser.parse_args()
    if args.rerun_failure:
        _rerun_failures()
    else:
        _run_full()
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()