crawlee/misc/trajectory_test_v17.py
2025-04-23 12:14:50 +08:00

310 lines
9.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
import requests
import logging
import base64
from openai import OpenAI
from dotenv import load_dotenv
from PIL import Image
import random
import matplotlib.pyplot as plt
from collections import Counter
from datetime import datetime
# Configure logging: INFO level, each record formatted as "time - level - message".
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
def load_json_data(file_path):
    """
    Read a UTF-8 encoded JSON file and return the decoded object.

    Any failure (missing file, bad encoding, malformed JSON, ...) is logged
    and reported to the caller as ``None`` instead of being raised.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as e:
        logger.error(f"加载JSON文件失败: {e}")
        return None
    return parsed
def encode_image_to_base64(image_path):
    """
    Return the contents of the file at ``image_path`` as a base64 text string.

    The file is read in binary mode. Any failure is logged and reported to
    the caller as ``None`` instead of being raised.
    """
    try:
        with open(image_path, "rb") as image_file:
            raw_bytes = image_file.read()
            encoded = base64.b64encode(raw_bytes)
            return encoded.decode('utf-8')
    except Exception as e:
        logger.error(f"图片编码失败: {e}")
        return None
# System prompt that opens every new conversation. The text is sent to the
# model verbatim, so its lines are deliberately kept at column 0.
# NOTE(review): the `type(...)` hint renders as `use "" at the end` -- this
# looks like a lost newline escape; confirm against the model's reference
# prompt before changing it.
_GUI_AGENT_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use \"\" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
## Note
- Use Chinese in `Thought` part.
- Summarize your next action (with its target element) in one sentence in `Thought` part."""


def call_gpt4o_mini(image_path, title, messages=None):
    """
    Send a screenshot plus the task title to the model and return its reply.

    Despite the function name, the model requested is "UI-TARS-72B-DPO",
    reached through the OpenAI SDK using the endpoint and key found in the
    OPENAI_API_BASE_URL / OPENAI_API_KEY environment variables (a .env file
    is loaded first).

    Args:
        image_path: Path of the screenshot to attach to the user message.
        title: Task description embedded in the user message.
        messages: Conversation history from previous steps, or None to start
            a new conversation with the system prompt. A list passed in is
            extended in place with the new user and assistant messages.

    Returns:
        A ``(reply_text, messages)`` tuple. ``reply_text`` is None when the
        image cannot be read, the API configuration is missing, or the API
        call fails. Every path returns a 2-tuple because the caller unpacks
        two values; the previous bare ``None`` on the early-exit paths made
        that unpacking raise TypeError.
    """
    # Load variables from a .env file into the process environment.
    load_dotenv()

    # Read and encode the screenshot; the helper logs its own failures.
    image_base64 = encode_image_to_base64(image_path)
    if not image_base64:
        return None, messages

    # API endpoint and credentials come from the environment.
    api_base = os.getenv('OPENAI_API_BASE_URL')
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_base or not api_key:
        logger.error("未找到API配置环境变量")
        return None, messages

    client = OpenAI(
        api_key=api_key,
        base_url=api_base
    )

    # A new conversation starts with the system prompt only.
    if messages is None:
        messages = [
            {
                "role": "system",
                "content": _GUI_AGENT_SYSTEM_PROMPT
            }
        ]

    # User turn for the current step: task text followed by the screenshot
    # as an inline data URL.
    user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": f"你的任务是'{title}'"},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_base64}"
                }
            }
        ]
    }
    messages.append(user_message)

    try:
        response = client.chat.completions.create(
            model="UI-TARS-72B-DPO",
            messages=messages
        )
        reply_text = response.choices[0].message.content
        # Record the assistant turn so the next step sees the full history.
        messages.append({
            "role": "assistant",
            "content": reply_text
        })
        return reply_text, messages
    except Exception as e:
        logger.error(f"API调用失败: {e}")
        return None, messages
def extract_coordinates(response_text, image_path):
    """
    Extract the first "(x,y)" pair from the model reply and scale it to the image.

    The pair is treated as a position on a 1000x1000 grid and is rescaled to
    the pixel dimensions of the image at ``image_path``. Optional whitespace
    inside the parentheses (for example "(123, 456)") is accepted, so a
    differently formatted but valid answer is not rejected.

    Args:
        response_text: Raw text returned by the model.
        image_path: Path of the screenshot the coordinates refer to.

    Returns:
        ``{"x": int, "y": int}`` in image pixels, or None when no pair is
        found or any error occurs (errors are logged, not raised).
    """
    logger.info(f"API响应: {response_text}")
    try:
        # Only the first pair in the text is used.
        match = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', response_text)
        if match is None:
            return None

        # Coordinates as given on the 1000x1000 grid.
        x_1000 = int(match.group(1))
        y_1000 = int(match.group(2))

        # Actual pixel size of the screenshot.
        with Image.open(image_path) as img:
            actual_width, actual_height = img.size

        # Scale from the 1000-unit grid to real pixels.
        x = round(actual_width * x_1000 / 1000)
        y = round(actual_height * y_1000 / 1000)
        logger.info(f"坐标映射: 从 ({x_1000}, {y_1000}) 映射到 ({x}, {y})")
        return {"x": x, "y": y}
    except Exception as e:
        logger.error(f"坐标提取或映射失败: {e}")
        return None
def is_point_in_box(point, box):
    """
    Tell whether ``point`` lies inside ``box``, edges included.

    ``point`` provides "x" and "y"; ``box`` provides "x", "y", "width" and
    "height", with ("x", "y") being the corner the extents are added to.
    """
    left = box["x"]
    if not (left <= point["x"] <= left + box["width"]):
        return False
    top = box["y"]
    return top <= point["y"] <= top + box["height"]
