# webrlvr/random_sample/generate_tasks copy 2.py
import os
import random
import json
import mysql.connector
from openai import OpenAI
from dotenv import load_dotenv
# --- Configuration ---
load_dotenv()
MYSQL_CONFIG = {
    "host": "localhost",
    "port": "23306",
    "user": "mcpuser",
    "password": "StrongPass123!",
    "database": "magentodb",
}
OPENAI_CONFIG = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_BASE_URL"),
    "model": "gpt-4o",
}
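# Expected environment variables (assumed .env layout; the names come from the
# os.getenv calls above, the values below are placeholders):
#
#   OPENAI_API_KEY=sk-...
#   OPENAI_BASE_URL=https://api.openai.com/v1  # optional; omit to use the default endpoint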
# --- Prompt Template ---
# This is a carefully engineered prompt to guide the LLM's output.
PROMPT_TEMPLATE = """
You are an expert database analyst and a creative test case designer for e-commerce web applications.
Your goal is to generate realistic administrative tasks that can be solved by a Web Agent navigating an admin panel.
I will provide you with the following context:
1. **Full Database Schema**: A list of `CREATE TABLE` statements for the core tables of a Magento e-commerce platform.
2. **Sampled Data**: A JSON object containing 5 random rows of data from 5 randomly selected core tables. This data is REAL and should be used to inspire specific, answerable questions.
## Your Task
Based on the provided schema and sample data, create a JSON object containing a single key, "questions", which holds an array of up to 10 unique task objects.
### Requirements for Each Question:
- **Web Agent Solvable**: The task must represent a realistic action an administrator would perform in a web UI (e.g., "Find all orders for customer X", "Update the stock for product Y", "Approve a pending review").
- **Grounded in Data**: The questions should be specific, using names, IDs, or values from the provided **Sampled Data** to make them concrete.
- **Utilize Schema**: You can formulate questions that require joining tables, even if not all tables were sampled. The full schema is your guide.
### Output Format
The final output MUST be a single, valid JSON object. Do not include any other text, explanations, or markdown formatting like ```json.
The JSON object must have one key: "questions", containing a JSON array of task objects.
Each object in the array must contain exactly three keys: `question`, `answer`, and `sql`.
- **`question`**: (string) A natural language description of the task for a web agent.
- **`answer`**: (string, integer, float, or list) The precise and concise answer to the question, derived by running the SQL query against the database.
- **`sql`**: (string) The exact, runnable MySQL query that was used to find the answer.
### Output Format Example
```json
{{
"questions": [
{{
"question": "What is the email address for customer with ID 5?",
"answer": "customer5@example.com",
"sql": "SELECT email FROM customer_entity WHERE entity_id = 5;"
}},
{{
"question": "Find the total quantity of item with SKU 'ABC-123' in the cart.",
"answer": 3,
"sql": "SELECT SUM(qty) FROM quote_item WHERE sku = 'ABC-123';"
}}
]
}}
```
---
### Full Database Schema
{schema_context}
---
### Sampled Data
Here is the sample data from randomly selected tables. Use this to make your questions specific.
{sampled_data_str}
---
Now, generate the JSON object based on these instructions.
"""
# This is a carefully engineered prompt to verify the LLM's own output.
SEMANTIC_VERIFICATION_PROMPT_TEMPLATE = """
You are a meticulous data verifier. Your task is to determine if a given "answer" is semantically correct and accurately supported by the "SQL query result".
I will provide you with a JSON object containing:
1. `question`: The original question asked.
2. `sql`: The SQL query used to find the answer.
3. `answer`: The answer generated by a previous AI.
4. `sql_result`: The actual data returned by executing the SQL query.
## Your Task
Carefully analyze the `sql_result` and compare it to the `answer`. The match should be semantic, not just a simple substring match. For example, if the question is "How many products are in stock?", an answer of "5" should be verifiable from the SQL result which might be `[(5,)]`.
### Requirements:
- Respond with a single JSON object.
- Do not include any other text, explanations, or markdown formatting.
- The JSON object must have exactly two keys:
- `is_match`: (boolean) `true` if the `answer` is fully and accurately supported by the `sql_result`, otherwise `false`.
- `reason`: (string) A brief explanation for your decision. If it's a mismatch, explain why (e.g., "The answer is 'John Doe' but the result contains 'Jane Doe'", "The answer is a count but the result is a list of names").
---
### Verification Data
{task_data_json}
---
Now, provide your verification as a JSON object.
"""


def get_db_connection():
    """Establishes a connection to the MySQL database."""
    try:
        conn = mysql.connector.connect(**MYSQL_CONFIG)
        return conn
    except mysql.connector.Error as err:
        print(f"Error connecting to MySQL: {err}")
        return None


def get_full_schema(cursor, tables):
    """Fetches the CREATE TABLE statements for all core tables."""
    schema_parts = []
    for table_name in tables:
        try:
            cursor.execute(f"SHOW CREATE TABLE `{table_name}`")
            result = cursor.fetchone()
            if result:
                schema_parts.append(result[1])  # result[1] is the CREATE TABLE statement
        except mysql.connector.Error as err:
            print(f"Warning: Could not get schema for table {table_name}: {err}")
    return "\n\n".join(schema_parts)


def get_random_tables_and_samples(cursor, tables, num_tables=5, num_samples=5):
    """Selects random tables and samples random rows from them."""
    # Guard against requesting more tables than exist (random.sample raises otherwise).
    selected_tables = random.sample(tables, min(num_tables, len(tables)))
    sampled_data = {}
    for table_name in selected_tables:
        try:
            # Use ORDER BY RAND() for random sampling. Can be slow on very large tables.
            query = f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {num_samples}"
            cursor.execute(query)
            rows = cursor.fetchall()
            if not rows:
                sampled_data[table_name] = []
                continue
            columns = [desc[0] for desc in cursor.description]
            # Convert rows (tuples) to a list of dictionaries
            sampled_rows = []
            for row in rows:
                row_dict = {}
                for i, col_value in enumerate(row):
                    # Handle bytes by decoding, fall back to string representation
                    if isinstance(col_value, bytes):
                        try:
                            row_dict[columns[i]] = col_value.decode('utf-8')
                        except UnicodeDecodeError:
                            row_dict[columns[i]] = str(col_value)
                    else:
                        row_dict[columns[i]] = col_value
                sampled_rows.append(row_dict)
            sampled_data[table_name] = sampled_rows
        except mysql.connector.Error as err:
            print(f"Warning: Could not sample data from table {table_name}: {err}")
            sampled_data[table_name] = f"Error: {err}"
    return sampled_data
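# ORDER BY RAND() sorts the entire table by a random value, so it scans every
# row. For very large tables, a random-offset sample is usually much faster.
# A sketch (not used above; note it returns num_samples contiguous rows from a
# random starting position rather than a fully independent random sample):
#
#   cursor.execute(f"SELECT COUNT(*) FROM `{table_name}`")
#   total = cursor.fetchone()[0]
#   offset = random.randint(0, max(total - num_samples, 0))
#   cursor.execute(f"SELECT * FROM `{table_name}` LIMIT {num_samples} OFFSET {offset}")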


def generate_questions(client, schema_context, sampled_data):
    """Generates questions by calling the OpenAI API."""
    if not client:
        raise ValueError("OpenAI client not provided.")
    sampled_data_str = json.dumps(sampled_data, indent=2, default=str)
    prompt = PROMPT_TEMPLATE.format(
        schema_context=schema_context,
        sampled_data_str=sampled_data_str,
    )
    try:
        response = client.chat.completions.create(
            model=OPENAI_CONFIG["model"],
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.7,
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content
        data = json.loads(content)
        # The prompt asks for {"questions": [...]}, so we extract the list.
        if isinstance(data, dict) and "questions" in data and isinstance(data["questions"], list):
            return data["questions"]
        elif isinstance(data, list):
            # Fallback in case the model returns a list directly
            print("Warning: Model returned a raw list instead of an object with a 'questions' key.")
            return data
        else:
            print(f"Warning: Failed to find a 'questions' list in the model's output. Got: {content}")
            return None
    except Exception as e:
        print(f"Error calling OpenAI API or parsing JSON: {e}")
        return None


def semantic_validate_tasks(tasks, client):
    """
    Uses an LLM to semantically validate if the task's answer matches the SQL result.
    """
    if not tasks:
        return []
    final_validated_tasks = []
    print("\nPerforming semantic validation with GPT-4o...")
    for task in tasks:
        # Prepare data for the prompt, including the SQL result
        task_data_for_prompt = {
            "question": task["question"],
            "sql": task["sql"],
            "answer": task["answer"],
            "sql_result": task["sql_result"],
        }
        task_data_json = json.dumps(task_data_for_prompt, indent=2, default=str)
        prompt = SEMANTIC_VERIFICATION_PROMPT_TEMPLATE.format(task_data_json=task_data_json)
        try:
            print(f"  - Verifying question: \"{task['question'][:80]}...\"")
            response = client.chat.completions.create(
                model=OPENAI_CONFIG["model"],
                messages=[
                    {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,  # We want deterministic validation
                response_format={"type": "json_object"},
            )
            content = response.choices[0].message.content
            verification_result = json.loads(content)
            if verification_result.get("is_match") is True:
                # Task is valid. Rename sql_result for the final output.
                print("    - Validation PASSED.")
                task['sql_execute_result'] = task.pop('sql_result')
                final_validated_tasks.append(task)
            else:
                reason = verification_result.get('reason', 'No reason provided.')
                print("    - Validation FAILED. Filtering task.")
                print(f"      - Reason: {reason}")
                print(f"      - Question: {task['question']}")
                print(f"      - Expected Answer: {json.dumps(task['answer'], default=str)}")
                print(f"      - SQL: {task['sql']}")
                sql_result_str = json.dumps(task['sql_result'], indent=2, default=str)
                print(f"      - SQL Result: {sql_result_str}")
        except Exception as e:
            print(f"  - An error occurred during semantic validation for task, filtering it out: {e}")
            print(f"    - Question: {task.get('question', 'N/A')}")
            print(f"    - SQL: {task.get('sql', 'N/A')}")
    return final_validated_tasks
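# Note: json.dumps(..., default=str) stringifies non-JSON types in the SQL
# result (e.g., decimal.Decimal, datetime), so the verifier LLM sees "3.0000"
# rather than Decimal('3.0000'); the semantic check is expected to tolerate
# such representation differences.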


def main():
    """Main function to run the script."""
    # 1. Load the list of core tables
    try:
        with open('core_tables.json', 'r') as f:
            core_tables = json.load(f)
    except FileNotFoundError:
        print("Error: core_tables.json not found. Please create it.")
        return
    # 2. Set up the OpenAI client (checked before opening a DB connection so a
    # missing API key does not leak an unclosed connection)
    if not OPENAI_CONFIG["api_key"]:
        print("Error: OPENAI_API_KEY environment variable not set.")
        return
    client = OpenAI(api_key=OPENAI_CONFIG["api_key"], base_url=OPENAI_CONFIG["base_url"])
    # 3. Connect to the database
    conn = get_db_connection()
    if not conn:
        return
    cursor = conn.cursor()
    try:
        # 4. Get full schema context
        print("Fetching full database schema...")
        schema_context = get_full_schema(cursor, core_tables)
        # 5. Get random samples and print them
        print("Sampling data from 5 random tables...")
        sampled_data = get_random_tables_and_samples(cursor, core_tables, num_tables=5, num_samples=5)
        print(f"Sampled from tables: {list(sampled_data.keys())}")
        print("\n--- Sampled Data ---")
        print(json.dumps(sampled_data, indent=2, default=str))
        print("---------------------\n")
        # 6. Generate questions using the LLM
        print("Generating questions with GPT-4o...")
        generated_tasks = generate_questions(client, schema_context, sampled_data)
        # 7. Initial validation (SQL execution and substring check)
        pre_validated_tasks = []
        if generated_tasks:
            print("\nPerforming initial validation (SQL execution and substring match)...")
            for task in generated_tasks:
                if not isinstance(task, dict) or not all(k in task for k in ['sql', 'answer', 'question']):
                    print(f"Filtering task due to malformed structure or missing keys: {task}")
                    continue
                # Only execute read-only queries; a generated UPDATE/DELETE
                # would otherwise mutate the database during validation.
                if not task['sql'].lstrip().lower().startswith('select'):
                    print(f"Filtering task with non-SELECT SQL: {task['sql']}")
                    continue
                try:
                    cursor.execute(task['sql'])
                    sql_result = cursor.fetchall()
                    answer_str = str(task['answer'])
                    result_str = str(sql_result)
                    if answer_str in result_str:
                        task['sql_result'] = sql_result  # Attach result for the next validation step
                        pre_validated_tasks.append(task)
                    else:
                        print(f"Filtering task: Answer '{answer_str}' not found in SQL result.")
                        print(f"  - Question: {task['question']}")
                        print(f"  - SQL: {task['sql']}")
                        print(f"  - Result: {result_str[:250]}...")
                except mysql.connector.Error as err:
                    print(f"Filtering task due to SQL error: {err}")
                    print(f"  - Question: {task['question']}")
                    print(f"  - SQL: {task['sql']}")
                except Exception as e:
                    print(f"An unexpected error occurred during initial validation for task {task}: {e}")
        # 8. Semantic validation using LLM
        validated_tasks = semantic_validate_tasks(pre_validated_tasks, client)
        # 9. Print the final JSON output
        if validated_tasks:
            print("\n--- Final Validated Tasks ---")
            print(json.dumps(validated_tasks, indent=2, default=str))
        else:
            print("Failed to generate any valid tasks after all validation steps.")
    finally:
        # 10. Close the database connection
        if conn.is_connected():
            cursor.close()
            conn.close()
            print("\nDatabase connection closed.")


if __name__ == "__main__":
    main()
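# Typical invocation (assumed; requires core_tables.json in the working
# directory and a .env file providing OPENAI_API_KEY):
#
#   python "generate_tasks copy 2.py"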