stvnnnnnn commited on
Commit
33d8b39
·
verified ·
1 Parent(s): 4cf4689

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +810 -0
app.py ADDED
@@ -0,0 +1,810 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import sqlite3
4
+ import io
5
+ import csv
6
+ import zipfile
7
+ import re
8
+ import difflib
9
+ from typing import List, Optional, Dict, Any
10
+
11
+ from fastapi import FastAPI, UploadFile, File, HTTPException
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+ from pydantic import BaseModel
14
+
15
+ import torch
16
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
+ from langdetect import detect
18
+ from transformers import MarianMTModel, MarianTokenizer
19
+
20
# ======================================================
# 0) General configuration
# ======================================================

# NL→SQL model fine-tuned on Spider, pulled from the Hugging Face Hub.
MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
DEVICE = torch.device("cpu")  # CPU-only inference

# Directory where uploaded databases (all converted to SQLite) are stored.
UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploaded_dbs")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# In-memory registry of connections (everything ends up as SQLite).
# Shape: { conn_id: { "db_path": str, "label": str } }
# NOTE(review): process-local and unbounded — entries are lost on restart
# and never evicted.
DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
35
+
36
# ======================================================
# 1) FastAPI initialization
# ======================================================

app = FastAPI(
    title="NL2SQL T5-large Backend Universal (single-file)",
    description=(
        "Intérprete NL→SQL (T5-large Spider) para usuarios no expertos. "
        "El usuario solo sube su BD (SQLite / dump .sql / CSV / ZIP de CSVs) "
        "y todo se convierte internamente a SQLite."
    ),
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # in production, restrict this to your frontend's domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
57
+
58
# ======================================================
# 2) NL→SQL model and ES→EN translator
# ======================================================

# Lazily-initialized singletons, populated by load_nl2sql_model() and
# load_es_en_translator() on first use.
t5_tokenizer = None
t5_model = None
mt_tokenizer = None
mt_model = None
66
+
67
+
68
def load_nl2sql_model():
    """Load the NL→SQL model (T5-large fine-tuned on Spider) from the HF Hub.

    Idempotent: populates the t5_tokenizer/t5_model globals once; later
    calls return immediately.
    """
    global t5_tokenizer, t5_model
    if t5_model is not None:
        return
    print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}")
    t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    # float32 on CPU (see DEVICE); no GPU is assumed anywhere in this app.
    t5_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float32)
    t5_model.to(DEVICE)
    t5_model.eval()
    print("✅ Modelo NL→SQL listo en memoria.")
79
+
80
+
81
def load_es_en_translator():
    """Load the Helsinki-NLP ES→EN translation model (only once).

    Idempotent: populates the mt_tokenizer/mt_model globals on first call.
    """
    global mt_tokenizer, mt_model
    if mt_model is not None:
        return
    model_name = "Helsinki-NLP/opus-mt-es-en"
    print(f"🔁 Cargando traductor ES→EN: {model_name}")
    mt_tokenizer = MarianTokenizer.from_pretrained(model_name)
    mt_model = MarianMTModel.from_pretrained(model_name)
    mt_model.to(DEVICE)
    mt_model.eval()
    print("✅ Traductor ES→EN listo.")
93
+
94
+
95
def detect_language(text: str) -> str:
    """Best-effort language detection via langdetect; "unknown" on failure."""
    try:
        lang = detect(text)
    except Exception:  # langdetect raises on empty/ambiguous input
        lang = "unknown"
    return lang
100
+
101
+
102
def translate_es_to_en(text: str) -> str:
    """Translate *text* with Marian ES→EN, but only when it is detected as
    Spanish ('es'); any other language is returned untouched.
    """
    if detect_language(text) != "es":
        return text
    if mt_model is None:
        load_es_en_translator()  # lazy one-time load of the Marian model
    encoded = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
    with torch.no_grad():
        generated = mt_model.generate(**encoded, max_length=256)
    return mt_tokenizer.decode(generated[0], skip_special_tokens=True)
