import os
import uuid
import sqlite3
import io
import csv
import zipfile
import re
import difflib
from typing import List, Optional, Dict, Any

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from langdetect import detect
from transformers import MarianMTModel, MarianTokenizer
# ======================================================
# 0) General configuration
# ======================================================

# NL→SQL model fine-tuned on Spider, hosted on the Hugging Face Hub
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
DEVICE = torch.device("cpu")  # CPU-only inference

# Directory where uploaded databases are stored after conversion to SQLite
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploaded_dbs")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# In-memory registry of connections (everything ends up as SQLite)
# { conn_id: { "db_path": str, "label": str } }
DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
# ======================================================
# 1) FastAPI initialization
# ======================================================
app = FastAPI(
    title="NL2SQL T5-large Backend Universal (single-file)",
    description=(
        "NL→SQL interpreter (T5-large fine-tuned on Spider) for non-expert users. "
        "The user only uploads a database (SQLite / .sql dump / CSV / ZIP of CSVs) "
        "and everything is converted internally to SQLite."
    ),
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production, restrict this to your own domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ======================================================
# 2) NL→SQL model and ES→EN translator
# ======================================================
t5_tokenizer = None
t5_model = None
mt_tokenizer = None
mt_model = None


def load_nl2sql_model():
    """Load the NL→SQL model (T5-large fine-tuned on Spider) from the HF Hub."""
    global t5_tokenizer, t5_model
    if t5_model is not None:
        return
    print(f"🔁 Loading NL→SQL model from: {MODEL_DIR}")
    t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    t5_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float32)
    t5_model.to(DEVICE)
    t5_model.eval()
    print("✅ NL→SQL model ready in memory.")


def load_es_en_translator():
    """Load the Helsinki-NLP ES→EN translation model (only once)."""
    global mt_tokenizer, mt_model
    if mt_model is not None:
        return
    model_name = "Helsinki-NLP/opus-mt-es-en"
    print(f"🔁 Loading ES→EN translator: {model_name}")
    mt_tokenizer = MarianTokenizer.from_pretrained(model_name)
    mt_model = MarianMTModel.from_pretrained(model_name)
    mt_model.to(DEVICE)
    mt_model.eval()
    print("✅ ES→EN translator ready.")


def detect_language(text: str) -> str:
    try:
        return detect(text)
    except Exception:
        return "unknown"


def translate_es_to_en(text: str) -> str:
    """
    Use Marian ES→EN only if the text is detected as Spanish ('es').
    Otherwise, return the text unchanged.
    """
    lang = detect_language(text)
    if lang != "es":
        return text
    if mt_model is None:
        load_es_en_translator()
    inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        out = mt_model.generate(**inputs, max_length=256)
    return mt_tokenizer.decode(out[0], skip_special_tokens=True)
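

# Illustrative sketch of the language-handling path (not part of the API).
# The exact Marian output may differ; the English wording below is only an
# assumed example:
#
#   detect_language("¿Cuántas canciones hay?")       -> "es"
#   translate_es_to_en("¿Cuántas canciones hay?")    -> e.g. "How many songs are there?"
#   translate_es_to_en("How many songs are there?")  -> unchanged (not detected as 'es')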
# ======================================================
# 3) Database utilities: creation/ingestion into SQLite
# ======================================================
def _sanitize_identifier(name: str) -> str:
    """Make a table/column name safe for SQLite."""
    base = name.strip().replace(" ", "_")
    base = re.sub(r"[^0-9a-zA-Z_]", "_", base)
    if not base:
        base = "table"
    if base[0].isdigit():
        base = "_" + base
    return base
def create_empty_sqlite_db(label: str) -> str:
    """Create an empty .sqlite file, register it, and return its connection id."""
    conn_id = f"db_{uuid.uuid4().hex[:8]}"
    db_filename = f"{conn_id}.sqlite"
    db_path = os.path.join(UPLOAD_DIR, db_filename)
    # Create the empty file
    conn = sqlite3.connect(db_path)
    conn.close()
    DB_REGISTRY[conn_id] = {"db_path": db_path, "label": label}
    return conn_id
def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
    """
    Try to import a .sql dump (MySQL/PostgreSQL/SQLite) into SQLite.
    Applies a VERY simple preprocessing pass to skip dialect-specific statements.
    """
    lines = sql_text.splitlines()
    cleaned_lines = []
    for line in lines:
        stripped = line.strip()
        upper = stripped.upper()
        # Skip typical MySQL/Postgres lines that break in SQLite
        if not stripped:
            continue
        if upper.startswith(("SET ", "LOCK TABLES", "UNLOCK TABLES",
                             "DELIMITER ", "USE ", "START TRANSACTION",
                             "COMMIT", "ROLLBACK")):
            continue
        if upper.startswith("--") or upper.startswith("/*") or upper.startswith("*"):
            continue
        if "OWNER TO" in upper:
            continue
        # Skip MySQL-style /*! ... */ directives
        if stripped.startswith("/*!") and stripped.endswith("*/;"):
            continue
        # Strip MySQL backticks
        line = line.replace("`", "")
        # Drop trailing table options such as ENGINE=InnoDB, etc.
        if "ENGINE=" in line.upper():
            line = re.split(r"ENGINE=", line, flags=re.IGNORECASE)[0].rstrip()
            if not line.endswith(";"):
                line += ";"
        cleaned_lines.append(line)
    cleaned_sql = "\n".join(cleaned_lines)
    conn = sqlite3.connect(db_path)
    try:
        conn.executescript(cleaned_sql)
        conn.commit()
    finally:
        conn.close()
def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> None:
    """
    Create a SQLite table with TEXT columns and load the rows from a CSV file.
    """
    table = _sanitize_identifier(table_name or "data")
    conn = sqlite3.connect(db_path)
    try:
        f = io.StringIO(csv_bytes.decode("utf-8", errors="ignore"))
        reader = csv.reader(f)
        rows = list(reader)
        if not rows:
            return
        header = rows[0]
        cols = [_sanitize_identifier(c or f"col_{i}") for i, c in enumerate(header)]
        # Create the table
        col_defs = ", ".join(f'"{c}" TEXT' for c in cols)
        conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" ({col_defs});')
        # Insert the rows
        placeholders = ", ".join(["?"] * len(cols))
        for row in rows[1:]:
            # Pad/truncate each row to the number of columns, just in case
            row = list(row) + [""] * (len(cols) - len(row))
            row = row[:len(cols)]
            conn.execute(
                f'INSERT INTO "{table}" ({", ".join(cols)}) VALUES ({placeholders})',
                row,
            )
        conn.commit()
    finally:
        conn.close()
def import_zip_of_csvs_to_sqlite(db_path: str, zip_bytes: bytes) -> None:
    """
    For a ZIP containing multiple CSV files: each CSV becomes one table.
    """
    conn = sqlite3.connect(db_path)
    conn.close()  # just make sure the file exists
    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
        for name in zf.namelist():
            if not name.lower().endswith(".csv"):
                continue
            with zf.open(name) as f:
                csv_bytes = f.read()
            base_name = os.path.basename(name)
            table_name = os.path.splitext(base_name)[0]
            import_csv_to_sqlite(db_path, csv_bytes, table_name)
# ======================================================
# 4) Schema introspection and execution (on SQLite)
# ======================================================
def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
    """
    Returns:
      - tables: {table_name: {"columns": [col1, col2, ...]}}
      - schema_str: "table(col1, col2) ; table2(...)"
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite file not found: {db_path}")
    conn = sqlite3.connect(db_path)
    cur = conn.cursor()
    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = [row[0] for row in cur.fetchall()]
    tables_info: Dict[str, Dict[str, List[str]]] = {}
    parts = []
    for t in tables:
        cur.execute(f"PRAGMA table_info('{t}');")
        rows = cur.fetchall()  # cid, name, type, notnull, dflt_value, pk
        cols = [r[1] for r in rows]
        tables_info[t] = {"columns": cols}
        parts.append(f"{t}(" + ", ".join(cols) + ")")
    conn.close()
    schema_str = " ; ".join(parts) if parts else "(empty_schema)"
    return {"tables": tables_info, "schema_str": schema_str}
def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
    # Minimal safety check to block destructive statements
    forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
    sql_low = sql.lower()
    if any(f in sql_low for f in forbidden):
        return {
            "ok": False,
            "error": "Query blocked for safety (destructive operation).",
            "rows": None,
            "columns": [],
        }
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        col_names = [desc[0] for desc in cur.description] if cur.description else []
        return {"ok": True, "error": None, "rows": rows, "columns": col_names}
    except Exception as e:
        return {"ok": False, "error": str(e), "rows": None, "columns": []}
    finally:
        conn.close()
# ======================================================
# 4.1) SQL REPAIR LAYER (advanced)
# ======================================================
def _normalize_name_for_match(name: str) -> str:
    """Normalize an identifier (table/column) for fuzzy matching."""
    s = name.lower()
    s = s.replace('"', '').replace("`", "")
    s = s.replace("_", "")
    # Very naive singularization: tracks -> track, songs -> song, etc.
    if s.endswith("s") and len(s) > 3:
        s = s[:-1]
    return s


def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
    """
    Build indexes of normalized names:
      - table_index: {normalized: [table1, table2, ...]}
      - column_index: {normalized: [col1, col2, ...]}
    """
    table_index: Dict[str, List[str]] = {}
    column_index: Dict[str, List[str]] = {}
    for t, info in tables_info.items():
        tn = _normalize_name_for_match(t)
        table_index.setdefault(tn, [])
        if t not in table_index[tn]:
            table_index[tn].append(t)
        for c in info.get("columns", []):
            cn = _normalize_name_for_match(c)
            column_index.setdefault(cn, [])
            if c not in column_index[cn]:
                column_index[cn].append(c)
    return {"table_index": table_index, "column_index": column_index}


def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
    """
    Given a missing name and a normalized index, return the best real match.
    """
    if not index:
        return None
    key = _normalize_name_for_match(missing)
    # Direct match
    if key in index and index[key]:
        return index[key][0]
    # Fuzzy matching with difflib
    candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
    if not candidates:
        return None
    best_key = candidates[0]
    if index[best_key]:
        return index[best_key][0]
    return None


# Dictionaries of common synonyms (Spider + Chinook / typical databases)
DOMAIN_SYNONYMS_TABLE = {
    "song": "track",
    "songs": "track",
    "tracks": "track",
    "artist": "artist",
    "artists": "artist",
    "album": "album",
    "albums": "album",
    "order": "invoice",
    "orders": "invoice",
}

DOMAIN_SYNONYMS_COLUMN = {
    "song": "name",
    "songs": "name",
    "track": "name",
    "title": "name",
    "length": "milliseconds",
    "duration": "milliseconds",
}
def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
    """
    Try to repair a SQL query using the error message and the schema:
      - no such table: X  -> map X to an existing table
      - no such column: Y -> map Y to an existing column
    Returns:
      - the repaired SQL (str) if something was changed
      - None if no repair could be applied
    """
    tables_info = schema_meta["tables"]
    idx = _build_schema_indexes(tables_info)
    table_index = idx["table_index"]
    column_index = idx["column_index"]
    repaired_sql = sql
    changed = False
    # 1) Detect the specific missing name from the SQLite error message
    missing_table = None
    missing_column = None
    m_t = re.search(r"no such table: ([\w\.]+)", error)
    if m_t:
        missing_table = m_t.group(1)
    m_c = re.search(r"no such column: ([\w\.]+)", error)
    if m_c:
        missing_column = m_c.group(1)
    # 2) Repair a missing table
    if missing_table:
        short = missing_table.split(".")[-1]  # handles forms like T1.Songs
        # Domain synonym first (song -> track, etc.)
        syn = DOMAIN_SYNONYMS_TABLE.get(short.lower())
        target = None
        if syn:
            target = _best_match_name(syn, table_index) or syn
        if not target:
            target = _best_match_name(short, table_index)
        if target:
            pattern = r"\b" + re.escape(short) + r"\b"
            new_sql = re.sub(pattern, target, repaired_sql)
            if new_sql != repaired_sql:
                repaired_sql = new_sql
                changed = True
    # 3) Repair a missing column
    if missing_column:
        short = missing_column.split(".")[-1]
        syn = DOMAIN_SYNONYMS_COLUMN.get(short.lower())
        target = None
        if syn:
            target = _best_match_name(syn, column_index) or syn
        if not target:
            target = _best_match_name(short, column_index)
        if target:
            pattern = r"\b" + re.escape(short) + r"\b"
            new_sql = re.sub(pattern, target, repaired_sql)
            if new_sql != repaired_sql:
                repaired_sql = new_sql
                changed = True
    if not changed:
        return None
    return repaired_sql
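

# Illustrative repair walk-through (hypothetical Chinook-style schema with a
# "track" table; an assumed example, not output captured from this code):
#
#   sql   = "SELECT Songs.Title FROM Songs"
#   error = "no such table: Songs"
#   try_repair_sql(sql, error, meta)
#   -> "SELECT track.Title FROM track"   # 'Songs' mapped via DOMAIN_SYNONYMS_TABLE
#
# A second round would then hit "no such column: ...Title" and map it to "Name"
# through DOMAIN_SYNONYMS_COLUMN, which is why the caller retries up to 3 times.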
# ======================================================
# 5) Prompt construction and NL→SQL + re-ranking
# ======================================================
def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
    """
    Spider-style training format:
      translate to SQL: {question} | db: {db_id} | schema: {schema_str} | note: ...
    """
    return (
        f"translate to SQL: {question_en} | "
        f"db: {db_id} | schema: {schema_str} | "
        f"note: use JOIN when foreign keys link tables"
    )
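

# Example of a fully built prompt (question, connection id and schema are hypothetical):
#
#   build_prompt("How many tracks are there?", "db_ab12cd34", "track(TrackId, Name, AlbumId)")
#   -> "translate to SQL: How many tracks are there? | db: db_ab12cd34 | "
#      "schema: track(TrackId, Name, AlbumId) | note: use JOIN when foreign keys link tables"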
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
    """
    Full pipeline:
      - language auto-detection + ES→EN translation
      - schema introspection
      - beam-search generation
      - re-ranking based on actual execution against SQLite
      - SQL repair layer (missing tables/columns, up to 3 attempts)
    """
    if conn_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' is not registered")
    db_path = DB_REGISTRY[conn_id]["db_path"]
    meta = introspect_sqlite_schema(db_path)
    schema_str = meta["schema_str"]
    detected = detect_language(question)
    question_en = translate_es_to_en(question) if detected == "es" else question
    prompt = build_prompt(question_en, db_id=conn_id, schema_str=schema_str)
    if t5_model is None:
        load_nl2sql_model()
    inputs = t5_tokenizer([prompt], return_tensors="pt", truncation=True, max_length=768).to(DEVICE)
    num_beams = 6
    num_return = 6
    with torch.no_grad():
        out = t5_model.generate(
            **inputs,
            max_length=220,
            num_beams=num_beams,
            num_return_sequences=num_return,
            return_dict_in_generate=True,
            output_scores=True,
        )
    sequences = out.sequences
    scores = out.sequences_scores
    if scores is not None:
        scores = scores.cpu().tolist()
    else:
        scores = [0.0] * sequences.size(0)
    candidates: List[Dict[str, Any]] = []
    best = None
    best_exec = False
    best_score = -1e9
    for i in range(sequences.size(0)):
        raw_sql = t5_tokenizer.decode(sequences[i], skip_special_tokens=True).strip()
        cand: Dict[str, Any] = {
            "sql": raw_sql,
            "score": float(scores[i]),
            "repaired_from": None,
            "repair_note": None,
            "raw_sql_model": raw_sql,
        }
        # Attempt 1: direct execution
        exec_info = execute_sqlite(db_path, raw_sql)
        # Up to 3 repair rounds if it keeps failing with "no such table/column"
        if (not exec_info["ok"]) and (
            "no such table" in (exec_info["error"] or "")
            or "no such column" in (exec_info["error"] or "")
        ):
            current_sql = raw_sql
            last_error = exec_info["error"]
            for step in range(1, 4):  # steps 1, 2, 3
                repaired_sql = try_repair_sql(current_sql, last_error, meta)
                if not repaired_sql or repaired_sql == current_sql:
                    break
                exec_info2 = execute_sqlite(db_path, repaired_sql)
                cand["repaired_from"] = current_sql if cand["repaired_from"] is None else cand["repaired_from"]
                cand["repair_note"] = f"auto-repair (table/column name, step {step})"
                cand["sql"] = repaired_sql
                exec_info = exec_info2
                current_sql = repaired_sql
                if exec_info2["ok"]:
                    break
                last_error = exec_info2["error"]
        # Store the final execution info
        cand["exec_ok"] = exec_info["ok"]
        cand["exec_error"] = exec_info["error"]
        cand["rows_preview"] = (
            [list(r) for r in exec_info["rows"][:5]] if exec_info["ok"] and exec_info["rows"] else None
        )
        cand["columns"] = exec_info["columns"]
        candidates.append(cand)
        # Select the "best" candidate: executable queries win, ties broken by beam score
        if exec_info["ok"]:
            if (not best_exec) or cand["score"] > best_score:
                best_exec = True
                best_score = cand["score"]
                best = cand
        elif not best_exec and cand["score"] > best_score:
            best_score = cand["score"]
            best = cand
    if best is None and candidates:
        best = candidates[0]
    return {
        "question_original": question,
        "detected_language": detected,
        "question_en": question_en,
        "connection_id": conn_id,
        "schema_summary": schema_str,
        "best_sql": best["sql"],
        "best_exec_ok": best.get("exec_ok", False),
        "best_exec_error": best.get("exec_error"),
        "best_rows_preview": best.get("rows_preview"),
        "best_columns": best.get("columns", []),
        "candidates": candidates,
    }
# ======================================================
# 6) Pydantic schemas
# ======================================================
class UploadResponse(BaseModel):
    connection_id: str
    label: str
    db_path: str
    note: Optional[str] = None


class ConnectionInfo(BaseModel):
    connection_id: str
    label: str


class SchemaResponse(BaseModel):
    connection_id: str
    schema_summary: str
    tables: Dict[str, Dict[str, List[str]]]


class PreviewResponse(BaseModel):
    connection_id: str
    table: str
    columns: List[str]
    rows: List[List[Any]]


class InferRequest(BaseModel):
    connection_id: str
    question: str


class InferResponse(BaseModel):
    question_original: str
    detected_language: str
    question_en: str
    connection_id: str
    schema_summary: str
    best_sql: str
    best_exec_ok: bool
    best_exec_error: Optional[str]
    best_rows_preview: Optional[List[List[Any]]]
    best_columns: List[str]
    candidates: List[Dict[str, Any]]
# ======================================================
# 7) FastAPI endpoints
# ======================================================
@app.on_event("startup")
async def startup_event():
    # Load the model at startup
    load_nl2sql_model()
    print(f"✅ NL2SQL backend initialized. MODEL_DIR={MODEL_DIR}, UPLOAD_DIR={UPLOAD_DIR}")
@app.post("/upload", response_model=UploadResponse)
async def upload_database(db_file: UploadFile = File(...)):
    """
    Universal database upload.
    The user can upload:
      - .sqlite / .db → used as-is
      - .sql          → MySQL/PostgreSQL/SQLite dump → imported into SQLite
      - .csv          → a SQLite DB with a single table is created
      - .zip          → multiple CSVs → multiple tables in one SQLite DB
    Returns a connection_id to use with /schema, /preview and /infer.
    """
    filename = db_file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="File has no name.")
    fname_lower = filename.lower()
    contents = await db_file.read()
    note = None
    # Case 1: native SQLite
    if fname_lower.endswith(".sqlite") or fname_lower.endswith(".db"):
        conn_id = f"db_{uuid.uuid4().hex[:8]}"
        dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
        with open(dst_path, "wb") as f:
            f.write(contents)
        DB_REGISTRY[conn_id] = {"db_path": dst_path, "label": filename}
        note = "SQLite file stored as-is."
    # Case 2: .sql dump
    elif fname_lower.endswith(".sql"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        sql_text = contents.decode("utf-8", errors="ignore")
        import_sql_dump_to_sqlite(db_path, sql_text)
        note = "SQL dump imported into SQLite (best effort)."
    # Case 3: single CSV
    elif fname_lower.endswith(".csv"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        table_name = os.path.splitext(os.path.basename(filename))[0]
        import_csv_to_sqlite(db_path, contents, table_name)
        note = "CSV imported into a single SQLite table."
    # Case 4: ZIP of CSVs
    elif fname_lower.endswith(".zip"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        import_zip_of_csvs_to_sqlite(db_path, contents)
        note = "ZIP with CSVs imported into multiple SQLite tables."
    else:
        raise HTTPException(
            status_code=400,
            detail="Unsupported format. Use: .sqlite, .db, .sql, .csv or .zip",
        )
    return UploadResponse(
        connection_id=conn_id,
        label=DB_REGISTRY[conn_id]["label"],
        db_path=DB_REGISTRY[conn_id]["db_path"],
        note=note,
    )
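

# Typical client flow against these endpoints (illustrative only; the host/port,
# example file name and connection_id below are assumptions, not values defined
# in this file):
#
#   # 1) Upload a database and note the returned connection_id
#   curl -F "db_file=@chinook.sqlite" http://localhost:8000/upload
#
#   # 2) Ask a question against that connection
#   curl -X POST http://localhost:8000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"connection_id": "db_ab12cd34", "question": "How many tracks are there?"}'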
@app.get("/connections", response_model=List[ConnectionInfo])
async def list_connections():
    """
    List the registered connections (all backed by internal SQLite files).
    """
    out = []
    for cid, info in DB_REGISTRY.items():
        out.append(ConnectionInfo(connection_id=cid, label=info["label"]))
    return out
@app.get("/schema/{connection_id}", response_model=SchemaResponse)
async def get_schema(connection_id: str):
    """
    Return a schema summary for an uploaded database.
    """
    if connection_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail="connection_id not found")
    db_path = DB_REGISTRY[connection_id]["db_path"]
    meta = introspect_sqlite_schema(db_path)
    return SchemaResponse(
        connection_id=connection_id,
        schema_summary=meta["schema_str"],
        tables=meta["tables"],
    )
@app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
async def preview_table(connection_id: str, table: str, limit: int = 20):
    """
    Return a preview of rows from a specific table.
    Useful for the frontend (table view + diagram).
    """
    if connection_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail="connection_id not found")
    db_path = DB_REGISTRY[connection_id]["db_path"]
    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(f'SELECT * FROM "{table}" LIMIT {int(limit)};')
        rows = cur.fetchall()
        cols = [d[0] for d in cur.description] if cur.description else []
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading table '{table}': {e}")
    finally:
        conn.close()
    return PreviewResponse(
        connection_id=connection_id,
        table=table,
        columns=cols,
        rows=[list(r) for r in rows],
    )
@app.post("/infer", response_model=InferResponse)
async def infer_sql(req: InferRequest):
    """
    Given a natural-language question (ES or EN) and a connection_id,
    generate SQL, execute the query and return the result plus all candidates.
    """
    result = nl2sql_with_rerank(req.question, req.connection_id)
    return InferResponse(**result)
@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": t5_model is not None,
        "connections": len(DB_REGISTRY),
        "device": str(DEVICE),
    }
@app.get("/")
async def root():
    return {
        "message": "NL2SQL T5-large universal backend is running (single-file SQLite engine).",
        "endpoints": [
            "POST /upload (upload .sqlite / .db / .sql / .csv / .zip)",
            "GET /connections (list uploaded DBs)",
            "GET /schema/{id} (schema summary)",
            "GET /preview/{id}/{t} (table preview)",
            "POST /infer (NL→SQL + execution)",
            "GET /health (backend status)",
            "GET /docs (OpenAPI UI)",
        ],
    }
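

# Local entry point. This block is an assumption: the source does not define how
# the app is launched, so the uvicorn import and the host/port below (7860 is the
# usual Hugging Face Spaces port) are illustrative defaults, not values taken
# from the original file.
if __name__ == "__main__":
    import uvicorn  # assumed to be installed alongside FastAPI

    uvicorn.run(app, host="0.0.0.0", port=7860)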