amaisto commited on
Commit
9dfa437
·
verified ·
1 Parent(s): f16b834

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -45
app.py CHANGED
@@ -9,15 +9,21 @@ class MariannaBot:
9
  def __init__(self):
10
  self.data_path_main = "dati_per_database_riassunti.pkl"
11
  self.data_path_legends = "legends.pkl"
 
12
 
13
  print("Inizializzazione di MariannaBot (senza DB)...") # Debug
14
 
15
  self.database = self.load_data_from_pickle(self.data_path_main)
16
  self.database_legends = self.load_data_from_pickle(self.data_path_legends)
17
- self.database = self.database + self.database_legends
 
 
 
 
18
 
19
  self.db_keys = [el[0] for el in self.database] if isinstance(self.database, list) else []
20
  self.db_keys_legends = [el[0] for el in self.database] if isinstance(self.database, list) else []
 
21
  # print("Chiavi principali caricate:", len(self.db_keys)) # Debug
22
  # print("Chiavi leggende caricate:", len(self.db_keys_legends)) # Debug
23
 
@@ -27,9 +33,20 @@ class MariannaBot:
27
 
28
  self.reset_state()
29
 
 
 
 
 
 
 
 
 
 
 
 
30
  def load_queries_dataset(self):
31
  """Loads queries dataset"""
32
- return {"si, certo, certamente, ok, assolutamente si, sicuro, sisi":"si","no, non ho domande, non mi interessa, niente, nulla":"no","non so, scegli tu, fai tu, casuale, lascio a te, decidi tu, pensaci tu, sorprendimi":"non so","stronzo, vaffanculo, ti odio, pezzo di merda, cazzo":"parolacce"}
33
 
34
  def load_data_from_pickle(self, file_path):
35
  """Loads data from a pickle file."""
@@ -60,6 +77,7 @@ class MariannaBot:
60
  self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
61
  self.db_keys_legends_embeddings = self.encoder.encode(self.db_keys_legends, convert_to_tensor=True)
62
  self.first_query_emb = self.encoder.encode(self.query_dic_keys, convert_to_tensor=True)
 
63
 
64
  print(f"Encoder initialized with {len(self.db_keys)} keys.")
65
  return True
@@ -121,14 +139,46 @@ class MariannaBot:
121
  # per recuperare il contenuto effettivo.
122
  return self.database_legends.get(key) if isinstance(self.database_legends, dict) else None
123
 
124
- def get_value(self, key):
125
  """Retrieve a value from the loaded main data by key."""
126
- for k, v in self.database:
 
 
 
 
 
127
  if k == key:
128
  return v
129
  return None
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
 
133
  def handle_query(self, message):
134
  """Handle user queries by searching the database"""
@@ -148,65 +198,74 @@ class MariannaBot:
148
  for hit, score in zip(semantic_hits, cross_scores)],
149
  key=lambda x: x['cross-score'], reverse=True
150
  )
151
-
 
 
 
152
  best_hit = reranked_hits[0]
153
  best_title = self.db_keys[best_hit['corpus_id']]
154
  best_score = best_hit['cross-score']
155
- # print(best_title, best_score)
156
-
 
 
 
 
 
 
 
157
  # Main treshold = 0.75
158
  similarity_threshold = 0.75
159
 
160
  # treshold granularity
161
  if best_score < similarity_threshold:
162
  # low confidence (< 0.35)
163
- if best_score < 0.55:
164
- return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
165
  "Purtroppo non riesco a rammentare questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
166
- "Mi dispiace tantissimo, ma non riesco a ricordare altro. Vuoi chiedermi altro sulla città di Napoli?"])
167
 
168
  # medium confidence(0.55 - 0.75)
 
 
 
 
 
 
 
 
 
 
 
169
  else:
170
- alternative_hits = [self.db_keys[hit['corpus_id']] for hit in reranked_hits[:2]]
171
- suggestions = ", ".join(alternative_hits)
172
- value = self.get_value(best_title)
173
- if value:
174
- partial_info = value.get('short_intro', value.get('intro', '').split('.')[0] + '.')
175
- self.state = "query"
176
- self.is_telling_stories = False
177
- return random.choice([f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?",
178
- f"Credo che tu stia parlando de {best_title}, ma per essere sicura di ciò che vuoi sapere, potresti specificare se parli di {suggestions}?",
179
- f"Per assicurarmi di aver capito bene, vuoi che ti parli di {suggestions}?"])
180
- else:
181
- return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più?"
182
 
