# swift_test/sample/dataset_gen.py

import json
import os
import re
import subprocess
import logging
import shutil
# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# File and Directory Paths
CWD = os.path.dirname(os.path.abspath(__file__))
FULL_TRACE_PATH = os.path.join(CWD, 'full_trace.jsonl')
OUTPUT_DIR = os.path.join(CWD, 'output_data')
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'trace_rl.jsonl')
PERSISTENT_STATE_FILE = os.path.join(OUTPUT_DIR, 'persistent_state.json')
TMP_DIR = os.path.join(CWD, 'tmp_dataset_gen')
# --- Swift CLI Configuration ---
# Make sure to set the OPENAI_API_KEY in your environment variables
MODEL = "qwen3-8b"
SWIFT_COMMAND_TEMPLATE = (
'swift sample '
'--sampler_type distill '
'--sampler_engine client '
f'--model {MODEL} '
'--stream true '
'--orm_model external_web_acc '
'--dataset "{input_file}" '
'--num_return_sequences 1 '
'--temperature 0.8 '
'--top_p 0.95 '
'--external_plugins plugin.py '
'--system system_prompt.txt '
'--output_dir "{output_dir}" '
'--output_file "{output_file}" '
'--engine_kwargs \'{{"base_url":"http://192.168.16.116:18088/v1"}}\''
# '--engine_kwargs \'{{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}}\''
)
# --- Constants ---
MAX_ITERATIONS = 100
EXIT_PATTERN = re.compile(r'\bexit\(.*\)', re.IGNORECASE)
def read_jsonl(file_path):
"""Reads a .jsonl file and yields each line as a parsed JSON object."""
with open(file_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
yield json.loads(line)
def write_jsonl(file_path, data, mode='w'):
"""Writes a list of JSON objects to a .jsonl file."""
with open(file_path, mode, encoding='utf-8') as f:
for item in data:
f.write(json.dumps(item) + '\n')
def get_task_id(trace):
"""Extracts the task ID from the first user message of a trace."""
try:
match = re.search(r'\[#(\d+)\]', trace['messages'][0]['content'])
if match:
return match.group(1)
except (KeyError, IndexError, TypeError):
pass
logging.warning(f"Could not find task ID in trace: {trace}")
return None
def is_exit_action(trace):
"""Checks if the last assistant message is an exit action."""
try:
last_message = trace['messages'][-1]
if last_message['role'] == 'assistant':
return bool(EXIT_PATTERN.search(last_message['content']))
except (KeyError, IndexError, TypeError):
pass
return False
def save_state(b_set, original_traces_map):
"""Saves the current state (B set and original traces) to a file."""
state = {
'b_set': b_set,
'original_traces_map': original_traces_map
}
with open(PERSISTENT_STATE_FILE, 'w', encoding='utf-8') as f:
json.dump(state, f, indent=4)
logging.info(f"Successfully saved state to {PERSISTENT_STATE_FILE}")
def load_state():
"""Loads the state from a file if it exists."""
if os.path.exists(PERSISTENT_STATE_FILE):
logging.info(f"Found persistent state file. Resuming from {PERSISTENT_STATE_FILE}")
with open(PERSISTENT_STATE_FILE, 'r', encoding='utf-8') as f:
state = json.load(f)
return state['b_set'], state['original_traces_map']
return None, None
def main():
"""Main function to generate the dataset."""
if not os.path.exists(FULL_TRACE_PATH):
logging.error(f"Source file not found: {FULL_TRACE_PATH}")
return
# Create output and temporary directories
os.makedirs(OUTPUT_DIR, exist_ok=True)
os.makedirs(TMP_DIR, exist_ok=True)
b_set_to_process, original_traces_map = load_state()
if b_set_to_process is None:
logging.info("No persistent state found. Starting from scratch.")
# Initial run: Load full traces and create initial sets A and B
original_traces_map = {}
initial_a_set = []
logging.info(f"Reading full traces from {FULL_TRACE_PATH}")
for trace in read_jsonl(FULL_TRACE_PATH):
task_id = get_task_id(trace)
if task_id:
original_traces_map[task_id] = trace['messages']
if len(trace['messages']) >= 3:
initial_sub_trace = {"messages": trace['messages'][:3]}
initial_a_set.append(initial_sub_trace)
# Write all of set A to the final RL training data file
logging.info(f"Adding {len(initial_a_set)} initial traces to {FINAL_OUTPUT_PATH}")
write_jsonl(FINAL_OUTPUT_PATH, initial_a_set, mode='a')
# Create initial B set from non-exit traces in A
b_set_to_process = [trace for trace in initial_a_set if not is_exit_action(trace)]
logging.info(f"Created initial B set with {len(b_set_to_process)} non-exit traces.")
# Initial save of state
save_state(b_set_to_process, original_traces_map)
for i in range(1, MAX_ITERATIONS + 1):
if not b_set_to_process:
logging.info("B set is empty. Process finished successfully.")
break
logging.info(f"\n--- Iteration {i}/{MAX_ITERATIONS} ---")
logging.info(f"Processing {len(b_set_to_process)} traces in B set.")
# Prepare input for swift CLI
iteration_dir = os.path.join(TMP_DIR, f'iter_{i}')
os.makedirs(iteration_dir, exist_ok=True)
iteration_input_path = os.path.join(iteration_dir, 'sample_input.jsonl')
swift_input_data = []
for trace in b_set_to_process:
# The swift tool is expected to take the full trace,
# use messages[:-1] as prompt, and messages[-1] as the ground truth for distillation.
swift_input_data.append({
"messages": trace['messages']
})
write_jsonl(iteration_input_path, swift_input_data)
# Run swift CLI
# Per swift's requirement, provide a filename prefix without a directory.
swift_output_filename = 'sample_output.jsonl'
iteration_output_path = os.path.join(iteration_dir, swift_output_filename)
command = SWIFT_COMMAND_TEMPLATE.format(
input_file=iteration_input_path,
output_dir=iteration_dir,
output_file=swift_output_filename
)
logging.info(f"Executing Swift CLI command...")
try:
# Ensure OPENAI_API_KEY is available to the subprocess
env = os.environ.copy()
if "OPENAI_API_KEY" not in env:
logging.warning("OPENAI_API_KEY not found in environment. The script might fail.")
subprocess.run(command, shell=True, check=True, capture_output=True, text=True, cwd=CWD)
logging.info(f"Swift CLI finished. Output at {iteration_output_path}")
except subprocess.CalledProcessError as e:
logging.error(f"Swift CLI failed with exit code {e.returncode}")
logging.error(f"STDOUT: {e.stdout}")
logging.error(f"STDERR: {e.stderr}")
# Decide if we should stop or continue. For now, stop.
break
if not os.path.exists(iteration_output_path):
logging.warning(f"Swift output file not found at {iteration_output_path}. Assuming no traces were successfully sampled. Ending loop.")
break
# Process swift output and prepare for next iteration
b_next_iteration = []
extended_traces_for_rl = []
# The output from swift contains the generated assistant message
generated_traces = list(read_jsonl(iteration_output_path))
logging.info(f"Swift generated {len(generated_traces)} successful samples.")
successful_task_ids = {get_task_id(t) for t in generated_traces}
successful_task_ids.discard(None)
for generated_trace in generated_traces:
task_id = get_task_id(generated_trace)
if not task_id or task_id not in original_traces_map:
continue
generated_messages = generated_trace['messages']
original_full_trace_messages = original_traces_map[task_id]
# Extend the generated trace with the next user-assistant pair from the original trace
current_len = len(generated_messages)
if current_len + 2 <= len(original_full_trace_messages):
next_user_msg = original_full_trace_messages[current_len]
next_assistant_gt_msg = original_full_trace_messages[current_len + 1]
extended_messages = generated_messages + [next_user_msg, next_assistant_gt_msg]
extended_trace = {"messages": extended_messages}
extended_traces_for_rl.append(extended_trace)
if not is_exit_action(extended_trace):
b_next_iteration.append(extended_trace)
if extended_traces_for_rl:
logging.info(f"Adding {len(extended_traces_for_rl)} extended traces to {FINAL_OUTPUT_PATH}")
write_jsonl(FINAL_OUTPUT_PATH, extended_traces_for_rl, mode='a')
# Add back rejected tasks to be retried in the next iteration
rejected_count = 0
for original_b_trace in b_set_to_process:
task_id = get_task_id(original_b_trace)
if task_id and task_id not in successful_task_ids:
b_next_iteration.append(original_b_trace)
rejected_count += 1
if rejected_count > 0:
logging.info(f"Added {rejected_count} rejected traces to the queue for retry.")
b_set_to_process = b_next_iteration
save_state(b_set_to_process, original_traces_map)
else: # This else belongs to the for loop, executes if loop finishes without break
if b_set_to_process:
logging.warning(f"Reached max iterations ({MAX_ITERATIONS}) but B set is not empty.")
logging.warning(f"{len(b_set_to_process)} traces remain.")
# Cleanup
if os.path.exists(PERSISTENT_STATE_FILE):
os.remove(PERSISTENT_STATE_FILE)
logging.info(f"Process finished. Removed persistent state file.")
# Optional: You may want to clean up the tmp directory
# shutil.rmtree(TMP_DIR)
logging.info(f"Temporary files are stored in {TMP_DIR}")
if __name__ == '__main__':
main()