Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,15 +9,21 @@ class MariannaBot:
|
|
| 9 |
def __init__(self):
|
| 10 |
self.data_path_main = "dati_per_database_riassunti.pkl"
|
| 11 |
self.data_path_legends = "legends.pkl"
|
|
|
|
| 12 |
|
| 13 |
print("Inizializzazione di MariannaBot (senza DB)...") # Debug
|
| 14 |
|
| 15 |
self.database = self.load_data_from_pickle(self.data_path_main)
|
| 16 |
self.database_legends = self.load_data_from_pickle(self.data_path_legends)
|
| 17 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
self.db_keys = [el[0] for el in self.database] if isinstance(self.database, list) else []
|
| 20 |
self.db_keys_legends = [el[0] for el in self.database] if isinstance(self.database, list) else []
|
|
|
|
| 21 |
# print("Chiavi principali caricate:", len(self.db_keys)) # Debug
|
| 22 |
# print("Chiavi leggende caricate:", len(self.db_keys_legends)) # Debug
|
| 23 |
|
|
@@ -27,9 +33,20 @@ class MariannaBot:
|
|
| 27 |
|
| 28 |
self.reset_state()
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
def load_queries_dataset(self):
|
| 31 |
"""Loads queries dataset"""
|
| 32 |
-
return {"si, certo, certamente, ok, assolutamente si, sicuro, sisi":"si","no, non ho domande, non mi interessa, niente, nulla":"no","non so, scegli tu, fai tu, casuale, lascio a te, decidi tu, pensaci tu, sorprendimi":"non so","stronzo, vaffanculo, ti odio, pezzo di merda, cazzo":"parolacce"}
|
| 33 |
|
| 34 |
def load_data_from_pickle(self, file_path):
|
| 35 |
"""Loads data from a pickle file."""
|
|
@@ -60,6 +77,7 @@ class MariannaBot:
|
|
| 60 |
self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
|
| 61 |
self.db_keys_legends_embeddings = self.encoder.encode(self.db_keys_legends, convert_to_tensor=True)
|
| 62 |
self.first_query_emb = self.encoder.encode(self.query_dic_keys, convert_to_tensor=True)
|
|
|
|
| 63 |
|
| 64 |
print(f"Encoder initialized with {len(self.db_keys)} keys.")
|
| 65 |
return True
|
|
@@ -121,14 +139,46 @@ class MariannaBot:
|
|
| 121 |
# per recuperare il contenuto effettivo.
|
| 122 |
return self.database_legends.get(key) if isinstance(self.database_legends, dict) else None
|
| 123 |
|
| 124 |
-
def get_value(self, key):
|
| 125 |
"""Retrieve a value from the loaded main data by key."""
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
if k == key:
|
| 128 |
return v
|
| 129 |
return None
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
def handle_query(self, message):
|
| 134 |
"""Handle user queries by searching the database"""
|
|
@@ -148,65 +198,74 @@ class MariannaBot:
|
|
| 148 |
for hit, score in zip(semantic_hits, cross_scores)],
|
| 149 |
key=lambda x: x['cross-score'], reverse=True
|
| 150 |
)
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
| 152 |
best_hit = reranked_hits[0]
|
| 153 |
best_title = self.db_keys[best_hit['corpus_id']]
|
| 154 |
best_score = best_hit['cross-score']
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
# Main treshold = 0.75
|
| 158 |
similarity_threshold = 0.75
|
| 159 |
|
| 160 |
# treshold granularity
|
| 161 |
if best_score < similarity_threshold:
|
| 162 |
# low confidence (< 0.35)
|
| 163 |
-
|
| 164 |
-
|
| 165 |
"Purtroppo non riesco a rammentare questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
|
| 166 |
-
"Mi dispiace tantissimo, ma non riesco a ricordare
|
| 167 |
|
| 168 |
# medium confidence(0.55 - 0.75)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
else:
|
| 170 |
-
|
| 171 |
-
suggestions = ", ".join(alternative_hits)
|
| 172 |
-
value = self.get_value(best_title)
|
| 173 |
-
if value:
|
| 174 |
-
partial_info = value.get('short_intro', value.get('intro', '').split('.')[0] + '.')
|
| 175 |
-
self.state = "query"
|
| 176 |
-
self.is_telling_stories = False
|
| 177 |
-
return random.choice([f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?",
|
| 178 |
-
f"Credo che tu stia parlando de {best_title}, ma per essere sicura di ciò che vuoi sapere, potresti specificare se parli di {suggestions}?",
|
| 179 |
-
f"Per assicurarmi di aver capito bene, vuoi che ti parli di {suggestions}?"])
|
| 180 |
-
else:
|
| 181 |
-
return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più?"
|
| 182 |
|
| 183 |
-
|
| 184 |
if best_title is not None:
|
| 185 |
-
|
| 186 |
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
else:
|
| 196 |
-
self.current_further_info_values = [] # Se il valore non è un dizionario
|
| 197 |
-
self.current_index = 0
|
| 198 |
-
return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}?"
|
| 199 |
else:
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
| 204 |
except Exception as e:
|
| 205 |
print(e)
|
| 206 |
self.state = "initial"
|
| 207 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 208 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 209 |
-
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 210 |
|
| 211 |
|
| 212 |
def first_query(self, message):
|
|
@@ -217,7 +276,6 @@ class MariannaBot:
|
|
| 217 |
# Perform semantic search on the keys
|
| 218 |
semantic_hits = util.semantic_search(query_embedding, self.first_query_emb, top_k=4)
|
| 219 |
semantic_hits = semantic_hits[0]
|
| 220 |
-
print(semantic_hits)
|
| 221 |
cross_inp = [(message, self.query_dic_keys[hit['corpus_id']]) for hit in semantic_hits]
|
| 222 |
cross_scores = self.cross_encoder.predict(cross_inp)
|
| 223 |
reranked_hits = sorted(
|
|
@@ -228,7 +286,7 @@ class MariannaBot:
|
|
| 228 |
best_hit = reranked_hits[0]
|
| 229 |
best_title = self.query_dic[self.query_dic_keys[best_hit['corpus_id']]]
|
| 230 |
best_score = best_hit['cross-score']
|
| 231 |
-
print(best_title, best_score)
|
| 232 |
# Main treshold = 0.75
|
| 233 |
similarity_threshold = 0.35
|
| 234 |
|
|
@@ -252,14 +310,14 @@ class MariannaBot:
|
|
| 252 |
self.state = "initial"
|
| 253 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 254 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 255 |
-
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 256 |
|
| 257 |
|
| 258 |
def respond(self, message, history):
|
| 259 |
if not message:
|
| 260 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 261 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 262 |
-
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 263 |
|
| 264 |
message = message.lower().strip()
|
| 265 |
|
|
|
|
| 9 |
def __init__(self):
|
| 10 |
self.data_path_main = "dati_per_database_riassunti.pkl"
|
| 11 |
self.data_path_legends = "legends.pkl"
|
| 12 |
+
self.data_path_exp = "secondDB.pkl"
|
| 13 |
|
| 14 |
print("Inizializzazione di MariannaBot (senza DB)...") # Debug
|
| 15 |
|
| 16 |
self.database = self.load_data_from_pickle(self.data_path_main)
|
| 17 |
self.database_legends = self.load_data_from_pickle(self.data_path_legends)
|
| 18 |
+
self.database_expansion = self.load_data_from_pickle(self.data_path_exp)
|
| 19 |
+
|
| 20 |
+
self.database = self.database + self.database_legends + self.database_expansion
|
| 21 |
+
self.further_dataset = self.load_further_info_as_dataset(self.database)
|
| 22 |
+
self.further_dataset = self.further_dataset + self.database
|
| 23 |
|
| 24 |
self.db_keys = [el[0] for el in self.database] if isinstance(self.database, list) else []
|
| 25 |
self.db_keys_legends = [el[0] for el in self.database] if isinstance(self.database, list) else []
|
| 26 |
+
self.db_keys_further = [el[0] for el in self.further_dataset] if isinstance(self.further_dataset, list) else []
|
| 27 |
# print("Chiavi principali caricate:", len(self.db_keys)) # Debug
|
| 28 |
# print("Chiavi leggende caricate:", len(self.db_keys_legends)) # Debug
|
| 29 |
|
|
|
|
| 33 |
|
| 34 |
self.reset_state()
|
| 35 |
|
| 36 |
+
def load_further_info_as_dataset(self,dataset):
|
| 37 |
+
nuova_lista = []
|
| 38 |
+
for chiave_principale, info in dataset:
|
| 39 |
+
nuovo_dizionario = {'intro': info['intro']}
|
| 40 |
+
if 'further_info' in info:
|
| 41 |
+
for chiave_secondaria in info['further_info']:
|
| 42 |
+
nuova_lista.append((f"{chiave_secondaria} ({chiave_principale})", {'intro': info['further_info'][chiave_secondaria]}))
|
| 43 |
+
else:
|
| 44 |
+
nuova_lista.append((chiave_principale, nuovo_dizionario))
|
| 45 |
+
return nuova_lista
|
| 46 |
+
|
| 47 |
def load_queries_dataset(self):
|
| 48 |
"""Loads queries dataset"""
|
| 49 |
+
return {"si, certo, certamente, ok, assolutamente si, sicuro, sisi, continua, prosegui":"si","no, non ho domande, non mi interessa, niente, nulla":"no","non so, scegli tu, fai tu, casuale, lascio a te, decidi tu, pensaci tu, sorprendimi":"non so","stronzo, vaffanculo, ti odio, pezzo di merda, cazzo":"parolacce"}
|
| 50 |
|
| 51 |
def load_data_from_pickle(self, file_path):
|
| 52 |
"""Loads data from a pickle file."""
|
|
|
|
| 77 |
self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
|
| 78 |
self.db_keys_legends_embeddings = self.encoder.encode(self.db_keys_legends, convert_to_tensor=True)
|
| 79 |
self.first_query_emb = self.encoder.encode(self.query_dic_keys, convert_to_tensor=True)
|
| 80 |
+
self.further_embeddings = self.encoder.encode(self.db_keys_further, convert_to_tensor=True)
|
| 81 |
|
| 82 |
print(f"Encoder initialized with {len(self.db_keys)} keys.")
|
| 83 |
return True
|
|
|
|
| 139 |
# per recuperare il contenuto effettivo.
|
| 140 |
return self.database_legends.get(key) if isinstance(self.database_legends, dict) else None
|
| 141 |
|
| 142 |
+
def get_value(self, key,state):
|
| 143 |
"""Retrieve a value from the loaded main data by key."""
|
| 144 |
+
if state=="A":
|
| 145 |
+
for k, v in self.database:
|
| 146 |
+
if k == key:
|
| 147 |
+
return v
|
| 148 |
+
else:
|
| 149 |
+
for k, v in self.further_dataset:
|
| 150 |
if k == key:
|
| 151 |
return v
|
| 152 |
return None
|
| 153 |
|
| 154 |
|
| 155 |
+
def deeper_handle_query(self,message,query_embedding,CS_old):
|
| 156 |
+
print('Ricerca in profondità')
|
| 157 |
+
try:
|
| 158 |
+
semantic_hits = util.semantic_search(query_embedding, self.further_embeddings, top_k=3)
|
| 159 |
+
semantic_hits = semantic_hits[0]
|
| 160 |
+
|
| 161 |
+
cross_inp = [(message, self.db_keys_further[hit['corpus_id']]) for hit in semantic_hits]
|
| 162 |
+
cross_scores = self.cross_encoder.predict(cross_inp)
|
| 163 |
+
cross_scores = cross_scores + CS_old
|
| 164 |
+
|
| 165 |
+
reranked_hits = sorted(
|
| 166 |
+
[{'corpus_id': hit['corpus_id'], 'cross-score': score}
|
| 167 |
+
for hit, score in zip(semantic_hits, cross_scores)],
|
| 168 |
+
key=lambda x: x['cross-score'], reverse=True
|
| 169 |
+
)
|
| 170 |
+
for h in reranked_hits:
|
| 171 |
+
print(self.db_keys_further[h['corpus_id']],h['cross-score'])
|
| 172 |
+
best_hit = reranked_hits[0]
|
| 173 |
+
best_title = self.db_keys_further[best_hit['corpus_id']]
|
| 174 |
+
best_score = best_hit['cross-score']
|
| 175 |
+
return reranked_hits
|
| 176 |
+
except Exception as e:
|
| 177 |
+
print(e)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
|
| 182 |
|
| 183 |
def handle_query(self, message):
|
| 184 |
"""Handle user queries by searching the database"""
|
|
|
|
| 198 |
for hit, score in zip(semantic_hits, cross_scores)],
|
| 199 |
key=lambda x: x['cross-score'], reverse=True
|
| 200 |
)
|
| 201 |
+
|
| 202 |
+
for h in reranked_hits:
|
| 203 |
+
print(self.db_keys[h['corpus_id']],h['cross-score'])
|
| 204 |
+
chiavi = self.db_keys
|
| 205 |
best_hit = reranked_hits[0]
|
| 206 |
best_title = self.db_keys[best_hit['corpus_id']]
|
| 207 |
best_score = best_hit['cross-score']
|
| 208 |
+
state="A"
|
| 209 |
+
|
| 210 |
+
if best_score < 0.75:
|
| 211 |
+
reranked_hits = self.deeper_handle_query(message,query_embedding,cross_scores)
|
| 212 |
+
best_hit = reranked_hits[0]
|
| 213 |
+
best_title = self.db_keys_further[best_hit['corpus_id']]
|
| 214 |
+
best_score = best_hit['cross-score']
|
| 215 |
+
state="B"
|
| 216 |
+
|
| 217 |
# Main treshold = 0.75
|
| 218 |
similarity_threshold = 0.75
|
| 219 |
|
| 220 |
# treshold granularity
|
| 221 |
if best_score < similarity_threshold:
|
| 222 |
# low confidence (< 0.35)
|
| 223 |
+
if best_score < 0.55:
|
| 224 |
+
return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
|
| 225 |
"Purtroppo non riesco a rammentare questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
|
| 226 |
+
"Mi dispiace tantissimo, ma non riesco a ricordare. Vuoi chiedermi altro sulla città di Napoli?"])
|
| 227 |
|
| 228 |
# medium confidence(0.55 - 0.75)
|
| 229 |
+
else:
|
| 230 |
+
alternative_hits = [self.db_keys[hit['corpus_id']] for hit in reranked_hits[:2]]
|
| 231 |
+
suggestions = " o ".join(alternative_hits)
|
| 232 |
+
value = self.get_value(best_title,state)
|
| 233 |
+
if value:
|
| 234 |
+
partial_info = value.get('short_intro', value.get('intro', '').split('.')[0] + '.')
|
| 235 |
+
self.state = "query"
|
| 236 |
+
self.is_telling_stories = False
|
| 237 |
+
return random.choice([f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?",
|
| 238 |
+
f"Credo che tu stia parlando de {best_title}, ma per essere sicura di ciò che vuoi sapere, potresti specificare se parli di {suggestions}?",
|
| 239 |
+
f"Per assicurarmi di aver capito bene, vuoi che ti parli di {suggestions}?"])
|
| 240 |
else:
|
| 241 |
+
return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più?"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
+
# high confidence (above the threshold)
|
| 244 |
if best_title is not None:
|
| 245 |
+
value = self.get_value(best_title,state)
|
| 246 |
|
| 247 |
+
if value:
|
| 248 |
+
key = best_title
|
| 249 |
+
self.main_k.append(key)
|
| 250 |
+
self.state = "follow_up"
|
| 251 |
+
self.is_telling_stories = False
|
| 252 |
+
response = value.get('intro', '')
|
| 253 |
+
if isinstance(value, dict):
|
| 254 |
+
self.current_further_info_values = list(value.get('further_info', {}).values())
|
|
|
|
|
|
|
|
|
|
|
|
|
| 255 |
else:
|
| 256 |
+
self.current_further_info_values = [] # Se il valore non è un dizionario
|
| 257 |
+
self.current_index = 0
|
| 258 |
+
return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}?"
|
| 259 |
+
else:
|
| 260 |
+
return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
|
| 261 |
+
"Purtroppo non riesco a rammentare altro su questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
|
| 262 |
+
"Mi dispiace tantissimo, ma non riesco a ricordare altro. Vuoi chiedermi altro sulla città di Napoli?"])
|
| 263 |
except Exception as e:
|
| 264 |
print(e)
|
| 265 |
self.state = "initial"
|
| 266 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 267 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 268 |
+
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 269 |
|
| 270 |
|
| 271 |
def first_query(self, message):
|
|
|
|
| 276 |
# Perform semantic search on the keys
|
| 277 |
semantic_hits = util.semantic_search(query_embedding, self.first_query_emb, top_k=4)
|
| 278 |
semantic_hits = semantic_hits[0]
|
|
|
|
| 279 |
cross_inp = [(message, self.query_dic_keys[hit['corpus_id']]) for hit in semantic_hits]
|
| 280 |
cross_scores = self.cross_encoder.predict(cross_inp)
|
| 281 |
reranked_hits = sorted(
|
|
|
|
| 286 |
best_hit = reranked_hits[0]
|
| 287 |
best_title = self.query_dic[self.query_dic_keys[best_hit['corpus_id']]]
|
| 288 |
best_score = best_hit['cross-score']
|
| 289 |
+
print(message,best_title, best_score)
|
| 290 |
# Main treshold = 0.75
|
| 291 |
similarity_threshold = 0.35
|
| 292 |
|
|
|
|
| 310 |
self.state = "initial"
|
| 311 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 312 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 313 |
+
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 314 |
|
| 315 |
|
| 316 |
def respond(self, message, history):
|
| 317 |
if not message:
|
| 318 |
return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
|
| 319 |
"Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
|
| 320 |
+
"Mi dispiace, non ho capito. Puoi essere più preciso?"])
|
| 321 |
|
| 322 |
message = message.lower().strip()
|
| 323 |
|