"""
sql-natural-language-query

tool_calling/api · verified · python3.12/linux · json · download .py

Convert natural language to SQL, execute against SQLite, return structured results
"""

import re
import sys
import sqlite3
import subprocess
import requests as _requests

# ----------------------------------------
# PRE_EXECUTION
# FM-2.2: fetch ground truth for all registry_refs
# ----------------------------------------

REGISTRY_REFS = ["openai", "sqlalchemy"]
MAX_RETRIES = 2
registries = {}

# Fetch ground truth for each ref, retrying up to MAX_RETRIES times.
# Network errors are swallowed so the retry loop can take another pass.
for ref in REGISTRY_REFS:
    attempt = 0
    while attempt < MAX_RETRIES and ref not in registries:
        attempt += 1
        try:
            resp = _requests.get(
                f"https://checklist.day/api/registry/{ref}",
                timeout=10,
            )
        except _requests.exceptions.RequestException:
            continue
        if resp.status_code == 200:
            registries[ref] = resp.json()

# Hard abort if any registry could not be fetched — never proceed unverified.
for ref in REGISTRY_REFS:
    assert ref in registries, \
        f"ABORT: registry fetch failed for {ref} after {MAX_RETRIES} attempts"

# FM-2.4: surface breaking warnings
for ref, registry in registries.items():
    severe = [
        entry for entry in registry.get("warnings", [])
        if entry.get("severity") == "breaking"
    ]
    if not severe:
        continue
    print(f"PRE_EXECUTION: {ref} has {len(severe)} breaking warning(s):")
    for entry in severe:
        affected = entry.get('affected_versions', 'all')
        print(f"  [!] [{affected}] {entry['message'][:120]}")
        print(f"      fix: {entry['fix'][:100]}")

print()
print("PRE_EXECUTION: all registry refs verified ✓")

# ----------------------------------------
# KNOWN FAILURE MODES
#
# 1. No schema grounding — LLM hallucinates column/table names not in the DB
#    Always inject full schema into the prompt before asking for SQL
#
# 2. SQL injection via LLM output — LLM can be prompted to generate DROP TABLE,
#    DELETE, UPDATE, or stacked queries. Always validate before execution.
#
# 3. Executing destructive queries — agents should NEVER run DDL or DML
#    (DROP, DELETE, UPDATE, INSERT) unless explicitly permitted
#
# 4. No result validation — LLM says "query will return X" but result is empty
#    or wrong shape. Always assert on result structure.
#
# 5. Missing LIMIT — LLM-generated queries without LIMIT can return millions of rows
# ----------------------------------------

# Allowlist: only these SQL statement types are permitted for NL→SQL.
# Extend deliberately in source — never widen this set at runtime.
ALLOWED_STATEMENT_TYPES = {"SELECT"}
MAX_ROWS = 1000  # FM-1.5: hard cap — never return unbounded results


def get_schema(conn: sqlite3.Connection) -> str:
    """
    Extract a compact ``CREATE TABLE`` summary of every user table.

    FM-2.2: always ground the LLM with the actual schema.
    Never ask LLM to generate SQL without providing table/column definitions.

    Args:
        conn: an open SQLite connection.

    Returns:
        One ``CREATE TABLE name (col type, ...)`` line per table,
        newline-joined. Empty string when the database has no user tables.
    """
    cursor = conn.cursor()
    # Skip SQLite's internal bookkeeping tables (sqlite_sequence, ...) —
    # they are prompt noise and not legitimate query targets.
    cursor.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
    )
    tables = [row[0] for row in cursor.fetchall()]

    schema_parts = []
    for table in tables:
        # PRAGMA cannot be parameterized; quote the identifier (escaping
        # embedded quotes) so table names with spaces or odd characters
        # cannot break or alter the statement.
        quoted = table.replace('"', '""')
        cursor.execute(f'PRAGMA table_info("{quoted}")')
        cols = cursor.fetchall()  # rows: (cid, name, type, notnull, dflt, pk)
        col_defs = ", ".join(f"{col[1]} {col[2]}" for col in cols)
        schema_parts.append(f"CREATE TABLE {table} ({col_defs})")

    return "\n".join(schema_parts)


