# File-viewer header of the original listing: 61 lines, 2.1 KiB, Python
"""Build train/test splits from the correctly answered exam results.

Pipeline (all files live in ``temp_analysis/``):

1. Load ``results.json`` and keep the entries whose ``is_correct`` is
   exactly ``True``.
2. Copy ``answer_text`` from the ``exam.json`` entry with the same ``id``.
3. Match each entry's ``task`` against the ``title`` of the values in
   ``process_3.json`` and store the ``chainChildNum`` value selected by the
   entry's 1-based ``trajectory_step_num`` as ``page_child_num``.
4. Shuffle with a fixed seed, split 7:3 and write ``train.json`` and
   ``test.json``.
"""
import json
import random
from pathlib import Path

DATA_DIR = Path('temp_analysis')
TRAIN_RATIO = 0.7
RANDOM_SEED = 42  # fixed seed so the split is reproducible


def load_json(path):
    """Return the parsed content of the UTF-8 JSON file at *path*."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def dump_json(data, path):
    """Write *data* to *path* as indented, non-ASCII-escaped UTF-8 JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, ensure_ascii=False, indent=2)


def select_correct(results: list) -> list:
    """Return the entries of *results* whose ``is_correct`` is ``True``.

    The identity check is deliberate: truthy non-boolean values such as
    ``1`` or ``"true"`` are not accepted.
    """
    return [item for item in results if item.get('is_correct') is True]


def attach_answer_text(correct_results: list, exam_data: list) -> None:
    """Add ``answer_text`` to each result from the exam entry with the same id.

    Results are modified in place. A matched exam entry without an
    ``answer_text`` key yields an empty string. Results whose id is not in
    *exam_data* are left untouched and a warning is printed.
    """
    exam_map = {item.get('id'): item for item in exam_data}
    for result in correct_results:
        result_id = result.get('id')
        if result_id in exam_map:
            result['answer_text'] = exam_map[result_id].get('answer_text', '')
        else:
            print(f"警告: ID {result_id} 在 exam.json 中未找到")


def attach_page_child_num(correct_results: list, trajectory_data: dict) -> None:
    """Add ``page_child_num`` to each result from its matching trajectory.

    A result matches the trajectory whose ``title`` equals the result's
    ``task``. The stored value is
    ``shortestPathsMeta[0]['chainChildNum'][trajectory_step_num - 1]``.
    Results are modified in place. Results with an unknown task, or with a
    step number that is missing, not an integer, or outside
    ``1..len(chainChildNum)``, are left untouched and a warning is printed.

    Raises:
        KeyError, IndexError: if a matched trajectory lacks
            ``shortestPathsMeta[0]['chainChildNum']``.
    """
    # If several trajectories share a title, the last one wins.
    trajectory_map = {item.get('title'): item for item in trajectory_data.values()}
    for result in correct_results:
        task = result.get('task')
        if task not in trajectory_map:
            print(f"警告: {task} 在 trajectory_map 中未找到")
            continue

        chain = trajectory_map[task]['shortestPathsMeta'][0]['chainChildNum']
        step_num = result.get('trajectory_step_num')
        # Validate the 1-based step explicitly: without this check a step of
        # 0 would become index -1 and silently pick the LAST chain element.
        if isinstance(step_num, int) and 1 <= step_num <= len(chain):
            result['page_child_num'] = chain[step_num - 1]
        else:
            print(f"警告: {task} 的 trajectory_step_num 无效: {step_num}")


def split_train_test(items: list, train_ratio: float = TRAIN_RATIO,
                     seed: int = RANDOM_SEED) -> tuple:
    """Shuffle a copy of *items* deterministically and split it in two.

    Args:
        items: entries to split; the list itself is not modified.
        train_ratio: fraction of entries placed in the first part; the cut
            index is ``int(len(items) * train_ratio)``.
        seed: seed for the shuffle.

    Returns:
        ``(train, test)`` lists.
    """
    shuffled = list(items)
    # A dedicated generator gives the same permutation as random.seed(seed)
    # followed by random.shuffle(), without touching the global RNG state.
    random.Random(seed).shuffle(shuffled)
    split_index = int(len(shuffled) * train_ratio)
    return shuffled[:split_index], shuffled[split_index:]


def main() -> None:
    """Run the whole pipeline on the files in ``temp_analysis/``."""
    # Load results.json and keep only the correctly answered entries.
    results = load_json(DATA_DIR / 'results.json')
    correct_results = select_correct(results)
    print(f"找到 {len(correct_results)} 个正确的条目")

    # Enrich each correct entry with the answer text from exam.json.
    exam_data = load_json(DATA_DIR / 'exam.json')
    attach_answer_text(correct_results, exam_data)

    # Enrich each correct entry with the child count from process_3.json.
    trajectory_data = load_json(DATA_DIR / 'process_3.json')
    attach_page_child_num(correct_results, trajectory_data)

    # Shuffle reproducibly and split 7:3 into training and test sets.
    train_data, test_data = split_train_test(correct_results)
    print(f"训练集大小: {len(train_data)}")
    print(f"测试集大小: {len(test_data)}")

    # Save both sets.
    dump_json(train_data, DATA_DIR / 'train.json')
    dump_json(test_data, DATA_DIR / 'test.json')
    print("处理完成! 数据已保存到 temp_analysis/train.json 和 temp_analysis/test.json")


if __name__ == "__main__":
    main()