116
+
117
+
118
+ # ======================================================
119
+ # 3) Utilidades de BDs: creación/ingesta a SQLite
120
+ # ======================================================
121
+
122
+ def _sanitize_identifier(name: str) -> str:
123
+ """Hace un nombre de tabla/columna seguro para SQLite."""
124
+ base = name.strip().replace(" ", "_")
125
+ base = re.sub(r"[^0-9a-zA-Z_]", "_", base)
126
+ if not base:
127
+ base = "table"
128
+ if base[0].isdigit():
129
+ base = "_" + base
130
+ return base
131
+
132
+
133
def create_empty_sqlite_db(label: str) -> str:
    """Create an empty .sqlite file under UPLOAD_DIR, register it in
    DB_REGISTRY, and return the generated connection id.
    """
    conn_id = f"db_{uuid.uuid4().hex[:8]}"
    db_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
    # Connecting and immediately closing is enough to materialize the file.
    sqlite3.connect(db_path).close()
    DB_REGISTRY[conn_id] = {"db_path": db_path, "label": label}
    return conn_id
143
+
144
+
145
def import_sql_dump_to_sqlite(db_path: str, sql_text: str) -> None:
    """Best-effort import of a .sql dump (MySQL/PostgreSQL/SQLite) into SQLite.

    Strips directives and comments SQLite does not understand, removes MySQL
    backtick quoting and trailing "ENGINE=..." table options, then runs the
    remainder with executescript(). sqlite3 errors propagate if the cleaned
    dump is still invalid.
    """
    cleaned_lines = []
    for line in sql_text.splitlines():
        stripped = line.strip()
        upper = stripped.upper()

        if not stripped:
            continue
        # Engine-specific directives that break SQLite.
        if upper.startswith(("SET ", "LOCK TABLES", "UNLOCK TABLES",
                             "DELIMITER ", "USE ", "START TRANSACTION",
                             "COMMIT", "ROLLBACK")):
            continue
        # SQL comments; "/*" also covers MySQL /*! ... */ conditional comments,
        # so the original's separate "/*!" check was unreachable and is dropped.
        if upper.startswith(("--", "/*", "*")):
            continue
        if "OWNER TO" in upper:  # PostgreSQL ownership clauses
            continue

        # MySQL backtick quoting is not needed in SQLite.
        line = line.replace("`", "")

        # Drop trailing table options such as "ENGINE=InnoDB ...".
        # Fix: match case-insensitively — the original checked upper() but
        # split on the exact-case string, so "engine=" slipped through.
        m = re.search(r"ENGINE=", line, re.IGNORECASE)
        if m:
            line = line[:m.start()].rstrip()
            if not line.endswith(";"):
                line += ";"

        cleaned_lines.append(line)

    conn = sqlite3.connect(db_path)
    try:
        conn.executescript("\n".join(cleaned_lines))
        conn.commit()
    finally:
        conn.close()
191
+
192
+
193
def import_csv_to_sqlite(db_path: str, csv_bytes: bytes, table_name: str) -> None:
    """Create a TEXT-typed table in SQLite and load rows from a CSV payload.

    The first CSV row is the header; every column is stored as TEXT. Data
    rows shorter/longer than the header are padded with "" / truncated.
    No-op when the CSV is empty.
    """
    table = _sanitize_identifier(table_name or "data")
    conn = sqlite3.connect(db_path)
    try:
        f = io.StringIO(csv_bytes.decode("utf-8", errors="ignore"))
        rows = list(csv.reader(f))
        if not rows:
            return

        header = rows[0]
        cols = [_sanitize_identifier(c or f"col_{i}") for i, c in enumerate(header)]
        n = len(cols)

        col_defs = ", ".join(f'"{c}" TEXT' for c in cols)
        conn.execute(f'CREATE TABLE IF NOT EXISTS "{table}" ({col_defs});')

        # Fix: quote the column names in the INSERT as well — the original
        # quoted them only in CREATE, so a sanitized name that happens to be
        # a SQL keyword (e.g. "table") broke the insert.
        col_list = ", ".join(f'"{c}"' for c in cols)
        placeholders = ", ".join(["?"] * n)
        # Normalize every row to the header width before inserting.
        normalized = [(list(row) + [""] * (n - len(row)))[:n] for row in rows[1:]]
        # executemany: one prepared statement instead of one execute per row.
        conn.executemany(
            f'INSERT INTO "{table}" ({col_list}) VALUES ({placeholders})',
            normalized,
        )
        conn.commit()
    finally:
        conn.close()