def test_path(url, path_data):
    """
    Replay one recorded click path against the model and score it.

    For each step the matching screenshot is sent to the model, the returned
    coordinates are mapped onto the screenshot, and the step counts as
    successful when the point falls inside the recorded bounding box for that
    step. Evaluation stops at the first failing step.

    Args:
        url: Identifier of the path (the key in the loaded JSON data).
        path_data: Entry for that URL; the first item of its
            "shortestPathsMeta" list supplies "title", "chainIDs",
            "chainChildNum" and "chainViewportBoundingBoxes".

    Returns:
        A dict with "url", "title", "total_steps", "success_steps" and
        "is_success". Malformed path data yields a failed result with zero
        steps instead of raising, so one bad entry cannot abort a batch.
        A path with no steps is never reported as successful, because the
        model was never exercised on it.
    """
    logger.info(f"开始测试路径: {url}")

    try:
        meta = path_data["shortestPathsMeta"][0]
        title = meta["title"]
        chain_ids = meta["chainIDs"]
        child_nums = meta["chainChildNum"]
        viewport_boxes = meta["chainViewportBoundingBoxes"]
    except (KeyError, IndexError, TypeError) as e:
        logger.error(f"路径数据格式无效,跳过该路径: {e!r}")
        return {
            "url": url,
            "title": None,
            "total_steps": 0,
            "success_steps": 0,
            "is_success": False
        }

    # The last chain ID has no action of its own, hence the -1; clamp so an
    # empty chain reports 0 steps rather than -1.
    total_steps = max(len(chain_ids) - 1, 0)
    success_steps = 0
    logger.info(f"任务标题: {title}")
    logger.info(f"总步骤数: {total_steps}")

    # Conversation history, created by the first model call.
    messages = None

    for i in range(total_steps):
        current_id = chain_ids[i]
        child_num = child_nums[i] if i < len(child_nums) else None

        # Screenshot file name depends on whether a child number is recorded.
        if child_num is not None:
            image_path = f"screenshots/{current_id}_{child_num}.png"
        else:
            image_path = f"screenshots/{current_id}.png"
        logger.info(f"步骤 {i+1}/{total_steps}: 处理图片 {image_path}")

        if not os.path.exists(image_path):
            logger.error(f"图片不存在: {image_path}")
            break

        # Ask the model for the next action, carrying the history along.
        response, messages = call_gpt4o_mini(image_path, title, messages)
        if not response:
            logger.error("模型调用失败")
            break

        coordinates = extract_coordinates(response, image_path)
        if not coordinates:
            logger.error("无法从响应中提取坐标")
            break
        logger.info(f"模型返回坐标: {coordinates}")

        # A step without a recorded bounding box counts as a failure.
        target_box = viewport_boxes[i] if i < len(viewport_boxes) else None
        if target_box and is_point_in_box(coordinates, target_box):
            logger.info("坐标在边界框内,步骤成功!")
            success_steps += 1
        else:
            logger.warning(f"坐标不在边界框内,步骤失败。目标边界框: {target_box}")
            break

    # Require at least one step so an empty path is not a vacuous success.
    is_success = total_steps > 0 and success_steps == total_steps
    logger.info(f"路径测试完成: 总步骤 {total_steps}, 成功步骤 {success_steps}, 路径是否完全成功: {is_success}")
    return {
        "url": url,
        "title": title,
        "total_steps": total_steps,
        "success_steps": success_steps,
        "is_success": is_success
    }
def main():
    """
    Evaluate up to 10 randomly chosen paths and plot the success distribution.

    Loads the path data from "path/processed_3_with_analysis.json", runs
    test_path on each selected URL, logs a summary and per-path details, and
    saves a bar chart of successful-step counts to a timestamped PNG file.
    """
    json_path = "path/processed_3_with_analysis.json"
    data = load_json_data(json_path)
    if not data:
        logger.error("数据加载失败,程序退出")
        return

    # Sample 10 URLs when more are available; otherwise use them all.
    urls = list(data.keys())
    selected_urls = random.sample(urls, 10) if len(urls) > 10 else urls

    # Run every selected path and collect its result.
    results = [test_path(url, data[url]) for url in selected_urls]

    # Overall summary.
    total_paths = len(results)
    successful_paths = len([entry for entry in results if entry["is_success"]])
    logger.info(f"测试完成: 总路径数 {total_paths}, 成功路径数 {successful_paths}")

    # Per-path details.
    for entry in results:
        logger.info(f"URL: {entry['url']}")
        logger.info(f" 任务: {entry['title']}")
        logger.info(f" 总步骤: {entry['total_steps']}, 成功步骤: {entry['success_steps']}")
        logger.info(f" 路径是否成功: {entry['is_success']}")

    # Count how many paths reached each number of successful steps.
    success_steps_counts = Counter(entry['success_steps'] for entry in results)
    steps = sorted(success_steps_counts)
    counts = [success_steps_counts[step] for step in steps]

    # Bar chart of the distribution.
    plt.figure(figsize=(10, 6))
    plt.bar(steps, counts)
    plt.title('成功步数分布')
    plt.xlabel('成功步数')
    plt.ylabel('URL数量')

    # Print the exact count above each bar.
    for step, count in zip(steps, counts):
        plt.text(step, count, str(count), ha='center', va='bottom')

    # Save under a timestamped name so earlier charts are kept.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f'success_steps_distribution_{timestamp}.png'
    plt.savefig(save_path)
    plt.close()
    logger.info(f"成功步数分布图已保存为 {save_path}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()