"""Generate and validate web-agent admin tasks for a Magento MySQL database.

Samples random rows from core tables, asks an LLM to draft
question/answer/SQL triples grounded in that data, and keeps only the
tasks whose SQL actually reproduces the stated answer.
"""
# Standard library
import json
import os
import random

# Third-party
import mysql.connector
from dotenv import load_dotenv
from openai import OpenAI

# --- Configuration ---
load_dotenv()
# Connection settings for the local Magento MySQL instance. Every value can
# be overridden via environment variables so credentials need not live in
# source control; the defaults preserve the original hard-coded behavior.
MYSQL_CONFIG = {
    "host": os.getenv("MYSQL_HOST", "localhost"),
    # mysql.connector expects an integer port; the original passed "23306".
    "port": int(os.getenv("MYSQL_PORT", "23306")),
    "user": os.getenv("MYSQL_USER", "mcpuser"),
    # NOTE(review): default password kept for backward compatibility —
    # prefer setting MYSQL_PASSWORD in the environment / .env file.
    "password": os.getenv("MYSQL_PASSWORD", "StrongPass123!"),
    "database": os.getenv("MYSQL_DATABASE", "magentodb"),
}
# OpenAI client settings. api_key/base_url already came from the environment
# (populated by load_dotenv() at import time); the model name is now
# overridable too, defaulting to the original hard-coded "gpt-4o".
OPENAI_CONFIG = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_BASE_URL"),
    "model": os.getenv("OPENAI_MODEL", "gpt-4o"),
}
# --- Prompt Template ---
# This is a carefully engineered prompt to guide the LLM's output.
# Filled via str.format() with two placeholders:
#   {schema_context}    - concatenated CREATE TABLE statements
#   {sampled_data_str}  - JSON dump of rows sampled from random tables
# The template text is part of runtime behavior; do not edit casually.
PROMPT_TEMPLATE = """
You are an expert database analyst and a creative test case designer for e-commerce web applications.
Your goal is to generate realistic administrative tasks that can be solved by a Web Agent navigating an admin panel.

I will provide you with the following context:
1. **Full Database Schema**: A list of `CREATE TABLE` statements for the core tables of a Magento e-commerce platform.
2. **Sampled Data**: A JSON object containing 5 random rows of data from 5 randomly selected core tables. This data is REAL and should be used to inspire specific, answerable questions.

## Your Task

Based on the provided schema and sample data, create a JSON array of up to 10 unique questions.

### Requirements for Each Question:
- **Web Agent Solvable**: The task must represent a realistic action an administrator would perform in a web UI (e.g., "Find all orders for customer X", "Update the stock for product Y", "Approve a pending review").
- **Grounded in Data**: The questions should be specific, using names, IDs, or values from the provided **Sampled Data** to make them concrete.
- **Utilize Schema**: You can formulate questions that require joining tables, even if not all tables were sampled. The full schema is your guide.

### Output Format
The final output MUST be a single, valid JSON array of objects. Do not include any other text, explanations, or markdown formatting like ```json.

Each object in the array must contain exactly three keys: `question`, `answer`, and `sql`.

- **`question`**: (string) A natural language description of the task for a web agent.
- **`answer`**: (string, integer, float, or list) The precise and concise answer to the question, derived by running the SQL query against the database.
- **`sql`**: (string) The exact, runnable MySQL query that was used to find the answer.

---
### Full Database Schema
{schema_context}

---
### Sampled Data
Here is the sample data from randomly selected tables. Use this to make your questions specific.

{sampled_data_str}

---
Now, generate the JSON array based on these instructions.
"""
def get_db_connection():
    """Return an open connection to the MySQL database, or None on failure."""
    try:
        return mysql.connector.connect(**MYSQL_CONFIG)
    except mysql.connector.Error as err:
        # Report the failure but let the caller decide how to proceed.
        print(f"Error connecting to MySQL: {err}")
        return None
