khalooei
commited on
Commit
·
ccadb41
1
Parent(s):
9dbb753
update app
Browse files
app.py
CHANGED
|
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torchvision
|
|
@@ -12,7 +15,6 @@ import time
|
|
| 12 |
from datetime import datetime
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
-
# LeNet for MNIST
|
| 16 |
class LeNet(nn.Module):
|
| 17 |
def __init__(self):
|
| 18 |
super(LeNet, self).__init__()
|
|
@@ -138,7 +140,7 @@ def get_dataset_and_transform(dataset_name):
|
|
| 138 |
transforms.Normalize((0.1307,), (0.3081,))
|
| 139 |
])
|
| 140 |
dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
|
| 141 |
-
else:
|
| 142 |
transform = transforms.Compose([
|
| 143 |
transforms.Resize((224, 224)),
|
| 144 |
transforms.ToTensor(),
|
|
@@ -173,7 +175,7 @@ def initialize_model(model_name, device):
|
|
| 173 |
|
| 174 |
def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
|
| 175 |
start_time = time.time()
|
| 176 |
-
logs = []
|
| 177 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 178 |
|
| 179 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
|
@@ -295,7 +297,7 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 295 |
mvl_plot_paths.append(None)
|
| 296 |
|
| 297 |
return [
|
| 298 |
-
None,
|
| 299 |
cm_plot_path,
|
| 300 |
*mvl_plot_paths[:5],
|
| 301 |
integrated_mvl_plot_path,
|
|
@@ -340,9 +342,9 @@ def create_interface():
|
|
| 340 |
gr.Markdown("# Layer-wise Sustainability Analysis")
|
| 341 |
gr.Markdown(paper_info_html)
|
| 342 |
|
| 343 |
-
|
| 344 |
-
dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=
|
| 345 |
-
model_input = gr.Dropdown(get_models_for_dataset(
|
| 346 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
| 347 |
|
| 348 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
|
@@ -363,7 +365,7 @@ def create_interface():
|
|
| 363 |
with gr.Tab("Model Statistics"):
|
| 364 |
stats_output = gr.Markdown("## Model Statistics")
|
| 365 |
with gr.Tab("Logs"):
|
| 366 |
-
log_output = gr.Textbox(label="Processing Logs")
|
| 367 |
|
| 368 |
dataset_input.change(
|
| 369 |
fn=update_models,
|
|
|
|
| 1 |
+
# Developed by Mohammad Khalooei
|
| 2 |
+
# More information and contact: https://github.com/khalooei/LSA
|
| 3 |
+
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torchvision
|
|
|
|
| 15 |
from datetime import datetime
|
| 16 |
import gradio as gr
|
| 17 |
|
|
|
|
| 18 |
class LeNet(nn.Module):
|
| 19 |
def __init__(self):
|
| 20 |
super(LeNet, self).__init__()
|
|
|
|
| 140 |
transforms.Normalize((0.1307,), (0.3081,))
|
| 141 |
])
|
| 142 |
dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
|
| 143 |
+
else:
|
| 144 |
transform = transforms.Compose([
|
| 145 |
transforms.Resize((224, 224)),
|
| 146 |
transforms.ToTensor(),
|
|
|
|
| 175 |
|
| 176 |
def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
|
| 177 |
start_time = time.time()
|
| 178 |
+
logs = ["BSM:: experiment is being started ..."]
|
| 179 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 180 |
|
| 181 |
dataset, _ = get_dataset_and_transform(dataset_name)
|
|
|
|
| 297 |
mvl_plot_paths.append(None)
|
| 298 |
|
| 299 |
return [
|
| 300 |
+
None,
|
| 301 |
cm_plot_path,
|
| 302 |
*mvl_plot_paths[:5],
|
| 303 |
integrated_mvl_plot_path,
|
|
|
|
| 342 |
gr.Markdown("# Layer-wise Sustainability Analysis")
|
| 343 |
gr.Markdown(paper_info_html)
|
| 344 |
|
| 345 |
+
initial_input="MNIST"
|
| 346 |
+
dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=initial_input)
|
| 347 |
+
model_input = gr.Dropdown(get_models_for_dataset(initial_input), value=get_models_for_dataset(initial_input)[0], label="Select Model")
|
| 348 |
model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
|
| 349 |
|
| 350 |
attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
|
|
|
|
| 365 |
with gr.Tab("Model Statistics"):
|
| 366 |
stats_output = gr.Markdown("## Model Statistics")
|
| 367 |
with gr.Tab("Logs"):
|
| 368 |
+
log_output = gr.Textbox(label="Processing Logs", lines=15, interactive=False)
|
| 369 |
|
| 370 |
dataset_input.change(
|
| 371 |
fn=update_models,
|