temp save; temp hold on for rejection sample CoT; try direct RL with raw stepwise data first
This commit is contained in:
		
							parent
							
								
									7f4fc8b05b
								
							
						
					
					
						commit
						5e632c53ac
					
				
							
								
								
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							@ -1,5 +1,7 @@
 | 
			
		||||
output/
 | 
			
		||||
result/
 | 
			
		||||
sample/sample_output/
 | 
			
		||||
sample/tmp_dataset_gen/
 | 
			
		||||
sample/output_data/
 | 
			
		||||
 | 
			
		||||
*__pycache__/
 | 
			
		||||
							
								
								
									
										6
									
								
								sample/combine_output_file_3.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								sample/combine_output_file_3.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										5
									
								
								sample/combine_output_file_5.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								sample/combine_output_file_5.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										263
									
								
								sample/dataset_gen.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								sample/dataset_gen.py
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,263 @@
 | 
			
		||||
import json
 | 
			
		||||
import os
 | 
			
		||||
import re
 | 
			
		||||
import subprocess
 | 
			
		||||
import logging
 | 
			
		||||
import shutil
 | 
			
		||||
 | 
			
		||||
# --- Configuration ---
 | 
			
		||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 | 
			
		||||
 | 
			
		||||
# File and Directory Paths
 | 
			
		||||
CWD = os.path.dirname(os.path.abspath(__file__))
 | 
			
		||||
FULL_TRACE_PATH = os.path.join(CWD, 'full_trace.jsonl')
 | 
			
		||||
OUTPUT_DIR = os.path.join(CWD, 'output_data')
 | 
			
		||||
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'trace_rl.jsonl')
 | 
			
		||||
PERSISTENT_STATE_FILE = os.path.join(OUTPUT_DIR, 'persistent_state.json')
 | 
			
		||||
TMP_DIR = os.path.join(CWD, 'tmp_dataset_gen')
 | 
			
		||||
 | 
			
		||||
# --- Swift CLI Configuration ---
 | 
			
		||||
# Make sure to set the OPENAI_API_KEY in your environment variables
 | 
			
		||||
MODEL = "qwen3-8b"
 | 
			
		||||
