124 lines
3.6 KiB
Python
124 lines
3.6 KiB
Python
"""
|
||
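Read generated_tasks.json (input format shown below) and convert each question, one by one, into a JSON object in the output format shown below: question becomes intent, answer maps to the must_include array, sql fills reference_answer_raw_annotation, task_id is numbered starting from 0, and all other fields keep their original content.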
|
||
The output file is saved as test_rlvr.raw.json
|
||
|
||
Input
|
||
[
|
||
{
|
||
"question": "Does the attribute ID 120 use page builder?",
|
||
"sql": "SELECT is_pagebuilder_enabled FROM catalog_eav_attribute WHERE attribute_id = 120;",
|
||
"answer": [
|
||
"No"
|
||
],
|
||
"sql_execute_result": [
|
||
[
|
||
0
|
||
]
|
||
]
|
||
},
|
||
...
|
||
]
|
||
|
||
|
||
Output
|
||
[
|
||
{
|
||
"sites": [
|
||
"shopping_admin"
|
||
],
|
||
"task_id": 0,
|
||
"require_login": true,
|
||
"storage_state": "./.auth/shopping_admin_state.json",
|
||
"start_url": "__SHOPPING_ADMIN__",
|
||
"geolocation": null,
|
||
"intent_template": "",
|
||
"instantiation_dict": {},
|
||
"intent": "What are the top-3 best-selling product in Jan 2023",
|
||
"require_reset": false,
|
||
"eval": {
|
||
"eval_types": [
|
||
"string_match"
|
||
],
|
||
"reference_answers": {
|
||
"must_include": [
|
||
"Impulse Duffle",
|
||
"Overnight Duffle",
|
||
"Hawkeye Yoga Short-32-Blue"
|
||
]
|
||
},
|
||
"reference_url": "",
|
||
"program_html": [],
|
||
"string_note": "",
|
||
"reference_answer_raw_annotation": ""
|
||
},
|
||
"intent_template_id": 0,
|
||
"old_task_id": 0
|
||
},
|
||
...
|
||
]
|
||
|
||
"""
|
||
|
||
import json
|
||
import os
|
||
|
||
def _build_task(task, task_id, old_task_id):
    """Return the WebArena-style config dict for one source task.

    Args:
        task: Source record providing the 'question', 'answer' and 'sql' keys.
        task_id: Identifier written to the output record's "task_id" field.
        old_task_id: Value written to the output record's "old_task_id" field.
    """
    return {
        "sites": ["shopping_admin"],
        "task_id": task_id,
        "require_login": True,
        "storage_state": "./.auth/shopping_admin_state.json",
        "start_url": "__SHOPPING_ADMIN__",
        "geolocation": None,
        "intent_template": "",
        "instantiation_dict": {},
        "intent": task['question'],
        "require_reset": False,
        "eval": {
            "eval_types": ["string_match"],
            "reference_answers": {
                "must_include": task['answer'],
            },
            "reference_url": "",
            "program_html": [],
            "string_note": "",
            "reference_answer_raw_annotation": task['sql'],
        },
        "intent_template_id": 0,
        "old_task_id": old_task_id,
    }


def generate_config():
    """Convert 'generated_tasks.json' into 'test_rlvr.raw.json'.

    Both files live in the directory that contains this script. Each input
    entry must be a JSON object with the keys 'question', 'answer' and
    'sql'; any other entry is reported with a warning and skipped.

    Converted tasks receive a contiguous 'task_id' starting at 0, while
    'old_task_id' records the entry's index in the input file, so the two
    differ only when earlier entries were skipped.

    Returns:
        None. Problems reading or parsing the input are printed and the
        function returns without writing the output file.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    input_path = os.path.join(script_dir, 'generated_tasks.json')
    output_path = os.path.join(script_dir, 'test_rlvr.raw.json')

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            input_data = json.load(f)
    except FileNotFoundError:
        print(f"Error: Input file '{input_path}' not found.")
        return
    except json.JSONDecodeError:
        print(f"Error: Failed to decode JSON from '{input_path}'.")
        return

    # Iterating a dict or string would silently walk keys/characters, so
    # refuse anything that is not the documented top-level array.
    if not isinstance(input_data, list):
        print(f"Error: Expected a JSON array of tasks in '{input_path}'.")
        return

    required_keys = ('question', 'answer', 'sql')
    output_data = []
    for i, task in enumerate(input_data):
        # The dict check comes first: `k in task` raises TypeError for
        # numbers/null and performs a substring test for strings.
        if not isinstance(task, dict) or not all(k in task for k in required_keys):
            print(f"Warning: Skipping task at index {i} due to missing keys.")
            continue

        # Number by position in the output so ids stay gap-free after skips;
        # the source index is preserved as old_task_id.
        output_data.append(_build_task(task, len(output_data), i))

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=4)

    print(f"Successfully created '{output_path}' with {len(output_data)} tasks.")
|
||
|
||
# Run the conversion only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    generate_config()
|
||
|