def validate_sql(sql: str) -> str:
    """
    FM-1.1: validate LLM-generated SQL before execution.
    Strips markdown fences, checks for destructive statements.
    """
    # Strip markdown code fences LLMs commonly add
    sql = re.sub(r"```sql\s*", "", sql, flags=re.IGNORECASE)
    sql = re.sub(r"```\s*", "", sql)
    sql = sql.strip().rstrip(";")

    # FM-2.6: only allow SELECT — reject destructive operations
    first_word = sql.strip().split()[0].upper()
    assert first_word in ALLOWED_STATEMENT_TYPES, \
        f"ABORT: destructive SQL statement type '{first_word}' not allowed — only SELECT is permitted"

    # Reject stacked queries (semicolon injection)
    assert ";" not in sql, \
        "ABORT: stacked queries detected — possible SQL injection in LLM output"

    # Inject LIMIT if missing — FM-1.5: never return unbounded results
    if "LIMIT" not in sql.upper():
        sql = f"{sql} LIMIT {MAX_ROWS}"

    return sql


def nl_to_sql(natural_language: str, schema: str) -> str:
    """
    Convert a natural-language question to a SQL string.

    FM-2.2: schema always injected — LLM never guesses column names.

    Args:
        natural_language: the user's question.
        schema: output of get_schema(), injected into the prompt.

    Returns:
        A SQL string. It is NOT validated here — callers must pass it
        through validate_sql() before execution.

    In production: replace with your LLM provider call.
    This checklist uses a mock for auth_required: false execution.
    """
    # --- MOCK for local execution (no API key required) ---
    # Replace this block with a real openai call:
    #
    # from openai import OpenAI
    # client = OpenAI()  # reads OPENAI_API_KEY from env
    # response = client.chat.completions.create(
    #     model="gpt-4o",
    #     messages=[
    #         {"role": "system", "content": (
    #             "You are a SQL expert. Given a SQLite schema and a natural language question, "
    #             "return ONLY a valid SQLite SELECT query. No explanation. No markdown. "
    #             "Only use tables and columns that exist in the schema."
    #         )},
    #         {"role": "user", "content": f"Schema:\n{schema}\n\nQuestion: {natural_language}"}
    #     ],
    #     temperature=0
    # )
    # return response.choices[0].message.content.strip()

    # Mock: deterministic SQL for test assertion
    mock_responses = {
        "how many products are there": "SELECT COUNT(*) as total FROM products",
        "show me all products with price above 50": "SELECT id, name, price FROM products WHERE price > 50",
        "what is the most expensive product": "SELECT name, price FROM products ORDER BY price DESC LIMIT 1",
    }
    # Normalize the lookup key: lowercase, collapse whitespace runs, then
    # drop leading/trailing '?' and surrounding spaces. The old
    # `.lower().strip("?")` missed questions with trailing spaces or
    # doubled spaces, silently falling through to the SELECT * fallback.
    key = " ".join(natural_language.lower().split()).strip("?").strip()
    if key in mock_responses:
        return mock_responses[key]
    # Unknown question: fall back to a broad query; validate_sql will cap it.
    return "SELECT * FROM products"


def execute_query(conn: sqlite3.Connection, sql: str) -> list[dict]:
    """
    Execute a validated SELECT and return rows as plain dicts.

    FM-2.6: use row_factory for dict results — never return raw tuples to agent.
    FM-3.2: callers check result shape before trusting it.

    Args:
        conn: an open SQLite connection.
        sql: a statement already sanitized by validate_sql().

    Returns:
        One dict per row, keyed by column name.
    """
    cursor = conn.cursor()
    # Set row_factory on the cursor, not the connection: the old code
    # mutated conn.row_factory, silently changing row types for every
    # other cursor sharing the connection.
    cursor.row_factory = sqlite3.Row
    cursor.execute(sql)
    return [dict(row) for row in cursor.fetchall()]


