# Listing metadata (code-viewer residue): 513 lines, 17 KiB, Python
import json
|
||
import os
|
||
import re
|
||
import requests
|
||
import logging
|
||
import base64
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
from PIL import Image
|
||
import random
|
||
import matplotlib.pyplot as plt
|
||
from collections import Counter
|
||
from datetime import datetime
|
||
import shutil
|
||
import torch
|
||
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
|
||
import io
|
||
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||
from qwen_vl_utils import process_vision_info
|
||
|
||
# Configure logging for the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Output directory for result JSON files and charts; created at import time.
test_dir = 'test'
os.makedirs(test_dir, exist_ok=True)

# Module-level slots for a generic tokenizer/model pair.
# NOTE(review): nothing in the visible part of this file assigns or reads
# these two names — confirm whether they are still needed.
tokenizer = None
model = None

# Module-level slots for the Qwen2VL model and its processor; they are filled
# in by load_qwen_model() and read by call_gpt4o_mini().
qwen_model = None
processor = None


# --- Data loading, image encoding and model loading helpers ---

def load_json_data(file_path):
    """Load and parse a UTF-8 encoded JSON file.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        The parsed JSON value, or ``None`` when the file cannot be opened or
        parsed (the error is logged, not raised).
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except Exception as e:
        # Best-effort loader: callers check for None instead of catching.
        logger.error(f"加载JSON文件失败: {e}")
        return None


def encode_image_to_base64(image_path):
    """Read a file in binary mode and return its content as a base64 string.

    Args:
        image_path: Path of the image file to encode.

    Returns:
        The base64 text (decoded as UTF-8), or ``None`` when the file cannot
        be read (the error is logged, not raised).
    """
    try:
        with open(image_path, "rb") as image_file:
            return base64.b64encode(image_file.read()).decode('utf-8')
    except Exception as e:
        # Best-effort encoder: callers check for None instead of catching.
        logger.error(f"图片编码失败: {e}")
        return None


# Defaults used when load_qwen_model() is called without arguments.
_DEFAULT_QWEN_MODEL_PATH = "/data1/yuyr/models--bytedance-research--UI-TARS-7B-DPO/snapshots/727b0df39207dafc6cf211a61f29d84b7659c39c/"
_DEFAULT_QWEN_DEVICE_MAP = "cuda:1"


def load_qwen_model(model_path=None, device_map=None):
    """Load the Qwen2VL model and its processor into the module globals.

    Args:
        model_path: Local checkpoint directory passed to ``from_pretrained``.
            Defaults to the UI-TARS-7B-DPO snapshot path used by this script.
        device_map: Value forwarded as ``device_map`` when loading the model.
            Defaults to ``"cuda:1"``.

    Returns:
        ``True`` when both the model and the processor were loaded and stored
        in ``qwen_model`` / ``processor``; ``False`` when loading failed (the
        error is logged, not raised).
    """
    global qwen_model, processor

    if model_path is None:
        model_path = _DEFAULT_QWEN_MODEL_PATH
    if device_map is None:
        device_map = _DEFAULT_QWEN_DEVICE_MAP

    try:
        logger.info("正在加载Qwen2VL模型和处理器...")
        start_time = datetime.now()

        # Load the model weights.
        loaded_model = Qwen2VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto", device_map=device_map
        )
        logger.info("Qwen2VL模型加载完成")

        # Load the matching processor.
        loaded_processor = AutoProcessor.from_pretrained(model_path)
        logger.info("处理器加载完成")

        # Publish both globals together so a failure above never leaves the
        # module with a model but no processor.
        qwen_model = loaded_model
        processor = loaded_processor

        load_time = (datetime.now() - start_time).total_seconds()
        logger.info(f"Qwen2VL模型和处理器加载完成,耗时: {load_time:.2f}秒")

        return True
    except Exception as e:
        logger.error(f"Qwen2VL模型加载失败: {e}")
        return False


# System prompt sent as the first message of every new conversation.
# NOTE(review): the `type(...)` line reads `use "" at the end of content`;
# this looks like an escape sequence may have been lost from the prompt this
# was copied from — confirm against the upstream UI-TARS prompt.
_GUI_AGENT_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use \"\" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished()
call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.

## Note
- Use Chinese in `Thought` part.
- Summarize your next action (with its target element) in one sentence in `Thought` part."""


