khalooei committed
Commit 7d45691 · 1 Parent(s): 8e74c98

initial commit

Files changed (2)
  1. app.py +312 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,312 @@
+ import torch
+ import torch.nn as nn
+ import torchvision
+ import torchvision.transforms as transforms
+ from torchvision.models import vgg16, vgg19, googlenet, resnet18
+ import timm
+ import numpy as np
+ import matplotlib.pyplot as plt
+ from torchattacks import FGSM, PGD, APGD
+ import os
+ import time
+ import gradio as gr
+
+ class LeNet(nn.Module):
+     def __init__(self):
+         super(LeNet, self).__init__()
+         self.conv1 = nn.Conv2d(1, 6, 5)
+         self.conv2 = nn.Conv2d(6, 16, 5)
+         self.fc1 = nn.Linear(16 * 4 * 4, 120)
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)
+         self.relu = nn.ReLU()
+         self.pool = nn.MaxPool2d(2, 2)
+
+     def forward(self, x, return_all=False):
+         # Collect every intermediate activation so LSA can inspect layer-level behavior.
+         outputs = []
+         x1 = self.pool(self.relu(self.conv1(x)))
+         outputs.append(x1)
+         x2 = self.pool(self.relu(self.conv2(x1)))
+         outputs.append(x2)
+         x2_flat = x2.view(-1, 16 * 4 * 4)
+         x3 = self.relu(self.fc1(x2_flat))
+         outputs.append(x3)
+         x4 = self.relu(self.fc2(x3))
+         outputs.append(x4)
+         x5 = self.fc3(x4)
+         outputs.append(x5)
+         return outputs if return_all else x5
+
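+ # Shape check for LeNet's 16 * 4 * 4 flatten (MNIST input 1x28x28):
+ # conv1(k=5) -> 6x24x24 -> pool -> 6x12x12; conv2(k=5) -> 16x8x8 -> pool -> 16x4x4.
+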
+ def salt_pepper_noise(images, prob=0.01, device='cuda'):
+     # Flip roughly prob/2 of the pixels to white (salt) and prob/2 to black (pepper).
+     salt = torch.rand_like(images) < prob / 2
+     pepper = torch.rand_like(images) < prob / 2
+     noisy = images.clone()
+     noisy[salt] = 1.0
+     noisy[pepper] = 0.0
+     return torch.clamp(noisy, 0, 1)
+
+ def pepper_statistical_noise(images, prob=0.01, device='cuda'):
+     # Black out roughly prob of the pixels.
+     pepper = torch.rand_like(images) < prob
+     noisy = images.clone()
+     noisy[pepper] = 0.0
+     return torch.clamp(noisy, 0, 1)
+
+ def get_layer_outputs(model, input_tensor):
+     # Fallback for models without a return_all forward: capture Conv2d/Linear
+     # activations with forward hooks.
+     outputs = []
+     def hook(module, input, output):
+         outputs.append(output)
+     hooks = []
+     for layer in model.modules():
+         if isinstance(layer, (nn.Conv2d, nn.Linear)):
+             hooks.append(layer.register_forward_hook(hook))
+     model.eval()
+     with torch.no_grad():
+         model(input_tensor)
+     for h in hooks:
+         h.remove()
+     return outputs
+
+ def compute_mvl(model, clean_images, adv_images, device='cuda'):
+     model.eval()
+     with torch.no_grad():
+         try:
+             clean_outputs = model(clean_images, return_all=True)
+             adv_outputs = model(adv_images, return_all=True)
+         except TypeError:
+             # Model does not expose return_all; fall back to forward hooks.
+             clean_outputs = get_layer_outputs(model, clean_images)
+             adv_outputs = get_layer_outputs(model, adv_images)
+
+     mvl_list = []
+     for clean_out, adv_out in zip(clean_outputs, adv_outputs):
+         if clean_out.ndim == 4:  # convolutional feature maps: norm over C, H, W
+             diff = torch.norm(clean_out - adv_out, p=2, dim=(1, 2, 3))
+             clean_norm = torch.norm(clean_out, p=2, dim=(1, 2, 3))
+         else:  # flattened / fully connected features: norm over the feature dim
+             diff = torch.norm(clean_out - adv_out, p=2, dim=1)
+             clean_norm = torch.norm(clean_out, p=2, dim=1)
+         mvl = diff / (clean_norm + 1e-8)
+         mvl_list.append(mvl.mean().item())
+     return mvl_list
+
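+ # Each value in mvl_list estimates the per-layer measure of vulnerability (MVL); in the
+ # paper's spirit (notation assumed here): MVL_l = E_batch[ ||f_l(x) - f_l(x_adv)|| / ||f_l(x)|| ].
+ # Sanity check (hypothetical usage): compute_mvl(model, x, x) should return ~0 for every layer.
+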
+ def get_model_stats(model):
+     param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     layer_count = len([m for m in model.modules() if isinstance(m, (nn.Conv2d, nn.Linear))])
+     return param_count, layer_count
+
+ def modify_model(model, model_name):
+     # Replace each ImageNet head with a 10-class classifier for CIFAR-10.
+     if model_name.startswith('VGG'):
+         model.classifier[6] = nn.Linear(4096, 10)
+     elif model_name == 'GoogLeNet':
+         model.fc = nn.Linear(1024, 10)
+     elif model_name == 'ResNet18':
+         model.fc = nn.Linear(512, 10)
+     elif model_name == 'WideResNet':
+         model.fc = nn.Linear(2048, 10)
+     elif model_name == 'DenseNet121':
+         model.classifier = nn.Linear(model.classifier.in_features, 10)
+     elif model_name == 'MobileNetV2':
+         if isinstance(model.classifier, nn.Sequential):
+             model.classifier[1] = nn.Linear(model.classifier[1].in_features, 10)
+         else:
+             model.classifier = nn.Linear(model.classifier.in_features, 10)
+     elif model_name == 'EfficientNet-B0':
+         model.classifier = nn.Linear(model.classifier.in_features, 10)
+     return model
+
+ def get_models_for_dataset(dataset_name):
+     if dataset_name == 'MNIST':
+         return ['LeNet']
+     elif dataset_name == 'CIFAR-10':
+         return [
+             'VGG16', 'VGG19', 'GoogLeNet', 'ResNet18', 'WideResNet',
+             'DenseNet121', 'MobileNetV2', 'EfficientNet-B0'
+         ]
+     else:
+         return []
+
+ def get_dataset_and_transform(dataset_name):
+     if dataset_name == 'MNIST':
+         transform = transforms.Compose([
+             transforms.Resize((28, 28)),
+             transforms.Grayscale(num_output_channels=1),
+             transforms.ToTensor(),
+             transforms.Normalize((0.1307,), (0.3081,))
+         ])
+         dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
+     else:  # CIFAR-10, resized to 224x224 for the ImageNet-pretrained backbones
+         transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize((0.485, 0.456, 0.406),
+                                  (0.229, 0.224, 0.225))
+         ])
+         dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
+     return dataset, transform
+
+ def initialize_model(model_name, device):
+     if model_name == 'LeNet':
+         model = LeNet()
+     elif model_name == 'VGG16':
+         model = modify_model(vgg16(weights='IMAGENET1K_V1'), model_name)
+     elif model_name == 'VGG19':
+         model = modify_model(vgg19(weights='IMAGENET1K_V1'), model_name)
+     elif model_name == 'GoogLeNet':
+         model = modify_model(googlenet(weights='IMAGENET1K_V1'), model_name)
+     elif model_name == 'ResNet18':
+         model = modify_model(resnet18(weights='IMAGENET1K_V1'), model_name)
+     elif model_name == 'WideResNet':
+         model = modify_model(timm.create_model('wide_resnet50_2', pretrained=True), model_name)
+     elif model_name == 'DenseNet121':
+         model = modify_model(timm.create_model('densenet121', pretrained=True), model_name)
+     elif model_name == 'MobileNetV2':
+         model = modify_model(timm.create_model('mobilenetv2_100', pretrained=True), model_name)
+     elif model_name == 'EfficientNet-B0':
+         model = modify_model(timm.create_model('efficientnet_b0', pretrained=True), model_name)
+     else:
+         raise ValueError(f"Unknown model {model_name}")
+     return model.to(device)
+
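+ # Note: the replaced 10-class heads start randomly initialized and no fine-tuning step
+ # is included here, so the app probes the layer sustainability of pretrained features
+ # rather than CIFAR-10 classification accuracy.
+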
+ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, num_batches, output_dir_base='outputs'):
+     start_time = time.time()
+     logs = []
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     dataset, transform = get_dataset_and_transform(dataset_name)
+     testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
+     logs.append(f"{dataset_name} dataset loaded")
+
+     model = initialize_model(model_name, device)
+     logs.append(f"Model {model_name} loaded on {device}")
+
+     param_count, layer_count = get_model_stats(model)
+     logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
+
+     all_attacks = {
+         'FGSM': FGSM(model, eps=0.03),
+         'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
+         'APGD': APGD(model, eps=0.03, steps=100, loss='ce'),
+         'Salt & Pepper': lambda x, y: salt_pepper_noise(x, prob=0.01, device=device),
+         'Pepper Statistical': lambda x, y: pepper_statistical_noise(x, prob=0.01, device=device)
+     }
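+     # Caveat (assumption worth verifying): torchattacks expects inputs scaled to [0, 1],
+     # while the dataloaders above yield normalized tensors, so eps=0.03 effectively acts
+     # in normalized units; newer torchattacks releases provide set_normalization_used(mean, std)
+     # for exactly this situation.
+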
+     attacks = {name: attack for name, attack in all_attacks.items() if name in selected_attacks}
+     if not attacks:
+         logs.append("Error: No valid attacks selected")
+         # Match the 10 interface outputs: error, CM plot, 5 MVL plots, integrated plot, stats, logs.
+         return ["No valid attacks selected", None] + [None] * 6 + ["", '\n'.join(logs)]
+     logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
+
+     results = {attack_name: {'cm': [], 'mvl': []} for attack_name in attacks}
+
+     for i, (images, labels) in enumerate(testloader):
+         if i >= num_batches:
+             break
+         images, labels = images.to(device), labels.to(device)
+         logs.append(f"Processing batch {i+1}/{num_batches}...")
+
+         for attack_name, attack in attacks.items():
+             adv_images = attack(images, labels)
+             mvl_list = compute_mvl(model, images, adv_images, device)
+             results[attack_name]['mvl'].append(mvl_list)
+             # CM summarizes a batch as the mean MVL across layers.
+             cm = np.mean(mvl_list)
+             results[attack_name]['cm'].append(cm)
+
+     # Plot outputs, generated in the sketch below (the five per-attack MVL panels
+     # are left as None for now).
+     cm_plot_path = None
+     mvl_plot_paths = [None] * 5
+     integrated_mvl_plot_path = None
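+
+     # A minimal sketch of the plot generation (file names and styling are assumptions;
+     # adjust as needed). Bar chart of the mean CM per attack:
+     os.makedirs(output_dir_base, exist_ok=True)
+     fig, ax = plt.subplots()
+     ax.bar(list(results.keys()), [np.mean(r['cm']) for r in results.values()])
+     ax.set_ylabel('Comparative Measure (CM)')
+     ax.set_title(f'{model_name} on {dataset_name}')
+     cm_plot_path = os.path.join(output_dir_base, 'cm.png')
+     fig.savefig(cm_plot_path, bbox_inches='tight')
+     plt.close(fig)
+
+     # Integrated view: mean MVL per layer, one curve per attack.
+     fig, ax = plt.subplots()
+     for attack_name, r in results.items():
+         ax.plot(np.mean(r['mvl'], axis=0), marker='o', label=attack_name)
+     ax.set_xlabel('Layer index')
+     ax.set_ylabel('Mean MVL')
+     ax.legend()
+     integrated_mvl_plot_path = os.path.join(output_dir_base, 'integrated_mvl.png')
+     fig.savefig(integrated_mvl_plot_path, bbox_inches='tight')
+     plt.close(fig)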
+
+     processing_time = time.time() - start_time
+
+     stats = {
+         'Dataset': dataset_name,
+         'Model': model_name,
+         'Parameter Count': param_count,
+         'Layer Count': layer_count,
+         'Processing Time (s)': round(processing_time, 2),
+         'Number of Batches': num_batches,
+         'Attacks Used': ', '.join(attacks.keys())
+     }
+     stats_text = "## Model Statistics\n\n| Metric | Value |\n|--------|-------|\n"
+     for k, v in stats.items():
+         stats_text += f"| {k} | {v} |\n"
+
+     return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
+
+ paper_info_html = """
+ <div style="border: 1px solid #ccc; padding: 15px; border-radius: 8px; margin-bottom: 15px;">
+     <h2>Layer-wise Regularized Adversarial Training Using Layers Sustainability Analysis Framework</h2>
+     <h3>Authors</h3>
+     <p>Mohammad Khalooei, Mohammad Mehdi Homayounpour, Maryam Amirmazlaghani</p>
+
+     <h3>Abstract</h3>
+     <ul>
+         <li>The layer sustainability analysis (LSA) framework is introduced to evaluate, using Lipschitz-based theoretical concepts, how the layer-level representations of DNNs behave under input perturbations.</li>
+         <li>A layer-wise regularized adversarial training approach (AT-LR) significantly improves the generalization and robustness of different deep neural network architectures under significant perturbations while reducing layer-level vulnerabilities.</li>
+         <li>Intriguingly, the AT-LR loss landscape for each LSA MVL proposal offers an interpretation of the importance of different layers.</li>
+     </ul>
+
+     <h3>Links</h3>
+     <ul>
+         <li><a href="https://arxiv.org/abs/2202.02626" target="_blank">ArXiv Paper</a></li>
+         <li><a href="https://github.com/khalooei/LSA" target="_blank">GitHub Repository</a></li>
+         <li><a href="https://www.sciencedirect.com/science/article/abs/pii/S0925231223002928" target="_blank">ScienceDirect Article</a></li>
+     </ul>
+ </div>
+ """
+
+ def update_models(dataset_name):
+     models = get_models_for_dataset(dataset_name)
+     default_value = models[0] if models else None
+     # Send both the refreshed choices and the new default back to the dropdown.
+     return gr.update(choices=models, value=default_value)
+
+ def create_interface():
+     datasets = ['MNIST', 'CIFAR-10']
+     attacks = ['FGSM', 'PGD', 'APGD', 'Salt & Pepper', 'Pepper Statistical']
+
+     with gr.Blocks() as interface:
+         gr.Markdown("# Layer-wise Sustainability Analysis")
+         gr.HTML(paper_info_html)
+
+         dataset_input = gr.Dropdown(datasets, label="Select Dataset", value='CIFAR-10')
+         model_input = gr.Dropdown(get_models_for_dataset('CIFAR-10'), label="Select Model")
+         attack_input = gr.CheckboxGroup(choices=attacks, label="Select Attacks", value=attacks)
+         batch_input = gr.Slider(minimum=1, maximum=20, step=1, value=5, label="Number of Batches")
+         run_button = gr.Button("Run Analysis")
+
+         error_output = gr.Textbox(label="Error", visible=False)
+         cm_output = gr.Image(label="Comparative Measure (CM)")
+
+         with gr.Tabs():
+             mvl_outputs = []
+             for attack in attacks:
+                 with gr.Tab(f"MVL: {attack}"):
+                     mvl_output = gr.Image(label=f"MVL for {attack}")
+                     mvl_outputs.append(mvl_output)
+             with gr.Tab("Integrated MVL"):
+                 integrated_mvl_output = gr.Image(label="Integrated MVL for All Attacks")
+             with gr.Tab("Model Statistics"):
+                 stats_output = gr.Markdown("## Model Statistics")
+             with gr.Tab("Logs"):
+                 log_output = gr.Textbox(label="Processing Logs")
+
+         # Refresh the model dropdown whenever the dataset selection changes.
+         dataset_input.change(
+             fn=update_models,
+             inputs=dataset_input,
+             outputs=model_input
+         )
+
+         run_button.click(
+             fn=layer_sustainability_analysis,
+             inputs=[dataset_input, model_input, attack_input, batch_input],
+             outputs=[error_output, cm_output] + mvl_outputs + [integrated_mvl_output, stats_output, log_output]
+         )
+
+     return interface
+
+ if __name__ == '__main__':
+     interface = create_interface()
+     interface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torchattacks
+ timm
+ gradio
+ torch
+ torchvision
+ numpy
+ matplotlib
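
A typical local run (assuming a recent Python with pip): pip install -r requirements.txt, then python app.py.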