228
+
229
+
230
def import_zip_of_csvs_to_sqlite(db_path: str, zip_bytes: bytes) -> None:
    """Load every *.csv member of a ZIP archive into its own SQLite table,
    named after the file (extension stripped). Non-CSV members are ignored.
    """
    # Touch the database file so it exists even for an archive with no CSVs.
    sqlite3.connect(db_path).close()

    with zipfile.ZipFile(io.BytesIO(zip_bytes)) as archive:
        for member in archive.namelist():
            if not member.lower().endswith(".csv"):
                continue
            payload = archive.read(member)
            table = os.path.splitext(os.path.basename(member))[0]
            import_csv_to_sqlite(db_path, payload, table)
246
+
247
+
248
+ # ======================================================
249
+ # 4) Introspección de esquema y ejecución (sobre SQLite)
250
+ # ======================================================
251
+
252
def introspect_sqlite_schema(db_path: str) -> Dict[str, Any]:
    """Inspect a SQLite file and summarize its schema.

    Returns a dict with:
      - "tables": {table_name: {"columns": [col1, col2, ...]}}
      - "schema_str": "table(col1, col2) ; other(colA)" or "(empty_schema)"

    Raises FileNotFoundError if *db_path* does not exist.
    """
    if not os.path.exists(db_path):
        raise FileNotFoundError(f"SQLite no encontrado: {db_path}")

    tables_info: Dict[str, Dict[str, List[str]]] = {}
    parts: List[str] = []

    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        # Fix: skip SQLite-internal bookkeeping tables (sqlite_sequence, ...),
        # which would only pollute the schema string fed to the NL→SQL model.
        cur.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite_%';"
        )
        tables = [row[0] for row in cur.fetchall()]

        for t in tables:
            cur.execute(f"PRAGMA table_info('{t}');")
            # PRAGMA rows: cid, name, type, notnull, dflt_value, pk
            cols = [r[1] for r in cur.fetchall()]
            tables_info[t] = {"columns": cols}
            parts.append(f"{t}(" + ", ".join(cols) + ")")
    finally:
        # Fix: the original leaked the connection if any query raised.
        conn.close()

    schema_str = " ; ".join(parts) if parts else "(empty_schema)"
    return {"tables": tables_info, "schema_str": schema_str}
279
+
280
+
281
def execute_sqlite(db_path: str, sql: str) -> Dict[str, Any]:
    """Execute *sql* against *db_path* and capture the outcome.

    Never raises: failures are reported through the returned dict
    {"ok": bool, "error": str | None, "rows": list | None, "columns": list}.
    """
    # Minimal keyword blacklist to block obviously destructive statements.
    # NOTE(review): substring matching is coarse (it also triggers on these
    # words inside identifiers/strings) but errs on the safe side.
    forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "]
    sql_low = sql.lower()
    if any(word in sql_low for word in forbidden):
        return {
            "ok": False,
            "error": "Query bloqueada por seguridad (operación destructiva).",
            "rows": None,
            "columns": []
        }

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        cur.execute(sql)
        rows = cur.fetchall()
        col_names = [desc[0] for desc in cur.description] if cur.description else []
        return {"ok": True, "error": None, "rows": rows, "columns": col_names}
    except Exception as e:
        return {"ok": False, "error": str(e), "rows": None, "columns": []}
    finally:
        # Fix: the original only closed the connection on success, leaking
        # the handle whenever execute() raised.
        if conn is not None:
            conn.close()
