import os
import random
import json
import mysql.connector
import argparse
from openai import OpenAI
from dotenv import load_dotenv

# --- Configuration ---
load_dotenv()

server_address = "localhost"

MYSQL_CONFIG = {
    "host": server_address,
    "port": "23306",
    "user": "mcpuser",
    "password": "StrongPass123!",
    "database": "magentodb"
}

OPENAI_CONFIG = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_BASE_URL"),
    "model": "gpt-4o"
}
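
# The OpenAI settings above are read from the environment via python-dotenv.
# A minimal example .env file (placeholder values, adjust for your environment):
#   OPENAI_API_KEY=sk-example
#   OPENAI_BASE_URL=https://api.openai.com/v1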

# --- Prompt Template ---
# This is a carefully engineered prompt to guide the LLM's output.
# Literal braces in the JSON examples are doubled ({{ }}) so that str.format() leaves them intact.
PROMPT_TEMPLATE = """
You are an expert database analyst and a creative test case designer for e-commerce web applications.
Your goal is to generate realistic administrative tasks that can be solved by a Web Agent navigating an admin panel.

I will provide you with the following context:
1. **Full Database Schema**: A list of `CREATE TABLE` statements for the core tables of a Magento e-commerce platform.
2. **Sampled Data**: A JSON object containing 5 random rows of data from 5 randomly selected core tables. This data is REAL and should be used to inspire specific, answerable questions.

## Your Task

Based on the provided schema and sample data, create a JSON object containing a single key, "questions", which holds an array of up to 10 unique task objects.

### Requirements for Each Question
- **Web Agent Solvable**: The task must represent a realistic action an administrator would perform in a web UI (e.g., "Find all orders for customer X", "Update the stock for product Y", "Approve a pending review").
- **Grounded in Data**: The questions should be specific, using names, IDs, or values from the provided **Sampled Data** to make them concrete.
- **Utilize Schema**: You may formulate questions that require joining tables, even if not all tables were sampled. The full schema is your guide.

### Output Format
The final output MUST be a single, valid JSON object. Do not include any other text, explanations, or markdown formatting like ```json.
The JSON object must have one key, "questions", containing a JSON array of task objects.

Each object in the array must contain exactly three keys: `question`, `answer`, and `sql`.

- **`question`**: (string) A natural language description of the task for a web agent.
- **`answer`**: (string, integer, float, or list) The precise and concise answer to the question, derived by running the SQL query against the database.
- **`sql`**: (string) The exact, runnable MySQL query that was used to find the answer.

### Output Format Example
```json
{{
  "questions": [
    {{
      "question": "What is the email address for customer with ID 5?",
      "answer": "customer5@example.com",
      "sql": "SELECT email FROM customer_entity WHERE entity_id = 5;"
    }},
    {{
      "question": "Find the total quantity of item with SKU 'ABC-123' in the cart.",
      "answer": 3,
      "sql": "SELECT SUM(qty) FROM quote_item WHERE sku = 'ABC-123';"
    }}
  ]
}}
```

---
### Full Database Schema
{schema_context}

---
### Sampled Data
Here is the sample data from randomly selected tables. Use this to make your questions specific.

{sampled_data_str}

---
Now, generate the JSON object based on these instructions.
"""

# This is a new prompt to evaluate results and generate a corrected answer.
SEMANTIC_EVALUATION_PROMPT_TEMPLATE = """
You are a precise data analyst. Your task is to evaluate whether a SQL query's result adequately answers a given natural language question. You will then either refine the answer or, if the result set is large, rephrase the question entirely.

I will provide you with a JSON object containing:
1. `question`: The original question asked.
2. `sql`: The SQL query that was executed.
3. `sql_result`: The actual data returned by executing the SQL query.
4. `row_count`: The number of rows in `sql_result`.

## Your Task
Analyze the inputs and respond with a JSON object. There are three cases. The `new_answer` field MUST always be an array of strings.

### Data Analysis and Refinement Rules
1. **Analyze SQL and Question Intent**: Look at the SQL query (`SELECT`, `COUNT`, `DISTINCT`, etc.) and the natural language `question` to understand the user's goal. Is the goal to count things, list unique items, or retrieve specific related data points?
2. **Handle Duplicates and Merge Data**:
   - **De-duplication**: If the `question` implies a list of unique items (e.g., "List the cities..." or "What are the unique order statuses?"), you MUST de-duplicate the values in `sql_result` to form the `new_answer`. For example, if `sql_result` is `[["pending"], ["shipped"], ["pending"]]`, the `new_answer` should be `["pending", "shipped"]`.
   - **Data Merging**: If the `sql_result` contains multiple rows related to the same entity (e.g., different attributes of one product), combine the relevant information into a concise `new_answer`. For instance, if the question is "What are the name and price of product 'XYZ'?" and `sql_result` is `[["Product XYZ", 99.99]]`, the `new_answer` is `["Product XYZ", "99.99"]`. If the result was `[["Product XYZ", "Red"], ["Product XYZ", "Blue"]]` for a question about colors, `new_answer` could be `["Red", "Blue"]`. Extract only the information that directly answers the question.

After applying these rules, select one of the three cases below for your response format.

### Case 1: Large Result Set (Question Transformation)
If `row_count` is greater than 10 AND the original `question` does NOT already ask for a count (e.g., it is not phrased like "How many..."), you must transform the question.
Respond with:
```json
{{
  "can_answer": true,
  "new_question": "How many items were found?",
  "new_answer": ["42"]
}}
```
- `can_answer`: (boolean) Must be `true`.
- `new_question`: (string) A rephrased question that asks for the quantity of items. For example, if the original question was "List all products", the new question should be "How many products were found?".
- `new_answer`: (array of strings) An array containing the `row_count` as a single string element.

### Case 2: Standard Answer (No Transformation)
If Case 1 does not apply, but the `sql_result` still provides a clear answer to the original `question` (after applying the refinement rules), respond with:
```json
{{
  "can_answer": true,
  "new_answer": ["value1", "value2", ...]
}}
```
- `can_answer`: (boolean) Must be `true`.
- `new_answer`: (array of strings) An array containing all the essential parts of the answer extracted and refined from `sql_result`. Every value from the result set that contributes to the answer should be included as a string in the array. This ensures answer completeness.
  - **Example 1**: If `question` is "What is the status of order 123?" and `sql_result` is `[["processing"]]`, `new_answer` should be `["processing"]`.
  - **Example 2**: If `question` is "List emails for pending customers" and `sql_result` is `[["test@a.com"], ["test@b.com"]]`, `new_answer` should be `["test@a.com", "test@b.com"]`.
  - **Example 3**: If `question` is "Get product name and price for SKU 'XYZ'" and `sql_result` is `[["My Product", 19.99]]`, `new_answer` should be `["My Product", "19.99"]`.

### Case 3: The Question Cannot Be Answered
If the `sql_result` is empty, irrelevant, or insufficient to answer the question, respond with:
```json
{{
  "can_answer": false,
  "reason": "..."
}}
```
- `can_answer`: (boolean) Must be `false`.
- `reason`: (string) A brief explanation for why the question cannot be answered.

---
### Evaluation Data
{task_data_json}
---

Now, provide your evaluation as a JSON object.
"""


def get_db_connection():
    """Establishes a connection to the MySQL database."""
    try:
        conn = mysql.connector.connect(**MYSQL_CONFIG)
        return conn
    except mysql.connector.Error as err:
        print(f"Error connecting to MySQL: {err}")
        return None


def get_full_schema(cursor, tables):
    """Fetches the CREATE TABLE statements for all core tables."""
    schema_parts = []
    for table_name in tables:
        try:
            cursor.execute(f"SHOW CREATE TABLE `{table_name}`")
            result = cursor.fetchone()
            if result:
                schema_parts.append(result[1])  # result[1] is the CREATE TABLE statement
        except mysql.connector.Error as err:
            print(f"Warning: Could not get schema for table {table_name}: {err}")
    return "\n\n".join(schema_parts)


def get_random_tables_and_samples(cursor, tables, num_tables=5, num_samples=5):
    """Selects random tables and samples random rows from them."""
    selected_tables = random.sample(tables, num_tables)
    sampled_data = {}

    for table_name in selected_tables:
        try:
            # Use ORDER BY RAND() for random sampling. Can be slow on very large tables.
            query = f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {num_samples}"
            cursor.execute(query)

            rows = cursor.fetchall()
            if not rows:
                sampled_data[table_name] = []
                continue

            columns = [desc[0] for desc in cursor.description]

            # Convert rows (tuples) to a list of dictionaries.
            sampled_rows = []
            for row in rows:
                row_dict = {}
                for i, col_value in enumerate(row):
                    # Handle bytes by decoding; fall back to a string representation.
                    if isinstance(col_value, bytes):
                        try:
                            row_dict[columns[i]] = col_value.decode('utf-8')
                        except UnicodeDecodeError:
                            row_dict[columns[i]] = str(col_value)
                    else:
                        row_dict[columns[i]] = col_value
                sampled_rows.append(row_dict)

            sampled_data[table_name] = sampled_rows

        except mysql.connector.Error as err:
            print(f"Warning: Could not sample data from table {table_name}: {err}")
            sampled_data[table_name] = f"Error: {err}"

    return sampled_data
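

# Optional, illustrative alternative to ORDER BY RAND(): sample rows at random offsets so the
# server does not have to sort the whole table. This is only a sketch and is not used by the
# script; it still scans up to OFFSET rows per query, so it is not free either.
def sample_rows_by_offset(cursor, table_name, num_samples=5):
    """Samples up to num_samples rows from a table using random OFFSETs."""
    cursor.execute(f"SELECT COUNT(*) FROM `{table_name}`")
    total_rows = cursor.fetchone()[0]
    if total_rows == 0:
        return []
    offsets = random.sample(range(total_rows), min(num_samples, total_rows))
    rows = []
    for offset in offsets:
        cursor.execute(f"SELECT * FROM `{table_name}` LIMIT 1 OFFSET {offset}")
        row = cursor.fetchone()
        if row is not None:
            rows.append(row)
    return rows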


def generate_questions(client, schema_context, sampled_data, verbose=False):
    """Generates questions by calling the OpenAI API."""
    if not client:
        raise ValueError("OpenAI client not provided.")

    sampled_data_str = json.dumps(sampled_data, indent=2, default=str)

    prompt = PROMPT_TEMPLATE.format(
        schema_context=schema_context,
        sampled_data_str=sampled_data_str
    )

    if verbose:
        print("\n--- Generation Prompt ---")
        print(prompt)
        print("-------------------------\n")

    try:
        response = client.chat.completions.create(
            model=OPENAI_CONFIG["model"],
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
            response_format={"type": "json_object"},
        )
        content = response.choices[0].message.content
        if verbose:
            print("\n--- GPT-4o Raw Generation Response ---")
            print(content)
            print("--------------------------------------\n")
        data = json.loads(content)

        # The prompt asks for {"questions": [...]}, so we extract the list.
        if isinstance(data, dict) and "questions" in data and isinstance(data["questions"], list):
            return data["questions"]
        elif isinstance(data, list):
            # Fallback in case the model returns a list directly.
            print("Warning: Model returned a raw list instead of an object with a 'questions' key.")
            return data
        else:
            print(f"Warning: Failed to find a 'questions' list in the model's output. Got: {content}")
            return None

    except Exception as e:
        print(f"Error calling OpenAI API or parsing JSON: {e}")
        return None


def load_existing_tasks(filepath):
    """Loads tasks from a JSON file if it exists."""
    if not os.path.exists(filepath):
        return []
    try:
        with open(filepath, 'r') as f:
            content = f.read()
            if not content:  # Handle an empty file.
                return []
            return json.loads(content)
    except (json.JSONDecodeError, FileNotFoundError):
        print(f"Warning: Could not read or parse {filepath}. Starting with an empty list.")
        return []


def evaluate_and_refine_tasks(tasks, client, verbose=False):
    """
    Uses an LLM to evaluate if a SQL result answers the question and refines the answer.
    """
    if not tasks:
        return []

    final_validated_tasks = []
    print("\nPerforming semantic evaluation and answer refinement with GPT-4o...")

    for task in tasks:
        # Prepare data for the prompt, excluding the original 'answer'.
        task_data_for_prompt = {
            "question": task["question"],
            "sql": task["sql"],
            "sql_result": task["sql_result"],
            "row_count": task["row_count"]
        }
        task_data_json = json.dumps(task_data_for_prompt, indent=2, default=str)

        prompt = SEMANTIC_EVALUATION_PROMPT_TEMPLATE.format(task_data_json=task_data_json)

        if verbose:
            print("\n--- Evaluation Prompt ---")
            print(prompt)
            print("-------------------------\n")

        try:
            print(f"  - Evaluating question: \"{task['question'][:80]}...\"")
            response = client.chat.completions.create(
                model=OPENAI_CONFIG["model"],
                messages=[
                    {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.0,  # We want deterministic evaluation.
                response_format={"type": "json_object"},
            )
            content = response.choices[0].message.content
            if verbose:
                print("\n--- GPT-4o Raw Evaluation Response ---")
                print(content)
                print("----------------------------------------\n")
            evaluation_result = json.loads(content)

            if evaluation_result.get("can_answer") is True and "new_answer" in evaluation_result:
                # Task is valid. Update the answer with the refined one from the LLM.
                task['answer'] = evaluation_result['new_answer']

                # If the LLM provides a new question, update it.
                if 'new_question' in evaluation_result:
                    task['question'] = evaluation_result['new_question']
                    print(f"    - Question was rephrased: \"{task['question']}\"")

                task['sql_execute_result'] = task.pop('sql_result')
                task.pop('row_count', None)  # Clean up temp key.
                final_validated_tasks.append(task)
                print(f"    - Evaluation PASSED. New answer: {json.dumps(task['answer'])}")
            else:
                reason = evaluation_result.get('reason', 'No reason provided.')
                print(f"    - Evaluation FAILED. Filtering task.")
                print(f"      - Reason: {reason}")
                print(f"      - Question: {task['question']}")
                print(f"      - SQL: {task['sql']}")
                sql_result_str = json.dumps(task['sql_result'], indent=2, default=str)
                print(f"      - SQL Result: {sql_result_str}")

        except Exception as e:
            print(f"  - An error occurred during semantic evaluation for task, filtering it out: {e}")
            print(f"    - Question: {task.get('question', 'N/A')}")
            print(f"    - SQL: {task.get('sql', 'N/A')}")

    return final_validated_tasks
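

# Illustrative shape of each validated task that ends up in the output file (values below are
# made up, not taken from a real database):
#   {
#     "question": "How many orders have status 'pending'?",
#     "sql": "SELECT COUNT(*) FROM sales_order WHERE status = 'pending';",
#     "answer": ["12"],
#     "sql_execute_result": [[12]]
#   }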


def main():
    """Main function to run the script."""
    parser = argparse.ArgumentParser(description="Generate and validate e-commerce admin tasks.")
    parser.add_argument(
        "--target-count",
        type=int,
        required=True,
        help="The total number of questions to generate."
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default="generated_tasks.json",
        help="The file to save the generated tasks to (in JSON format)."
    )
    parser.add_argument(
        "-v", "--verbose",
        action="store_true",
        help="Enable verbose output, including prompts and raw LLM responses."
    )
    args = parser.parse_args()

    # Load existing tasks from the output file.
    all_tasks = load_existing_tasks(args.output_file)
    print(f"Found {len(all_tasks)} existing valid tasks in '{args.output_file}'.")

    # Connect to the database and set up the OpenAI client.
    conn = get_db_connection()
    if not conn:
        return
    cursor = conn.cursor()

    if not OPENAI_CONFIG["api_key"]:
        print("Error: OPENAI_API_KEY environment variable not set.")
        cursor.close()
        conn.close()
        return
    client = OpenAI(api_key=OPENAI_CONFIG["api_key"], base_url=OPENAI_CONFIG["base_url"])

    try:
        # Load the core table names once (core_tables.json is expected to be a JSON array of table name strings).
        try:
            with open('core_tables.json', 'r') as f:
                core_tables = json.load(f)
        except FileNotFoundError:
            print("Error: core_tables.json not found. Please create it.")
            return

        print("Fetching full database schema...")
        schema_context = get_full_schema(cursor, core_tables)

        # Start the generation loop.
        round_num = 1
        while len(all_tasks) < args.target_count:
            print(f"\n--- Starting Generation Round {round_num} ---")
            print(f"Goal: {args.target_count} | Current: {len(all_tasks)} | Needed: {args.target_count - len(all_tasks)}")

            # Get random samples for this round.
            print("Sampling data from 5 random tables...")
            sampled_data = get_random_tables_and_samples(cursor, core_tables, num_tables=5, num_samples=5)
            if args.verbose:
                print("\n--- Sampled Data ---")
                print(json.dumps(sampled_data, indent=2, default=str))
                print("--------------------\n")

            # Generate questions.
            print("Generating questions with GPT-4o...")
            generated_tasks = generate_questions(client, schema_context, sampled_data, verbose=args.verbose)

            # Execute SQL for the generated tasks.
            tasks_for_evaluation = []
            if generated_tasks:
                print("\nExecuting SQL for generated tasks...")
                for task in generated_tasks:
                    if not isinstance(task, dict) or not all(k in task for k in ['sql', 'answer', 'question']):
                        print(f"Filtering task due to malformed structure: {task}")
                        continue
                    try:
                        cursor.execute(task['sql'])
                        sql_result = cursor.fetchall()
                        # Create a new dict for evaluation, excluding the original 'answer'.
                        tasks_for_evaluation.append({
                            'question': task['question'],
                            'sql': task['sql'],
                            'sql_result': sql_result,
                            'row_count': len(sql_result)
                        })
                    except mysql.connector.Error as err:
                        print(f"Filtering task due to SQL error: {err} on SQL: {task['sql']}")

            # Perform semantic evaluation and get validated tasks.
            validated_tasks = evaluate_and_refine_tasks(tasks_for_evaluation, client, verbose=args.verbose)

            # Append new tasks and save to file.
            if validated_tasks:
                all_tasks.extend(validated_tasks)
                with open(args.output_file, 'w') as f:
                    json.dump(all_tasks, f, indent=2, default=str)

                print("\n--- Round Summary ---")
                print(f"Generated {len(validated_tasks)} new valid tasks in this round.")
                print(f"Progress: {len(all_tasks)} / {args.target_count} tasks.")
            else:
                print("\n--- Round Summary ---")
                print("No new valid tasks were generated in this round. Retrying...")

            round_num += 1

    finally:
        # Close the database connection.
        if conn.is_connected():
            cursor.close()
            conn.close()
            print("\nDatabase connection closed.")

    print(f"\nTarget of {args.target_count} tasks reached. Final output saved to {args.output_file}.")


if __name__ == "__main__":
    main()
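

# Example invocation (the filename is illustrative; use whatever this script is saved as):
#   python generate_admin_tasks.py --target-count 50 --output-file generated_tasks.json -v
#
# Prerequisites: a reachable MySQL instance matching MYSQL_CONFIG, a core_tables.json file in
# the working directory, and OPENAI_API_KEY (optionally OPENAI_BASE_URL) set in the environment
# or a .env file.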