import json
import os
import re
import subprocess
import logging
import shutil

# --- Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# File and Directory Paths
CWD = os.path.dirname(os.path.abspath(__file__))
FULL_TRACE_PATH = os.path.join(CWD, 'full_trace.jsonl')
OUTPUT_DIR = os.path.join(CWD, 'output_data')
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'trace_rl.jsonl')
PERSISTENT_STATE_FILE = os.path.join(OUTPUT_DIR, 'persistent_state.json')
TMP_DIR = os.path.join(CWD, 'tmp_dataset_gen')

# --- Swift CLI Configuration ---
# Make sure to set the OPENAI_API_KEY in your environment variables
MODEL = "qwen3-8b"
# The template is filled in with str.format(); the doubled braces in the
# --engine_kwargs value are format() escapes that become literal JSON braces.
SWIFT_COMMAND_TEMPLATE = (
    'swift sample '
    '--sampler_type distill '
    '--sampler_engine client '
    f'--model {MODEL} '
    '--stream true '
    '--orm_model external_web_acc '
    '--dataset "{input_file}" '
    '--num_return_sequences 1 '
    '--temperature 0.8 '
    '--top_p 0.95 '
    '--external_plugins plugin.py '
    '--system system_prompt.txt '
    '--output_dir "{output_dir}" '
    '--output_file "{output_file}" '
    '--engine_kwargs \'{{"base_url":"http://192.168.16.116:18088/v1"}}\''
    # '--engine_kwargs \'{{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}}\''
)

# --- Constants ---
MAX_ITERATIONS = 100
EXIT_PATTERN = re.compile(r'\bexit\(.*\)', re.IGNORECASE)