303
+
304
+
305
+ # ======================================================
306
+ # 4.1) SQL REPAIR LAYER (avanzado)
307
+ # ======================================================
308
+
309
+ def _normalize_name_for_match(name: str) -> str:
310
+ """Normaliza un identificador (tabla/columna) para hacer matching difuso."""
311
+ s = name.lower()
312
+ s = s.replace('"', '').replace("`", "")
313
+ s = s.replace("_", "")
314
+ # singularización muy simple: tracks -> track, songs -> song, etc.
315
+ if s.endswith("s") and len(s) > 3:
316
+ s = s[:-1]
317
+ return s
318
+
319
+
320
def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[str, Dict[str, List[str]]]:
    """Build normalized-name lookup indexes over the schema.

    Returns:
      - "table_index":  {normalized_name: [real_table, ...]}
      - "column_index": {normalized_name: [real_column, ...]}
    """
    table_index: Dict[str, List[str]] = {}
    column_index: Dict[str, List[str]] = {}

    for table_name, info in tables_info.items():
        bucket = table_index.setdefault(_normalize_name_for_match(table_name), [])
        if table_name not in bucket:
            bucket.append(table_name)

        for col in info.get("columns", []):
            col_bucket = column_index.setdefault(_normalize_name_for_match(col), [])
            if col not in col_bucket:
                col_bucket.append(col)

    return {"table_index": table_index, "column_index": column_index}
342
+
343
+
344
def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
    """Resolve *missing* to a real schema name via the normalized *index*.

    Tries an exact normalized hit first, then a difflib fuzzy match
    (cutoff 0.7). Returns None when nothing plausible exists.
    """
    if not index:
        return None

    key = _normalize_name_for_match(missing)
    direct = index.get(key)
    if direct:
        return direct[0]

    # Fuzzy fallback over the normalized keys.
    close = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
    if close and index[close[0]]:
        return index[close[0]][0]
    return None
364
+
365
+
366
# Common domain synonyms (Spider + Chinook-style databases): words the model
# tends to emit, mapped onto the table names such schemas actually use.
DOMAIN_SYNONYMS_TABLE = {
    "song": "track",
    "songs": "track",
    "tracks": "track",
    "artist": "artist",
    "artists": "artist",
    "album": "album",
    "albums": "album",
    "order": "invoice",
    "orders": "invoice",
}

# Same idea for columns, e.g. "length"/"duration" -> Chinook's "milliseconds".
DOMAIN_SYNONYMS_COLUMN = {
    "song": "name",
    "songs": "name",
    "track": "name",
    "title": "name",
    "length": "milliseconds",
    "duration": "milliseconds",
}
387
+
388
+
389
def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
    """Attempt to repair SQL using the SQLite error message and the schema.

    Handles:
      - "no such table: X"  -> remap X onto an existing table
      - "no such column: Y" -> remap Y onto an existing column
    Returns the repaired SQL string, or None when no substitution applied.
    """
    idx = _build_schema_indexes(schema_meta["tables"])

    repaired_sql = sql
    changed = False

    def _resolve(short: str, synonyms: Dict[str, str], index: Dict[str, List[str]]) -> Optional[str]:
        # Prefer a domain synonym (e.g. song -> track); fall back to the raw
        # synonym itself, then to fuzzy matching on the missing name.
        syn = synonyms.get(short.lower())
        target = (_best_match_name(syn, index) or syn) if syn else None
        return target or _best_match_name(short, index)

    repairs = (
        (r"no such table: ([\w\.]+)", DOMAIN_SYNONYMS_TABLE, idx["table_index"]),
        (r"no such column: ([\w\.]+)", DOMAIN_SYNONYMS_COLUMN, idx["column_index"]),
    )
    for pattern, synonyms, index in repairs:
        m = re.search(pattern, error)
        if not m:
            continue
        # Strip alias prefixes like "T1.Songs" -> "Songs".
        short = m.group(1).split(".")[-1]
        target = _resolve(short, synonyms, index)
        if not target:
            continue
        new_sql = re.sub(r"\b" + re.escape(short) + r"\b", target, repaired_sql)
        if new_sql != repaired_sql:
            repaired_sql = new_sql
            changed = True

    return repaired_sql if changed else None
456
+
457
+
458
+ # ======================================================
459
+ # 5) Construcción de prompt y NL→SQL + re-ranking
460
+ # ======================================================
461
+
462
def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
    """Format the model input in the Spider fine-tuning style:
    translate to SQL: {question} | db: {db_id} | schema: {schema} | note: ...
    """
    segments = [
        f"translate to SQL: {question_en}",
        f"db: {db_id}",
        f"schema: {schema_str}",
        "note: use JOIN when foreign keys link tables",
    ]
    return " | ".join(segments)