def call_gpt4o_mini(image_path, title, messages=None, max_retries=3):
    """Ask the local Qwen2VL model for the next GUI action on a screenshot.

    Despite its name, this function does not call the OpenAI API; it runs the
    model held in the module globals ``qwen_model`` / ``processor``, loading
    them through ``load_qwen_model()`` if they are not set yet.

    Args:
        image_path: Path of the screenshot for the current step.
        title: Task text sent alongside the screenshot.
        messages: Conversation history from previous steps, or ``None`` to
            start a new conversation with the system prompt. The list is
            modified in place: images are removed from earlier user turns and
            the new user/assistant turns are appended.
        max_retries: Number of additional attempts after the first failure.

    Returns:
        ``(output_text, messages)`` on success. ``(None, messages)`` when all
        attempts failed; in that case ``messages`` holds no trace of the
        failed attempts (and is ``None`` if the history was never created).
    """
    history_ready = False
    base_len = 0

    # One initial attempt plus `max_retries` retries.
    for attempt in range(max(max_retries, 0) + 1):
        try:
            # Make sure the model and processor are available.
            if qwen_model is None or processor is None:
                logger.warning("模型或处理器未加载,尝试加载...")
                if not load_qwen_model():
                    raise ValueError("模型加载失败")

            # Prepare the history once; retries reuse the prepared state.
            if not history_ready:
                if messages is None:
                    # New conversation: start with the system prompt.
                    messages = [
                        {"role": "system", "content": _GUI_AGENT_SYSTEM_PROMPT}
                    ]
                else:
                    # Drop images from earlier user turns so only the current
                    # screenshot is sent to the model.
                    for message in messages:
                        if message["role"] == "user":
                            message["content"] = [
                                content
                                for content in message["content"]
                                if content["type"] != "image"
                            ]
                base_len = len(messages)
                history_ready = True

            # User turn for the current step: screenshot plus task text.
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image_url": f"file://{image_path}"
                        },
                        {"type": "text", "text": title},
                    ],
                }
            )

            # Build the model inputs.
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(qwen_model.device)

            # Generate, then cut the prompt tokens off each output sequence.
            logger.info("正在生成模型回复...")
            generated_ids = qwen_model.generate(**inputs, max_new_tokens=1024)
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]

            logger.info(f"模型输出: {output_text}")

            # Record the assistant reply in the history.
            messages.append({"role": "assistant", "content": output_text})

            return output_text, messages
        except Exception as e:
            logger.error(f"模型推理失败: {e}")

            # Undo whatever this attempt appended, so a retry does not send a
            # duplicated user turn (whose image would otherwise be stripped).
            if history_ready:
                del messages[base_len:]

            remaining = max_retries - attempt
            if remaining > 0:
                logger.warning(f"重试中,剩余次数: {remaining}")
            else:
                logger.error("重试次数已用完,推理失败")

    return None, messages


# Matches the first "(x,y)" pair of non-negative integers; whitespace around
# the numbers and the comma is tolerated, e.g. "(235, 512)".
_COORD_PATTERN = re.compile(r'\(\s*(\d+)\s*,\s*(\d+)\s*\)')


def extract_coordinates(response_text, image_path):
    """Extract the first ``(x,y)`` pair from a model reply and rescale it.

    The pair is treated as a position on a 1000x1000 grid and mapped to the
    pixel size of the image at ``image_path``.

    Args:
        response_text: Raw text returned by the model.
        image_path: Path of the screenshot the coordinates refer to; opened
            only to read its width and height.

    Returns:
        A dict with ``raw_x`` / ``raw_y`` (values as written by the model) and
        ``adjust_x`` / ``adjust_y`` (rounded pixel coordinates), or ``None``
        when no pair is found or any step fails (the error is logged).
    """
    logger.info(f"API响应: {response_text}")
    try:
        match = _COORD_PATTERN.search(response_text)
        if not match:
            return None

        # Coordinates on the 1000x1000 grid.
        x_1000 = int(match.group(1))
        y_1000 = int(match.group(2))

        # Actual pixel size of the screenshot.
        with Image.open(image_path) as img:
            actual_width, actual_height = img.size

        # Scale to pixel coordinates.
        x = round(actual_width * x_1000 / 1000)
        y = round(actual_height * y_1000 / 1000)

        logger.info(f"坐标映射: 从 ({x_1000}, {y_1000}) 映射到 ({x}, {y})")
        return {
            "raw_x": x_1000,
            "raw_y": y_1000,
            "adjust_x": x,
            "adjust_y": y
        }
    except Exception as e:
        logger.error(f"坐标提取或映射失败: {e}")
        return None


def is_point_in_box(point, box):
    """Return whether a point lies inside a bounding box (edges included).

    Args:
        point: Mapping with ``adjust_x`` and ``adjust_y``.
        box: Mapping with ``x``, ``y``, ``width`` and ``height``.

    Returns:
        ``True`` when ``x <= adjust_x <= x + width`` and
        ``y <= adjust_y <= y + height``, otherwise ``False``.
    """
    return (box["x"] <= point["adjust_x"] <= box["x"] + box["width"] and
            box["y"] <= point["adjust_y"] <= box["y"] + box["height"])


