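"""Aggregate per-run ``overall.json`` results into summary tables.

This script walks an output directory, keeps the latest run for every
(agent, task) pair, and writes JSON/YAML/CSV summaries together with
validation-status breakdowns.
"""
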
import argparse
import datetime
import json
import os
import re
import time
from collections import OrderedDict

import yaml

from .configs import ConfigLoader
from .utils import ColorMessage
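

# Maps agent keys from the assignment config to the canonical model names
# used in the generated reports; config agents not listed here are skipped.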
MODEL_MAP = {
    "gpt-4o-2024-05-13": "gpt-4o-2024-05-13",
    "gpt-4": "gpt-4",
    "gpt-3.5-turbo-0613": "gpt-3.5-turbo",
    "llama-2-13b": "llama2-13b",
    "llama-2-7b": "llama2-7b",
    "chatglm-6b": "chatglm-6b",
    "wizard-30b": "wizardlm-30b",
    "vicuna-33b": "vicuna-33b",
    "oasst-12b": "oasst-12b",
    "guanaco-65b": "guanaco-65b",
    "koala-13b": "koala-13b",
    "text-davinci-003": "text-davinci-003",
    "wizard-13b": "wizardlm-13b",
    "guanaco-33b": "guanaco-33b",
    "text-davinci-002": "text-davinci-002",
    "llama2-70b": "llama2-70b",
    "codellama": "codellama-34b",
    "openchat": "openchat-13b",
    "claude-ins": "claude-instant",
    "claude-v1.3": "claude",
    "claude-2": "claude-2",
    "codellama-13b": "codellama-13b",
    "codellama-7b": "codellama-7b",
    "codegeex2-6b": "codegeex2-6b",
    "dolly": "dolly-12b",
    "vicuna-7b": "vicuna-7b",
    "vicuna-13b": "vicuna-13b",
    "chat-bison": "chat-bison-001",
}
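

# Maps report-friendly status labels to functions that extract the matching
# count from a raw ``validation`` counter dict.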
VALIDATION_MAP_FUNC = {
    "Completed": lambda x: x["COMPLETED"],
    "Context Limit Exceeded": lambda x: x["AGENT_CONTEXT_LIMIT"],
    "Invalid Format": lambda x: x["AGENT_VALIDATION_FAILED"],
    "Invalid Action": lambda x: x["AGENT_INVALID_ACTION"],
    # Any status not covered by the labels above counts as a task-side
    # limit or error.
    "Task Limit Exceeded": lambda x: sum(
        x[t] for t in x if t in ("UNKNOWN", "TASK_ERROR", "TASK_LIMIT_REACHED")
    ),
}


def analyze_output(config: str, output: str, since_timestamp: float):
    """
    Walk through the output folder (including sub-directories) and analyze
    each overall.json file found.

    Rules:
    - a valid overall file matches **/{agent}/{task}/overall.json
    - if the same (agent, task) pair occurs more than once, keep the latest
    """
    loader = ConfigLoader()
    config: dict = loader.load_from(config)
    assert "definition" in config, "definition not found in config"
    assert "agent" in config["definition"], "agent not found in config.definition"
    assert "task" in config["definition"], "task not found in config.definition"
    agents = set(config["definition"]["agent"].keys()).intersection(
        set(MODEL_MAP.keys())
    )
    tasks = list(config["definition"]["task"].keys())

    print(
        ColorMessage.cyan(
            f"Available Agents ({len(agents)}):\n    "
            + "\n    ".join(agents)
            + "\n\n"
            + f"Available Tasks ({len(tasks)}):\n    "
            + "\n    ".join(tasks)
            + "\n"
        )
    )

    overall_dict = OrderedDict()  # agent -> task -> {file: str, time: float}
    for root, dirs, files in os.walk(output):
        if "overall.json" in files:
            # get the full path of root
            root = os.path.abspath(root)
            # the two innermost directories name the agent and the task
            pattern = root.split(os.sep)
            if len(pattern) < 2:
                continue
            agent = pattern[-2]
            task = pattern[-1]
            ct = os.path.getmtime(os.path.join(root, "overall.json"))
            if agent not in agents:
                continue
            elif task not in tasks:
                continue
            elif ct < since_timestamp:
                continue
            agent = MODEL_MAP[agent]
            if agent in overall_dict and task in overall_dict[agent]:
                # keep only the latest run for each (agent, task) pair
                if ct < overall_dict[agent][task]["time"]:
                    continue
            overall_dict.setdefault(agent, OrderedDict())
            overall_dict[agent][task] = {
                "file": os.path.join(root, "overall.json"),
                "time": ct,
            }

    # agent -> task -> {file: str, time: str(YYYY-MM-DD HH:MM:SS), overall: dict}

    agent_names = []
    task_names = []
    validation_names = []

    for agent in overall_dict:
        if agent not in agent_names:
            agent_names.append(agent)
        for task in overall_dict[agent]:
            if task not in task_names:
                task_names.append(task)
            entry = overall_dict[agent][task]
            entry["time"] = datetime.datetime.fromtimestamp(entry["time"]).strftime(
                "%Y-%m-%d %H:%M:%S"
            )
            with open(entry["file"], "r", encoding="utf-8") as f:
                entry["overall"] = json.load(f)
            if "validation" in entry["overall"]:
                # re-bucket the raw status counters into the report labels
                entry["overall"]["validation"] = {
                    validation: VALIDATION_MAP_FUNC[validation](
                        entry["overall"]["validation"]
                    )
                    for validation in VALIDATION_MAP_FUNC
                }
                for validation in entry["overall"]["validation"]:
                    if validation not in validation_names:
                        validation_names.append(validation)

    return agent_names, task_names, validation_names, overall_dict
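

# Shape of the returned overall_dict (agent/task names here are illustrative):
#   {"gpt-4": {"os-std": {"file": ".../overall.json",
#                         "time": "2024-01-31 12:00:00",
#                         "overall": {...}}}}

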
class TaskHandler:
    """Dispatch helper that maps a task name to its headline metric."""

    def match(self, task_name) -> bool:
        """Return True if this handler is responsible for ``task_name``."""
        raise NotImplementedError()

    def get_main_metric(self, overall_result):
        """Extract the main metric from a task's overall result dict."""
        raise NotImplementedError()

    def get_order_priority(self):
        """Column order in the summary; lower values are printed first."""
        return 100000

    @staticmethod
    def get_handler(task_name) -> "TaskHandler":
        handlers = [DCG(), HH(), OS(), DB(), KG(), LTP(), WB(), WS()]
        for handler in handlers:
            if handler.match(task_name):
                return handler
        raise ValueError(f"Unknown task: {task_name}")
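

# A minimal usage sketch ("os-std" is a hypothetical task name):
#   handler = TaskHandler.get_handler("os-std")       # returns an OS() instance
#   metric = handler.get_main_metric(overall_result)  # task's headline metric

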
class DCG(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return (
            "card" in task_name
            or task_name.startswith("cg")
            or task_name.startswith("dcg")
        )

    def get_main_metric(self, overall_result):
        try:
            return overall_result["custom"]["score"]
        except KeyError:
            # older runs only reported a win rate
            return {"win_rate(legacy)": overall_result["custom"]["win_rate"]}

    def get_order_priority(self):
        return 4


class HH(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("alf")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["overall"]["success_rate"]

    def get_order_priority(self):
        return 6


class OS(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("os") or task_name.startswith("operating")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["overall"]["acc"]

    def get_order_priority(self):
        return 1


class DB(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("db") or task_name.startswith("database")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["overall_cat_accuracy"]

    def get_order_priority(self):
        return 2


class KG(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("kg") or task_name.startswith("knowledge")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["main"]

    def get_order_priority(self):
        return 3


class LTP(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("ltp") or task_name.startswith("literal")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["main"]

    def get_order_priority(self):
        return 5


class WB(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("m2w") or task_name.startswith("mind2web")

    def get_main_metric(self, overall_result):
        # step_sr is reported on a 0-100 scale; normalize it to 0-1
        return overall_result["custom"]["step_sr"] / 100

    def get_order_priority(self):
        return 8


class WS(TaskHandler):
    def match(self, task_name) -> bool:
        task_name = task_name.lower()
        return task_name.startswith("ws") or task_name.startswith("webshop")

    def get_main_metric(self, overall_result):
        return overall_result["custom"]["reward"]

    def get_order_priority(self):
        return 7


def parse_timestamp(time_str: str) -> float:
    """Parse an absolute or relative time expression into epoch seconds."""
    # plain int or float: already epoch seconds
    try:
        return float(time_str)
    except ValueError:
        pass
    # absolute datetime, tried in decreasing precision
    for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%d", "%Y-%m"):
        try:
            return datetime.datetime.strptime(time_str, fmt).timestamp()
        except ValueError:
            pass
    # relative time delta (e.g. 1d, 1h, 1m, 1s), counted back from now
    num = float(re.findall(r"[\d\.]+", time_str)[0])
    unit = re.findall(r"[a-zA-Z]+", time_str)[0]
    if unit == "d":
        delta = num * 24 * 60 * 60
    elif unit == "h":
        delta = num * 60 * 60
    elif unit == "m":
        delta = num * 60
    elif unit == "s":
        delta = num
    else:
        raise ValueError(f"Unknown time unit: {unit}")
    return time.time() - delta
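
# Accepted inputs for parse_timestamp (values below are illustrative):
#   "1700000000"           -> epoch seconds, used as-is
#   "2024-01-31 12:00:00"  -> absolute datetime ("2024-01-31", "2024-01" too)
#   "7d" / "12h" / "30m"   -> relative offset: now minus the delta

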
def main(args):
    agent_names, task_names, validation_names, details = analyze_output(
        args.config, args.output, parse_timestamp(args.time)
    )
    task_names.sort(key=lambda x: TaskHandler.get_handler(x).get_order_priority())
    summary = OrderedDict()
    for agent in details:
        summary[agent] = OrderedDict()
        for task in details[agent]:
            # get_handler raises ValueError for unknown tasks rather than
            # returning None, so its result can be used directly
            handler = TaskHandler.get_handler(task)
            summary[agent][task] = handler.get_main_metric(
                details[agent][task]["overall"]
            )

    for agent in details:
        for task in details[agent]:
            print(
                ColorMessage.cyan(
                    f"Agent: {agent:20} Task: {task:20} Path: {details[agent][task]['file']}"
                )
            )

    final_result = {
        "summary": summary,
        "details": details,
    }

    os.makedirs(args.save, exist_ok=True)

    # Overall Calculation

    with open(os.path.join(args.save, "result.json"), "w", encoding="utf-8") as f:
        json.dump(final_result, f, indent=4, ensure_ascii=False, sort_keys=True)
    with open(os.path.join(args.save, "result.yaml"), "w", encoding="utf-8") as f:
        yaml.dump(final_result, f, indent=4, allow_unicode=True, sort_keys=True)
    with open(os.path.join(args.save, "summary.csv"), "w", encoding="utf-8") as f:
        """
        Format:
            Agent\\Task, Task1, Task2, ...
            Agent1, MainMetric(Agent1,Task1), MainMetric(Agent1,Task2), ...
            ......
        """
        f.write("Agent\\Task," + ",".join(task_names) + "\n")
        for agent in summary:
            f.write(
                agent
                + ","
                + ",".join(
                    [
                        (str(summary[agent][task]) if task in summary[agent] else "")
                        for task in task_names
                    ]
                )
                + "\n"
            )

    # Validation Analysis
    agent_validations = {
        agent: {validation: [] for validation in validation_names}
        for agent in agent_names
    }
    task_validations = {
        task: {validation: [] for validation in validation_names} for task in task_names
    }

    for agent in summary:
        for task in summary[agent]:
            if "validation" in details[agent][task]["overall"]:
                for validation in details[agent][task]["overall"]["validation"]:
                    agent_validations[agent][validation].append(
                        details[agent][task]["overall"]["validation"][validation]
                    )
                    task_validations[task][validation].append(
                        details[agent][task]["overall"]["validation"][validation]
                    )

    # Agent-Centric Validation Analysis
    with open(
        os.path.join(args.save, "agent_validation.csv"), "w", encoding="utf-8"
    ) as f:
        """
        Format:
            Agent\\Validation, Validation1, Validation2, ...
            Agent1, Avg(Agent1,Validation1), Avg(Agent1,Validation2), ...
            ......
        """
        f.write("Agent\\Validation," + ",".join(validation_names) + "\n")
        for agent in agent_validations:
            f.write(
                agent
                + ","
                + ",".join(
                    [
                        (
                            str(
                                sum(agent_validations[agent][validation])
                                / len(agent_validations[agent][validation])
                            )
                            if validation in agent_validations[agent]
                            and len(agent_validations[agent][validation]) > 0
                            else "--"
                        )
                        for validation in validation_names
                    ]
                )
                + "\n"
            )

    # Task-Centric Validation Analysis
    with open(
        os.path.join(args.save, "task_validation.csv"), "w", encoding="utf-8"
    ) as f:
        """
        Format:
            Task\\Validation, Validation1, Validation2, ...
            Task1, Avg(Task1,Validation1), Avg(Task1,Validation2), ...
            ......
        """
        f.write("Task\\Validation," + ",".join(validation_names) + "\n")
        for task in task_validations:
            f.write(
                task
                + ","
                + ",".join(
                    [
                        (
                            str(
                                sum(task_validations[task][validation])
                                / len(task_validations[task][validation])
                            )
                            if validation in task_validations[task]
                            and len(task_validations[task][validation]) > 0
                            else "--"
                        )
                        for validation in validation_names
                    ]
                )
                + "\n"
            )

    print(ColorMessage.green(f"Analysis result saved to {os.path.abspath(args.save)}"))


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "-c", "--config", type=str, default="configs/assignments/definition.yaml"
    )
    arg_parser.add_argument("-o", "--output", type=str, default="outputs")
    arg_parser.add_argument("-s", "--save", type=str, default="analysis")
    arg_parser.add_argument("-t", "--time", type=str, default="0")
    args = arg_parser.parse_args()
    main(args)
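
# Example invocation (the package path is an assumption; this file uses
# relative imports, so it must be run as a module within its package):
#   python -m <package>.analysis -o outputs -s analysis -t 7d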