419 lines
15 KiB
Python
419 lines
15 KiB
Python
import json
|
||
import os
|
||
import re
|
||
import logging
|
||
import concurrent.futures
|
||
import argparse
|
||
from openai import OpenAI
|
||
from dotenv import load_dotenv
|
||
import datetime
|
||
import base64
|
||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||
# from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
|
||
from qwen_vl_utils import process_vision_info
|
||
|
||
|
||
# Restrict which physical GPUs CUDA may see before any CUDA library initializes.
# With this mapping, device_map="cuda:1" below refers to physical GPU 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
|
||
|
||
# --- Logging setup ---
# Module-level logger that writes to both a log file and the console.
log = logging.getLogger(__name__)
log.setLevel(logging.INFO)

# File handler: persists the run log for later inspection.
# NOTE(review): assumes the 'temp_analysis' directory already exists — confirm,
# otherwise FileHandler raises at import time.
file_handler = logging.FileHandler('temp_analysis/exam_run_vision.log')
file_handler.setLevel(logging.INFO)

# Console handler mirrors the same records to the terminal.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# Shared record format: timestamp - level - message.
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)

# Attach both handlers to the logger.
log.addHandler(file_handler)
log.addHandler(console_handler)
|
||
|
||
# Load environment variables from a .env file
# (expects OPENAI_API_KEY, OPENAI_API_BASE_URL, MODEL_NAME).
load_dotenv()

# Number of worker threads for main()'s ThreadPoolExecutor.
PARALLEL_WORKERS=1

# Model identifier passed to the OpenAI-compatible API (None if unset).
MODEL_NAME=os.getenv("MODEL_NAME")
print(f"MODEL_NAME: {MODEL_NAME}")

