khalooei commited on
Commit
ccadb41
·
1 Parent(s): 9dbb753

update app

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torchvision
@@ -12,7 +15,6 @@ import time
12
  from datetime import datetime
13
  import gradio as gr
14
 
15
- # LeNet for MNIST
16
  class LeNet(nn.Module):
17
  def __init__(self):
18
  super(LeNet, self).__init__()
@@ -138,7 +140,7 @@ def get_dataset_and_transform(dataset_name):
138
  transforms.Normalize((0.1307,), (0.3081,))
139
  ])
140
  dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
141
- else: # CIFAR-10
142
  transform = transforms.Compose([
143
  transforms.Resize((224, 224)),
144
  transforms.ToTensor(),
@@ -173,7 +175,7 @@ def initialize_model(model_name, device):
173
 
174
  def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
175
  start_time = time.time()
176
- logs = []
177
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
178
 
179
  dataset, _ = get_dataset_and_transform(dataset_name)
@@ -295,7 +297,7 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
295
  mvl_plot_paths.append(None)
296
 
297
  return [
298
- None, # no error
299
  cm_plot_path,
300
  *mvl_plot_paths[:5],
301
  integrated_mvl_plot_path,
@@ -340,9 +342,9 @@ def create_interface():
340
  gr.Markdown("# Layer-wise Sustainability Analysis")
341
  gr.Markdown(paper_info_html)
342
 
343
- default_input="MNIST"
344
- dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=default_input)
345
- model_input = gr.Dropdown(get_models_for_dataset(default_input), value=get_models_for_dataset(default_input)[0], label="Select Model")
346
  model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
347
 
348
  attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
@@ -363,7 +365,7 @@ def create_interface():
363
  with gr.Tab("Model Statistics"):
364
  stats_output = gr.Markdown("## Model Statistics")
365
  with gr.Tab("Logs"):
366
- log_output = gr.Textbox(label="Processing Logs")
367
 
368
  dataset_input.change(
369
  fn=update_models,
 
1
+ # Developed by Mohammad Khalooei
2
+ # More information and contact: https://github.com/khalooei/LSA
3
+
4
  import torch
5
  import torch.nn as nn
6
  import torchvision
 
15
  from datetime import datetime
16
  import gradio as gr
17
 
 
18
  class LeNet(nn.Module):
19
  def __init__(self):
20
  super(LeNet, self).__init__()
 
140
  transforms.Normalize((0.1307,), (0.3081,))
141
  ])
142
  dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
143
+ else:
144
  transform = transforms.Compose([
145
  transforms.Resize((224, 224)),
146
  transforms.ToTensor(),
 
175
 
176
  def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
177
  start_time = time.time()
178
+ logs = ["BSM:: experiment is being started ..."]
179
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
180
 
181
  dataset, _ = get_dataset_and_transform(dataset_name)
 
297
  mvl_plot_paths.append(None)
298
 
299
  return [
300
+ None,
301
  cm_plot_path,
302
  *mvl_plot_paths[:5],
303
  integrated_mvl_plot_path,
 
342
  gr.Markdown("# Layer-wise Sustainability Analysis")
343
  gr.Markdown(paper_info_html)
344
 
345
+ initial_input="MNIST"
346
+ dataset_input = gr.Dropdown(datasets, label="Select Dataset", value=initial_input)
347
+ model_input = gr.Dropdown(get_models_for_dataset(initial_input), value=get_models_for_dataset(initial_input)[0], label="Select Model")
348
  model_text = gr.Textbox(value="LeNet", visible=False, interactive=False, label="Model")
349
 
350
  attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
 
365
  with gr.Tab("Model Statistics"):
366
  stats_output = gr.Markdown("## Model Statistics")
367
  with gr.Tab("Logs"):
368
+ log_output = gr.Textbox(label="Processing Logs", lines=15, interactive=False)
369
 
370
  dataset_input.change(
371
  fn=update_models,