transpolymer committed on
Commit
dd45972
·
verified ·
1 Parent(s): a70af4b

Update prediction.py

Browse files
Files changed (1) hide show
  1. prediction.py +22 -22
prediction.py CHANGED
@@ -5,15 +5,18 @@ from transformers import AutoTokenizer, AutoModel
5
  from rdkit import Chem
6
  from rdkit.Chem import AllChem, Descriptors
7
  from torch import nn
8
- import pandas as pd
9
- import requests
10
- import datetime
11
  from db import get_database # Assuming you have a file db.py with get_database function to connect to MongoDB
12
 
13
- # Model Setup
14
- tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
15
- chemberta = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
16
- chemberta.eval()
 
 
 
 
 
17
 
18
  # Define your model architecture
19
  class TransformerRegressor(nn.Module):
@@ -35,10 +38,16 @@ class TransformerRegressor(nn.Module):
35
  aggregated = encoded.mean(dim=1)
36
  return self.regression_head(aggregated)
37
 
38
- # Load model
39
- model = TransformerRegressor()
40
- model.load_state_dict(torch.load("transformer_model.pt", map_location=torch.device('cpu')))
41
- model.eval()
 
 
 
 
 
 
42
 
43
  # Feature Functions
44
  descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
@@ -69,15 +78,7 @@ def embed_smiles(smiles_list):
69
  outputs = chemberta(**inputs)
70
  return outputs.last_hidden_state[:, 0, :]
71
 
72
- # Function to validate SMILES string
73
- def is_valid_smiles(smiles):
74
- """ Validate if the input is a valid SMILES string using RDKit """
75
- mol = Chem.MolFromSmiles(smiles)
76
- return mol is not None
77
-
78
  # Function to save prediction to MongoDB
79
- from datetime import datetime
80
-
81
  def save_to_db(smiles, predictions):
82
  # Convert all prediction values to native Python float
83
  predictions_clean = {k: float(v) for k, v in predictions.items()}
@@ -88,11 +89,10 @@ def save_to_db(smiles, predictions):
88
  "timestamp": datetime.now()
89
  }
90
 
91
- db = get_database()
92
  collection = db["polymer_predictions"]
93
  collection.insert_one(doc)
94
 
95
-
96
  # Prediction Page UI
97
  def show():
98
  st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
@@ -130,4 +130,4 @@ def show():
130
 
131
  # Save the prediction to MongoDB
132
  save_to_db(smiles_input, predictions)
133
- st.success("Prediction saved successfully!")
 
5
  from rdkit import Chem
6
  from rdkit.Chem import AllChem, Descriptors
7
  from torch import nn
8
+ from datetime import datetime
 
 
9
  from db import get_database # Assuming you have a file db.py with get_database function to connect to MongoDB
10
 
11
+ # Load tokenizer and ChemBERTa
12
+ @st.cache_resource
13
+ def load_chemberta():
14
+ tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
15
+ model = AutoModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
16
+ model.eval()
17
+ return tokenizer, model
18
+
19
+ tokenizer, chemberta = load_chemberta()
20
 
21
  # Define your model architecture
22
  class TransformerRegressor(nn.Module):
 
38
  aggregated = encoded.mean(dim=1)
39
  return self.regression_head(aggregated)
40
 
41
+ # Load your saved model
42
+ @st.cache_resource
43
+ def load_regression_model():
44
+ model = TransformerRegressor()
45
+ state_dict = torch.load("transformer_model.pt", map_location=torch.device("cpu"))
46
+ model.load_state_dict(state_dict)
47
+ model.eval()
48
+ return model
49
+
50
+ model = load_regression_model()
51
 
52
  # Feature Functions
53
  descriptor_fns = [Descriptors.MolWt, Descriptors.MolLogP, Descriptors.TPSA,
 
78
  outputs = chemberta(**inputs)
79
  return outputs.last_hidden_state[:, 0, :]
80
 
 
 
 
 
 
 
81
  # Function to save prediction to MongoDB
 
 
82
  def save_to_db(smiles, predictions):
83
  # Convert all prediction values to native Python float
84
  predictions_clean = {k: float(v) for k, v in predictions.items()}
 
89
  "timestamp": datetime.now()
90
  }
91
 
92
+ db = get_database() # Connect to MongoDB
93
  collection = db["polymer_predictions"]
94
  collection.insert_one(doc)
95
 
 
96
  # Prediction Page UI
97
  def show():
98
  st.markdown("<h1 style='text-align: center; color: #4CAF50;'>🔬 Polymer Property Prediction</h1>", unsafe_allow_html=True)
 
130
 
131
  # Save the prediction to MongoDB
132
  save_to_db(smiles_input, predictions)
133
+ st.success("Prediction saved successfully!")