import os
import random
import json

import mysql.connector
from openai import OpenAI
from dotenv import load_dotenv

# --- Configuration ---
load_dotenv()

MYSQL_CONFIG = {
    "host": "localhost",
    "port": 23306,
    "user": "mcpuser",
    "password": "StrongPass123!",
    "database": "magentodb"
}

OPENAI_CONFIG = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "base_url": os.getenv("OPENAI_BASE_URL"),
    "model": "gpt-4o"
}
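# A minimal example of the .env file consumed by load_dotenv() above. The
# values are illustrative placeholders, not real credentials:
#
#   OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
#   OPENAI_BASE_URL=https://api.openai.com/v1
#
# OPENAI_BASE_URL may be omitted; os.getenv() then yields None and the OpenAI
# client falls back to its default endpoint.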
# --- Prompt Template ---
# This is a carefully engineered prompt to guide the LLM's output.
PROMPT_TEMPLATE = """
You are an expert database analyst and a creative test case designer for e-commerce web applications. Your goal is to generate realistic administrative tasks that can be solved by a Web Agent navigating an admin panel.

I will provide you with the following context:
1. **Full Database Schema**: A list of `CREATE TABLE` statements for the core tables of a Magento e-commerce platform.
2. **Sampled Data**: A JSON object containing 5 random rows of data from 5 randomly selected core tables. This data is REAL and should be used to inspire specific, answerable questions.

## Your Task
Based on the provided schema and sample data, create a JSON object containing a single key, "questions", which holds an array of up to 10 unique task objects.

### Requirements for Each Question:
- **Web Agent Solvable**: The task must represent a realistic action an administrator would perform in a web UI (e.g., "Find all orders for customer X", "Update the stock for product Y", "Approve a pending review").
- **Grounded in Data**: The questions should be specific, using names, IDs, or values from the provided **Sampled Data** to make them concrete.
- **Utilize Schema**: You can formulate questions that require joining tables, even if not all tables were sampled. The full schema is your guide.

### Output Format
The final output MUST be a single, valid JSON object. Do not include any other text, explanations, or markdown formatting like ```json.

The JSON object must have one key: "questions", containing a JSON array of task objects. Each object in the array must contain exactly three keys: `question`, `answer`, and `sql`.
- **`question`**: (string) A natural language description of the task for a web agent.
- **`answer`**: (string, integer, float, or list) The precise and concise answer to the question, derived by running the SQL query against the database.
- **`sql`**: (string) The exact, runnable MySQL query that was used to find the answer.

### Output Format Example
```json
{{
    "questions": [
        {{
            "question": "What is the email address for customer with ID 5?",
            "answer": "customer5@example.com",
            "sql": "SELECT email FROM customer_entity WHERE entity_id = 5;"
        }},
        {{
            "question": "Find the total quantity of item with SKU 'ABC-123' in the cart.",
            "answer": 3,
            "sql": "SELECT SUM(qty) FROM quote_item WHERE sku = 'ABC-123';"
        }}
    ]
}}
```

---
### Full Database Schema
{schema_context}

---
### Sampled Data
Here is the sample data from randomly selected tables. Use this to make your questions specific.
{sampled_data_str}

---
Now, generate the JSON object based on these instructions.
"""

# This is a carefully engineered prompt to verify the LLM's own output.
SEMANTIC_VERIFICATION_PROMPT_TEMPLATE = """
You are a meticulous data verifier. Your task is to determine if a given "answer" is semantically correct and accurately supported by the "SQL query result".

I will provide you with a JSON object containing:
1. `question`: The original question asked.
2. `sql`: The SQL query used to find the answer.
3. `answer`: The answer generated by a previous AI.
4. `sql_result`: The actual data returned by executing the SQL query.

## Your Task
Carefully analyze the `sql_result` and compare it to the `answer`. The match should be semantic, not just a simple substring match. For example, if the question is "How many products are in stock?", an answer of "5" should be verifiable from the SQL result, which might be `[(5,)]`.

### Requirements:
- Respond with a single JSON object.
- Do not include any other text, explanations, or markdown formatting.
- The JSON object must have exactly two keys:
  - `is_match`: (boolean) `true` if the `answer` is fully and accurately supported by the `sql_result`, otherwise `false`.
  - `reason`: (string) A brief explanation for your decision. If it's a mismatch, explain why (e.g., "The answer is 'John Doe' but the result contains 'Jane Doe'", "The answer is a count but the result is a list of names").

---
### Verification Data
{task_data_json}

---
Now, provide your verification as a JSON object.
"""


def get_db_connection():
    """Establishes a connection to the MySQL database."""
    try:
        conn = mysql.connector.connect(**MYSQL_CONFIG)
        return conn
    except mysql.connector.Error as err:
        print(f"Error connecting to MySQL: {err}")
        return None


def get_full_schema(cursor, tables):
    """Fetches the CREATE TABLE statements for all core tables."""
    schema_parts = []
    for table_name in tables:
        try:
            cursor.execute(f"SHOW CREATE TABLE `{table_name}`")
            result = cursor.fetchone()
            if result:
                schema_parts.append(result[1])  # result[1] is the CREATE TABLE statement
        except mysql.connector.Error as err:
            print(f"Warning: Could not get schema for table {table_name}: {err}")
    return "\n\n".join(schema_parts)


def get_random_tables_and_samples(cursor, tables, num_tables=5, num_samples=5):
    """Selects random tables and samples random rows from them."""
    selected_tables = random.sample(tables, num_tables)
    sampled_data = {}
    for table_name in selected_tables:
        try:
            # Use ORDER BY RAND() for random sampling. Can be slow on very
            # large tables (see the alternative sketch after this function).
            query = f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {num_samples}"
            cursor.execute(query)
            rows = cursor.fetchall()
            if not rows:
                sampled_data[table_name] = []
                continue
            columns = [desc[0] for desc in cursor.description]
            # Convert rows (tuples) to a list of dictionaries.
            sampled_rows = []
            for row in rows:
                row_dict = {}
                for i, col_value in enumerate(row):
                    # Handle bytes by decoding, fall back to string representation.
                    if isinstance(col_value, bytes):
                        try:
                            row_dict[columns[i]] = col_value.decode('utf-8')
                        except UnicodeDecodeError:
                            row_dict[columns[i]] = str(col_value)
                    else:
                        row_dict[columns[i]] = col_value
                sampled_rows.append(row_dict)
            sampled_data[table_name] = sampled_rows
        except mysql.connector.Error as err:
            print(f"Warning: Could not sample data from table {table_name}: {err}")
            sampled_data[table_name] = f"Error: {err}"
    return sampled_data
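# Optional alternative to ORDER BY RAND(), which sorts the whole table per
# query. On large tables with a dense integer AUTO_INCREMENT key, probing
# random ids is much cheaper. This is a sketch only and is not called anywhere
# in this script; `pk_column` is a hypothetical parameter, and the technique
# is biased toward rows that follow gaps in the key sequence.
def sample_by_random_pk(cursor, table_name, pk_column, num_samples=5):
    cursor.execute(f"SELECT MIN(`{pk_column}`), MAX(`{pk_column}`) FROM `{table_name}`")
    lo, hi = cursor.fetchone()
    if lo is None:  # Empty table.
        return []
    rows, seen = [], set()
    # Oversample candidate ids; skip repeats and ids past the end of the table.
    for _ in range(num_samples * 4):
        if len(rows) >= num_samples:
            break
        candidate = random.randint(lo, hi)
        if candidate in seen:
            continue
        seen.add(candidate)
        cursor.execute(
            f"SELECT * FROM `{table_name}` WHERE `{pk_column}` >= %s LIMIT 1",
            (candidate,),
        )
        row = cursor.fetchone()
        if row:  # Note: sparse keys can map two candidates to the same row.
            rows.append(row)
    return rows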
query = f"SELECT * FROM `{table_name}` ORDER BY RAND() LIMIT {num_samples}" cursor.execute(query) rows = cursor.fetchall() if not rows: sampled_data[table_name] = [] continue columns = [desc[0] for desc in cursor.description] # Convert rows (tuples) to a list of dictionaries sampled_rows = [] for row in rows: row_dict = {} for i, col_value in enumerate(row): # Handle bytes by decoding, fall back to string representation if isinstance(col_value, bytes): try: row_dict[columns[i]] = col_value.decode('utf-8') except UnicodeDecodeError: row_dict[columns[i]] = str(col_value) else: row_dict[columns[i]] = col_value sampled_rows.append(row_dict) sampled_data[table_name] = sampled_rows except mysql.connector.Error as err: print(f"Warning: Could not sample data from table {table_name}: {err}") sampled_data[table_name] = f"Error: {err}" return sampled_data def generate_questions(client, schema_context, sampled_data): """Generates questions by calling the OpenAI API.""" if not client: raise ValueError("OpenAI client not provided.") sampled_data_str = json.dumps(sampled_data, indent=2, default=str) prompt = PROMPT_TEMPLATE.format( schema_context=schema_context, sampled_data_str=sampled_data_str ) try: response = client.chat.completions.create( model=OPENAI_CONFIG["model"], messages=[ {"role": "system", "content": "You are a helpful assistant designed to output JSON."}, {"role": "user", "content": prompt} ], temperature=0.7, response_format={"type": "json_object"}, ) content = response.choices[0].message.content data = json.loads(content) # The prompt asks for {"questions": [...]}, so we extract the list. if isinstance(data, dict) and "questions" in data and isinstance(data["questions"], list): return data["questions"] elif isinstance(data, list): # Fallback in case the model returns a list directly print("Warning: Model returned a raw list instead of an object with a 'questions' key.") return data else: print(f"Warning: Failed to find a 'questions' list in the model's output. Got: {content}") return None except Exception as e: print(f"Error calling OpenAI API or parsing JSON: {e}") return None def semantic_validate_tasks(tasks, client): """ Uses an LLM to semantically validate if the task's answer matches the SQL result. """ if not tasks: return [] final_validated_tasks = [] print("\nPerforming semantic validation with GPT-4o...") for task in tasks: # Prepare data for the prompt, including the SQL result task_data_for_prompt = { "question": task["question"], "sql": task["sql"], "answer": task["answer"], "sql_result": task["sql_result"] } task_data_json = json.dumps(task_data_for_prompt, indent=2, default=str) prompt = SEMANTIC_VERIFICATION_PROMPT_TEMPLATE.format(task_data_json=task_data_json) try: print(f" - Verifying question: \"{task['question'][:80]}...\"") response = client.chat.completions.create( model=OPENAI_CONFIG["model"], messages=[ {"role": "system", "content": "You are a helpful assistant designed to output JSON."}, {"role": "user", "content": prompt} ], temperature=0.0, # We want deterministic validation response_format={"type": "json_object"}, ) content = response.choices[0].message.content verification_result = json.loads(content) if verification_result.get("is_match") is True: # Task is valid. Rename sql_result for the final output. print(f" - Validation PASSED.") task['sql_execute_result'] = task.pop('sql_result') final_validated_tasks.append(task) else: reason = verification_result.get('reason', 'No reason provided.') print(f" - Validation FAILED. 
Filtering task.") print(f" - Reason: {reason}") print(f" - Question: {task['question']}") print(f" - Expected Answer: {json.dumps(task['answer'], default=str)}") print(f" - SQL: {task['sql']}") sql_result_str = json.dumps(task['sql_result'], indent=2, default=str) print(f" - SQL Result: {sql_result_str}") except Exception as e: print(f" - An error occurred during semantic validation for task, filtering it out: {e}") print(f" - Question: {task.get('question', 'N/A')}") print(f" - SQL: {task.get('sql', 'N/A')}") return final_validated_tasks def main(): """Main function to run the script.""" # 1. Load the list of core tables try: with open('core_tables.json', 'r') as f: core_tables = json.load(f) except FileNotFoundError: print("Error: core_tables.json not found. Please create it.") return # 2. Connect to the database conn = get_db_connection() if not conn: return cursor = conn.cursor() # 3. Setup OpenAI Client if not OPENAI_CONFIG["api_key"]: print("Error: OPENAI_API_KEY environment variable not set.") return client = OpenAI(api_key=OPENAI_CONFIG["api_key"], base_url=OPENAI_CONFIG["base_url"]) try: # 4. Get full schema context print("Fetching full database schema...") schema_context = get_full_schema(cursor, core_tables) # 5. Get random samples and print them print("Sampling data from 5 random tables...") sampled_data = get_random_tables_and_samples(cursor, core_tables, num_tables=5, num_samples=5) print(f"Sampled from tables: {list(sampled_data.keys())}") print("\n--- Sampled Data ---") print(json.dumps(sampled_data, indent=2, default=str)) print("---------------------\n") # 6. Generate questions using the LLM print("Generating questions with GPT-4o...") generated_tasks = generate_questions(client, schema_context, sampled_data) # 7. Initial validation (SQL execution and substring check) pre_validated_tasks = [] if generated_tasks: print("\nPerforming initial validation (SQL execution and substring match)...") for task in generated_tasks: if not isinstance(task, dict) or not all(k in task for k in ['sql', 'answer', 'question']): print(f"Filtering task due to malformed structure or missing keys: {task}") continue try: cursor.execute(task['sql']) sql_result = cursor.fetchall() answer_str = str(task['answer']) result_str = str(sql_result) if answer_str in result_str: task['sql_result'] = sql_result # Attach result for the next validation step pre_validated_tasks.append(task) else: print(f"Filtering task: Answer '{answer_str}' not found in SQL result.") print(f" - Question: {task['question']}") print(f" - SQL: {task['sql']}") print(f" - Result: {result_str[:250]}...") except mysql.connector.Error as err: print(f"Filtering task due to SQL error: {err}") print(f" - Question: {task['question']}") print(f" - SQL: {task['sql']}") except Exception as e: print(f"An unexpected error occurred during initial validation for task {task}: {e}") # 8. Semantic validation using LLM validated_tasks = semantic_validate_tasks(pre_validated_tasks, client) # 9. Print the final JSON output if validated_tasks: print("\n--- Final Validated Tasks ---") print(json.dumps(validated_tasks, indent=2, default=str)) else: print("Failed to generate any valid tasks after all validation steps.") finally: # 10. Close the database connection if conn.is_connected(): cursor.close() conn.close() print("\nDatabase connection closed.") if __name__ == "__main__": main()