# ----------------------------------------
# EXECUTION
# Set up in-memory SQLite DB, demonstrate NL→SQL pipeline
# FM-1.1: in-memory DB is ephemeral — idempotent on every run
# ----------------------------------------

print()
print("EXECUTION: setting up in-memory SQLite database...")

# Ephemeral in-memory DB — recreated from scratch on every run (FM-1.1).
conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE products (
        id      INTEGER PRIMARY KEY,
        name    TEXT NOT NULL,
        price   REAL NOT NULL,
        category TEXT NOT NULL
    )
""")
seed_rows = [
    ("Widget A", 29.99, "hardware"),
    ("Widget B", 79.99, "hardware"),
    ("Gadget X", 149.99, "electronics"),
    ("Gadget Y", 9.99, "electronics"),
    ("Tool Z", 59.99, "tools"),
]
conn.executemany(
    "INSERT INTO products (name, price, category) VALUES (?, ?, ?)",
    seed_rows,
)
conn.commit()
print("EXECUTION: database seeded with 5 products ✓")

# Ground LLM with actual schema — never skip this step
schema = get_schema(conn)
print(f"EXECUTION: schema extracted ✓\n  {schema}")

# Each case carries the NL question plus the values asserted on in
# POST_EXECUTION below.
test_cases = [
    {"question": "how many products are there",
     "expected_key": "total",
     "expected_value": 5},
    {"question": "show me all products with price above 50",
     "expected_count": 3},  # Widget B, Gadget X, Tool Z
    {"question": "what is the most expensive product",
     "expected_name": "Gadget X"},
]

results_store = {}

print()
for case in test_cases:
    question = case["question"]
    print(f"EXECUTION: '{question}'")

    # NL → raw SQL → validated SQL → rows: never skip the validation hop.
    generated = nl_to_sql(question, schema)
    print(f"  raw SQL  : {generated}")

    vetted = validate_sql(generated)
    print(f"  safe SQL : {vetted}")

    fetched = execute_query(conn, vetted)
    print(f"  rows     : {fetched}")
    results_store[question] = fetched

conn.close()

# ----------------------------------------
# POST_EXECUTION
# FM-3.2: verify result shape before returning to agent
# FM-3.3: exact match assertions
# ----------------------------------------

# Test 1: count
count_rows = results_store["how many products are there"]
assert len(count_rows) == 1, \
    f"FAIL: expected 1 row for COUNT query, got {len(count_rows)}"
assert count_rows[0]["total"] == 5, \
    f"FAIL: expected total=5, got {count_rows[0]}"

# Test 2: filter
above_fifty = results_store["show me all products with price above 50"]
assert len(above_fifty) == 3, \
    f"FAIL: expected 3 products above $50, got {len(above_fifty)}"
names_seen = {row["name"] for row in above_fifty}
assert names_seen == {"Widget B", "Gadget X", "Tool Z"}, \
    f"FAIL: wrong products returned — got {names_seen}"

# Test 3: ordering
priciest = results_store["what is the most expensive product"]
assert len(priciest) == 1, \
    f"FAIL: expected 1 row, got {len(priciest)}"
assert priciest[0]["name"] == "Gadget X", \
    f"FAIL: expected 'Gadget X', got '{priciest[0]['name']}'"

print()
for summary in (
    "POST_EXECUTION: count query verified ✓  (total=5)",
    "POST_EXECUTION: filter query verified ✓  (3 products above $50)",
    "POST_EXECUTION: order query verified ✓  (most expensive = Gadget X)",
):
    print(summary)

result = {
    "status": "pass",
    "count_query_verified": True,
    "filter_query_verified": True,
    "order_query_verified": True,
}
print(result)
print("PASS")