def test_path(url, path_data, use_title=True):
    """Run the model over one recorded path and score each predicted click.

    For every step the matching screenshot under ``screenshots/`` is sent to
    the model, the first coordinate pair in its reply is extracted, and the
    point is checked against that step's viewport bounding box. The run stops
    at the first step that fails for any reason.

    Args:
        url: Key of the path; used for logging only.
        path_data: Path record; the first entry of ``shortestPathsMeta`` is
            read for ``task_summaries``, ``raw_result``, ``chainIDs``,
            ``chainChildNum`` and ``chainViewportBoundingBoxes``.
        use_title: When true the task text is the first task summary's
            ``question``; otherwise it is ``raw_result``.

    Returns:
        A dict with ``steps_data`` (one record per step that produced
        coordinates), ``success_steps`` and ``is_success`` (true when every
        step succeeded).
    """
    logger.info(f"开始测试路径: {url}, 使用{'标题' if use_title else '描述'}")
    meta = path_data["shortestPathsMeta"][0]
    title = meta["task_summaries"][0]["question"] if use_title else meta["raw_result"]
    chain_ids = meta["chainIDs"]
    child_nums = meta["chainChildNum"]
    viewport_boxes = meta["chainViewportBoundingBoxes"]

    # The last chain entry has no following click, hence one step fewer.
    total_steps = len(chain_ids) - 1
    success_steps = 0
    steps_data = []

    logger.info(f"任务内容: {title}")
    logger.info(f"总步骤数: {total_steps}")

    # Conversation history shared across the steps of this path.
    messages = None

    for i in range(total_steps):
        current_id = chain_ids[i]
        child_num = child_nums[i] if i < len(child_nums) else None

        # Screenshot file name depends on whether a child number exists.
        if child_num is not None:
            image_path = f"screenshots/{current_id}_{child_num}.png"
        else:
            image_path = f"screenshots/{current_id}.png"

        logger.info(f"步骤 {i+1}/{total_steps}: 处理图片 {image_path}")

        if not os.path.exists(image_path):
            logger.error(f"图片不存在: {image_path}")
            break

        # Query the model, carrying the history forward.
        response, messages = call_gpt4o_mini(image_path, title, messages)
        if not response:
            logger.error("模型调用失败")
            break

        coordinates = extract_coordinates(response, image_path)
        if not coordinates:
            logger.error("无法从响应中提取坐标")
            break

        logger.info(f"模型返回坐标: {coordinates}")

        # A step without a recorded target box counts as a failure.
        target_box = viewport_boxes[i] if i < len(viewport_boxes) else None
        step_success = False

        if target_box and is_point_in_box(coordinates, target_box):
            logger.info("坐标在边界框内,步骤成功!")
            success_steps += 1
            step_success = True
        else:
            logger.warning(f"坐标不在边界框内,步骤失败。目标边界框: {target_box}")

        steps_data.append({
            "model_output": response,
            "raw_x": coordinates["raw_x"],
            "raw_y": coordinates["raw_y"],
            "adjust_x": coordinates["adjust_x"],
            "adjust_y": coordinates["adjust_y"],
            "bounding_box": target_box,
            "is_success": step_success
        })

        # Later steps depend on this click, so stop at the first miss.
        if not step_success:
            break

    is_success = (success_steps == total_steps)
    logger.info(f"路径测试完成: 总步骤 {total_steps}, 成功步骤 {success_steps}, 路径是否完全成功: {is_success}")

    return {
        "steps_data": steps_data,
        "success_steps": success_steps,
        "is_success": is_success
    }


