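"""Streamlit app that predicts article tags from a title and abstract
using a fine-tuned DistilBERT classifier."""
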
import streamlit as st
import torch
import json
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig

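# Cache the expensive model/tokenizer load across Streamlit reruns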
@st.cache_resource
def load_model_and_tokenizer():
    # Load the list of class labels saved during training
    with open('unique_tags.json', 'r', encoding='utf-8') as f:
        unique_tags = json.load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Size the classification head to the number of known tags
    num_classes = len(unique_tags)
    config = DistilBertConfig.from_pretrained(
        'distilbert-base-cased',
        num_labels=num_classes,
        problem_type="single_label_classification"
    )
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
    model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-cased', config=config)
    model.to(device)

    # Restore the fine-tuned weights from the training checkpoint
    checkpoint = torch.load("best_model1.pt", map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()

    return model, tokenizer, unique_tags

def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95, max_tags=4):
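    """For each input text, return the most probable tags, stopping once
    their cumulative probability reaches top_threshold or max_tags is hit."""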
    device = next(model.parameters()).device
    model.eval()
    predictions = []
    
    # Disable gradient tracking for inference
    with torch.no_grad():
        for text in texts:
            # Tokenize, padding/truncating to DistilBERT's 512-token limit
            inputs = tokenizer(
                text,
                max_length=512,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            ).to(device)
            
            # Forward pass; softmax turns logits into per-class probabilities
            outputs = model(**inputs)
            raw_probs = torch.softmax(outputs.logits, dim=1).cpu().numpy().flatten()

            # Greedily collect the most probable tags until their cumulative
            # probability reaches top_threshold or max_tags have been taken
            sorted_indices = raw_probs.argsort()[::-1]
            cumulative = 0.0
            selected = []

            for idx in sorted_indices:
                prob = raw_probs[idx]
                cumulative += prob
                selected.append((unique_tags[idx], float(prob)))
                if cumulative >= top_threshold or len(selected) == max_tags:
                    break

            predictions.append(selected)
    
    return predictions

# Load everything needed (once at startup)
model, tokenizer, unique_tags = load_model_and_tokenizer()

st.title("Предсказывание тегов статей")
title_input = st.text_input("Введите Title статьи:")
abstract_input = st.text_area("Введите Abstract статьи:")

if st.button("Предсказать"):
    full_text = (title_input.strip() + " " + abstract_input.strip()).strip()
    if len(full_text) > 0:
        predictions = predict([full_text], model, tokenizer, unique_tags)
        
        st.subheader("Предсказанные теги:")
        for tag, prob in predictions[0]:
            st.write(f"• **{tag}** (вероятность={prob:.4f})")
    else:
        st.warning("Пожалуйста, введите хотя бы что-то в Title или Abstract.")