150 lines
4.1 KiB
Python
150 lines
4.1 KiB
Python
"""Calculate the success-rate breakdown by achievability, website and task type."""
|
|
|
|
import json
|
|
import os
|
|
import argparse
|
|
|
|
|
|
def config(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line arguments for this script.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour.

    Returns:
        Namespace with ``log_file`` (path to the evaluation log) and
        ``config_file`` (path to the JSON task-config file).
    """
    parser = argparse.ArgumentParser()
    # Required: without it a ``None`` path would reach ``open()`` later and
    # fail with an opaque TypeError instead of a usage message.
    parser.add_argument(
        "--log_file", type=str, required=True, help="Path to the log file"
    )
    parser.add_argument(
        "--config_file",
        type=str,
        help="Path to the config file",
        default="config_files/wa/test_webarena.json",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def calc_sr(id_to_success: dict[int, bool]) -> dict[str, float]:
    """Summarise pass/fail outcomes for a set of tasks.

    Args:
        id_to_success: Mapping of task id to whether that task passed.

    Returns:
        Dict with ``total``, ``success`` and ``fail`` counts, plus
        ``success_rate``: the percentage of passing tasks, rounded to two
        decimal places. An empty mapping yields a ``success_rate`` of 0.0
        instead of raising ``ZeroDivisionError``.
    """
    total = len(id_to_success)
    success = sum(id_to_success.values())
    sr = {
        "total": total,
        "success": success,
        "fail": total - success,
    }
    # Scale to a percentage *before* rounding. Rounding the fraction first and
    # then multiplying by 100 leaves float noise such as 57.99999999999999.
    sr["success_rate"] = round(success / total * 100, 2) if total else 0.0
    return sr
|
|
|
|
|
|
def parse_result(log_file: str) -> dict[int, bool]:
    """Extract per-task pass/fail outcomes from an evaluation log.

    The log is scanned line by line:

    - A line containing ``[Config file]`` sets the current task id. The id is
      the basename of the line's last whitespace-separated token, cut at the
      first ``.`` (e.g. ``config_files/12.json`` -> ``12``).
    - A later line containing ``[Result]`` records that task as failed if it
      contains ``(FAIL)``, or as passed if it contains ``(PASS)``.

    If the same task id gets several results, the last one wins.

    Args:
        log_file: Path to the log file.

    Returns:
        Mapping of integer task id to success flag.

    Raises:
        ValueError: if a ``[Result]`` line contains neither ``(PASS)`` nor
            ``(FAIL)``, appears before any ``[Config file]`` line, or the
            extracted task id is not an integer.
    """
    id_to_success: dict[int, bool] = {}
    task_id = None  # id from the most recent "[Config file]" line

    # Only ASCII markers and numeric ids are inspected, so undecodable bytes
    # elsewhere in the log are replaced rather than aborting the parse.
    with open(log_file, "r", encoding="utf-8", errors="replace") as f:
        for line in f:
            if "[Config file]" in line:
                task_id = os.path.basename(line.split()[-1]).split(".")[0]
            elif "[Result]" in line:
                if "(FAIL)" in line:
                    success = False
                elif "(PASS)" in line:
                    success = True
                else:
                    raise ValueError(f"Unknown result: {line}")
                if task_id is None:
                    raise ValueError(
                        f"Result line found before any config file line: {line}"
                    )
                id_to_success[int(task_id)] = success

    return id_to_success
|
|
|
|
|
|
def _subset_sr(id_to_success: dict[int, bool], ids: set[int]) -> dict[str, float]:
    """Return ``calc_sr`` restricted to the tasks whose id is in *ids*."""
    return calc_sr({k: v for k, v in id_to_success.items() if k in ids})


def _non_achievable_ids(tasks: list[dict]) -> set[int]:
    """Return the ids of tasks treated as unachievable.

    A task counts as unachievable when it is evaluated only by
    ``string_match`` and its ``fuzzy_match`` reference answer is ``"N/A"``.
    """
    return {
        int(task["task_id"])
        for task in tasks
        if task["eval"]["eval_types"] == ["string_match"]
        and task["eval"]["reference_answers"].get("fuzzy_match", "") == "N/A"
    }


def _website_to_ids(tasks: list[dict]) -> dict[str, set[int]]:
    """Group task ids by every site listed in each task's ``sites``.

    A fixed list of known sites is always present (and therefore reported,
    even when empty). Any other site gets its own bucket, appended after them.
    """
    website_to_ids: dict[str, set[int]] = {
        website: set()
        for website in [
            "shopping",
            "shopping_admin",
            "gitlab",
            "reddit",
            "map",
            "wikipedia",
        ]
    }
    for task in tasks:
        task_id = int(task["task_id"])
        for site in task["sites"]:
            # setdefault: an unlisted site must not raise KeyError.
            website_to_ids.setdefault(site, set()).add(task_id)
    return website_to_ids


def _task_type_to_ids(tasks: list[dict]) -> dict[str, set[int]]:
    """Group task ids by task type, derived from the evaluator list.

    ``["string_match"]`` -> info_seeking, ``["url_match"]`` -> site_nav,
    anything else -> content_config.
    """
    task_type_to_ids: dict[str, set[int]] = {
        "info_seeking": set(),
        "site_nav": set(),
        "content_config": set(),
    }
    for task in tasks:
        task_id = int(task["task_id"])
        eval_types = task["eval"]["eval_types"]
        if eval_types == ["string_match"]:
            task_type_to_ids["info_seeking"].add(task_id)
        elif eval_types == ["url_match"]:
            task_type_to_ids["site_nav"].add(task_id)
        else:
            task_type_to_ids["content_config"].add(task_id)
    return task_type_to_ids


def main(args: argparse.Namespace) -> None:
    """Print overall and per-category success rates for an evaluation run.

    Per-task outcomes are read from ``args.log_file``. Task definitions are
    read from ``args.config_file``: a JSON list of task dicts with
    ``task_id``, ``sites`` and ``eval`` entries. The script prints success
    rates overall, for achievable vs. unachievable tasks, per website and per
    task type.
    """
    id_to_success = parse_result(args.log_file)
    overall_sr = calc_sr(id_to_success)

    # Named ``tasks`` so it does not shadow the module-level ``config()``.
    # JSON text is UTF-8 by specification.
    with open(args.config_file, "r", encoding="utf-8") as f:
        tasks = json.load(f)

    non_achievable = _non_achievable_ids(tasks)
    achievable_sr = calc_sr(
        {k: v for k, v in id_to_success.items() if k not in non_achievable}
    )
    unachievable_sr = _subset_sr(id_to_success, non_achievable)

    website_sr = {
        website: _subset_sr(id_to_success, ids)
        for website, ids in _website_to_ids(tasks).items()
    }
    task_type_sr = {
        task_type: _subset_sr(id_to_success, ids)
        for task_type, ids in _task_type_to_ids(tasks).items()
    }

    separator = "====================================="
    print(separator)
    print("Overall SR:")
    print(overall_sr)
    print(separator)
    print("Achievable SR:")
    print(achievable_sr)
    print(separator)
    print("Unachievable SR:")
    print(unachievable_sr)
    print(separator)
    print("Website SR:")
    for website, sr in website_sr.items():
        print(website)
        print(sr)
    print(separator)
    print("Task type SR:")
    for task_type, sr in task_type_sr.items():
        print(task_type)
        print(sr)
|
|
|
|
|
|
if __name__ == "__main__":
    # Entry point: parse the command line, then compute and print the breakdown.
    main(config())
|