Temporary save: put rejection-sampling CoT on hold for now; try direct RL with the raw stepwise data first

This commit is contained in:
yuyr 2025-06-30 10:29:00 +00:00
parent 7f4fc8b05b
commit 5e632c53ac
10 changed files with 1119 additions and 4 deletions

2
.gitignore vendored
View File

@ -1,5 +1,7 @@
output/ output/
result/ result/
sample/sample_output/ sample/sample_output/
sample/tmp_dataset_gen/
sample/output_data/
*__pycache__/ *__pycache__/

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

263
sample/dataset_gen.py Normal file
View File

@ -0,0 +1,263 @@
import json
import os
import re
import subprocess
import logging
import shutil
# --- Configuration ---
# Emit INFO-and-above records, each prefixed with a timestamp and level name.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# File and Directory Paths
# Every path below is anchored at the directory that contains this script.
CWD = os.path.dirname(os.path.abspath(__file__))
# Input: the full traces that main() reads on a fresh (non-resumed) run.
FULL_TRACE_PATH = os.path.join(CWD, 'full_trace.jsonl')
OUTPUT_DIR = os.path.join(CWD, 'output_data')
# Output: main() appends the initial and the extended traces to this file.
FINAL_OUTPUT_PATH = os.path.join(OUTPUT_DIR, 'trace_rl.jsonl')
# Resume file written by save_state() and read back by load_state().
PERSISTENT_STATE_FILE = os.path.join(OUTPUT_DIR, 'persistent_state.json')
# Per-iteration working directories (iter_<n>) are created under this path.
TMP_DIR = os.path.join(CWD, 'tmp_dataset_gen')
# --- Swift CLI Configuration ---
# Make sure to set the OPENAI_API_KEY in your environment variables
MODEL = "qwen3-8b"
# Shell command line for one `swift sample` run.
# Only the MODEL fragment is an f-string and is interpolated right here; the
# {input_file}, {output_dir} and {output_file} placeholders are filled in
# later through str.format() in main(). The doubled braces around the
# engine_kwargs JSON are str.format() escapes that produce literal braces.
SWIFT_COMMAND_TEMPLATE = (
'swift sample '
'--sampler_type distill '
'--sampler_engine client '
f'--model {MODEL} '
'--stream true '
'--orm_model external_web_acc '
'--dataset "{input_file}" '
'--num_return_sequences 1 '
'--temperature 0.8 '
'--top_p 0.95 '
'--external_plugins plugin.py '
'--system system_prompt.txt '
'--output_dir "{output_dir}" '
'--output_file "{output_file}" '
'--engine_kwargs \'{{"base_url":"http://192.168.16.116:18088/v1"}}\''
# Alternative endpoint (DashScope compatible mode), kept for manual switching:
# '--engine_kwargs \'{{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}}\''
)
# --- Constants ---
# Upper bound on the number of sampling rounds main() will run.
MAX_ITERATIONS = 100
# Matches an `exit(...)` call (case-insensitive); is_exit_action() uses it to
# recognise a terminal assistant action.
EXIT_PATTERN = re.compile(r'\bexit\(.*\)', re.IGNORECASE)
def read_jsonl(file_path):
    """Lazily parse a JSON Lines file.

    Opens ``file_path`` as UTF-8 text and yields one decoded JSON value per
    line. Lines that are empty or contain only whitespace are skipped.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        # Filter out blank lines first, then decode what remains one by one.
        meaningful_lines = (raw for raw in handle if raw.strip())
        for raw in meaningful_lines:
            yield json.loads(raw)
def write_jsonl(file_path, data, mode='w'):
    """Serialize an iterable of JSON-compatible objects as JSON Lines.

    Each element of ``data`` is encoded with ``json.dumps`` and written on its
    own line. ``mode`` is handed to ``open`` unchanged, so ``'w'`` truncates
    the file and ``'a'`` appends to it.
    """
    with open(file_path, mode, encoding='utf-8') as handle:
        handle.writelines(json.dumps(record) + '\n' for record in data)
def get_task_id(trace):
    """Return the task ID embedded in a trace, or None if there is none.

    The ID is the run of digits inside a ``[#<digits>]`` marker found in the
    content of ``trace['messages'][0]``; it is returned as a string. When the
    marker is missing, or the trace does not have the expected structure, a
    warning is logged and None is returned.
    """
    found = None
    try:
        first_content = trace['messages'][0]['content']
        found = re.search(r'\[#(\d+)\]', first_content)
    except (KeyError, IndexError, TypeError):
        # Malformed trace: treat it the same as "marker not present".
        found = None
    if found is not None:
        return found.group(1)
    logging.warning(f"Could not find task ID in trace: {trace}")
    return None
def is_exit_action(trace):
    """Tell whether a trace ends with an assistant ``exit(...)`` action.

    Returns True only when the last message has role ``'assistant'`` and its
    content matches ``EXIT_PATTERN``. Traces that lack the expected structure
    yield False.
    """
    try:
        final_msg = trace['messages'][-1]
        if final_msg['role'] != 'assistant':
            return False
        return EXIT_PATTERN.search(final_msg['content']) is not None
    except (KeyError, IndexError, TypeError):
        return False
def save_state(b_set, original_traces_map):
    """Persist the resume state (B set and original traces) to disk.

    The state is first written to a sibling temporary file and then moved over
    ``PERSISTENT_STATE_FILE`` with ``os.replace``. Writing in place would let
    an interruption mid-dump leave a truncated JSON file that ``load_state``
    could no longer parse, defeating the purpose of the resume file.

    Args:
        b_set: List of traces still waiting to be processed.
        original_traces_map: Mapping from task ID to the full message list of
            the original trace.
    """
    state = {
        'b_set': b_set,
        'original_traces_map': original_traces_map,
    }
    tmp_path = PERSISTENT_STATE_FILE + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(state, f, indent=4)
    # Swap the finished file into place in a single step.
    os.replace(tmp_path, PERSISTENT_STATE_FILE)
    logging.info(f"Successfully saved state to {PERSISTENT_STATE_FILE}")
def load_state():
    """Load a previously saved resume state, if one exists.

    Returns:
        A ``(b_set, original_traces_map)`` tuple read from
        ``PERSISTENT_STATE_FILE``, or ``(None, None)`` when that file does
        not exist.
    """
    if not os.path.exists(PERSISTENT_STATE_FILE):
        return None, None
    logging.info(f"Found persistent state file. Resuming from {PERSISTENT_STATE_FILE}")
    with open(PERSISTENT_STATE_FILE, 'r', encoding='utf-8') as handle:
        saved = json.load(handle)
    return saved['b_set'], saved['original_traces_map']
def _build_initial_sets():
    """Read the full traces and derive the initial B set and the trace map.

    Every trace with a task ID is recorded in the trace map. Those with at
    least three messages contribute their first three messages to set A, which
    is appended to ``FINAL_OUTPUT_PATH``. Set B is the subset of A that does
    not already end in an exit action.

    Returns:
        A ``(b_set, original_traces_map)`` tuple.
    """
    original_traces_map = {}
    initial_a_set = []
    logging.info(f"Reading full traces from {FULL_TRACE_PATH}")
    for trace in read_jsonl(FULL_TRACE_PATH):
        task_id = get_task_id(trace)
        if not task_id:
            continue
        original_traces_map[task_id] = trace['messages']
        if len(trace['messages']) >= 3:
            initial_a_set.append({"messages": trace['messages'][:3]})

    # Write all of set A to the final RL training data file.
    logging.info(f"Adding {len(initial_a_set)} initial traces to {FINAL_OUTPUT_PATH}")
    write_jsonl(FINAL_OUTPUT_PATH, initial_a_set, mode='a')

    # Only traces that have not reached an exit action need further sampling.
    b_set = [trace for trace in initial_a_set if not is_exit_action(trace)]
    logging.info(f"Created initial B set with {len(b_set)} non-exit traces.")
    return b_set, original_traces_map


def _run_swift(input_path, output_dir, output_filename):
    """Run one ``swift sample`` pass; return True if it exited successfully.

    On a non-zero exit status the captured stdout/stderr are logged and False
    is returned.
    """
    command = SWIFT_COMMAND_TEMPLATE.format(
        input_file=input_path,
        output_dir=output_dir,
        output_file=output_filename,
    )
    logging.info("Executing Swift CLI command...")
    # Hand the environment to the child explicitly so the API key check below
    # inspects exactly what the subprocess will receive.
    env = os.environ.copy()
    if "OPENAI_API_KEY" not in env:
        logging.warning("OPENAI_API_KEY not found in environment. The script might fail.")
    try:
        # shell=True: the command is a single shell-quoted string built from
        # module constants and paths located under this script's directory.
        subprocess.run(command, shell=True, check=True, capture_output=True,
                       text=True, cwd=CWD, env=env)
    except subprocess.CalledProcessError as e:
        logging.error(f"Swift CLI failed with exit code {e.returncode}")
        logging.error(f"STDOUT: {e.stdout}")
        logging.error(f"STDERR: {e.stderr}")
        return False
    return True


def _extend_generated_traces(generated_traces, original_traces_map):
    """Extend each sampled trace with the next original user/assistant pair.

    Traces whose task ID is unknown, or whose original trace has no further
    user/assistant pair, are skipped.

    Returns:
        A ``(extended_traces, next_b_set)`` tuple: all extended traces, and
        the subset that does not yet end in an exit action.
    """
    extended_traces = []
    next_b_set = []
    for generated_trace in generated_traces:
        task_id = get_task_id(generated_trace)
        if not task_id or task_id not in original_traces_map:
            continue
        generated_messages = generated_trace['messages']
        original_messages = original_traces_map[task_id]
        current_len = len(generated_messages)
        if current_len + 2 > len(original_messages):
            # The original trace has no further user/assistant pair.
            continue
        next_user_msg = original_messages[current_len]
        next_assistant_gt_msg = original_messages[current_len + 1]
        extended_trace = {
            "messages": generated_messages + [next_user_msg, next_assistant_gt_msg]
        }
        extended_traces.append(extended_trace)
        if not is_exit_action(extended_trace):
            next_b_set.append(extended_trace)
    return extended_traces, next_b_set


def main():
    """Generate the stepwise RL dataset by iteratively sampling with Swift.

    Starting from the first three messages of every full trace (or from a
    saved state when one exists), each iteration feeds the pending traces
    (set B) to ``swift sample``, extends every returned trace with the next
    user/assistant pair of its original trace, appends the extended traces to
    ``FINAL_OUTPUT_PATH`` and queues the non-exit ones for the next round.
    Traces missing from the Swift output are retried.

    The persistent state file is removed only when set B has been drained.
    If the run stops early (Swift failure, missing output, or the iteration
    limit), the file is kept so that a later run resumes instead of starting
    from scratch and appending the initial traces to the output a second time.
    """
    if not os.path.exists(FULL_TRACE_PATH):
        logging.error(f"Source file not found: {FULL_TRACE_PATH}")
        return

    # Create output and temporary directories.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(TMP_DIR, exist_ok=True)

    b_set_to_process, original_traces_map = load_state()
    if b_set_to_process is None:
        logging.info("No persistent state found. Starting from scratch.")
        b_set_to_process, original_traces_map = _build_initial_sets()
        # Initial save of state.
        save_state(b_set_to_process, original_traces_map)

    # Per swift's requirement, provide a filename prefix without a directory.
    swift_output_filename = 'sample_output.jsonl'

    for i in range(1, MAX_ITERATIONS + 1):
        if not b_set_to_process:
            logging.info("B set is empty. Process finished successfully.")
            break

        logging.info(f"\n--- Iteration {i}/{MAX_ITERATIONS} ---")
        logging.info(f"Processing {len(b_set_to_process)} traces in B set.")

        # Prepare input for swift CLI.
        iteration_dir = os.path.join(TMP_DIR, f'iter_{i}')
        os.makedirs(iteration_dir, exist_ok=True)
        iteration_input_path = os.path.join(iteration_dir, 'sample_input.jsonl')
        iteration_output_path = os.path.join(iteration_dir, swift_output_filename)

        # The swift tool is expected to take the full trace, use messages[:-1]
        # as prompt, and messages[-1] as the ground truth for distillation.
        write_jsonl(
            iteration_input_path,
            [{"messages": trace['messages']} for trace in b_set_to_process],
        )

        # Iteration numbering restarts at 1 on a resumed run, so an output
        # file from an earlier run may already sit in this directory. Remove
        # it so the existence check below reflects this run only.
        if os.path.exists(iteration_output_path):
            os.remove(iteration_output_path)

        if not _run_swift(iteration_input_path, iteration_dir, swift_output_filename):
            # Decide if we should stop or continue. For now, stop.
            break
        logging.info(f"Swift CLI finished. Output at {iteration_output_path}")

        if not os.path.exists(iteration_output_path):
            logging.warning(f"Swift output file not found at {iteration_output_path}. Assuming no traces were successfully sampled. Ending loop.")
            break

        # The output from swift contains the generated assistant message.
        generated_traces = list(read_jsonl(iteration_output_path))
        logging.info(f"Swift generated {len(generated_traces)} successful samples.")
        successful_task_ids = {get_task_id(t) for t in generated_traces}
        successful_task_ids.discard(None)

        extended_traces_for_rl, b_next_iteration = _extend_generated_traces(
            generated_traces, original_traces_map
        )
        if extended_traces_for_rl:
            logging.info(f"Adding {len(extended_traces_for_rl)} extended traces to {FINAL_OUTPUT_PATH}")
            write_jsonl(FINAL_OUTPUT_PATH, extended_traces_for_rl, mode='a')

        # Add back rejected tasks to be retried in the next iteration.
        rejected_count = 0
        for original_b_trace in b_set_to_process:
            task_id = get_task_id(original_b_trace)
            if task_id and task_id not in successful_task_ids:
                b_next_iteration.append(original_b_trace)
                rejected_count += 1
        if rejected_count > 0:
            logging.info(f"Added {rejected_count} rejected traces to the queue for retry.")

        b_set_to_process = b_next_iteration
        save_state(b_set_to_process, original_traces_map)
    else:  # Belongs to the for loop: runs only if the loop finished without break.
        if b_set_to_process:
            logging.warning(f"Reached max iterations ({MAX_ITERATIONS}) but B set is not empty.")
            logging.warning(f"{len(b_set_to_process)} traces remain.")

    # Cleanup: drop the resume file only when nothing is left to process.
    if b_set_to_process:
        logging.warning(
            f"{len(b_set_to_process)} traces are still pending. "
            f"Keeping {PERSISTENT_STATE_FILE} so the next run can resume."
        )
    elif os.path.exists(PERSISTENT_STATE_FILE):
        os.remove(PERSISTENT_STATE_FILE)
        logging.info("Process finished. Removed persistent state file.")

    # TMP_DIR is deliberately left in place for inspection; delete it manually
    # (e.g. with shutil.rmtree) once it is no longer needed.
    logging.info(f"Temporary files are stored in {TMP_DIR}")


if __name__ == '__main__':
    main()

808
sample/full_trace.jsonl Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -7,10 +7,11 @@ swift sample \
--model $MODEL \ --model $MODEL \
--stream true \ --stream true \
--orm_model external_web_acc \ --orm_model external_web_acc \
--dataset combine_output_file_4.jsonl \ --dataset sample_input_step2.jsonl \
--num_return_sequences 1 \ --num_return_sequences 1 \
--temperature 0.6 \ --temperature 0.6 \
--top_p 0.95 \ --top_p 0.95 \
--external_plugins plugin.py \ --external_plugins plugin.py \
--system system_prompt.txt \ --system system_prompt.txt \
--output_file sample_output_step2.jsonl \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}' --engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'

View File

@ -18,6 +18,24 @@ def do(action, argument, element):
:param element: optional. Only for "Click", "Right Click", "Type", "Search", "Select Dropdown Option", and "Hover". Should be specific element id in the html. :param element: optional. Only for "Click", "Right Click", "Type", "Search", "Select Dropdown Option", and "Hover". Should be specific element id in the html.
Returns: Returns:
None. The webpage will be updated after executing the action. None. The webpage will be updated after executing the action.
IMPORTANT Notes:
**1. Task Classification:**
- **Execution Task:** The instruction asks to perform an action, like "delete an item", "fill out a form", "navigate to a page".
- **Query Task:** The instruction asks to find information, like "how many items are there?", "what is the price?", "find all products".
**2. Answer Rules:**
**If the task is 'Execution':**
- If the task was completed successfully, the final answer should be **DONE**.
- If the task failed or could not be completed, the final answer should be **INFEASIBLE**.
**If the task is 'Query':**
- **Not Found:** If the answer is "N/A" or indicates the information could not be found, the final answer should be **N/A**.
- **Single Answer:** If the result is a single piece of information (e.g., a number, a name, a date), the final answer should be the most concise answer. For example, if the question is "How many products?" and the answer is "There are 5 products", the final answer should be just "5".
- **Multiple Answers (List):** If the result is a list of items, the final answer should be a single string with items separated by a comma. For example: "item1, item2, item3".
- **Multiple Answers (Key-Value):** If the result is a set of key-value pairs, the final answer should be a JSON string. For example: `{"k1": "v1", "k2": "v2"}`.
""" """
def exit(message): def exit(message):
@ -45,7 +63,8 @@ do(action="Select Dropdown Option", argument="Month", element="20")
- # Element: the 'From' date picker input field, middle center - # Element: the 'From' date picker input field, middle center
do(action="Type", argument="01/01/2023", element="22") do(action="Type", argument="01/01/2023", element="22")
- do(action="Scroll Down") - do(action="Scroll Down")
- exit(message="The top-3 best-selling products in January 2023 are: 1") - # Note: The top-3 best-selling products in January 2023 are: 1
exit(message="1")
- # Element: The search bar - # Element: The search bar
do(action="Search", argument="international airport near Carnegie Mellon University within a driving distance of 50 km", element="13") do(action="Search", argument="international airport near Carnegie Mellon University within a driving distance of 50 km", element="13")
- # Note: Pittsburgh International Airport, Southern Beltway, Findlay Township, Allegheny County, 15231, United States - # Note: Pittsburgh International Airport, Southern Beltway, Findlay Township, Allegheny County, 15231, United States
@ -65,7 +84,7 @@ Key guidelines you MUST follow:
Your reply should strictly follow the format: Your reply should strictly follow the format:
Thought: Your reasoning trace. A good practice is to summarize information on the current web page that are relevant to the task goal, then generate a high-level plan that contains the sequence of actions you probably need to take Thought: Your reasoning trace. A good practice is to summarize information on the current web page that are relevant to the task goal, then generate a high-level plan that contains the sequence of actions you probably need to take. Also, use this section to record **important information** that may be useful later. For example, at each step, you might only gather a **fragment** of the overall task. Keeping track of these pieces can help you synthesize them into a complete conclusion when enough context is available.
Action: Based on this reasoning, identify the single most optimal action. You should output it in the format specified above (under "STRICTLY follow the format") Action: Based on this reasoning, identify the single most optimal action. You should output it in the format specified above (under "STRICTLY follow the format")
After each action, you'll receive a new observation. Proceed until task completion. Website Note "- Always save progress through appropriate buttons (Save, Submit, Post, etc.) After each action, you'll receive a new observation. Proceed until task completion. Website Note "- Always save progress through appropriate buttons (Save, Submit, Post, etc.)