def main():
    """Evaluate every path in the dataset and write results and charts.

    Loads the model, reads ``path/processed_3_with_analysis.json``, runs each
    path twice (once with the title as task text, once with the description),
    writes all results to one JSON file under ``test_dir`` and then generates
    the statistics charts. Returns early, after logging, when the model or the
    data cannot be loaded.
    """
    logger.info("开始加载Qwen2VL模型...")
    if not load_qwen_model():
        logger.error("Qwen2VL模型加载失败,程序退出")
        return

    # Load the recorded path data.
    path_json = "path/processed_3_with_analysis.json"
    data = load_json_data(path_json)
    if not data:
        logger.error("数据加载失败,程序退出")
        return

    # All paths are tested (no random sampling).
    selected_urls = list(data.keys())

    # Millisecond-precision timestamp stored inside the result file.
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
    result_data = {
        "timestamp": timestamp,
        "num": len(selected_urls),
        "list": []
    }

    for url in selected_urls:
        path_data = data[url]
        meta = path_data["shortestPathsMeta"][0]
        title = meta["task_summaries"][0]["question"]
        description = meta["raw_result"]
        total_steps = len(meta["chainIDs"]) - 1

        # Kept for the output file name; after the loop this holds the step
        # count of the last path tested.
        path_length = total_steps

        logger.info(f"测试路径: {url}")
        logger.info(f"标题: {title}")
        logger.info(f"描述: {description}")

        # Run once with the title and once with the description.
        title_result = test_path(url, path_data, use_title=True)
        desc_result = test_path(url, path_data, use_title=False)

        # NOTE: the key "desc_success_step" (singular) is what
        # generate_statistics() reads; keep the two in sync.
        result_data["list"].append({
            "url": url,
            "title": title,
            "description": description,
            "total_steps": total_steps,
            "title_steps": title_result["steps_data"],
            "desc_steps": desc_result["steps_data"],
            "title_success_steps": title_result["success_steps"],
            "title_is_success": title_result["is_success"],
            "desc_success_step": desc_result["success_steps"],
            "desc_is_success": desc_result["is_success"]
        })

    # File-name-safe timestamp for the output file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    result_file = f"{test_dir}/sample_path_{path_length + 1}_{timestamp}.json"
    with open(result_file, 'w', encoding='utf-8') as f:
        json.dump(result_data, f, indent=2, ensure_ascii=False)

    logger.info(f"结果已保存到 {result_file}")

    # Chart the distribution of successful step counts.
    generate_statistics(result_data)


def _draw_step_distribution(success_counts, color, chart_title):
    """Draw one bar chart of success-step counts on a new figure; return the sorted steps."""
    plt.figure(figsize=(10, 6))
    steps = sorted(success_counts.keys())
    counts = [success_counts[step] for step in steps]

    plt.bar(steps, counts, color=color, alpha=0.7)
    plt.title(chart_title)
    plt.xlabel('成功步数')
    plt.ylabel('URL数量')

    # Print the exact value above every bar.
    for step, count in zip(steps, counts):
        plt.text(step, count, str(count), ha='center', va='bottom')

    return steps


def generate_statistics(result_data):
    """Save three bar charts of the success-step distributions.

    Writes, under ``test_dir``, one chart for the title runs, one for the
    description runs and one grouped chart comparing both. All three file
    names share the same timestamp.

    Args:
        result_data: Result dict whose ``list`` entries carry
            ``title_success_steps`` and ``desc_success_step``.
    """
    # How many paths reached each number of successful steps.
    title_success_counts = Counter(
        item["title_success_steps"] for item in result_data["list"]
    )
    desc_success_counts = Counter(
        item["desc_success_step"] for item in result_data["list"]
    )

    # Chart for the title runs.
    title_steps = _draw_step_distribution(
        title_success_counts, 'blue', '标题测试成功步数分布'
    )
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    title_chart_path = f'{test_dir}/title_success_steps_distribution_{timestamp}.png'
    plt.savefig(title_chart_path)
    plt.close()

    # Chart for the description runs.
    desc_steps = _draw_step_distribution(
        desc_success_counts, 'green', '描述测试成功步数分布'
    )
    desc_chart_path = f'{test_dir}/desc_success_steps_distribution_{timestamp}.png'
    plt.savefig(desc_chart_path)
    plt.close()

    # Grouped comparison chart over every step count seen in either run.
    plt.figure(figsize=(12, 7))
    all_steps = sorted(set(title_steps + desc_steps))
    title_all_counts = [title_success_counts.get(step, 0) for step in all_steps]
    desc_all_counts = [desc_success_counts.get(step, 0) for step in all_steps]

    # Bars are placed side by side around each integer position.
    x = range(len(all_steps))
    width = 0.35

    plt.bar([i - width/2 for i in x], title_all_counts, width, label='brief', color='blue', alpha=0.7)
    plt.bar([i + width/2 for i in x], desc_all_counts, width, label='detail', color='green', alpha=0.7)

    plt.title('brief vs detail success steps distribution')
    plt.xlabel('success steps')
    plt.ylabel('URL count')
    plt.xticks(x, all_steps)
    plt.legend()

    # Label only the bars that are actually visible.
    for i, count in enumerate(title_all_counts):
        if count > 0:
            plt.text(i - width/2, count, str(count), ha='center', va='bottom')

    for i, count in enumerate(desc_all_counts):
        if count > 0:
            plt.text(i + width/2, count, str(count), ha='center', va='bottom')

    compare_chart_path = f'{test_dir}/title_vs_desc_distribution_{timestamp}.png'
    plt.savefig(compare_chart_path)
    plt.close()

    logger.info(f"统计图表已保存: {title_chart_path}, {desc_chart_path}, {compare_chart_path}")


# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()