ksj47 committed
Commit
ce7abfc
·
verified Β·
1 Parent(s): aa78e42

Upload 3 files

Files changed (3)
  1. app.py +418 -0
  2. model.pth +3 -0
  3. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,418 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import numpy as np
+ import torchvision.transforms as transforms
+ from PIL import Image, ImageDraw
+ import os
+
+ # Define the neural network model - matching the trained model, with 3 input channels
+ class Net(nn.Module):
+     def __init__(self):
+         super(Net, self).__init__()
+         # 3 input image channels (RGB), 6 output channels, 5x5 square convolution kernel
+         self.conv1 = nn.Conv2d(3, 6, 5)
+         self.conv2 = nn.Conv2d(6, 16, 5)
+         # an affine operation: y = Wx + b
+         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)
+
+     def forward(self, x):
+         # Convolution layer C1: 3 input image channels, 6 output channels,
+         # 5x5 square convolution; uses a ReLU activation function and
+         # outputs a tensor of size (N, 6, 28, 28), where N is the batch size
+         c1 = F.relu(self.conv1(x))
+         # Subsampling layer S2: 2x2 grid, purely functional,
+         # has no parameters, and outputs a (N, 6, 14, 14) tensor
+         s2 = F.max_pool2d(c1, (2, 2))
+         # Convolution layer C3: 6 input channels, 16 output channels,
+         # 5x5 square convolution; uses a ReLU activation function and
+         # outputs a (N, 16, 10, 10) tensor
+         c3 = F.relu(self.conv2(s2))
+         # Subsampling layer S4: 2x2 grid, purely functional,
+         # has no parameters, and outputs a (N, 16, 5, 5) tensor
+         s4 = F.max_pool2d(c3, 2)
+         # Flatten operation: purely functional, outputs a (N, 400) tensor
+         s4 = torch.flatten(s4, 1)
+         # Fully connected layer F5: takes a (N, 400) tensor and
+         # outputs a (N, 120) tensor; uses a ReLU activation function
+         f5 = F.relu(self.fc1(s4))
+         # Fully connected layer F6: takes a (N, 120) tensor and
+         # outputs a (N, 84) tensor; uses a ReLU activation function
+         f6 = F.relu(self.fc2(f5))
+         # Output layer: takes a (N, 84) tensor and outputs a (N, 10) tensor
+         output = self.fc3(f6)
+         return output
+
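+ # A minimal shape sanity check (illustrative comment only, not executed by the app):
+ # each 5x5 conv trims 4 pixels and each 2x2 pool halves the spatial size, so
+ # 32 -> 28 -> 14 -> 10 -> 5, which is where fc1's 16 * 5 * 5 = 400 comes from, e.g.
+ #   assert Net()(torch.randn(1, 3, 32, 32)).shape == (1, 10)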
+ # Initialize the model
+ model = Net()
+
+ # Load the trained model weights
+ def load_model():
+     model_path = "model.pth"  # Update this path if the model is stored elsewhere
+     if os.path.exists(model_path):
+         try:
+             # Load the trained model weights
+             model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+             print("Loaded trained model weights")
+             return True
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             return False
+     else:
+         print("No trained model found at", model_path)
+         # Initialize with random weights for demonstration
+         for m in model.modules():
+             if isinstance(m, (nn.Conv2d, nn.Linear)):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+         return False
+
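+ # Note: newer PyTorch releases (1.13+) accept torch.load(..., weights_only=True)
+ # for safer checkpoint deserialization; it is omitted here because
+ # requirements.txt only pins torch>=1.7.0.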
+ # Preprocessing function for input images - handles RGB images
+ def preprocess_image(image):
+     # Ensure a 3-channel image (uploads may be RGBA or grayscale)
+     image = image.convert("RGB")
+     # Resize to 32x32 (expected input size for the network)
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),
+         transforms.ToTensor(),
+     ])
+
+     image_tensor = transform(image)
+     # Add batch dimension -> (1, 3, 32, 32)
+     image_tensor = image_tensor.unsqueeze(0)
+     return image_tensor
+
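+ # Illustrative usage (the file name here is hypothetical):
+ #   t = preprocess_image(Image.open("digit.png"))
+ #   print(t.shape)  # torch.Size([1, 3, 32, 32])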
+ # Prediction function - follows the inference pattern from the PyTorch tutorial
+ def predict(image):
+     if image is None:
+         return {f"Class {i}": 0 for i in range(10)}
+
+     # Preprocess the image
+     input_tensor = preprocess_image(image)
+
+     # Make prediction with gradients disabled
+     model.eval()
+     with torch.no_grad():
+         output = model(input_tensor)
+         # Apply softmax to get probabilities
+         probabilities = F.softmax(output, dim=1)
+         probabilities = probabilities.numpy()[0]
+
+     # Create labels (0-9 for MNIST-like classification)
+     labels = [f"Class {i}" for i in range(10)]
+
+     # Return as a dictionary for Gradio
+     return {label: float(prob) for label, prob in zip(labels, probabilities)}
+
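+ # Illustrative call (not executed by the app):
+ #   probs = predict(Image.new("RGB", (64, 64), "white"))
+ #   print(max(probs, key=probs.get))  # name of the top-scoring class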
+ # Create example images with different qualities and styles
+ def create_example_images():
+     examples = []
+
+     # Create hand-drawn style digits
+     for i in range(10):
+         # Create a 64x64 RGB image for better quality
+         img = Image.new('RGB', (64, 64), color=(255, 255, 255))  # White background
+         draw = ImageDraw.Draw(img)
+
+         # Draw a simple representation of each digit
+         if i == 0:
+             # Draw a 0 (oval)
+             draw.ellipse([10, 10, 54, 54], outline=(0, 0, 0), width=5)
+         elif i == 1:
+             # Draw a 1 (simple line)
+             draw.line([32, 10, 32, 54], fill=(0, 0, 0), width=5)
+         elif i == 2:
+             # Draw a 2 (connected lines)
+             draw.line([15, 15, 49, 15], fill=(0, 0, 0), width=5)  # Top line
+             draw.line([49, 15, 49, 35], fill=(0, 0, 0), width=5)  # Right line
+             draw.line([49, 35, 15, 35], fill=(0, 0, 0), width=5)  # Middle line
+             draw.line([15, 35, 15, 54], fill=(0, 0, 0), width=5)  # Left line
+             draw.line([15, 54, 49, 54], fill=(0, 0, 0), width=5)  # Bottom line
+         elif i == 3:
+             # Draw a 3 (two right-facing arcs; PIL angles start at 3 o'clock
+             # and run clockwise, so 270 -> 90 traces the right half)
+             draw.arc([15, 10, 49, 35], 270, 90, fill=(0, 0, 0), width=5)  # Top arc
+             draw.arc([15, 35, 49, 60], 270, 90, fill=(0, 0, 0), width=5)  # Bottom arc
+         elif i == 4:
+             # Draw a 4 (vertical, diagonal, and horizontal strokes)
+             draw.line([35, 10, 35, 54], fill=(0, 0, 0), width=5)  # Vertical line
+             draw.line([15, 10, 35, 30], fill=(0, 0, 0), width=5)  # Diagonal line
+             draw.line([10, 30, 54, 30], fill=(0, 0, 0), width=5)  # Horizontal line
+         elif i == 5:
+             # Draw a 5 (connected lines)
+             draw.line([15, 15, 49, 15], fill=(0, 0, 0), width=5)  # Top line
+             draw.line([15, 15, 15, 35], fill=(0, 0, 0), width=5)  # Left line
+             draw.line([15, 35, 49, 35], fill=(0, 0, 0), width=5)  # Middle line
+             draw.line([49, 35, 49, 54], fill=(0, 0, 0), width=5)  # Right line
+             draw.line([15, 54, 49, 54], fill=(0, 0, 0), width=5)  # Bottom line
+         elif i == 6:
+             # Draw a 6 (circle with a stem)
+             draw.ellipse([15, 20, 49, 54], outline=(0, 0, 0), width=5)
+             draw.line([15, 20, 25, 10], fill=(0, 0, 0), width=5)  # Top stroke
+         elif i == 7:
+             # Draw a 7 (horizontal with diagonal)
+             draw.line([15, 15, 49, 15], fill=(0, 0, 0), width=5)  # Top line
+             draw.line([49, 15, 20, 54], fill=(0, 0, 0), width=5)  # Diagonal line
+         elif i == 8:
+             # Draw an 8 (two stacked circles)
+             draw.ellipse([15, 10, 49, 32], outline=(0, 0, 0), width=5)  # Top circle
+             draw.ellipse([15, 32, 49, 54], outline=(0, 0, 0), width=5)  # Bottom circle
+         elif i == 9:
+             # Draw a 9 (circle with a tail)
+             draw.ellipse([15, 10, 49, 44], outline=(0, 0, 0), width=5)
+             draw.line([49, 44, 40, 54], fill=(0, 0, 0), width=5)  # Tail
+
+         examples.append(img)
+
+     return examples
+
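+ # Note: these synthetic strokes only approximate handwritten digits, so a
+ # checkpoint trained on real data (e.g. MNIST-style images) may assign them
+ # modest confidence; they are intended as quick UI demos.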
+ # Custom CSS for enhanced UI
+ custom_css = """
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap');
+
+ body {
+     font-family: 'Roboto', sans-serif;
+     background: linear-gradient(135deg, #1a2a6c, #b21f1f, #1a2a6c);
+     background-size: 400% 400%;
+     animation: gradientBG 15s ease infinite;
+     color: white;
+     min-height: 100vh;
+ }
+
+ @keyframes gradientBG {
+     0% { background-position: 0% 50%; }
+     50% { background-position: 100% 50%; }
+     100% { background-position: 0% 50%; }
+ }
+
+ .gradio-container {
+     background: rgba(0, 0, 0, 0.7) !important;
+     backdrop-filter: blur(10px);
+     border-radius: 20px !important;
+     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.5);
+     border: 1px solid rgba(255, 255, 255, 0.1);
+     max-width: 1200px !important;
+     margin: 20px auto !important;
+ }
+
+ .container {
+     max-width: 100% !important;
+ }
+
+ h1 {
+     background: linear-gradient(to right, #ff7e5f, #feb47b);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     text-align: center;
+     font-weight: 700 !important;
+     font-size: 2.5em !important;
+     margin-bottom: 10px !important;
+     text-shadow: 0 2px 4px rgba(0,0,0,0.2);
+ }
+
+ h2 {
+     color: #feb47b;
+     border-bottom: 2px solid #ff7e5f;
+     padding-bottom: 10px;
+ }
+
+ .markdown {
+     background: rgba(255, 255, 255, 0.05);
+     border-radius: 15px;
+     padding: 20px;
+     margin-bottom: 20px;
+     border: 1px solid rgba(255, 255, 255, 0.1);
+ }
+
+ .gradio-button {
+     background: linear-gradient(45deg, #ff7e5f, #feb47b) !important;
+     border: none !important;
+     color: white !important;
+     font-weight: 600 !important;
+     transition: all 0.3s ease !important;
+     box-shadow: 0 4px 15px rgba(255, 126, 95, 0.3) !important;
+ }
+
+ .gradio-button:hover {
+     transform: translateY(-3px) !important;
+     box-shadow: 0 6px 20px rgba(255, 126, 95, 0.5) !important;
+ }
+
+ .gradio-button:active {
+     transform: translateY(1px) !important;
+ }
+
+ .gradio-image {
+     border-radius: 15px !important;
+     overflow: hidden !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.4) !important;
+     border: 2px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ .gradio-label {
+     background: rgba(255, 255, 255, 0.08) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
+ }
+
+ label {
+     color: #feb47b !important;
+     font-weight: 500 !important;
+ }
+
+ .examples {
+     background: rgba(255, 255, 255, 0.05) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     margin-top: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ footer {
+     display: none !important;
+ }
+
+ @media (max-width: 768px) {
+     .gradio-container {
+         margin: 10px !important;
+     }
+
+     h1 {
+         font-size: 2em !important;
+     }
+ }
+ """
+
+ # Initialize the model
+ model_loaded = load_model()
+
+ # Create example images and save them to disk: gr.Examples expects file paths
+ # (rather than in-memory PIL images) for Image inputs, especially with caching.
+ # The example_*.png names are arbitrary.
+ example_paths = []
+ for idx, example_img in enumerate(create_example_images()):
+     example_path = f"example_{idx}.png"
+     example_img.save(example_path)
+     example_paths.append(example_path)
+
+ # Create the Gradio interface with enhanced styling
+ with gr.Blocks(
+     title="PyTorch Neural Network Classifier",
+     css=custom_css,
+     theme=gr.themes.Default(
+         font=["Roboto", "Arial", "sans-serif"]
+     )
+ ) as demo:
+     gr.Markdown("""
+     # 🔥 PyTorch Neural Network Classifier
+     ## Convolutional Neural Network for Image Classification
+
+     This is a demonstration of a convolutional neural network based on the
+     [PyTorch Neural Networks Tutorial](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html).
+
+     The model architecture consists of:
+     - 2 convolutional layers with ReLU activation
+     - 2 max-pooling layers
+     - 3 fully connected layers
+     """)
+
+     # Show model loading status
+     if model_loaded:
+         gr.Markdown("✅ Model weights loaded successfully")
+     else:
+         gr.Markdown("⚠️ Model not found or failed to load. Using random weights for demonstration.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 📥 Input")
+             input_image = gr.Image(type="pil", label="Upload or Draw an Image", height=300)
+             with gr.Row():
+                 submit_btn = gr.Button("Classify Image", elem_classes=["custom-button"])
+                 clear_btn = gr.Button("Clear")
+
+             gr.Markdown("""
+             ### 🎯 Model Architecture
+             ```
+             Input (3×32×32) → Conv2D(5×5) → ReLU → MaxPool2D
+             → Conv2D(5×5) → ReLU → MaxPool2D
+             → Flatten → Linear(400→120) → ReLU
+             → Linear(120→84) → ReLU → Linear(84→10)
+             → Output
+             ```
+             """)
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 📊 Classification Results")
+             output_label = gr.Label(label="Prediction Probabilities", num_top_classes=5)
+
+             gr.Markdown("""
+             ### ℹ️ Instructions
+             1. Upload an image or draw one using the editor
+             2. The image will be automatically resized to 32×32 pixels
+             3. Click "Classify Image" to get predictions
+             4. Results show probabilities for 10 classes
+
+             ### 📝 Notes
+             - Model expects RGB images
+             - Best results with MNIST-style digits
+             - Classes 0-9 represent digits
+             """)
+
+     with gr.Row():
+         gr.Markdown("### 📋 Example Images")
+         gr.Markdown("""
+         The examples below show hand-drawn style digits. Click any example to load it,
+         or use the drawing tool to create your own digits. The app accepts:
+         - Different handwriting styles
+         - Various image sizes (automatically resized to 32×32)
+         - Both black and white backgrounds
+         - Low-resolution images
+         """)
+
+     # Register examples via the file paths saved above, as Gradio 4.x expects
+     gr.Examples(
+         examples=example_paths,
+         inputs=input_image,
+         outputs=output_label,
+         fn=predict,
+         cache_examples=True
+     )
+
+     gr.Markdown("""
+     ### 🧪 Testing Different Image Qualities
+
+     The app is designed to handle a range of image conditions:
+     - **Resolution**: accepts images of any resolution (automatically resized to 32×32)
+     - **Contrast**: handles both high and low contrast images
+     - **Noise**: can tolerate some image noise
+     - **Rotation**: some tolerance to slight rotations
+     - **Scale**: works with digits of different sizes within the image
+
+     For best results:
+     1. Center the digit in the image
+     2. Use clear contrast between the digit and background
+     3. Avoid excessive noise or artifacts
+     4. Fill most of the image area with the digit
+     """)
+
+     # Event handling
+     submit_btn.click(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+     clear_btn.click(
+         fn=lambda: (None, {f"Class {i}": 0 for i in range(10)}),
+         inputs=None,
+         outputs=[input_image, output_label]
+     )
+
+     # Allow image upload to trigger prediction automatically
+     input_image.change(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
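+ # A possible local-dev variant (an assumption, not part of the committed app):
+ #   demo.launch(server_name="0.0.0.0", server_port=7860)
+ # exposes the UI on the local network; the plain launch() above is what
+ # Hugging Face Spaces expects.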
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6744133a43fe90290fdb9770d7caa0bddaa453682bd4f8a7e8f2482feb852950
+ size 251604
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=1.7.0
+ torchvision>=0.8.0
+ gradio==4.44.1
+ pillow>=8.0.0
+ numpy>=1.19.0