hw4 / app.py
aduleba's picture
[UPD] softmax model
1909343
raw
history blame
2.94 kB
import streamlit as st
import torch
import json
from transformers import DistilBertForSequenceClassification, DistilBertTokenizer, DistilBertConfig
@st.cache_resource
def load_model_and_tokenizer(
    tags_path='unique_tags.json',
    checkpoint_path='best_model1.pt',
    base_model='distilbert-base-cased',
):
    """Load the tag vocabulary, tokenizer and fine-tuned DistilBERT classifier.

    Decorated with ``st.cache_resource`` so Streamlit builds these objects once
    and reuses them across script reruns.

    Args:
        tags_path: JSON file with the tag list; its length sets the number of
            output classes. Presumably tag ``i`` corresponds to logit ``i`` —
            confirm against the training script.
        checkpoint_path: file produced by training; must be a dict containing a
            ``"model_state_dict"`` entry.
        base_model: Hugging Face model id used for the config and tokenizer.

    Returns:
        ``(model, tokenizer, unique_tags)``, with the model in eval mode on CUDA
        when available, otherwise on CPU.
    """
    with open(tags_path, 'r', encoding='utf-8') as f:
        unique_tags = json.load(f)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = DistilBertConfig.from_pretrained(
        base_model,
        num_labels=len(unique_tags),
        problem_type="single_label_classification",
    )
    tokenizer = DistilBertTokenizer.from_pretrained(base_model)

    # Build the architecture from the config only. The strict load_state_dict
    # below overwrites every weight, so loading the pretrained base weights
    # first was redundant work (and a large download).
    model = DistilBertForSequenceClassification(config)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model, tokenizer, unique_tags
def _select_top_tags(probs, unique_tags, top_threshold, max_tags):
    """Pick the most probable tags until their cumulative probability reaches a threshold.

    Args:
        probs: 1-D numpy array of per-class probabilities.
        unique_tags: sequence mapping class index -> tag name.
        top_threshold: stop once the summed probability of the selected tags
            is >= this value.
        max_tags: maximum number of tags to return; ``None`` means no cap and
            values <= 0 yield an empty selection.

    Returns:
        List of ``(tag, probability)`` pairs in descending probability order.
    """
    ranked = probs.argsort()[::-1]
    # Apply the cap up front so it also holds for 0 / negative values
    # (an equality check after appending would never trigger for those).
    if max_tags is not None:
        ranked = ranked[:max(max_tags, 0)]

    selected = []
    cumulative = 0.0
    for idx in ranked:
        prob = float(probs[idx])
        cumulative += prob
        selected.append((unique_tags[int(idx)], prob))
        if cumulative >= top_threshold:
            break
    return selected


def predict(texts, model, tokenizer, unique_tags, top_threshold=0.95, max_tags=4):
    """Predict tags for each text with a softmax classifier.

    For every text, tags are taken in descending probability order until their
    cumulative probability reaches ``top_threshold`` or ``max_tags`` tags have
    been chosen, whichever comes first.

    Args:
        texts: iterable of input strings.
        model: sequence-classification model whose output has a ``logits``
            attribute; it is switched to eval mode.
        tokenizer: tokenizer callable returning an encoding with ``.to(device)``.
        unique_tags: sequence mapping class index -> tag name.
        top_threshold: cumulative-probability cut-off.
        max_tags: maximum tags per text; ``None`` means no cap.

    Returns:
        A list with one entry per text, each a list of ``(tag, probability)``
        pairs.
    """
    device = next(model.parameters()).device
    model.eval()
    predictions = []
    with torch.no_grad():
        for text in texts:
            # A single sequence needs no padding: padding to 512 only adds
            # tokens that the attention mask ignores, slowing inference down.
            inputs = tokenizer(
                text,
                max_length=512,
                truncation=True,
                return_tensors='pt',
            ).to(device)
            logits = model(**inputs).logits
            probs = torch.softmax(logits, dim=-1).cpu().numpy().flatten()
            predictions.append(
                _select_top_tags(probs, unique_tags, top_threshold, max_tags)
            )
    return predictions
# Load the model, tokenizer and tag list. Streamlit reruns this script on every
# widget interaction; the st.cache_resource decorator on the loader keeps the
# heavy load from repeating on each rerun.
model, tokenizer, unique_tags = load_model_and_tokenizer()

st.title("Предсказывание тегов статей")
title_input = st.text_input("Введите Title статьи:")
abstract_input = st.text_area("Введите Abstract статьи:")

if st.button("Предсказать"):
    # Join the non-empty, stripped fields with one space. This yields the same
    # string as stripping each field, concatenating with " ", and stripping.
    fields = (title_input.strip(), abstract_input.strip())
    full_text = " ".join(field for field in fields if field)

    if not full_text:
        st.warning("Пожалуйста, введите хотя бы что-то в Title или Abstract.")
    else:
        tag_probs = predict([full_text], model, tokenizer, unique_tags)[0]
        st.subheader("Предсказанные теги:")
        for tag, prob in tag_probs:
            st.write(f"• **{tag}** (вероятность={prob:.4f})")