webrlvr/random_sample/make_config_files.py
2025-06-11 17:30:06 +08:00

124 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
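Read generated_tasks.json (input format shown below) and convert each question, one by one, into an entry of the output JSON format shown below: `question` becomes `intent`, the `answer` array becomes `must_include`, and `sql` fills `reference_answer_raw_annotation`. `task_id` is numbered starting from 0; every other field keeps the value shown in the output example.
The output file is saved as test_rlvr.raw.json.
Input: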
[
{
"question": "Does the attribute ID 120 use page builder?",
"sql": "SELECT is_pagebuilder_enabled FROM catalog_eav_attribute WHERE attribute_id = 120;",
"answer": [
"No"
],
"sql_execute_result": [
[
0
]
]
},
...
]
Output:
[
{
"sites": [
"shopping_admin"
],
"task_id": 0,
"require_login": true,
"storage_state": "./.auth/shopping_admin_state.json",
"start_url": "__SHOPPING_ADMIN__",
"geolocation": null,
"intent_template": "",
"instantiation_dict": {},
"intent": "What are the top-3 best-selling product in Jan 2023",
"require_reset": false,
"eval": {
"eval_types": [
"string_match"
],
"reference_answers": {
"must_include": [
"Impulse Duffle",
"Overnight Duffle",
"Hawkeye Yoga Short-32-Blue"
]
},
"reference_url": "",
"program_html": [],
"string_note": "",
"reference_answer_raw_annotation": ""
},
"intent_template_id": 0,
"old_task_id": 0
},
...
]
"""
import json
import os
# Keys every generated task must provide to be converted.
_REQUIRED_KEYS = ('question', 'answer', 'sql')


def _normalize_answers(answer):
    """Return *answer* as a list of strings for ``must_include``.

    The input format documents ``answer`` as a list; a bare scalar is wrapped
    in a one-element list. Non-string entries (e.g. numbers copied from SQL
    results) are converted with ``str`` so every reference answer is a string,
    which is presumably what the ``string_match`` evaluator compares against
    (confirm against the evaluator in use).
    """
    items = answer if isinstance(answer, list) else [answer]
    return [str(item) for item in items]


def _build_task(task, task_id, old_task_id):
    """Build one WebArena-style config entry from a generated task.

    Args:
        task: Input record containing ``question``, ``answer`` and ``sql``.
        task_id: Sequential id assigned in the output file.
        old_task_id: Index of ``task`` within the input list.

    Returns:
        A freshly built dict (no shared mutable members between entries).
    """
    return {
        "sites": ["shopping_admin"],
        "task_id": task_id,
        "require_login": True,
        "storage_state": "./.auth/shopping_admin_state.json",
        "start_url": "__SHOPPING_ADMIN__",
        "geolocation": None,
        "intent_template": "",
        "instantiation_dict": {},
        "intent": task['question'],
        "require_reset": False,
        "eval": {
            "eval_types": ["string_match"],
            "reference_answers": {
                "must_include": _normalize_answers(task['answer']),
            },
            "reference_url": "",
            "program_html": [],
            "string_note": "",
            "reference_answer_raw_annotation": task['sql'],
        },
        "intent_template_id": 0,
        "old_task_id": old_task_id,
    }


def generate_config(input_path=None, output_path=None):
    """
    Reads tasks from 'generated_tasks.json', converts them to the WebArena
    format, and saves them to 'test_rlvr.raw.json'.

    Args:
        input_path: JSON file holding a list of generated tasks. Defaults to
            'generated_tasks.json' in this script's directory.
        output_path: Destination file. Defaults to 'test_rlvr.raw.json' in
            this script's directory.

    Returns:
        The list of converted tasks that was written, or None when the input
        could not be read or is not a JSON list (an error is printed).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    if input_path is None:
        input_path = os.path.join(script_dir, 'generated_tasks.json')
    if output_path is None:
        output_path = os.path.join(script_dir, 'test_rlvr.raw.json')

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            input_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input file '{input_path}' not found.")
        return None
    except json.JSONDecodeError as e:
        print(f"Error: Failed to decode JSON from '{input_path}': {e}")
        return None

    # Iterating a dict/str/number here would silently produce garbage.
    if not isinstance(input_data, list):
        print(f"Error: Expected a JSON list of tasks in '{input_path}'.")
        return None

    output_data = []
    for i, task in enumerate(input_data):
        if not isinstance(task, dict):
            print(f"Warning: Skipping task at index {i}: not a JSON object.")
            continue
        if not all(k in task for k in _REQUIRED_KEYS):
            print(f"Warning: Skipping task at index {i} due to missing keys.")
            continue
        # task_id stays consecutive (0, 1, 2, ...) even when entries are
        # skipped; old_task_id records the entry's position in the input.
        output_data.append(_build_task(task, len(output_data), i))

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4, ensure_ascii=False)
    print(f"Successfully created '{output_path}' with {len(output_data)} tasks.")
    return output_data
if __name__ == "__main__":
    # When executed as a script, convert generated_tasks.json (next to this
    # file) into test_rlvr.raw.json.
    generate_config()