|
|
import streamlit as st |
|
|
import torch |
|
|
import json |
|
|
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig |
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer(
    tags_path='unique_tags.json',
    checkpoint_path="best_model1.pt",
    base_model='distilbert-base-cased',
):
    """Load the fine-tuned DistilBERT tag classifier, its tokenizer and tag list.

    Decorated with ``st.cache_resource`` so Streamlit builds these objects once
    per distinct argument set instead of on every script rerun.

    Args:
        tags_path: JSON file with the tag list. Its length sets the number of
            output labels.
        checkpoint_path: File written by ``torch.save`` holding the fine-tuned
            weights. It may be a ``{"model_state_dict": ...}`` dict or a bare
            state dict.
        base_model: Hugging Face model id used for the config, the tokenizer
            and the initial weights.

    Returns:
        ``(model, tokenizer, unique_tags)``. The model is moved to CUDA when
        available (CPU otherwise) and switched to eval mode.
    """
    with open(tags_path, 'r', encoding='utf-8') as f:
        unique_tags = json.load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = DistilBertConfig.from_pretrained(
        base_model,
        num_labels=len(unique_tags),
        problem_type="single_label_classification",
    )
    tokenizer = DistilBertTokenizer.from_pretrained(base_model)
    model = DistilBertForSequenceClassification.from_pretrained(base_model, config=config)
    model.to(device)

    # NOTE: torch.load unpickles the file, so only load checkpoints from a
    # trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both a training checkpoint dict and a plain state dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    return model, tokenizer, unique_tags
|
|
|
|
|
def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95, max_tags=4):
    """Predict the most likely tags for each input text.

    For every text, the model's softmax distribution over ``unique_tags`` is
    computed. Tags are then taken in descending probability order until their
    cumulative probability reaches ``top_threshold`` or ``max_tags`` tags have
    been selected, whichever comes first.

    Args:
        texts: Iterable of input strings. A single ``str`` is treated as a
            one-element list rather than iterated character by character.
        model: Sequence-classification model whose output exposes ``.logits``.
            Inference runs on the device of its first parameter.
        tokenizer: Tokenizer matching ``model``.
        unique_tags: Sequence indexed by class id, giving the tag name.
        top_threshold: Cumulative-probability cut-off for the selected tags.
        max_tags: Upper bound on the number of tags per text. A value <= 0
            yields no tags.

    Returns:
        A list with one entry per text. Each entry is a list of
        ``(tag, probability)`` tuples in decreasing probability order.
    """
    if isinstance(texts, str):
        texts = [texts]

    device = next(model.parameters()).device
    model.eval()

    predictions = []
    with torch.no_grad():
        for text in texts:
            # Truncate only. Padding one sequence to 512 tokens adds nothing
            # (pad positions are masked) but makes every forward pass pay
            # full-length attention cost.
            inputs = tokenizer(
                text,
                max_length=512,
                truncation=True,
                return_tensors='pt',
            ).to(device)
            outputs = model(**inputs)
            # Single-sequence batch: take row 0 of the (1, num_labels) probabilities.
            probs = torch.softmax(outputs.logits, dim=-1)[0].cpu().numpy()

            selected = []
            cumulative = 0.0
            for idx in probs.argsort()[::-1]:
                # Check the cap before appending so max_tags is a true upper
                # bound (the old post-append `==` ignored values <= 0).
                if len(selected) >= max_tags:
                    break
                prob = float(probs[idx])
                selected.append((unique_tags[idx], prob))
                cumulative += prob
                if cumulative >= top_threshold:
                    break
            predictions.append(selected)

    return predictions
|
|
|
|
|
|
|
|
# --- Streamlit UI -----------------------------------------------------------
# load_model_and_tokenizer is wrapped in st.cache_resource, so reruns of this
# script reuse the already-loaded objects.
model, tokenizer, unique_tags = load_model_and_tokenizer()

st.title("Предсказывание тегов статей")
title_input = st.text_input("Введите Title статьи:")
abstract_input = st.text_area("Введите Abstract статьи:")

if st.button("Предсказать"):
    # Join the stripped fields with one space, then strip again so that a
    # single empty field does not leave a stray leading/trailing space.
    full_text = " ".join([title_input.strip(), abstract_input.strip()]).strip()

    if not full_text:
        st.warning("Пожалуйста, введите хотя бы что-то в Title или Abstract.")
    else:
        predictions = predict([full_text], model, tokenizer, unique_tags)
        st.subheader("Предсказанные теги:")
        for tag, prob in predictions[0]:
            st.write(f"• **{tag}** (вероятность={prob:.4f})")
|
|
|