# Configure the OpenAI client (only used by the call_api() code path).
# client = OpenAI()
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY"),
    base_url=os.getenv("OPENAI_API_BASE_URL")  # optional: point at a compatible service
)
|
||
|
||
# Modified version of the task proposer agent prompt from https://arxiv.org/pdf/2502.11357
# System prompt: instructs the model to propose the first grounded action for a
# task given a screenshot/OCR result plus a parsed HTML/accessibility tree, and
# to answer with a fenced ```{...}``` JSON block that extract_action() parses.
cot_system_prompt = """
What does this webpage show? Imagine you are a real user on this webpage. Given the webpage
screenshot or ocr result and parsed HTML/accessibility tree and the task description, please provide
the first action towards completing that task.

Do the following step by step:
1. Given the webpage screenshot or ocr result and parsed HTML/accessibility tree, generate the first action
towards completing that task (in natural language form).
2. Given the webpage screenshot or ocr result, parsed HTML/accessibility tree, and the natural language
action, generate the grounded version of that action.

ACTION SPACE: Your action space is: ['click [element ID]', 'type [element ID] [content]',
'select [element ID] [content of option to select]', 'scroll [up]', 'scroll [down]', and 'stop'].
Action output should follow the syntax as given below:
click [element text]: This action clicks on an element with a specific text on the webpage.
type [element ID] [content]: Use this to type the content into the field with id. By default, the
"Enter" key is pressed after typing. Both the content and the ID should be within square braces
as per the syntax.
select [element ID] [content of option to select]: Select an option from a dropdown menu. The
content of the option to select should be within square braces. When you get (select an option)
tags from the accessibility tree, you need to select the serial number (element_id) corresponding
to the select tag, not the option, and select the most likely content corresponding to the option as
input.
scroll [down]: Scroll the page down.
scroll [up]: Scroll the page up.

IMPORTANT:
To be successful, it is important to STRICTLY follow the below rules:
Action generation rules:
1. You should generate a single atomic action at each step.
2. The action should be an atomic action from the given vocabulary - click, type, select, scroll
(up or down), or stop.
3. The arguments to each action should be within square braces. For example, "click [127]",
"type [43] [content to type]", "scroll [up]", "scroll [down]".
4. The natural language form of action (corresponding to the field "action_in_natural_language")
should be consistent with the grounded version of the action (corresponding to the field "grounded
_action"). Do NOT add any additional information in the grounded action. For example, if a
particular element ID is specified in the grounded action, a description of that element must be
present in the natural language action.
5. If the type action is selected, the natural language form of action ("action_in_natural_language") should always specify the actual text to be typed.
6. You should issue a "stop" action if the current webpage asks to log in or for credit card
information.
7. To input text, there is NO need to click the textbox first, directly type content. After typing,
the system automatically hits the 'ENTER' key.
8. STRICTLY Avoid repeating the same action (click/type) if the webpage remains unchanged.
You may have selected the wrong web element.
9. Do NOT use quotation marks in the action generation.

OUTPUT FORMAT:
Please give a short analysis of the screenshot, parsed
HTML/accessibility tree, then put your answer within ``` ```, for example,
"In summary, the proposed task and the corresponding action is: ```{
"action_in_natural_language": "<ACTION_IN_NATURAL_LANGUAGE>:str",
"grounded_action": "<ACTION>:str"}```
"""
|
||
|
||
# User prompt template rendered via str.format (doubled braces escape the
# literal JSON braces). Placeholders: INIT_URL, A11Y_TREE, TASK_DESCRIPTION,
# SCREENSHOT.
# NOTE(review): after .format() the image line still reads f"file://{path}" —
# the literal 'f"' prefix looks unintended; confirm the consumer tolerates it.
cot_user_prompt = """
[
{{
"type": "text",
"text": "Website URL: {INIT_URL}\nParsed HTML/Accessibility Tree: {A11Y_TREE}\nTask description: {TASK_DESCRIPTION}"
}},
{{
"type": "image_url",
"image_url": {{
"url": f"file://{SCREENSHOT}"
}}
}}
]
"""
|
||
|
||
def load_qwen_model(model_path="/data1/yuyr/Qwen2.5-VL-7B-Instruct"):
    """Load the Qwen2.5-VL model and its processor into module globals.

    Args:
        model_path: Local checkpoint directory (or HF repo id) to load from.
            Defaults to the previously hard-coded local path, so existing
            callers are unaffected.

    Returns:
        True on success, False if loading raised.

    Side effects:
        Sets the module-level globals ``qwen_model`` and ``processor``.
    """
    global qwen_model, processor

    try:
        log.info("正在加载Qwen2VL模型和处理器...")

        # Load the vision-language model. device_map="cuda:1" is the second
        # *visible* device (physical GPU 2 given CUDA_VISIBLE_DEVICES above).
        qwen_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_path, torch_dtype="auto", device_map="cuda:1"
        )
        log.info("Qwen2VL模型加载完成")

        # Load the matching processor (tokenizer + image preprocessing).
        processor = AutoProcessor.from_pretrained(model_path)
        log.info("处理器加载完成")

        log.info("Qwen2VL模型和处理器加载完成")

        return True
    except Exception as e:
        log.error(f"Qwen2VL模型加载失败: {e}")
        return False
|
||
|
||
|
||
def call_api(messages):
    """Send one chat-completion request through the OpenAI-compatible client.

    Args:
        messages: Chat messages in OpenAI format.

    Returns:
        The raw response object, or None if the request raised.
    """
    try:
        log.info(f"call llm messages ")
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
        )
        log.info(f"call llm response: {completion}")
        return completion
    except Exception as err:
        log.error(f"API调用出错: {err}")
        return None