472
+
473
+
474
def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
    """
    Full NL→SQL pipeline for one question:
      - language auto-detection + ES→EN translation
      - schema introspection of the target SQLite DB
      - beam-search generation (6 beams, 6 candidates)
      - re-ranking by actually executing each candidate against SQLite
      - SQL-repair layer for missing table/column names (up to 3 rounds)

    Returns a dict matching the InferResponse schema.
    Raises HTTPException(404) for an unknown conn_id.
    """
    if conn_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")

    db_path = DB_REGISTRY[conn_id]["db_path"]
    meta = introspect_sqlite_schema(db_path)
    schema_str = meta["schema_str"]

    # Translate only when the question is detected as Spanish.
    detected = detect_language(question)
    question_en = translate_es_to_en(question) if detected == "es" else question

    prompt = build_prompt(question_en, db_id=conn_id, schema_str=schema_str)

    # Lazy load — also done at startup; kept here as a safety net.
    if t5_model is None:
        load_nl2sql_model()

    inputs = t5_tokenizer([prompt], return_tensors="pt", truncation=True, max_length=768).to(DEVICE)
    num_beams = 6
    num_return = 6

    with torch.no_grad():
        out = t5_model.generate(
            **inputs,
            max_length=220,
            num_beams=num_beams,
            num_return_sequences=num_return,
            return_dict_in_generate=True,
            output_scores=True,
        )

    sequences = out.sequences
    scores = out.sequences_scores  # beam scores; may be None depending on generate() config
    if scores is not None:
        scores = scores.cpu().tolist()
    else:
        scores = [0.0] * sequences.size(0)

    candidates: List[Dict[str, Any]] = []
    best = None
    best_exec = False  # whether `best` executed successfully
    best_score = -1e9

    for i in range(sequences.size(0)):
        raw_sql = t5_tokenizer.decode(sequences[i], skip_special_tokens=True).strip()
        cand: Dict[str, Any] = {
            "sql": raw_sql,
            "score": float(scores[i]),
            "repaired_from": None,
            "repair_note": None,
            "raw_sql_model": raw_sql,  # model output before any repair
        }

        # Attempt 1: execute the candidate as generated.
        exec_info = execute_sqlite(db_path, raw_sql)

        # Up to 3 repair rounds while failing on missing tables/columns.
        if (not exec_info["ok"]) and (
            "no such table" in (exec_info["error"] or "")
            or "no such column" in (exec_info["error"] or "")
        ):
            current_sql = raw_sql
            last_error = exec_info["error"]
            for step in range(1, 4):  # step 1, 2, 3
                repaired_sql = try_repair_sql(current_sql, last_error, meta)
                if not repaired_sql or repaired_sql == current_sql:
                    break  # nothing more the repair layer can do
                exec_info2 = execute_sqlite(db_path, repaired_sql)
                # Keep the very first SQL we repaired from.
                cand["repaired_from"] = current_sql if cand["repaired_from"] is None else cand["repaired_from"]
                cand["repair_note"] = f"auto-repair (table/column name, step {step})"
                cand["sql"] = repaired_sql
                exec_info = exec_info2
                current_sql = repaired_sql
                if exec_info2["ok"]:
                    break
                last_error = exec_info2["error"]

        # Final execution outcome for this candidate (first 5 rows only).
        cand["exec_ok"] = exec_info["ok"]
        cand["exec_error"] = exec_info["error"]
        cand["rows_preview"] = (
            [list(r) for r in exec_info["rows"][:5]] if exec_info["ok"] and exec_info["rows"] else None
        )
        cand["columns"] = exec_info["columns"]

        candidates.append(cand)

        # Re-ranking: an executable candidate always beats a failing one;
        # among candidates of the same class, the higher beam score wins.
        if exec_info["ok"]:
            if (not best_exec) or cand["score"] > best_score:
                best_exec = True
                best_score = cand["score"]
                best = cand
        elif not best_exec and cand["score"] > best_score:
            best_score = cand["score"]
            best = cand

    # Defensive fallback (num_return >= 1, so candidates is never empty here).
    if best is None and candidates:
        best = candidates[0]

    return {
        "question_original": question,
        "detected_language": detected,
        "question_en": question_en,
        "connection_id": conn_id,
        "schema_summary": schema_str,
        "best_sql": best["sql"],
        "best_exec_ok": best.get("exec_ok", False),
        "best_exec_error": best.get("exec_error"),
        "best_rows_preview": best.get("rows_preview"),
        "best_columns": best.get("columns", []),
        "candidates": candidates,
    }
