310 lines
9.6 KiB
Python
310 lines
9.6 KiB
Python
import json
|
||
import os
|
||
import re
|
||
import requests
|
||
import logging
|
||
import base64
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
from PIL import Image
|
||
import random
|
||
import matplotlib.pyplot as plt
|
||
from collections import Counter
|
||
from datetime import datetime
|
||
|
||
# Configure logging once at import time: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger shared by every function in this script.
logger = logging.getLogger(__name__)
|
||
|
||
def load_json_data(file_path):
    """Load and parse the JSON document stored at ``file_path``.

    Args:
        file_path: Path of a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON value, or ``None`` when the file cannot be opened or
        parsed (the error is logged instead of being raised).
    """
    parsed = None
    try:
        with open(file_path, mode='r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except Exception as exc:
        # Best-effort loader: report the problem and fall through to None.
        logger.error(f"加载JSON文件失败: {exc}")
    return parsed
|
||
|
||
def encode_image_to_base64(image_path):
    """Read a file in binary mode and return its content as base64 text.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64 representation decoded as a UTF-8 ``str``, or ``None``
        when reading or encoding fails (the error is logged).
    """
    try:
        with open(image_path, "rb") as stream:
            raw_bytes = stream.read()
        encoded = base64.b64encode(raw_bytes)
        return encoded.decode('utf-8')
    except Exception as exc:
        logger.error(f"图片编码失败: {exc}")
        return None
|
||
|
||
# System prompt sent as the first message of every new conversation. It fixes
# the reply format ("Thought" / "Action") and the set of actions the model may
# emit.
_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use \"\" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.

## Note
- Use Chinese in `Thought` part.
- Summarize your next action (with its target element) in one sentence in `Thought` part."""


def call_gpt4o_mini(image_path, title, messages=None):
    """Ask the model for the next GUI action on a screenshot.

    Despite the function name, the request is sent through the OpenAI SDK to
    the ``UI-TARS-72B-DPO`` model, using the endpoint and key taken from the
    ``OPENAI_API_BASE_URL`` and ``OPENAI_API_KEY`` environment variables
    (loaded from a ``.env`` file when present).

    Args:
        image_path: Path of the screenshot to send; it is embedded in the
            request as a base64 ``data:image/png`` URL.
        title: Task description inserted into the user message.
        messages: Conversation history from previous steps. ``None`` starts a
            new conversation seeded with the system prompt. A list passed in
            is extended in place with the new user message and, on success,
            the assistant reply.

    Returns:
        tuple: ``(reply_text, messages)``. ``reply_text`` is ``None`` when the
        image cannot be encoded, the API configuration is missing, or the API
        call fails. A 2-tuple is returned on every path so callers can always
        unpack the result.
    """
    load_dotenv()

    # Read and encode the screenshot; without it there is nothing to send.
    image_base64 = encode_image_to_base64(image_path)
    if not image_base64:
        return None, messages

    # Endpoint and credentials come from the environment.
    api_base = os.getenv('OPENAI_API_BASE_URL')
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_base or not api_key:
        logger.error("未找到API配置环境变量")
        return None, messages

    client = OpenAI(
        api_key=api_key,
        base_url=api_base,
    )

    # A new conversation starts with the system prompt only.
    if messages is None:
        messages = [{"role": "system", "content": _SYSTEM_PROMPT}]

    # Current step: the task text plus the screenshot.
    messages.append({
        "role": "user",
        "content": [
            {"type": "text", "text": f"你的任务是'{title}'。"},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_base64}"
                },
            },
        ],
    })

    try:
        response = client.chat.completions.create(
            model="UI-TARS-72B-DPO",
            messages=messages,
        )
        reply_text = response.choices[0].message.content
    except Exception as e:
        logger.error(f"API调用失败: {e}")
        return None, messages

    # Record the reply so the next step sees the full action history.
    messages.append({"role": "assistant", "content": reply_text})
    return reply_text, messages
|
||
|
||
def extract_coordinates(response_text, image_path):
    """Extract the first ``(x,y)`` pair from a model reply and rescale it.

    The pair is treated as a position on a 1000x1000 grid and is mapped
    proportionally onto the pixel size of the image at ``image_path``.

    Args:
        response_text: Raw text returned by the model.
        image_path: Path of the screenshot the reply refers to; only its size
            is read.

    Returns:
        ``{"x": int, "y": int}`` in image pixels, or ``None`` when no pair is
        found or when parsing / opening the image fails (the error is logged).
    """
    logger.info(f"API响应: {response_text}")
    try:
        # Optional whitespace is accepted around the numbers so that replies
        # such as "(235, 512)" are recognised as well as "(235,512)".
        match = re.search(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)', response_text)
        if match is None:
            return None

        # Coordinates on the 1000x1000 grid.
        x_1000 = int(match.group(1))
        y_1000 = int(match.group(2))

        # Actual pixel size of the screenshot.
        with Image.open(image_path) as img:
            actual_width, actual_height = img.size

        # Scale from the 1000x1000 grid to image pixels.
        x = round(actual_width * x_1000 / 1000)
        y = round(actual_height * y_1000 / 1000)

        logger.info(f"坐标映射: 从 ({x_1000}, {y_1000}) 映射到 ({x}, {y})")
        return {"x": x, "y": y}
    except Exception as e:
        logger.error(f"坐标提取或映射失败: {e}")
        return None
|
||
|
||
def is_point_in_box(point, box):
    """Tell whether ``point`` lies inside ``box``, edges included.

    Args:
        point: Mapping with ``"x"`` and ``"y"``.
        box: Mapping with ``"x"``, ``"y"``, ``"width"`` and ``"height"``;
            ``x``/``y`` are the corner with the smallest coordinates.

    Returns:
        ``True`` when the point is within the box on both axes, else ``False``.
    """
    left = box["x"]
    # Horizontal check first; the vertical keys are only read if it passes.
    if not (left <= point["x"] <= left + box["width"]):
        return False
    top = box["y"]
    return top <= point["y"] <= top + box["height"]
|
||
|
||
def test_path(url, path_data):
    """Replay one recorded path against the model and count correct steps.

    For each step the matching screenshot is sent to the model, the returned
    coordinates are extracted, and the step counts as successful when the
    point falls inside the recorded bounding box. Evaluation stops at the
    first step that cannot be processed or that misses its box.

    Args:
        url: Identifier of the path; used for logging and echoed in the result.
        path_data: Mapping whose ``"shortestPathsMeta"`` list holds, in its
            first entry, ``"title"``, ``"chainIDs"``, ``"chainChildNum"`` and
            ``"chainViewportBoundingBoxes"``.

    Returns:
        dict with ``url``, ``title``, ``total_steps``, ``success_steps`` and
        ``is_success`` (true only when every step succeeded).
    """
    logger.info(f"开始测试路径: {url}")

    meta = path_data["shortestPathsMeta"][0]
    title = meta["title"]
    chain_ids = meta["chainIDs"]
    child_nums = meta["chainChildNum"]
    viewport_boxes = meta["chainViewportBoundingBoxes"]

    # The last chain entry has no step of its own, hence the "- 1".
    total_steps = len(chain_ids) - 1
    success_steps = 0

    logger.info(f"任务标题: {title}")
    logger.info(f"总步骤数: {total_steps}")

    # Conversation history; created by the first model call.
    messages = None

    for step in range(total_steps):
        current_id = chain_ids[step]
        child_num = child_nums[step] if step < len(child_nums) else None

        # Screenshot file name depends on whether a child number is recorded.
        image_path = (
            f"screenshots/{current_id}_{child_num}.png"
            if child_num is not None
            else f"screenshots/{current_id}.png"
        )

        logger.info(f"步骤 {step+1}/{total_steps}: 处理图片 {image_path}")

        if not os.path.exists(image_path):
            logger.error(f"图片不存在: {image_path}")
            break

        # Ask the model for the next action, carrying the history along.
        response, messages = call_gpt4o_mini(image_path, title, messages)
        if not response:
            logger.error("模型调用失败")
            break

        coordinates = extract_coordinates(response, image_path)
        if not coordinates:
            logger.error("无法从响应中提取坐标")
            break

        logger.info(f"模型返回坐标: {coordinates}")

        # A missing or empty box counts as a miss.
        target_box = viewport_boxes[step] if step < len(viewport_boxes) else None
        if not (target_box and is_point_in_box(coordinates, target_box)):
            logger.warning(f"坐标不在边界框内,步骤失败。目标边界框: {target_box}")
            break

        logger.info("坐标在边界框内,步骤成功!")
        success_steps += 1

    # Summarise the outcome for this path.
    is_success = (success_steps == total_steps)
    logger.info(f"路径测试完成: 总步骤 {total_steps}, 成功步骤 {success_steps}, 路径是否完全成功: {is_success}")

    return {
        "url": url,
        "title": title,
        "total_steps": total_steps,
        "success_steps": success_steps,
        "is_success": is_success,
    }
|
||
|
||
def main():
    """Evaluate up to ten randomly chosen paths and plot the results.

    Loads the path data from ``path/processed_3_with_analysis.json``, runs
    ``test_path`` on at most ten randomly sampled URLs, logs overall and
    per-path results, and saves a bar chart of how many paths reached each
    number of successful steps to a timestamped PNG file.
    """
    json_path = "path/processed_3_with_analysis.json"
    data = load_json_data(json_path)
    if not data:
        logger.error("数据加载失败,程序退出")
        return

    # Sample ten URLs when more are available; otherwise use all of them.
    urls = list(data.keys())
    selected_urls = random.sample(urls, 10) if len(urls) > 10 else urls

    # Run every selected path and keep its result record.
    results = [test_path(url, data[url]) for url in selected_urls]

    # Overall statistics.
    total_paths = len(results)
    successful_paths = len([entry for entry in results if entry["is_success"]])
    logger.info(f"测试完成: 总路径数 {total_paths}, 成功路径数 {successful_paths}")

    # Per-path details.
    for result in results:
        logger.info(f"URL: {result['url']}")
        logger.info(f" 任务: {result['title']}")
        logger.info(f" 总步骤: {result['total_steps']}, 成功步骤: {result['success_steps']}")
        logger.info(f" 路径是否成功: {result['is_success']}")

    # Distribution of the number of successful steps across paths.
    success_steps_counts = Counter([entry['success_steps'] for entry in results])
    steps = sorted(success_steps_counts.keys())
    counts = [success_steps_counts[step] for step in steps]

    # Bar chart of the distribution.
    plt.figure(figsize=(10, 6))
    plt.bar(steps, counts)
    plt.title('成功步数分布')
    plt.xlabel('成功步数')
    plt.ylabel('URL数量')

    # Print the exact count above each bar.
    for step, count in zip(steps, counts):
        plt.text(step, count, str(count), ha='center', va='bottom')

    # Save under a timestamped name so earlier charts are not overwritten.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = f'success_steps_distribution_{timestamp}.png'
    plt.savefig(save_path)
    plt.close()

    logger.info(f"成功步数分布图已保存为 {save_path}")
|
||
|
||
# Run the evaluation only when the file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()