khalooei
commited on
Commit
·
4bd1b68
1
Parent(s):
7d45691
update app
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ import time
|
|
| 12 |
from datetime import datetime
|
| 13 |
import gradio as gr
|
| 14 |
|
|
|
|
| 15 |
class LeNet(nn.Module):
|
| 16 |
def __init__(self):
|
| 17 |
super(LeNet, self).__init__()
|
|
@@ -41,6 +42,7 @@ class LeNet(nn.Module):
|
|
| 41 |
else:
|
| 42 |
return x5
|
| 43 |
|
|
|
|
| 44 |
def salt_pepper_noise(images, prob=0.01, device='cuda'):
|
| 45 |
batch_smap = torch.rand_like(images) < prob / 2
|
| 46 |
pepper = torch.rand_like(images) < prob / 2
|
|
@@ -55,6 +57,7 @@ def pepper_statistical_noise(images, prob=0.01, device='cuda'):
|
|
| 55 |
noisy[pepper] = 0.0
|
| 56 |
return torch.clamp(noisy, 0, 1)
|
| 57 |
|
|
|
|
| 58 |
def get_layer_outputs(model, input_tensor):
|
| 59 |
outputs = []
|
| 60 |
def hook(module, input, output):
|
|
@@ -128,6 +131,7 @@ def get_models_for_dataset(dataset_name):
|
|
| 128 |
else:
|
| 129 |
return []
|
| 130 |
|
|
|
|
| 131 |
def get_dataset_and_transform(dataset_name):
|
| 132 |
if dataset_name == 'MNIST':
|
| 133 |
transform = transforms.Compose([
|
|
@@ -175,16 +179,20 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 175 |
logs = []
|
| 176 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 177 |
|
|
|
|
| 178 |
dataset, transform = get_dataset_and_transform(dataset_name)
|
| 179 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
| 180 |
logs.append(f"{dataset_name} dataset loaded")
|
| 181 |
|
|
|
|
| 182 |
model = initialize_model(model_name, device)
|
| 183 |
logs.append(f"Model {model_name} loaded on {device}")
|
| 184 |
|
|
|
|
| 185 |
param_count, layer_count = get_model_stats(model)
|
| 186 |
logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
|
| 187 |
|
|
|
|
| 188 |
all_attacks = {
|
| 189 |
'FGSM': FGSM(model, eps=0.03),
|
| 190 |
'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
|
|
@@ -198,41 +206,109 @@ def layer_sustainability_analysis(dataset_name, model_name, selected_attacks, nu
|
|
| 198 |
return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
|
| 199 |
logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
|
| 200 |
|
| 201 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
|
|
|
|
| 203 |
for i, (images, labels) in enumerate(testloader):
|
| 204 |
if i >= num_batches:
|
| 205 |
break
|
| 206 |
images, labels = images.to(device), labels.to(device)
|
| 207 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
| 208 |
|
| 209 |
-
for
|
| 210 |
-
adv_images =
|
| 211 |
-
|
| 212 |
-
results[
|
| 213 |
-
cm
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
|
| 221 |
processing_time = time.time() - start_time
|
| 222 |
|
| 223 |
stats = {
|
| 224 |
'Dataset': dataset_name,
|
| 225 |
'Model': model_name,
|
| 226 |
-
'
|
| 227 |
-
'
|
| 228 |
-
'
|
| 229 |
-
'
|
| 230 |
-
'
|
| 231 |
}
|
| 232 |
-
stats_text = "## Model Statistics\n\n| Metric | Value |\n
|
| 233 |
for k,v in stats.items():
|
| 234 |
stats_text += f"| {k} | {v} |\n"
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
|
| 237 |
|
| 238 |
paper_info_html = """
|
|
@@ -260,7 +336,7 @@ paper_info_html = """
|
|
| 260 |
def update_models(dataset_name):
|
| 261 |
models = get_models_for_dataset(dataset_name)
|
| 262 |
default_value = models[0] if models else None
|
| 263 |
-
return models, default_value # Return choices and default value
|
| 264 |
|
| 265 |
def create_interface():
|
| 266 |
datasets = ['MNIST', 'CIFAR-10']
|
|
@@ -292,11 +368,10 @@ def create_interface():
|
|
| 292 |
with gr.Tab("Logs"):
|
| 293 |
log_output = gr.Textbox(label="Processing Logs")
|
| 294 |
|
| 295 |
-
# Return choices and value separately for older gradio versions
|
| 296 |
dataset_input.change(
|
| 297 |
fn=update_models,
|
| 298 |
inputs=dataset_input,
|
| 299 |
-
outputs=[model_input, model_input]
|
| 300 |
)
|
| 301 |
|
| 302 |
run_button.click(
|
|
|
|
| 12 |
from datetime import datetime
|
| 13 |
import gradio as gr
|
| 14 |
|
| 15 |
+
# LeNet for MNIST
|
| 16 |
class LeNet(nn.Module):
|
| 17 |
def __init__(self):
|
| 18 |
super(LeNet, self).__init__()
|
|
|
|
| 42 |
else:
|
| 43 |
return x5
|
| 44 |
|
| 45 |
+
# Noise functions
|
| 46 |
def salt_pepper_noise(images, prob=0.01, device='cuda'):
|
| 47 |
batch_smap = torch.rand_like(images) < prob / 2
|
| 48 |
pepper = torch.rand_like(images) < prob / 2
|
|
|
|
| 57 |
noisy[pepper] = 0.0
|
| 58 |
return torch.clamp(noisy, 0, 1)
|
| 59 |
|
| 60 |
+
# MVL calculation with hooks fallback
|
| 61 |
def get_layer_outputs(model, input_tensor):
|
| 62 |
outputs = []
|
| 63 |
def hook(module, input, output):
|
|
|
|
| 131 |
else:
|
| 132 |
return []
|
| 133 |
|
| 134 |
+
|
| 135 |
def get_dataset_and_transform(dataset_name):
|
| 136 |
if dataset_name == 'MNIST':
|
| 137 |
transform = transforms.Compose([
|
|
|
|
| 179 |
logs = []
|
| 180 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 181 |
|
| 182 |
+
# Prepare dataset & loader
|
| 183 |
dataset, transform = get_dataset_and_transform(dataset_name)
|
| 184 |
testloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False)
|
| 185 |
logs.append(f"{dataset_name} dataset loaded")
|
| 186 |
|
| 187 |
+
# Init model
|
| 188 |
model = initialize_model(model_name, device)
|
| 189 |
logs.append(f"Model {model_name} loaded on {device}")
|
| 190 |
|
| 191 |
+
# Model stats
|
| 192 |
param_count, layer_count = get_model_stats(model)
|
| 193 |
logs.append(f"Model stats: {param_count} parameters, {layer_count} layers")
|
| 194 |
|
| 195 |
+
# Setup attacks
|
| 196 |
all_attacks = {
|
| 197 |
'FGSM': FGSM(model, eps=0.03),
|
| 198 |
'PGD': PGD(model, eps=0.03, alpha=0.01, steps=40, random_start=True),
|
|
|
|
| 206 |
return ["No valid attacks selected", None] + [None]*6 + ["", '\n'.join(logs)]
|
| 207 |
logs.append(f"Selected attacks: {', '.join(attacks.keys())}")
|
| 208 |
|
| 209 |
+
# Prepare output dir for plots
|
| 210 |
+
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
|
| 211 |
+
output_dir = os.path.join(output_dir_base, f"{model_name}_{timestamp}")
|
| 212 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 213 |
+
logs.append(f"Output directory: {output_dir}")
|
| 214 |
+
|
| 215 |
+
# Collect results
|
| 216 |
+
results = {atk: {'cm': [], 'mvl': []} for atk in attacks}
|
| 217 |
|
| 218 |
+
# Process batches
|
| 219 |
for i, (images, labels) in enumerate(testloader):
|
| 220 |
if i >= num_batches:
|
| 221 |
break
|
| 222 |
images, labels = images.to(device), labels.to(device)
|
| 223 |
logs.append(f"Processing batch {i+1}/{num_batches}...")
|
| 224 |
|
| 225 |
+
for atk_name, atk in attacks.items():
|
| 226 |
+
adv_images = atk(images, labels)
|
| 227 |
+
mvl_vals = compute_mvl(model, images, adv_images, device)
|
| 228 |
+
results[atk_name]['mvl'].append(mvl_vals)
|
| 229 |
+
results[atk_name]['cm'].append(np.mean(mvl_vals))
|
| 230 |
+
|
| 231 |
+
# Compute mean/std CM per attack
|
| 232 |
+
cm_means = {atk: np.mean(results[atk]['cm']) for atk in attacks}
|
| 233 |
+
cm_stds = {atk: np.std(results[atk]['cm']) for atk in attacks}
|
| 234 |
+
|
| 235 |
+
# Plot CM bar
|
| 236 |
+
plt.figure(figsize=(8,6))
|
| 237 |
+
names = list(attacks.keys())
|
| 238 |
+
means = [cm_means[n] for n in names]
|
| 239 |
+
stds = [cm_stds[n] for n in names]
|
| 240 |
+
x = np.arange(len(names))
|
| 241 |
+
plt.bar(x, means, yerr=stds, capsize=5)
|
| 242 |
+
plt.xticks(x, names, rotation=45)
|
| 243 |
+
plt.ylabel("CM (Relative Error)")
|
| 244 |
+
plt.title(f"CM for {model_name} ({dataset_name})")
|
| 245 |
+
plt.tight_layout()
|
| 246 |
+
cm_plot_path = os.path.join(output_dir, "cm_plot.png")
|
| 247 |
+
plt.savefig(cm_plot_path)
|
| 248 |
+
plt.close()
|
| 249 |
+
logs.append(f"Saved CM plot to {cm_plot_path}")
|
| 250 |
+
|
| 251 |
+
# Plot MVL per attack
|
| 252 |
+
mvl_plot_paths = []
|
| 253 |
+
colors = ['skyblue', 'lightgreen', 'coral', 'lightgray', 'purple']
|
| 254 |
+
for idx, atk in enumerate(names):
|
| 255 |
+
mvl_arr = np.array(results[atk]['mvl'])
|
| 256 |
+
mean_vals = np.mean(mvl_arr, axis=0)
|
| 257 |
+
std_vals = np.std(mvl_arr, axis=0)
|
| 258 |
+
layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
|
| 259 |
+
plt.figure(figsize=(8,6))
|
| 260 |
+
plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
|
| 261 |
+
plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
|
| 262 |
+
plt.title(f"MVL per Layer - {atk}")
|
| 263 |
+
plt.ylabel("MVL (Mean ± Std)")
|
| 264 |
+
plt.xticks(rotation=45)
|
| 265 |
+
plt.grid(True)
|
| 266 |
+
plt.tight_layout()
|
| 267 |
+
path = os.path.join(output_dir, f"mvl_{atk.lower().replace(' ', '_')}.png")
|
| 268 |
+
plt.savefig(path)
|
| 269 |
+
plt.close()
|
| 270 |
+
mvl_plot_paths.append(path)
|
| 271 |
+
logs.append(f"Saved MVL plot for {atk} to {path}")
|
| 272 |
+
|
| 273 |
+
# Integrated MVL plot
|
| 274 |
+
plt.figure(figsize=(10,6))
|
| 275 |
+
for idx, atk in enumerate(names):
|
| 276 |
+
mvl_arr = np.array(results[atk]['mvl'])
|
| 277 |
+
mean_vals = np.mean(mvl_arr, axis=0)
|
| 278 |
+
std_vals = np.std(mvl_arr, axis=0)
|
| 279 |
+
layers = [f"Layer {i+1}" for i in range(len(mean_vals))]
|
| 280 |
+
plt.plot(layers, mean_vals, marker='o', color=colors[idx % len(colors)], label=atk)
|
| 281 |
+
plt.fill_between(layers, mean_vals - std_vals, mean_vals + std_vals, color=colors[idx % len(colors)], alpha=0.3)
|
| 282 |
+
plt.title(f"Integrated MVL - {model_name}")
|
| 283 |
+
plt.ylabel("MVL (Mean ± Std)")
|
| 284 |
+
plt.xticks(rotation=45)
|
| 285 |
+
plt.legend()
|
| 286 |
+
plt.grid(True)
|
| 287 |
+
plt.tight_layout()
|
| 288 |
+
integrated_mvl_plot_path = os.path.join(output_dir, "integrated_mvl.png")
|
| 289 |
+
plt.savefig(integrated_mvl_plot_path)
|
| 290 |
+
plt.close()
|
| 291 |
+
logs.append(f"Saved integrated MVL plot to {integrated_mvl_plot_path}")
|
| 292 |
|
| 293 |
processing_time = time.time() - start_time
|
| 294 |
|
| 295 |
stats = {
|
| 296 |
'Dataset': dataset_name,
|
| 297 |
'Model': model_name,
|
| 298 |
+
'Parameters': param_count,
|
| 299 |
+
'Layers': layer_count,
|
| 300 |
+
'Batches': num_batches,
|
| 301 |
+
'Attacks': ', '.join(names),
|
| 302 |
+
'Time (s)': round(processing_time, 2)
|
| 303 |
}
|
| 304 |
+
stats_text = "## Model Statistics\n\n| Metric | Value |\n|---|---|\n"
|
| 305 |
for k,v in stats.items():
|
| 306 |
stats_text += f"| {k} | {v} |\n"
|
| 307 |
|
| 308 |
+
# Pad MVL plot paths to length 5 for UI consistency
|
| 309 |
+
while len(mvl_plot_paths) < 5:
|
| 310 |
+
mvl_plot_paths.append(None)
|
| 311 |
+
|
| 312 |
return [None, cm_plot_path] + mvl_plot_paths[:5] + [integrated_mvl_plot_path, stats_text, '\n'.join(logs)]
|
| 313 |
|
| 314 |
paper_info_html = """
|
|
|
|
| 336 |
def update_models(dataset_name):
|
| 337 |
models = get_models_for_dataset(dataset_name)
|
| 338 |
default_value = models[0] if models else None
|
| 339 |
+
return models, default_value # Return choices and default value for older gradio versions
|
| 340 |
|
| 341 |
def create_interface():
|
| 342 |
datasets = ['MNIST', 'CIFAR-10']
|
|
|
|
| 368 |
with gr.Tab("Logs"):
|
| 369 |
log_output = gr.Textbox(label="Processing Logs")
|
| 370 |
|
|
|
|
| 371 |
dataset_input.change(
|
| 372 |
fn=update_models,
|
| 373 |
inputs=dataset_input,
|
| 374 |
+
outputs=[model_input, model_input] # updates choices and default value
|
| 375 |
)
|
| 376 |
|
| 377 |
run_button.click(
|