|
||
|
||
def call_qwen_model(messages):
    """Run one chat turn through the locally loaded Qwen2.5-VL model.

    Args:
        messages: Chat messages (may reference local image files).

    Returns:
        The decoded completion text, or None if inference failed.
    """
    try:
        # Render the conversation into a prompt string and pull out any
        # image/video inputs referenced by the messages.
        prompt_text = processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        images, videos = process_vision_info(messages)

        model_inputs = processor(
            text=[prompt_text],
            images=images,
            videos=videos,
            padding=True,
            return_tensors="pt",
        ).to(qwen_model.device)

        # Generate, then strip the prompt tokens so only new tokens remain.
        log.info("正在生成模型回复...")
        sequences = qwen_model.generate(**model_inputs, max_new_tokens=1024)
        continuations = [
            full[len(prompt):]
            for prompt, full in zip(model_inputs.input_ids, sequences)
        ]
        reply = processor.batch_decode(
            continuations, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

        log.info(f"模型输出: {reply}")
        return reply
    except Exception as e:
        log.error(f"模型推理失败: {e}")
        return None
|
||
|
||
|
||
def extract_action(response_text):
    """Extract the proposed action pair from a model reply.

    Looks for a fenced ```{...}``` block containing the keys
    "action_in_natural_language" and "grounded_action".

    Returns:
        (action_in_natural_language, grounded_action), or (None, None)
        when the block is missing or not valid JSON.
    """
    fenced = re.search(r'```\s*{\s*(.+?)\s*}\s*```', response_text, re.DOTALL)
    if fenced is not None:
        # Re-wrap the captured interior in braces and parse it as JSON.
        payload = "{" + fenced.group(1) + "}"
        try:
            parsed = json.loads(payload)
        except json.JSONDecodeError:
            log.error(f"无法解析JSON: {fenced.group(1)}")
        else:
            return parsed.get("action_in_natural_language"), parsed.get("grounded_action")
    return None, None
|
||
|
||
def check_answer(grounded_action, answer):
    """Check whether a grounded click action targets an expected element.

    Args:
        grounded_action: Action string such as "click [123]" (may be None).
        answer: Comma-separated list of acceptable element IDs, e.g. "12, 34".

    Returns:
        True if the action is a click whose element ID appears in ``answer``;
        False otherwise, including on any unexpected error.

    Note: the original placed its docstring as a stray string expression after
    the first statements (a no-op); it is now a real docstring. The loop
    variable no longer shadows the builtin ``id``.
    """
    try:
        log.info(f"grounded_action: {grounded_action}, answer: {answer}")
        # Only click actions can match an expected element ID.
        if not grounded_action or not grounded_action.startswith("click"):
            return False

        # Extract the numeric element ID from "click [N]".
        match = re.search(r'click \[(\d+)\]', grounded_action)
        if not match:
            return False

        element_id = match.group(1)
        # `answer` may list several acceptable IDs, comma-separated.
        answer_ids = [candidate.strip() for candidate in answer.split(",")]

        # Accept if the clicked element is any of the expected IDs.
        return element_id in answer_ids
    except Exception as e:
        log.error(f"检查答案出错: {e}")
        return False
|
||
|
||
def process_item(item):
    """Process a single exam item: build the prompt, query the model, score it.

    Args:
        item: Dict with keys num, id, page_id, url, task, expected_answer,
            answer_text, trajectory_id, trajectory_step_num, page_child_num.

    Returns:
        A result dict describing the model's action and whether it was
        correct, or None when the screenshot is missing or the model call
        failed.
    """
    item_num = item["num"]
    item_id = item["id"]
    page_id = item["page_id"]
    url = item["url"]
    task_description = item["task"]
    answer = item["expected_answer"]
    answer_text = item["answer_text"]
    trajectory_id = item["trajectory_id"]
    trajectory_step_num = item["trajectory_step_num"]
    page_child_num = item["page_child_num"]

    log.info(f"处理ID: {item_id}, URL: {url}")

    # Screenshot naming convention: <page_id>_<page_child_num>.png
    image_path = f"screenshots/{page_id}_{page_child_num}.png"
    if not os.path.exists(image_path):
        log.error(f"找不到文件: {image_path}")
        return None

    # NOTE(review): currently unused — kept for the remote-API code path below,
    # which would need base64-encoded images instead of file:// URLs.
    def encode_image_to_base64(image_path):
        """Encode an image file as a base64 string; None on failure."""
        try:
            with open(image_path, "rb") as image_file:
                return base64.b64encode(image_file.read()).decode('utf-8')
        except Exception as e:
            log.error(f"图片编码失败: {e}")
            return None

    SCREENSHOT = image_path

    # Build the chat messages; the accessibility tree is intentionally empty
    # here (screenshot-only evaluation).
    messages = [
        {"role": "system", "content": cot_system_prompt},
        {"role": "user", "content": cot_user_prompt.format(
            INIT_URL=url,
            A11Y_TREE="",
            SCREENSHOT=SCREENSHOT,
            TASK_DESCRIPTION=task_description
        )}
    ]

    log.info(f"task_description: {task_description}")
    log.info(f"answer: {answer}, answer_text: {answer_text}")

    # Query the local Qwen model (remote API path kept for reference).
    # response = call_api(messages)
    response = call_qwen_model(messages)

    print(f"response: {response}")

    if response:
        content = response
        # FIX: reasoning extraction is disabled for the local model; default to
        # None so the result dict below never hits a NameError (previously
        # "thinking": reasoning_content referenced an undefined name because
        # its only assignments were commented out).
        reasoning_content = None
        # if MODEL_NAME == "qwen/qwq-32b:free":
        #     reasoning_content = response.choices[0].message.reasoning
        # else:
        #     reasoning_content = response.choices[0].message.reasoning_content

        log.info(f"content: {content}")

        # Parse the model's proposed action out of the reply.
        action_nl, grounded_action = extract_action(content)
        log.info(f"action_nl: {action_nl}, grounded_action: {grounded_action}")

        # Score against the textual answer variant.
        is_correct = check_answer(grounded_action, answer_text)
        log.info(f"is_correct: {is_correct}")

        result = {
            "num": item_num,
            "id": item_id,
            "trajectory_id": trajectory_id,
            "trajectory_step_num": trajectory_step_num,
            "page_id": page_id,
            "url": url,
            "task": task_description,
            "expected_answer": answer,
            "thinking": reasoning_content,
            "raw_content": content,
            "action_nl": action_nl,
            "grounded_action": grounded_action,
            "is_correct": is_correct,
            "model": MODEL_NAME,
            "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        }

        log.info(f"ID: {item_id}")
        log.info(f"任务: {task_description}")
        log.info(f"动作: {grounded_action}")
        log.info(f"是否正确: {is_correct}")
        log.info("-" * 50)

        return result
    else:
        log.error(f"API调用失败: {response}")
        return None
|
||
|
||
def main():
    """Load the model, run every exam item, persist results, report accuracy.

    Reads items from temp_analysis/test.json, resumes from any existing
    temp_analysis/results_vl.json, and saves results incrementally.
    """
    load_qwen_model()

    # Load the exam items.
    with open("temp_analysis/test.json", "r") as f:
        exam_data = json.load(f)

    # FIX: load previously saved results ONCE before the loop. The original
    # re-read the file on every iteration (discarding results appended in
    # memory) while only writing the file after the loop — so all but the
    # last result were lost.
    try:
        with open("temp_analysis/results_vl.json", "r") as f:
            results = json.load(f)
    except FileNotFoundError:
        results = []

    total_items = len(exam_data) + len(results)
    completed = len(results)
    # Previously saved results are counted as successes, matching the log
    # message below — TODO confirm this resume assumption.
    success_count = len(results)
    fail_count = 0

    log.info(f"开始测试,需要执行 {total_items} 个任务, 已经成功 {success_count} 个任务")

    # Fan the items out over a thread pool; results are collected (and the
    # file written) only in this main thread, so no lock is needed.
    with concurrent.futures.ThreadPoolExecutor(max_workers=PARALLEL_WORKERS) as executor:
        future_to_item = {executor.submit(process_item, item): item for item in exam_data}

        for future in concurrent.futures.as_completed(future_to_item):
            result = future.result()
            completed += 1

            if result:
                results.append(result)
                if result["is_correct"]:
                    success_count += 1
                else:
                    fail_count += 1
                # Persist after every completed item so a crash loses at
                # most the in-flight work.
                with open("temp_analysis/results_vl.json", "w") as f:
                    json.dump(results, f, indent=2)
            else:
                fail_count += 1

            # Progress report.
            progress = (completed / total_items) * 100
            log.info(f"进度: {progress:.2f}% ({completed}/{total_items}) - 成功: {success_count}, 失败: {fail_count}")

    log.info(f"save results to temp_analysis/results_vl.json")
    # Final save (also covers the nothing-succeeded case).
    with open("temp_analysis/results_vl.json", "w") as f:
        json.dump(results, f, indent=2)

    # Overall accuracy; guard against an empty exam.
    accuracy = success_count / total_items if total_items > 0 else 0

    log.info(f"测试完成! 总计: {total_items}题,正确: {success_count}题,错误: {fail_count}题,正确率: {accuracy:.2%}")
|
||
|
||
# Script entry point.
if __name__ == "__main__":
    main()
|
||
|
||
|
||
|
||
|
||
|
||
|