def get_full_schema(cursor, tables):
    """Return the CREATE TABLE statements for *tables*, joined by blank lines.

    Tables whose schema cannot be fetched are skipped with a printed warning.
    """
    statements = []
    for table_name in tables:
        try:
            cursor.execute(f"SHOW CREATE TABLE `{table_name}`")
            row = cursor.fetchone()
        except mysql.connector.Error as err:
            print(f"Warning: Could not get schema for table {table_name}: {err}")
            continue
        if row:
            # SHOW CREATE TABLE yields (table_name, create_statement).
            statements.append(row[1])
    return "\n\n".join(statements)
def _decode_value(value):
    """Decode a bytes column as UTF-8, falling back to str(); pass others through."""
    if isinstance(value, bytes):
        try:
            return value.decode('utf-8')
        except UnicodeDecodeError:
            return str(value)
    return value


def get_random_tables_and_samples(cursor, tables, num_tables=5, num_samples=5):
    """Select up to *num_tables* random tables and sample *num_samples* rows from each.

    Args:
        cursor: An open MySQL cursor.
        tables: List of candidate table names.
        num_tables: Maximum number of tables to sample (clamped to len(tables)).
        num_samples: Rows to fetch per table.

    Returns:
        Dict mapping table name -> list of row dicts, or an "Error: ..." string
        for tables that failed to sample.
    """
    # Clamp so random.sample never raises ValueError when fewer tables exist
    # than requested (original code crashed in that case).
    selected_tables = random.sample(tables, min(num_tables, len(tables)))
    sampled_data = {}

    for table_name in selected_tables:
        try:
            # Use ORDER BY RAND() for random sampling. Can be slow on very large tables.
            cursor.execute(f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {num_samples}")
            rows = cursor.fetchall()
            if not rows:
                sampled_data[table_name] = []
                continue

            columns = [desc[0] for desc in cursor.description]
            # One dict per row; bytes values are decoded for JSON-friendliness.
            sampled_data[table_name] = [
                {col: _decode_value(val) for col, val in zip(columns, row)}
                for row in rows
            ]
        except mysql.connector.Error as err:
            print(f"Warning: Could not sample data from table {table_name}: {err}")
            sampled_data[table_name] = f"Error: {err}"

    return sampled_data
def _parse_llm_json(content):
    """Parse the model reply as JSON, tolerating a ```json ... ``` fence.

    The prompt forbids markdown fences, but models occasionally emit them
    anyway; stripping them here prevents spurious parse failures.
    """
    text = (content or "").strip()
    if text.startswith("```"):
        # Drop the opening fence line (with optional language tag).
        newline = text.find("\n")
        text = text[newline + 1:] if newline != -1 else ""
        text = text.rstrip()
        if text.endswith("```"):
            text = text[:-3]
    return json.loads(text)


def generate_questions(schema_context, sampled_data):
    """Generate candidate tasks by calling the OpenAI chat completions API.

    Args:
        schema_context: Concatenated CREATE TABLE statements for the prompt.
        sampled_data: Mapping of table name -> sampled rows used to ground
            the generated questions.

    Returns:
        The parsed JSON task list, or None if the API call or parsing fails.

    Raises:
        ValueError: If OPENAI_API_KEY is not set in the environment.
    """
    if not OPENAI_CONFIG["api_key"]:
        raise ValueError("OPENAI_API_KEY environment variable not set.")

    client = OpenAI(api_key=OPENAI_CONFIG["api_key"], base_url=OPENAI_CONFIG["base_url"])

    # default=str keeps non-JSON-native values (datetime, Decimal) serializable.
    sampled_data_str = json.dumps(sampled_data, indent=2, default=str)

    prompt = PROMPT_TEMPLATE.format(
        schema_context=schema_context,
        sampled_data_str=sampled_data_str
    )

    try:
        response = client.chat.completions.create(
            model=OPENAI_CONFIG["model"],
            messages=[
                {"role": "system", "content": "You are a helpful assistant designed to output JSON."},
                {"role": "user", "content": prompt}
            ],
            temperature=0.7,
        )
        content = response.choices[0].message.content
        return _parse_llm_json(content)
    except Exception as e:
        # Covers network/API failures and malformed JSON alike; the caller
        # treats None as "no tasks generated".
        print(f"Error calling OpenAI API: {e}")
        return None
def main():
    """Run the end-to-end pipeline: load table list, connect to MySQL,
    sample data, generate tasks with the LLM, validate each task's SQL
    against the live database, and print the surviving tasks as JSON.
    """
    # 1. Load the list of core tables
    try:
        with open('core_tables.json', 'r') as f:
            core_tables = json.load(f)
    except FileNotFoundError:
        print("Error: core_tables.json not found. Please create it.")
        return

    # 2. Connect to the database
    conn = get_db_connection()
    if not conn:
        # get_db_connection already printed the error.
        return

    cursor = conn.cursor()

    try:
        # 3. Get full schema context
        print("Fetching full database schema...")
        schema_context = get_full_schema(cursor, core_tables)

        # 4. Get random samples and print them
        print("Sampling data from 5 random tables...")
        sampled_data = get_random_tables_and_samples(cursor, core_tables, num_tables=5, num_samples=5)
        print(f"Sampled from tables: {list(sampled_data.keys())}")
        print("\n--- Sampled Data ---")
        print(json.dumps(sampled_data, indent=2, default=str))
        print("---------------------\n")

        # 5. Generate questions using the LLM
        print("Generating questions with GPT-4o...")
        generated_tasks = generate_questions(schema_context, sampled_data)

        # 6. Validate and filter the generated tasks
        validated_tasks = []
        if generated_tasks:
            print("\nValidating generated tasks...")
            for task in generated_tasks:
                # Basic validation for task structure
                if not isinstance(task, dict) or not all(k in task for k in ['sql', 'answer', 'question']):
                    print(f"Filtering task due to malformed structure or missing keys: {task}")
                    continue

                try:
                    # Execute the SQL query from the task.
                    # NOTE(review): this runs LLM-generated SQL verbatim against
                    # the database — acceptable only because the target is a
                    # local test instance; confirm before pointing at real data.
                    cursor.execute(task['sql'])
                    sql_result = cursor.fetchall()

                    # Convert both answer and result to string for flexible
                    # substring matching (deliberately loose: "5" matches
                    # "[(5,)]" but may also produce false positives).
                    answer_str = str(task['answer'])
                    result_str = str(sql_result)

                    # If the answer exists in the result, the task is valid
                    if answer_str in result_str:
                        validated_tasks.append(task)
                    else:
                        # Log tasks that are filtered because the answer doesn't match
                        print(f"Filtering task: Answer '{answer_str}' not found in SQL result.")
                        print(f" - Question: {task['question']}")
                        print(f" - SQL: {task['sql']}")
                        # Showing a snippet of a large result is helpful for debugging
                        print(f" - Result: {result_str[:250]}...")

                except mysql.connector.Error as err:
                    # Log tasks that are filtered due to SQL errors
                    print(f"Filtering task due to SQL error: {err}")
                    print(f" - Question: {task['question']}")
                    print(f" - SQL: {task['sql']}")
                except Exception as e:
                    # Catch-all so one bad task cannot abort the whole run.
                    print(f"An unexpected error occurred during validation for task {task}: {e}")

        # 7. Print the final JSON output
        if validated_tasks:
            print("\n--- Generated and Validated Tasks ---")
            print(json.dumps(validated_tasks, indent=2))
        else:
            print("Failed to generate any valid tasks.")

    finally:
        # 8. Close the database connection
        if conn.is_connected():
            cursor.close()
            conn.close()
            print("\nDatabase connection closed.")
# Script entry point: run the generation/validation pipeline.
if __name__ == "__main__":
    main()