webrl/VAB-WebArena-Lite/scripts/calc_breakdown_sr.py
2025-04-23 17:01:18 +08:00

150 lines
4.1 KiB
Python

"""Calculate the success-rate breakdown by achievability, website, and task type."""
import json
import os
import argparse
def config(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments for the breakdown script.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets
            other scripts and tests drive the parser directly.

    Returns:
        Namespace with ``log_file`` (evaluation log to parse) and
        ``config_file`` (task-config JSON, defaulting to the WebArena test set).
    """
    parser = argparse.ArgumentParser()
    # Required: without it main() would later fail with open(None).
    parser.add_argument(
        "--log_file", type=str, required=True, help="Path to the log file"
    )
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to the config file",
        default="config_files/wa/test_webarena.json",
    )
    return parser.parse_args(argv)
def calc_sr(id_to_success: dict[int, bool]) -> dict[str, float]:
    """Summarise pass/fail outcomes for a set of tasks.

    Args:
        id_to_success: Mapping of task id to whether that task passed.

    Returns:
        Dict with integer ``total``, ``success`` and ``fail`` counts plus
        ``success_rate``: the pass percentage rounded to 2 decimal places.
        An empty mapping yields ``0.0`` rather than raising
        ``ZeroDivisionError`` (e.g. a website with no tasks in the log).
    """
    total = len(id_to_success)
    success = sum(id_to_success.values())
    sr = {
        "total": total,
        "success": success,
        "fail": total - success,
    }
    # Scale to a percentage *before* rounding: round(x, 4) * 100 leaves float
    # artifacts such as 56.99999999999999.
    sr["success_rate"] = round(100 * success / total, 2) if total else 0.0
    return sr
def parse_result(log_file: str) -> dict[int, bool]:
    """Extract per-task pass/fail outcomes from an evaluation log.

    A line containing ``[Config file]`` sets the current task id. The id is
    the basename of the line's last whitespace-separated token, cut at its
    first ``.``. Each subsequent line containing ``[Result]`` is attributed
    to the most recent task id: it records ``True`` for ``(PASS)`` and
    ``False`` for ``(FAIL)``. A later result for the same id overwrites an
    earlier one.

    Args:
        log_file: Path to the log file.

    Returns:
        Mapping from integer task id to whether the task passed.

    Raises:
        ValueError: if a ``[Result]`` line has neither marker, or appears
            before any ``[Config file]`` line.
    """
    id_to_success: dict[int, bool] = {}
    # Named task_id (not ``id``) so the builtin is never shadowed; None until
    # the first config line so an orphan result line is reported clearly.
    task_id = None
    with open(log_file, "r") as f:
        for line in f:
            if "[Config file]" in line:
                task_id = os.path.basename(line.split()[-1]).split(".")[0]
            elif "[Result]" in line:
                if "(FAIL)" in line:
                    success = False
                elif "(PASS)" in line:
                    success = True
                else:
                    raise ValueError(f"Unknown result: {line}")
                if task_id is None:
                    raise ValueError(
                        f"Result line appears before any config file line: {line}"
                    )
                id_to_success[int(task_id)] = success
    return id_to_success
# Websites always reported (in this order), even when no task targets them.
_DEFAULT_WEBSITES = (
    "shopping",
    "shopping_admin",
    "gitlab",
    "reddit",
    "map",
    "wikipedia",
)
# Task-type buckets, in report order.
_TASK_TYPES = ("info_seeking", "site_nav", "content_config")


def _is_unachievable(task: dict) -> bool:
    """Return True for a string-match task whose fuzzy_match reference is "N/A"."""
    evaluation = task["eval"]
    return (
        evaluation["eval_types"] == ["string_match"]
        and evaluation["reference_answers"].get("fuzzy_match", "") == "N/A"
    )


def _task_type(task: dict) -> str:
    """Classify a task by its evaluator list into one of _TASK_TYPES."""
    eval_types = task["eval"]["eval_types"]
    if eval_types == ["string_match"]:
        return "info_seeking"
    if eval_types == ["url_match"]:
        return "site_nav"
    return "content_config"


def _subset(id_to_success: dict[int, bool], ids: set) -> dict[int, bool]:
    """Restrict id_to_success to the task ids contained in ids."""
    return {k: v for k, v in id_to_success.items() if k in ids}


def main(args: argparse.Namespace):
    """Print overall and broken-down success rates for one evaluation run.

    Breakdowns: achievable vs. unachievable, per website, and per task type.
    Tasks present in the log but absent from the config count toward the
    overall and achievable rates, but not toward any website or task type.

    Args:
        args: Namespace providing ``log_file`` (parsed by parse_result) and
            ``config_file``. The config file is read as a JSON list of task
            dicts with ``task_id``, ``sites`` and an ``eval`` section holding
            ``eval_types`` and ``reference_answers``.
    """
    id_to_success = parse_result(args.log_file)
    overall_sr = calc_sr(id_to_success)

    # Named ``tasks`` so the module-level config() function is not shadowed.
    with open(args.config_file, "r") as f:
        tasks = json.load(f)

    # One pass over the config builds every grouping.
    non_achievable = set()
    website_to_ids = {website: set() for website in _DEFAULT_WEBSITES}
    task_type_to_ids = {task_type: set() for task_type in _TASK_TYPES}
    for task in tasks:
        task_id = int(task["task_id"])
        if _is_unachievable(task):
            non_achievable.add(task_id)
        for site in task["sites"]:
            # setdefault: an unlisted site gets its own bucket instead of
            # raising KeyError.
            website_to_ids.setdefault(site, set()).add(task_id)
        task_type_to_ids[_task_type(task)].add(task_id)

    achievable_sr = calc_sr(
        {k: v for k, v in id_to_success.items() if k not in non_achievable}
    )
    unachievable_sr = calc_sr(_subset(id_to_success, non_achievable))
    website_sr = {
        website: calc_sr(_subset(id_to_success, ids))
        for website, ids in website_to_ids.items()
    }
    task_type_sr = {
        task_type: calc_sr(_subset(id_to_success, ids))
        for task_type, ids in task_type_to_ids.items()
    }

    separator = "====================================="
    for title, sr in (
        ("Overall SR:", overall_sr),
        ("Achievable SR:", achievable_sr),
        ("Unachievable SR:", unachievable_sr),
    ):
        print(separator)
        print(title)
        print(sr)
    for title, breakdown in (
        ("Website SR:", website_sr),
        ("Task type SR:", task_type_sr),
    ):
        print(separator)
        print(title)
        for name, sr in breakdown.items():
            print(name)
            print(sr)
if __name__ == "__main__":
    # Script entry point: parse the CLI flags, then print every breakdown.
    main(args=config())