def read_jsonl(file_path):
    """Yield one parsed JSON object per non-blank line of a .jsonl file.

    Args:
        file_path: Path of the UTF-8 encoded .jsonl file to read.

    Yields:
        The object decoded from each line that is not empty/whitespace-only.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                yield json.loads(line)


def write_jsonl(file_path, data, mode='w'):
    """Write an iterable of JSON-serializable objects, one per line.

    Args:
        file_path: Destination path.
        data: Iterable of objects accepted by ``json.dumps``.
        mode: File mode passed to ``open`` ('w' to overwrite, 'a' to append).
    """
    with open(file_path, mode, encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')


def get_task_id(trace):
    """Extract the task ID from the first message of a trace.

    The ID is the run of digits inside a ``[#<digits>]`` marker found in
    ``trace['messages'][0]['content']``.

    Args:
        trace: Mapping with a ``'messages'`` list of message dicts.

    Returns:
        The ID as a string, or ``None`` (after logging a warning) when the
        marker is absent or the trace does not have the expected shape.
    """
    try:
        match = re.search(r'\[#(\d+)\]', trace['messages'][0]['content'])
        if match:
            return match.group(1)
    except (KeyError, IndexError, TypeError):
        pass
    logging.warning(f"Could not find task ID in trace: {trace}")
    return None


def is_exit_action(trace):
    """Return True if the last message is an assistant ``exit(...)`` action.

    Args:
        trace: Mapping with a ``'messages'`` list of message dicts.

    Returns:
        ``True`` when the final message has role ``'assistant'`` and its
        content matches ``EXIT_PATTERN``; ``False`` otherwise, including
        when the trace is malformed.
    """
    try:
        last_message = trace['messages'][-1]
        if last_message['role'] == 'assistant':
            return bool(EXIT_PATTERN.search(last_message['content']))
    except (KeyError, IndexError, TypeError):
        pass
    return False


def save_state(b_set, original_traces_map):
    """Persist the pending B set and the original traces for later resume.

    Args:
        b_set: List of traces still waiting to be processed.
        original_traces_map: Mapping of task ID to the full original
            message list of that task.
    """
    state = {
        'b_set': b_set,
        'original_traces_map': original_traces_map
    }
    with open(PERSISTENT_STATE_FILE, 'w', encoding='utf-8') as f:
        json.dump(state, f, indent=4)
    logging.info(f"Successfully saved state to {PERSISTENT_STATE_FILE}")


def load_state():
    """Load previously saved state, if a state file exists.

    Returns:
        A ``(b_set, original_traces_map)`` tuple read from
        ``PERSISTENT_STATE_FILE``, or ``(None, None)`` when no state file
        is present.
    """
    if os.path.exists(PERSISTENT_STATE_FILE):
        logging.info(f"Found persistent state file. Resuming from {PERSISTENT_STATE_FILE}")
        with open(PERSISTENT_STATE_FILE, 'r', encoding='utf-8') as f:
            state = json.load(f)
        return state['b_set'], state['original_traces_map']
    return None, None


def main():
    """Generate the RL dataset by iteratively extending traces with Swift.

    On a fresh start, every source trace contributes its first three
    messages (set A) to the final output, and the non-exit ones form the
    work queue (set B). Each iteration samples set B through the Swift CLI,
    extends every sampled trace with the next two messages of its original
    trace, appends the extended traces to the final output, and re-queues
    both the extended non-exit traces and the traces Swift did not return.

    The queue is persisted after every iteration. The state file is removed
    only once the queue is empty; if the run stops early (Swift failure,
    missing output, or iteration limit) it is kept so the next invocation
    resumes instead of starting over.
    """
    if not os.path.exists(FULL_TRACE_PATH):
        logging.error(f"Source file not found: {FULL_TRACE_PATH}")
        return

    # Create output and temporary directories
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(TMP_DIR, exist_ok=True)

    b_set_to_process, original_traces_map = load_state()

    if b_set_to_process is None:
        logging.info("No persistent state found. Starting from scratch.")
        # Initial run: Load full traces and create initial sets A and B
        original_traces_map = {}
        initial_a_set = []
        logging.info(f"Reading full traces from {FULL_TRACE_PATH}")
        for trace in read_jsonl(FULL_TRACE_PATH):
            task_id = get_task_id(trace)
            if task_id:
                original_traces_map[task_id] = trace['messages']
                if len(trace['messages']) >= 3:
                    initial_sub_trace = {"messages": trace['messages'][:3]}
                    initial_a_set.append(initial_sub_trace)

        # Write all of set A to the final RL training data file
        logging.info(f"Adding {len(initial_a_set)} initial traces to {FINAL_OUTPUT_PATH}")
        write_jsonl(FINAL_OUTPUT_PATH, initial_a_set, mode='a')

        # Create initial B set from non-exit traces in A
        b_set_to_process = [trace for trace in initial_a_set if not is_exit_action(trace)]
        logging.info(f"Created initial B set with {len(b_set_to_process)} non-exit traces.")

        # Initial save of state
        save_state(b_set_to_process, original_traces_map)

    for i in range(1, MAX_ITERATIONS + 1):
        if not b_set_to_process:
            logging.info("B set is empty. Process finished successfully.")
            break

        logging.info(f"\n--- Iteration {i}/{MAX_ITERATIONS} ---")
        logging.info(f"Processing {len(b_set_to_process)} traces in B set.")

        # Prepare input for swift CLI
        iteration_dir = os.path.join(TMP_DIR, f'iter_{i}')
        os.makedirs(iteration_dir, exist_ok=True)
        iteration_input_path = os.path.join(iteration_dir, 'sample_input.jsonl')

        # The swift tool is expected to take the full trace,
        # use messages[:-1] as prompt, and messages[-1] as the ground truth for distillation.
        swift_input_data = [{"messages": trace['messages']} for trace in b_set_to_process]
        write_jsonl(iteration_input_path, swift_input_data)

        # Run swift CLI
        # Per swift's requirement, provide a filename prefix without a directory.
        swift_output_filename = 'sample_output.jsonl'
        iteration_output_path = os.path.join(iteration_dir, swift_output_filename)

        # Iteration directories are reused across runs (numbering restarts at
        # 1), so a leftover output file would make the existence check below
        # meaningless and could be mistaken for this iteration's results.
        if os.path.exists(iteration_output_path):
            os.remove(iteration_output_path)

        command = SWIFT_COMMAND_TEMPLATE.format(
            input_file=iteration_input_path,
            output_dir=iteration_dir,
            output_file=swift_output_filename
        )

        logging.info("Executing Swift CLI command...")
        try:
            # Ensure OPENAI_API_KEY is available to the subprocess
            env = os.environ.copy()
            if "OPENAI_API_KEY" not in env:
                logging.warning("OPENAI_API_KEY not found in environment. The script might fail.")

            # NOTE: shell=True is kept because the template is a single
            # shell-quoted string; only script-controlled paths are
            # interpolated into it.
            subprocess.run(command, shell=True, check=True, capture_output=True,
                           text=True, cwd=CWD, env=env)
            logging.info(f"Swift CLI finished. Output at {iteration_output_path}")
        except subprocess.CalledProcessError as e:
            logging.error(f"Swift CLI failed with exit code {e.returncode}")
            logging.error(f"STDOUT: {e.stdout}")
            logging.error(f"STDERR: {e.stderr}")
            # Decide if we should stop or continue. For now, stop.
            break

        if not os.path.exists(iteration_output_path):
            logging.warning(f"Swift output file not found at {iteration_output_path}. Assuming no traces were successfully sampled. Ending loop.")
            break

        # Process swift output and prepare for next iteration
        b_next_iteration = []
        extended_traces_for_rl = []

        # The output from swift contains the generated assistant message
        generated_traces = list(read_jsonl(iteration_output_path))
        logging.info(f"Swift generated {len(generated_traces)} successful samples.")

        successful_task_ids = {get_task_id(t) for t in generated_traces}
        successful_task_ids.discard(None)

        for generated_trace in generated_traces:
            task_id = get_task_id(generated_trace)
            if not task_id or task_id not in original_traces_map:
                continue

            generated_messages = generated_trace['messages']
            original_full_trace_messages = original_traces_map[task_id]

            # Extend the generated trace with the next user-assistant pair from the original trace
            current_len = len(generated_messages)
            if current_len + 2 <= len(original_full_trace_messages):
                next_user_msg = original_full_trace_messages[current_len]
                next_assistant_gt_msg = original_full_trace_messages[current_len + 1]

                extended_messages = generated_messages + [next_user_msg, next_assistant_gt_msg]
                extended_trace = {"messages": extended_messages}

                extended_traces_for_rl.append(extended_trace)

                if not is_exit_action(extended_trace):
                    b_next_iteration.append(extended_trace)

        if extended_traces_for_rl:
            logging.info(f"Adding {len(extended_traces_for_rl)} extended traces to {FINAL_OUTPUT_PATH}")
            write_jsonl(FINAL_OUTPUT_PATH, extended_traces_for_rl, mode='a')

        # Add back rejected tasks to be retried in the next iteration
        rejected_count = 0
        for original_b_trace in b_set_to_process:
            task_id = get_task_id(original_b_trace)
            if task_id and task_id not in successful_task_ids:
                b_next_iteration.append(original_b_trace)
                rejected_count += 1

        if rejected_count > 0:
            logging.info(f"Added {rejected_count} rejected traces to the queue for retry.")

        b_set_to_process = b_next_iteration
        save_state(b_set_to_process, original_traces_map)
    else:
        # This else belongs to the for loop, executes if loop finishes without break
        if b_set_to_process:
            logging.warning(f"Reached max iterations ({MAX_ITERATIONS}) but B set is not empty.")
            logging.warning(f"{len(b_set_to_process)} traces remain.")

    # Cleanup: drop the resume state only when nothing is left to process.
    # After an early stop the state must survive, otherwise the next run
    # would start from scratch and append set A to the final output again.
    if not b_set_to_process:
        if os.path.exists(PERSISTENT_STATE_FILE):
            os.remove(PERSISTENT_STATE_FILE)
            logging.info("Process finished. Removed persistent state file.")
    else:
        logging.info(f"Keeping {PERSISTENT_STATE_FILE} so the next run can resume "
                     f"the {len(b_set_to_process)} remaining traces.")

    # Optional: You may want to clean up the tmp directory
    # shutil.rmtree(TMP_DIR)
    logging.info(f"Temporary files are stored in {TMP_DIR}")


if __name__ == '__main__':
    main()