SWIFT_COMMAND_TEMPLATE = (
 | 
			
		||||
    'swift sample '
 | 
			
		||||
    '--sampler_type distill '
 | 
			
		||||
    '--sampler_engine client '
 | 
			
		||||
    f'--model {MODEL} '
 | 
			
		||||
    '--stream true '
 | 
			
		||||
    '--orm_model external_web_acc '
 | 
			
		||||
    '--dataset "{input_file}" '
 | 
			
		||||
    '--num_return_sequences 1 '
 | 
			
		||||
    '--temperature 0.8 '
 | 
			
		||||
    '--top_p 0.95 '
 | 
			
		||||
    '--external_plugins plugin.py '
 | 
			
		||||
    '--system system_prompt.txt '
 | 
			
		||||
    '--output_dir "{output_dir}" '
 | 
			
		||||
    '--output_file "{output_file}" '
 | 
			
		||||
    '--engine_kwargs \'{{"base_url":"http://192.168.16.116:18088/v1"}}\''
 | 
			
		||||
    # '--engine_kwargs \'{{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}}\''
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# --- Constants ---
 | 
			
		||||
MAX_ITERATIONS = 100
 | 
			
		||||
EXIT_PATTERN = re.compile(r'\bexit\(.*\)', re.IGNORECASE)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def read_jsonl(file_path):
 | 
			
		||||
    """Reads a .jsonl file and yields each line as a parsed JSON object."""
 | 
			
		||||
    with open(file_path, 'r', encoding='utf-8') as f:
 | 
			
		||||
        for line in f:
 | 
			
		||||
            if line.strip():
 | 
			
		||||
                yield json.loads(line)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def write_jsonl(file_path, data, mode='w'):
 | 
			
		||||
    """Writes a list of JSON objects to a .jsonl file."""
 | 
			
		||||
    with open(file_path, mode, encoding='utf-8') as f:
 | 
			
		||||
        for item in data:
 | 
			
		||||
            f.write(json.dumps(item) + '\n')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_task_id(trace):
 | 
			
		||||
    """Extracts the task ID from the first user message of a trace."""
 | 
			
		||||
    try:
 | 
			
		||||
        match = re.search(r'\[#(\d+)\]', trace['messages'][0]['content'])
 | 
			
		||||
        if match:
 | 
			
		||||
            return match.group(1)
 | 
			
		||||
    except (KeyError, IndexError, TypeError):
 | 
			
		||||
        pass
 | 
			
		||||
    logging.warning(f"Could not find task ID in trace: {trace}")
 | 
			
		||||
    return None
 | 
			
		||||
 | 
			
		||||
def is_exit_action(trace):
 | 
			
		||||
    """Checks if the last assistant message is an exit action."""
 | 
			
		||||
    try:
 | 
			
		||||
        last_message = trace['messages'][-1]
 | 
			
		||||
        if last_message['role'] == 'assistant':
 | 
			
		||||
            return bool(EXIT_PATTERN.search(last_message['content']))
 | 
			
		||||
    except (KeyError, IndexError, TypeError):
 | 
			
		||||
        pass
 | 
			
		||||
    return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def save_state(b_set, original_traces_map):
 | 
			
		||||
    """Saves the current state (B set and original traces) to a file."""
 | 
			
		||||
    state = {
 | 
			
		||||
        'b_set': b_set,
 | 
			
		||||
        'original_traces_map': original_traces_map
 | 
			
		||||
    }
 | 
			
		||||
    with open(PERSISTENT_STATE_FILE, 'w', encoding='utf-8') as f:
 | 
			
		||||
        json.dump(state, f, indent=4)
 | 
			
		||||
    logging.info(f"Successfully saved state to {PERSISTENT_STATE_FILE}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_state():
 | 
			
		||||
    """Loads the state from a file if it exists."""
 | 
			
		||||
    if os.path.exists(PERSISTENT_STATE_FILE):
 | 
			
		||||
        logging.info(f"Found persistent state file. Resuming from {PERSISTENT_STATE_FILE}")
 | 
			
		||||
        with open(PERSISTENT_STATE_FILE, 'r', encoding='utf-8') as f:
 | 
			
		||||
            state = json.load(f)
 | 
			
		||||
            return state['b_set'], state['original_traces_map']
 | 
			
		||||
    return None, None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    """Main function to generate the dataset."""
 | 
			
		||||
    if not os.path.exists(FULL_TRACE_PATH):
 | 
			
		||||
        logging.error(f"Source file not found: {FULL_TRACE_PATH}")
 | 
			
		||||
        return
 | 
			
		||||
 | 
			
		||||
    # Create output and temporary directories
 | 
			
		||||
    os.makedirs(OUTPUT_DIR, exist_ok=True)
 | 
			
		||||
    os.makedirs(TMP_DIR, exist_ok=True)
 | 
			
		||||
 | 
			
		||||
    b_set_to_process, original_traces_map = load_state()
 | 
			
		||||
 | 
			
		||||
    if b_set_to_process is None:
 | 
			
		||||
        logging.info("No persistent state found. Starting from scratch.")
 | 
			
		||||
        # Initial run: Load full traces and create initial sets A and B
 | 
			
		||||
        original_traces_map = {}
 | 
			
		||||
        initial_a_set = []
 | 
			
		||||
 | 
			
		||||
        logging.info(f"Reading full traces from {FULL_TRACE_PATH}")
 | 
			
		||||
        for trace in read_jsonl(FULL_TRACE_PATH):
 | 
			
		||||
            task_id = get_task_id(trace)
 | 
			
		||||
            if task_id:
 | 
			
		||||
                original_traces_map[task_id] = trace['messages']
 | 
			
		||||
                if len(trace['messages']) >= 3:
 | 
			
		||||
                    initial_sub_trace = {"messages": trace['messages'][:3]}
 | 
			
		||||
                    initial_a_set.append(initial_sub_trace)
 | 
			
		||||
 | 
			
		||||
        # Write all of set A to the final RL training data file
 | 
			
		||||
        logging.info(f"Adding {len(initial_a_set)} initial traces to {FINAL_OUTPUT_PATH}")
 | 
			
		||||
        write_jsonl(FINAL_OUTPUT_PATH, initial_a_set, mode='a')
 | 
			
		||||
 | 
			
		||||
        # Create initial B set from non-exit traces in A
 | 
			
		||||
        b_set_to_process = [trace for trace in initial_a_set if not is_exit_action(trace)]
 | 
			
		||||
        logging.info(f"Created initial B set with {len(b_set_to_process)} non-exit traces.")
 | 
			
		||||
        
 | 
			
		||||
        # Initial save of state
 | 
			
		||||
        save_state(b_set_to_process, original_traces_map)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    for i in range(1, MAX_ITERATIONS + 1):
 | 
			
		||||
        if not b_set_to_process:
 | 
			
		||||
            logging.info("B set is empty. Process finished successfully.")
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
        logging.info(f"\n--- Iteration {i}/{MAX_ITERATIONS} ---")
 | 
			
		||||
        logging.info(f"Processing {len(b_set_to_process)} traces in B set.")
 | 
			
		||||
 | 
			
		||||
        # Prepare input for swift CLI
 | 
			
		||||
        iteration_dir = os.path.join(TMP_DIR, f'iter_{i}')
 | 
			
		||||
        os.makedirs(iteration_dir, exist_ok=True)
 | 
			
		||||
        iteration_input_path = os.path.join(iteration_dir, 'sample_input.jsonl')
 | 
			
		||||
        swift_input_data = []
 | 
			
		||||
        for trace in b_set_to_process:
 | 
			
		||||
            # The swift tool is expected to take the full trace,
 | 
			
		||||
            # use messages[:-1] as prompt, and messages[-1] as the ground truth for distillation.
 | 
			
		||||
            swift_input_data.append({
 | 
			
		||||
                "messages": trace['messages']
 | 
			
		||||
            })
 | 
			
		||||
        write_jsonl(iteration_input_path, swift_input_data)
 | 
			
		||||
        
 | 
			
		||||
        # Run swift CLI
 | 
			
		||||
        # Per swift's requirement, provide a filename prefix without a directory.
 | 
			
		||||
        swift_output_filename = 'sample_output.jsonl'
 | 
			
		||||
        iteration_output_path = os.path.join(iteration_dir, swift_output_filename)
 | 
			
		||||
 | 
			
		||||
        command = SWIFT_COMMAND_TEMPLATE.format(
 | 
			
		||||
            input_file=iteration_input_path,
 | 
			
		||||
            output_dir=iteration_dir,
 | 
			
		||||
            output_file=swift_output_filename
 | 
			
		||||
        )
 | 
			
		||||
        logging.info(f"Executing Swift CLI command...")
 | 
			
		||||
        
 | 
			
		||||
        try:
 | 
			
		||||
            # Ensure OPENAI_API_KEY is available to the subprocess
 | 
			
		||||
            env = os.environ.copy()
 | 
			
		||||
            if "OPENAI_API_KEY" not in env:
 | 
			
		||||
                logging.warning("OPENAI_API_KEY not found in environment. The script might fail.")
 | 
			
		||||
                
 | 
			
		||||
            subprocess.run(command, shell=True, check=True, capture_output=True, text=True, cwd=CWD)
 | 
			
		||||
            logging.info(f"Swift CLI finished. Output at {iteration_output_path}")
 | 
			
		||||
        except subprocess.CalledProcessError as e:
 | 
			
		||||
            logging.error(f"Swift CLI failed with exit code {e.returncode}")
 | 
			
		||||
            logging.error(f"STDOUT: {e.stdout}")
 | 
			
		||||
            logging.error(f"STDERR: {e.stderr}")
 | 
			
		||||
            # Decide if we should stop or continue. For now, stop.
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
        if not os.path.exists(iteration_output_path):
 | 
			
		||||
            logging.warning(f"Swift output file not found at {iteration_output_path}. Assuming no traces were successfully sampled. Ending loop.")
 | 
			
		||||
            break
 | 
			
		||||
 | 
			
		||||
        # Process swift output and prepare for next iteration
 | 
			
		||||
        b_next_iteration = []
 | 
			
		||||
        extended_traces_for_rl = []
 | 
			
		||||
        
 | 
			
		||||
        # The output from swift contains the generated assistant message
 | 
			
		||||
        generated_traces = list(read_jsonl(iteration_output_path))
 | 
			
		||||
        logging.info(f"Swift generated {len(generated_traces)} successful samples.")
 | 
			
		||||
 | 
			
		||||
        successful_task_ids = {get_task_id(t) for t in generated_traces}
 | 
			
		||||
        successful_task_ids.discard(None)
 | 
			
		||||
 | 
			
		||||
        for generated_trace in generated_traces:
 | 
			
		||||
            task_id = get_task_id(generated_trace)
 | 
			
		||||
            if not task_id or task_id not in original_traces_map:
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            generated_messages = generated_trace['messages']
 | 
			
		||||
            original_full_trace_messages = original_traces_map[task_id]
 | 
			
		||||
            
 | 
			
		||||
            # Extend the generated trace with the next user-assistant pair from the original trace
 | 
			
		||||
            current_len = len(generated_messages)
 | 
			
		||||
            if current_len + 2 <= len(original_full_trace_messages):
 | 
			
		||||
                next_user_msg = original_full_trace_messages[current_len]
 | 
			
		||||
                next_assistant_gt_msg = original_full_trace_messages[current_len + 1]
 | 
			
		||||
 | 
			
		||||
                extended_messages = generated_messages + [next_user_msg, next_assistant_gt_msg]
 | 
			
		||||
                extended_trace = {"messages": extended_messages}
 | 
			
		||||
                
 | 
			
		||||
                extended_traces_for_rl.append(extended_trace)
 | 
			
		||||
                
 | 
			
		||||
                if not is_exit_action(extended_trace):
 | 
			
		||||
                    b_next_iteration.append(extended_trace)
 | 
			
		||||
 | 
			
		||||
        if extended_traces_for_rl:
 | 
			
		||||
            logging.info(f"Adding {len(extended_traces_for_rl)} extended traces to {FINAL_OUTPUT_PATH}")
 | 
			
		||||
            write_jsonl(FINAL_OUTPUT_PATH, extended_traces_for_rl, mode='a')
 | 
			
		||||
        
 | 
			
		||||
        # Add back rejected tasks to be retried in the next iteration
 | 
			
		||||
        rejected_count = 0
 | 
			
		||||
        for original_b_trace in b_set_to_process:
 | 
			
		||||
            task_id = get_task_id(original_b_trace)
 | 
			
		||||
            if task_id and task_id not in successful_task_ids:
 | 
			
		||||
                b_next_iteration.append(original_b_trace)
 | 
			
		||||
                rejected_count += 1
 | 
			
		||||
        
 | 
			
		||||
        if rejected_count > 0:
 | 
			
		||||
            logging.info(f"Added {rejected_count} rejected traces to the queue for retry.")
 | 
			
		||||
 | 
			
		||||
        b_set_to_process = b_next_iteration
 | 
			
		||||
        save_state(b_set_to_process, original_traces_map)
 | 
			
		||||
 | 
			
		||||
    else: # This else belongs to the for loop, executes if loop finishes without break
 | 
			
		||||
        if b_set_to_process:
 | 
			
		||||
            logging.warning(f"Reached max iterations ({MAX_ITERATIONS}) but B set is not empty.")
 | 
			
		||||
            logging.warning(f"{len(b_set_to_process)} traces remain.")
 | 
			
		||||
 | 
			
		||||
    # Cleanup
 | 
			
		||||
    if os.path.exists(PERSISTENT_STATE_FILE):
 | 
			
		||||
        os.remove(PERSISTENT_STATE_FILE)
 | 
			
		||||
        logging.info(f"Process finished. Removed persistent state file.")
 | 
			
		||||
    
 | 
			
		||||
    # Optional: You may want to clean up the tmp directory
 | 
			
		||||
    # shutil.rmtree(TMP_DIR)
 | 
			
		||||
    logging.info(f"Temporary files are stored in {TMP_DIR}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
							
								
								
									
										808
									
								
								sample/full_trace.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										808
									
								
								sample/full_trace.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										10
									
								
								sample/full_trace_10.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								sample/full_trace_10.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							
							
								
								
									
										2
									
								
								sample/sample_input_step2.jsonl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								sample/sample_input_step2.jsonl
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							@ -7,10 +7,11 @@ swift sample \
 | 
			
		||||
    --model $MODEL \
 | 
			
		||||
    --stream true \
 | 
			
		||||
    --orm_model external_web_acc \
 | 
			
		||||
    --dataset combine_output_file_4.jsonl \
 | 
			
		||||
    --dataset sample_input_step2.jsonl \
 | 
			
		||||
    --num_return_sequences 1 \
 | 
			
		||||
    --temperature 0.6 \
 | 
			
		||||
    --top_p 0.95 \
 | 
			
		||||
    --external_plugins plugin.py \
 | 
			
		||||
    --system system_prompt.txt \
 | 
			
		||||
    --output_file sample_output_step2.jsonl \
 | 
			
		||||
    --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
 | 
			
		||||
@ -18,6 +18,24 @@ def do(action, argument, element):
 | 
			
		||||
                :param element: optional. Only for "Click", "Right Click", "Type", "Search", "Select Dropdown Option", and "Hover". Should be specific element id in the html.
 | 
			
		||||
        Returns:
 | 
			
		||||
                None. The webpage will be updated after executing the action.
 | 
			
		||||
            
 | 
			
		||||
        IMPORTANT Notes: 
 | 
			
		||||
        **1. Task Classification:**
 | 
			
		||||
        - **Execution Task:** The instruction asks to perform an action, like "delete an item", "fill out a form", "navigate to a page".
 | 
			
		||||
        - **Query Task:** The instruction asks to find information, like "how many items are there?", "what is the price?", "find all products".
 | 
			
		||||
 | 
			
		||||
        **2. Answer Rules:**
 | 
			
		||||
 | 
			
		||||
        **If the task is 'Execution':**
 | 
			
		||||
        - If the task was completed successfully, the final answer should be **DONE**.
 | 
			
		||||
        - If the task failed or could not be completed, the final answer should be **INFEASIBLE**.
 | 
			
		||||
 | 
			
		||||
        **If the task is 'Query':**
 | 
			
		||||
        - **Not Found:** If the answer is "N/A" or indicates the information could not be found, the final answer should be **N/A**.
 | 
			
		||||
        - **Single Answer:** If the result is a single piece of information (e.g., a number, a name, a date), the final answer should be the most concise answer. For example, if the question is "How many products?" and the answer is "There are 5 products", the final answer should be just "5".
 | 
			
		||||
        - **Multiple Answers (List):** If the result is a list of items, the final answer should be a single string with items separated by a comma. For example: "item1, item2, item3".
 | 
			
		||||
        - **Multiple Answers (Key-Value):** If the result is a set of key-value pairs, the final answer should be a JSON string. For example: `{"k1": "v1", "k2": "v2"}`.
 | 
			
		||||
 | 
			
		||||
        """
 | 
			
		||||
 | 
			
		||||
def exit(message):
 | 
			
		||||
@ -45,7 +63,8 @@ do(action="Select Dropdown Option", argument="Month", element="20")
 | 
			
		||||
- # Element: the 'From' date picker input field, middle center
 | 
			
		||||
do(action="Type", argument="01/01/2023", element="22")
 | 
			
		||||
- do(action="Scroll Down")
 | 
			
		||||
- exit(message="The top-3 best-selling products in January 2023 are: 1")
 | 
			
		||||
- # Note: The top-3 best-selling products in January 2023 are: 1
 | 
			
		||||
exit(message="1")
 | 
			
		||||
- # Element: The search bar
 | 
			
		||||
do(action="Search", argument="international airport near Carnegie Mellon University within a driving distance of 50 km", element="13")
 | 
			
		||||
- # Note: Pittsburgh International Airport, Southern Beltway, Findlay Township, Allegheny County, 15231, United States
 | 
			
		||||
@ -65,7 +84,7 @@ Key guidelines you MUST follow:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Your reply should strictly follow the format:
 | 
			
		||||
Thought: Your reasoning trace. A good practice is to summarize information on the current web page that are relevant to the task goal, then generate a high-level plan that contains the sequence of actions you probably need to take
 | 
			
		||||
Thought: Your reasoning trace. A good practice is to summarize information on the current web page that are relevant to the task goal, then generate a high-level plan that contains the sequence of actions you probably need to take. Also, use this section to record **important information** that may be useful later. For example, at each step, you might only gather a **fragment** of the overall task. Keeping track of these pieces can help you synthesize them into a complete conclusion when enough context is available.
 | 
			
		||||
Action: Based on this reasoning, identify the single most optimal action. You should output it in the format specified above (under "STRICTLY follow the format")
 | 
			
		||||
 | 
			
		||||
After each action, you'll receive a new observation. Proceed until task completion. Website Note "- Always save progress through appropriate buttons (Save, Submit, Post, etc.)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user