183
- # high confidence (above the threshold)
184
  if best_title is not None:
185
- value = self.get_value(best_title)
186
 
187
- if value:
188
- key = best_title
189
- self.main_k.append(key)
190
- self.state = "follow_up"
191
- self.is_telling_stories = False
192
- response = value.get('intro', '')
193
- if isinstance(value, dict):
194
- self.current_further_info_values = list(value.get('further_info', {}).values())
195
- else:
196
- self.current_further_info_values = [] # Se il valore non è un dizionario
197
- self.current_index = 0
198
- return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}?"
199
  else:
200
- return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
201
- "Purtroppo non riesco a rammentare altro su questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
202
- "Mi dispiace tantissimo, ma non riesco a ricordare altro. Vuoi chiedermi altro sulla città di Napoli?"])
203
-
 
 
 
204
  except Exception as e:
205
  print(e)
206
  self.state = "initial"
207
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
208
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
209
- "Mi dispiace, non ho capito. Puoi essere più preciso?"])
210
 
211
 
212
  def first_query(self, message):
@@ -217,7 +276,6 @@ class MariannaBot:
217
  # Perform semantic search on the keys
218
  semantic_hits = util.semantic_search(query_embedding, self.first_query_emb, top_k=4)
219
  semantic_hits = semantic_hits[0]
220
- print(semantic_hits)
221
  cross_inp = [(message, self.query_dic_keys[hit['corpus_id']]) for hit in semantic_hits]
222
  cross_scores = self.cross_encoder.predict(cross_inp)
223
  reranked_hits = sorted(
@@ -228,7 +286,7 @@ class MariannaBot:
228
  best_hit = reranked_hits[0]
229
  best_title = self.query_dic[self.query_dic_keys[best_hit['corpus_id']]]
230
  best_score = best_hit['cross-score']
231
- print(best_title, best_score)
232
  # Main treshold = 0.75
233
  similarity_threshold = 0.35
234
 
@@ -252,14 +310,14 @@ class MariannaBot:
252
  self.state = "initial"
253
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
254
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
255
- "Mi dispiace, non ho capito. Puoi essere più preciso?"])
256
 
257
 
258
  def respond(self, message, history):
259
  if not message:
260
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
261
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
262
- "Mi dispiace, non ho capito. Puoi essere più preciso?"])
263
 
264
  message = message.lower().strip()
265
 
 
9
  def __init__(self):
10
  self.data_path_main = "dati_per_database_riassunti.pkl"
11
  self.data_path_legends = "legends.pkl"
12
+ self.data_path_exp = "secondDB.pkl"
13
 
14
  print("Inizializzazione di MariannaBot (senza DB)...") # Debug
15
 
16
  self.database = self.load_data_from_pickle(self.data_path_main)
17
  self.database_legends = self.load_data_from_pickle(self.data_path_legends)
18
+ self.database_expansion = self.load_data_from_pickle(self.data_path_exp)
19
+
20
+ self.database = self.database + self.database_legends + self.database_expansion
21
+ self.further_dataset = self.load_further_info_as_dataset(self.database)
22
+ self.further_dataset = self.further_dataset + self.database
23
 
24
  self.db_keys = [el[0] for el in self.database] if isinstance(self.database, list) else []
25
  self.db_keys_legends = [el[0] for el in self.database] if isinstance(self.database, list) else []
26
+ self.db_keys_further = [el[0] for el in self.further_dataset] if isinstance(self.further_dataset, list) else []
27
  # print("Chiavi principali caricate:", len(self.db_keys)) # Debug
28
  # print("Chiavi leggende caricate:", len(self.db_keys_legends)) # Debug
29
 
 
33
 
34
  self.reset_state()
35
 