594
+
595
+
596
+ # ======================================================
597
+ # 6) Schemas Pydantic
598
+ # ======================================================
599
+
600
class UploadResponse(BaseModel):
    """Response body for POST /upload."""
    connection_id: str
    label: str
    db_path: str
    note: Optional[str] = None
605
+
606
+
607
class ConnectionInfo(BaseModel):
    """One registered database connection, as listed by GET /connections."""
    connection_id: str
    label: str
610
+
611
+
612
class SchemaResponse(BaseModel):
    """Response body for GET /schema/{connection_id}."""
    connection_id: str
    schema_summary: str
    tables: Dict[str, Dict[str, List[str]]]
616
+
617
+
618
class PreviewResponse(BaseModel):
    """Response body for GET /preview/{connection_id}/{table}."""
    connection_id: str
    table: str
    columns: List[str]
    rows: List[List[Any]]
623
+
624
+
625
class InferRequest(BaseModel):
    """Request body for POST /infer: a question (ES/EN) plus the target DB."""
    connection_id: str
    question: str
628
+
629
+
630
class InferResponse(BaseModel):
    """Response body for POST /infer (mirrors nl2sql_with_rerank's dict)."""
    question_original: str
    detected_language: str
    question_en: str
    connection_id: str
    schema_summary: str
    best_sql: str
    best_exec_ok: bool
    best_exec_error: Optional[str]
    best_rows_preview: Optional[List[List[Any]]]
    best_columns: List[str]
    candidates: List[Dict[str, Any]]
642
+
643
+
644
+ # ======================================================
645
+ # 7) Endpoints FastAPI
646
+ # ======================================================
647
+
648
@app.on_event("startup")
async def startup_event():
    """Warm up the NL→SQL model so the first /infer request is not slow."""
    load_nl2sql_model()
    print(f"✅ Backend NL2SQL inicializado. MODEL_DIR={MODEL_DIR}, UPLOAD_DIR={UPLOAD_DIR}")
