import streamlit as st
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModel
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors
from torch import nn
from datetime import datetime
from db import get_database  # assumes a local db.py exposing get_database() that connects to MongoDB
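# For reference, a minimal sketch of the assumed db.py. The MONGO_URI
# environment variable and the "polymer_db" database name are assumptions;
# only the get_database() function is required by this file.
#
#   # db.py
#   import os
#   from pymongo import MongoClient
#
#   def get_database():
#       # Read the connection string from the environment rather than hardcoding it.
#       client = MongoClient(os.environ.get("MONGO_URI", "mongodb://localhost:27017"))
#       return client["polymer_db"]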
# Load the ChemBERTa tokenizer and encoder once and cache them across Streamlit reruns
@st.cache_resource
def load_chemberta():
    tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
    model.eval()
    return tokenizer, model

tokenizer, chemberta = load_chemberta()
# Model architecture: fuses the ChemBERTa embedding with RDKit descriptor/fingerprint
# features and regresses six polymer properties
class TransformerRegressor(nn.Module):
    def __init__(self, emb_dim=768, feat_dim=2058, output_dim=6, nhead=8, num_layers=2):
        super().__init__()
        # Project the 2058-dim feature vector (10 descriptors + 2048 fingerprint bits)
        # into the same space as the ChemBERTa embedding
        self.feat_proj = nn.Linear(feat_dim, emb_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=emb_dim, nhead=nhead, dim_feedforward=1024,
            dropout=0.1, batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.regression_head = nn.Sequential(
            nn.Linear(emb_dim, 256), nn.ReLU(),
            nn.Linear(256, 128), nn.ReLU(),
            nn.Linear(128, output_dim),
        )

    def forward(self, x, feat):
        feat_emb = self.feat_proj(feat)
        # Treat the embedding and the projected features as a 2-token sequence
        stacked = torch.stack([x, feat_emb], dim=1)   # (batch, 2, emb_dim)
        encoded = self.transformer_encoder(stacked)
        aggregated = encoded.mean(dim=1)              # mean-pool over the 2 tokens
        return self.regression_head(aggregated)
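# Illustrative shape check (not called at import time; run manually if you
# change the architecture): a batch of 2 ChemBERTa embeddings (2, 768) and
# feature vectors (2, 2058) should yield predictions of shape (2, 6)
def _shape_sanity_check():
    m = TransformerRegressor()
    out = m(torch.randn(2, 768), torch.randn(2, 2058))
    assert out.shape == (2, 6)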
# Load the saved model weights once and cache the model across reruns
@st.cache_resource
def load_regression_model():
    model = TransformerRegressor()
    state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
    model.load_state_dict(state_dict)
    model.eval()
    return model

model = load_regression_model()
# Feature functions: ten RDKit molecular descriptors
descriptor_fns = [
    Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
    Descriptors.NumRotatableBonds, Descriptors.NumHAcceptors,
    Descriptors.NumHDonors, Descriptors.RingCount,
    Descriptors.FractionCSP3, Descriptors.HeavyAtomCount,
    Descriptors.NHOHCount,
]
# Canonicalize a SMILES string; return None if it cannot be parsed
def fix_smiles(s):
    try:
        mol = Chem.MolFromSmiles(s.strip())
        if mol:
            return Chem.MolToSmiles(mol)
    except Exception:
        return None
    return None
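# Usage sketch (illustrative): RDKit maps equivalent SMILES to one canonical form.
#   fix_smiles("C1=CC=CC=C1")  -> "c1ccccc1"  (benzene, aromatized)
#   fix_smiles("not a smiles") -> None        (MolFromSmiles returns None)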
# Compute the 2058-dim feature vector: 10 descriptors + 2048 Morgan fingerprint
# bits (matches feat_dim in TransformerRegressor); falls back to zeros if parsing fails
def compute_features(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if not mol:
        return [0] * 10 + [0] * 2048
    desc = [fn(mol) for fn in descriptor_fns]
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
    return desc + list(fp)
# Embed SMILES strings with ChemBERTa and return the [CLS] token embeddings
def embed_smiles(smiles_list):
    inputs = tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = chemberta(**inputs)
    return outputs.last_hidden_state[:, 0, :]
# Save a prediction to MongoDB
def save_to_db(smiles, predictions):
    # Convert numpy scalars to native Python floats so they are BSON-serializable
    predictions_clean = {k: float(v) for k, v in predictions.items()}
    doc = {
        "smiles": smiles,
        "predictions": predictions_clean,
        "timestamp": datetime.now(),
    }
    db = get_database()  # connect to MongoDB
    collection = db["polymer_predictions"]
    collection.insert_one(doc)
# Prediction page UI
def show():
    st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
    st.markdown("<hr style='border: 1px solid #ccc;'>", unsafe_allow_html=True)

    smiles_input = st.text_input("Enter SMILES Representation of Polymer")

    if st.button("Predict"):
        fixed = fix_smiles(smiles_input)
        if not fixed:
            st.error("Invalid SMILES string.")
        else:
            features = compute_features(fixed)
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0)
            embedding = embed_smiles([fixed])

            with torch.no_grad():
                pred = model(embedding, features_tensor)
            result = pred.numpy().flatten()

            properties = [
                "Tensile Strength",
                "Ionization Energy",
                "Electron Affinity",
                "logP",
                "Refractive Index",
                "Molecular Weight",
            ]

            predictions = {}
            st.success("Predicted Polymer Properties:")
            for prop, val in zip(properties, result):
                st.write(f"**{prop}**: {val:.4f}")
                predictions[prop] = val

            # Save the canonical SMILES (the form actually used for prediction) to MongoDB
            save_to_db(fixed, predictions)
            st.success("Prediction saved successfully!")
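# Minimal entry point, assuming this module is normally imported by a multipage
# router that calls show(); the guard lets the page also run standalone via
# `streamlit run`
if __name__ == "__main__":
    show()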