aduleba commited on
Commit
1909343
·
1 Parent(s): f3c91cb

[UPD] softmax model

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -28,8 +28,8 @@ def load_model_and_tokenizer():
28
 
29
  return model, tokenizer, unique_tags
30
 
31
- def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95):
32
- device = next(model.parameters()).device # Берём девайс из модели
33
  model.eval()
34
  predictions = []
35
 
@@ -44,17 +44,18 @@ def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95):
44
  ).to(device)
45
 
46
  outputs = model(**inputs)
47
- probs = torch.sigmoid(outputs.logits).cpu().numpy().flatten()
48
 
49
- sorted_indices = probs.argsort()[::-1]
50
- cumulative = 0
51
  selected = []
52
 
53
  for idx in sorted_indices:
54
- cumulative += probs[idx]
55
- selected.append((unique_tags[idx], float(probs[idx])))
56
- if cumulative >= top_threshold:
57
- break
 
58
  predictions.append(selected)
59
 
60
  return predictions
 
28
 
29
  return model, tokenizer, unique_tags
30
 
31
+ def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95, max_tags=4):
32
+ device = next(model.parameters()).device
33
  model.eval()
34
  predictions = []
35
 
 
44
  ).to(device)
45
 
46
  outputs = model(**inputs)
47
+ raw_probs = torch.softmax(outputs.logits, dim=1).cpu().numpy().flatten()
48
 
49
+ sorted_indices = raw_probs.argsort()[::-1]
50
+ cumulative = 0.0
51
  selected = []
52
 
53
  for idx in sorted_indices:
54
+ prob = raw_probs[idx]
55
+ cumulative += prob
56
+ selected.append((unique_tags[idx], float(prob)))
57
+ if cumulative >= top_threshold or len(selected) == max_tags:
58
+ break
59
  predictions.append(selected)
60
 
61
  return predictions