36
+ def load_further_info_as_dataset(self,dataset):
37
+ nuova_lista = []
38
+ for chiave_principale, info in dataset:
39
+ nuovo_dizionario = {'intro': info['intro']}
40
+ if 'further_info' in info:
41
+ for chiave_secondaria in info['further_info']:
42
+ nuova_lista.append((f"{chiave_secondaria} ({chiave_principale})", {'intro': info['further_info'][chiave_secondaria]}))
43
+ else:
44
+ nuova_lista.append((chiave_principale, nuovo_dizionario))
45
+ return nuova_lista
46
+
47
  def load_queries_dataset(self):
48
  """Loads queries dataset"""
49
+ return {"si, certo, certamente, ok, assolutamente si, sicuro, sisi, continua, prosegui":"si","no, non ho domande, non mi interessa, niente, nulla":"no","non so, scegli tu, fai tu, casuale, lascio a te, decidi tu, pensaci tu, sorprendimi":"non so","stronzo, vaffanculo, ti odio, pezzo di merda, cazzo":"parolacce"}
50
 
51
  def load_data_from_pickle(self, file_path):
52
  """Loads data from a pickle file."""
 
77
  self.db_keys_embeddings = self.encoder.encode(self.db_keys, convert_to_tensor=True)
78
  self.db_keys_legends_embeddings = self.encoder.encode(self.db_keys_legends, convert_to_tensor=True)
79
  self.first_query_emb = self.encoder.encode(self.query_dic_keys, convert_to_tensor=True)
80
+ self.further_embeddings = self.encoder.encode(self.db_keys_further, convert_to_tensor=True)
81
 
82
  print(f"Encoder initialized with {len(self.db_keys)} keys.")
83
  return True
 
139
  # per recuperare il contenuto effettivo.
140
  return self.database_legends.get(key) if isinstance(self.database_legends, dict) else None
141
 
142
+ def get_value(self, key,state):
143
  """Retrieve a value from the loaded main data by key."""
144
+ if state=="A":
145
+ for k, v in self.database:
146
+ if k == key:
147
+ return v
148
+ else:
149
+ for k, v in self.further_dataset:
150
  if k == key:
151
  return v
152
  return None
153
 
154
 
155
+ def deeper_handle_query(self,message,query_embedding,CS_old):
156
+ print('Ricerca in profondità')
157
+ try:
158
+ semantic_hits = util.semantic_search(query_embedding, self.further_embeddings, top_k=3)
159
+ semantic_hits = semantic_hits[0]
160
+
161
+ cross_inp = [(message, self.db_keys_further[hit['corpus_id']]) for hit in semantic_hits]
162
+ cross_scores = self.cross_encoder.predict(cross_inp)
163
+ cross_scores = cross_scores + CS_old
164
+
165
+ reranked_hits = sorted(
166
+ [{'corpus_id': hit['corpus_id'], 'cross-score': score}
167
+ for hit, score in zip(semantic_hits, cross_scores)],
168
+ key=lambda x: x['cross-score'], reverse=True
169
+ )
170
+ for h in reranked_hits:
171
+ print(self.db_keys_further[h['corpus_id']],h['cross-score'])
172
+ best_hit = reranked_hits[0]
173
+ best_title = self.db_keys_further[best_hit['corpus_id']]
174
+ best_score = best_hit['cross-score']
175
+ return reranked_hits
176
+ except Exception as e:
177
+ print(e)
178
+
179
+
180
+
181
+
182
 
183
  def handle_query(self, message):
184
  """Handle user queries by searching the database"""
 
198
  for hit, score in zip(semantic_hits, cross_scores)],
199
  key=lambda x: x['cross-score'], reverse=True
200
  )
201
+
202
+ for h in reranked_hits:
203
+ print(self.db_keys[h['corpus_id']],h['cross-score'])
204
+ chiavi = self.db_keys
205
  best_hit = reranked_hits[0]
206
  best_title = self.db_keys[best_hit['corpus_id']]
207
  best_score = best_hit['cross-score']
208
+ state="A"
209
+
210
+ if best_score < 0.75:
211
+ reranked_hits = self.deeper_handle_query(message,query_embedding,cross_scores)
212
+ best_hit = reranked_hits[0]
213
+ best_title = self.db_keys_further[best_hit['corpus_id']]
214
+ best_score = best_hit['cross-score']
215
+ state="B"
216
+
217
  # Main treshold = 0.75
218
  similarity_threshold = 0.75
219
 
220
  # treshold granularity