653
+
654
+
655
@app.post("/upload", response_model=UploadResponse)
async def upload_database(db_file: UploadFile = File(...)):
    """
    Universal database upload.

    Accepted formats (dispatched by file extension):
      - .sqlite / .db -> stored and used as-is
      - .sql          -> MySQL/PostgreSQL/SQLite dump imported into SQLite
      - .csv          -> new SQLite DB with a single table
      - .zip          -> multiple CSVs, one SQLite table per file

    Returns a connection_id for use with /schema, /preview and /infer.

    NOTE(review): the whole upload is read into memory with no size limit —
    very large files could exhaust RAM.
    """
    filename = db_file.filename
    if not filename:
        raise HTTPException(status_code=400, detail="Archivo sin nombre.")

    fname_lower = filename.lower()
    contents = await db_file.read()

    note = None

    # Case 1: native SQLite file — persist the bytes untouched.
    if fname_lower.endswith(".sqlite") or fname_lower.endswith(".db"):
        conn_id = f"db_{uuid.uuid4().hex[:8]}"
        dst_path = os.path.join(UPLOAD_DIR, f"{conn_id}.sqlite")
        with open(dst_path, "wb") as f:
            f.write(contents)
        DB_REGISTRY[conn_id] = {"db_path": dst_path, "label": filename}
        note = "SQLite file stored as-is."

    # Case 2: .sql dump — best-effort conversion into SQLite.
    elif fname_lower.endswith(".sql"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        sql_text = contents.decode("utf-8", errors="ignore")
        import_sql_dump_to_sqlite(db_path, sql_text)
        note = "SQL dump imported into SQLite (best effort)."

    # Case 3: single CSV — one table named after the file.
    elif fname_lower.endswith(".csv"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        table_name = os.path.splitext(os.path.basename(filename))[0]
        import_csv_to_sqlite(db_path, contents, table_name)
        note = "CSV imported into a single SQLite table."

    # Case 4: ZIP of CSVs — one table per CSV member.
    elif fname_lower.endswith(".zip"):
        conn_id = create_empty_sqlite_db(label=filename)
        db_path = DB_REGISTRY[conn_id]["db_path"]
        import_zip_of_csvs_to_sqlite(db_path, contents)
        note = "ZIP with CSVs imported into multiple SQLite tables."

    else:
        raise HTTPException(
            status_code=400,
            detail="Formato no soportado. Usa: .sqlite, .db, .sql, .csv o .zip",
        )

    return UploadResponse(
        connection_id=conn_id,
        label=DB_REGISTRY[conn_id]["label"],
        db_path=DB_REGISTRY[conn_id]["db_path"],
        note=note,
    )
719
+
720
+
721
@app.get("/connections", response_model=List[ConnectionInfo])
async def list_connections():
    """Return every registered connection (all backed by internal SQLite files)."""
    return [
        ConnectionInfo(connection_id=cid, label=info["label"])
        for cid, info in DB_REGISTRY.items()
    ]
730
+
731
+
732
@app.get("/schema/{connection_id}", response_model=SchemaResponse)
async def get_schema(connection_id: str):
    """Return a schema summary for a previously uploaded database (404 if unknown)."""
    info = DB_REGISTRY.get(connection_id)
    if info is None:
        raise HTTPException(status_code=404, detail="connection_id no encontrado")

    meta = introspect_sqlite_schema(info["db_path"])
    return SchemaResponse(
        connection_id=connection_id,
        schema_summary=meta["schema_str"],
        tables=meta["tables"],
    )
747
+
748
+
749
@app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
async def preview_table(connection_id: str, table: str, limit: int = 20):
    """Return up to *limit* rows of one table (for the frontend table view).

    Raises 404 for an unknown connection and 400 when the table cannot be read.
    """
    if connection_id not in DB_REGISTRY:
        raise HTTPException(status_code=404, detail="connection_id no encontrado")

    db_path = DB_REGISTRY[connection_id]["db_path"]
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cur = conn.cursor()
        # Table name is interpolated (double-quoted); limit is coerced to int
        # to keep the query-string parameter out of the SQL text.
        cur.execute(f'SELECT * FROM "{table}" LIMIT {int(limit)};')
        rows = cur.fetchall()
        cols = [d[0] for d in cur.description] if cur.description else []
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error al leer tabla '{table}': {e}")
    finally:
        # Fix: the original leaked the connection whenever execute() raised.
        if conn is not None:
            conn.close()

    return PreviewResponse(
        connection_id=connection_id,
        table=table,
        columns=cols,
        rows=[list(r) for r in rows],
    )
775
+
776
+
777
@app.post("/infer", response_model=InferResponse)
async def infer_sql(req: InferRequest):
    """Generate SQL for a natural-language question (ES or EN), execute it,
    and return the best result together with all ranked candidates."""
    payload = nl2sql_with_rerank(req.question, req.connection_id)
    return InferResponse(**payload)
785
+
786
+
787
@app.get("/health")
async def health():
    """Liveness probe with basic runtime stats."""
    status = {
        "status": "ok",
        "model_loaded": t5_model is not None,
        "connections": len(DB_REGISTRY),
        "device": str(DEVICE),
    }
    return status
795
+
796
+
797
@app.get("/")
async def root():
    """Landing endpoint listing the available routes."""
    endpoints = [
        "POST /upload (subir .sqlite / .db / .sql / .csv / .zip)",
        "GET /connections (listar BDs subidas)",
        "GET /schema/{id} (esquema resumido)",
        "GET /preview/{id}/{t} (preview de tabla)",
        "POST /infer (NL→SQL + ejecución)",
        "GET /health (estado del backend)",
        "GET /docs (OpenAPI UI)",
    ]
    return {
        "message": "NL2SQL T5-large universal backend is running (single-file SQLite engine).",
        "endpoints": endpoints,
    }