221
  if best_score < similarity_threshold:
222
  # low confidence (< 0.35)
223
+ if best_score < 0.55:
224
+ return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
225
  "Purtroppo non riesco a rammentare questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
226
+ "Mi dispiace tantissimo, ma non riesco a ricordare. Vuoi chiedermi altro sulla città di Napoli?"])
227
 
228
  # medium confidence(0.55 - 0.75)
229
+ else:
230
+ alternative_hits = [self.db_keys[hit['corpus_id']] for hit in reranked_hits[:2]]
231
+ suggestions = " o ".join(alternative_hits)
232
+ value = self.get_value(best_title,state)
233
+ if value:
234
+ partial_info = value.get('short_intro', value.get('intro', '').split('.')[0] + '.')
235
+ self.state = "query"
236
+ self.is_telling_stories = False
237
+ return random.choice([f"Potrei avere alcune informazioni su {best_title}, ma non sono completamente sicura sia ciò che stai cercando. I miei suggerimenti sono {suggestions}. \n\nCosa ti interessa?",
238
+ f"Credo che tu stia parlando de {best_title}, ma per essere sicura di ciò che vuoi sapere, potresti specificare se parli di {suggestions}?",
239
+ f"Per assicurarmi di aver capito bene, vuoi che ti parli di {suggestions}?"])
240
  else:
241
+ return f"Ho trovato qualcosa su {best_title}, ma non sono completamente sicura. Vuoi saperne di più?"
 
 
 
 
 
 
 
 
 
 
 
242
 
243
+ # high confidence (above the threshold)
244
  if best_title is not None:
245
+ value = self.get_value(best_title,state)
246
 
247
+ if value:
248
+ key = best_title
249
+ self.main_k.append(key)
250
+ self.state = "follow_up"
251
+ self.is_telling_stories = False
252
+ response = value.get('intro', '')
253
+ if isinstance(value, dict):
254
+ self.current_further_info_values = list(value.get('further_info', {}).values())
 
 
 
 
255
  else:
256
+ self.current_further_info_values = [] # Se il valore non è un dizionario
257
+ self.current_index = 0
258
+ return f"{response}\n\nVuoi sapere altro su {self.main_k[-1]}?"
259
+ else:
260
+ return random.choice(["Mi dispiace, non ho informazioni su questo argomento. Puoi chiedermi di altro sulla città di Napoli.",
261
+ "Purtroppo non riesco a rammentare altro su questo argomento, la mia memoria non è più quella di un tempo. Chiedimi qualcos'altro su Napoli e le sue bellezze!",
262
+ "Mi dispiace tantissimo, ma non riesco a ricordare altro. Vuoi chiedermi altro sulla città di Napoli?"])
263
  except Exception as e:
264
  print(e)
265
  self.state = "initial"
266
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
267
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
268
+ "Mi dispiace, non ho capito. Puoi essere più preciso?"])
269
 
270
 
271
  def first_query(self, message):
 
276
  # Perform semantic search on the keys
277
  semantic_hits = util.semantic_search(query_embedding, self.first_query_emb, top_k=4)
278
  semantic_hits = semantic_hits[0]
 
279
  cross_inp = [(message, self.query_dic_keys[hit['corpus_id']]) for hit in semantic_hits]
280
  cross_scores = self.cross_encoder.predict(cross_inp)
281
  reranked_hits = sorted(
 
286
  best_hit = reranked_hits[0]
287
  best_title = self.query_dic[self.query_dic_keys[best_hit['corpus_id']]]
288
  best_score = best_hit['cross-score']
289
+ print(message,best_title, best_score)
290
  # Main treshold = 0.75
291
  similarity_threshold = 0.35
292
 
 
310
  self.state = "initial"
311
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
312
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
313
+ "Mi dispiace, non ho capito. Puoi essere più preciso?"])
314
 
315
 
316
  def respond(self, message, history):
317
  if not message:
318
  return random.choice(["Mi dispiace, c'è stato un errore. Puoi riprovare con un'altra domanda? ",
319
  "Scusami, sto facendo confusione. Puoi farmi un'altra domanda?",
320
+ "Mi dispiace, non ho capito. Puoi essere più preciso?"])
321
 
322
  message = message.lower().strip()
323