ksj47 committed on
Commit
3f59345
·
verified ·
1 Parent(s): 80a641f

Upload 6 files

Files changed (6)
  1. CIFAR Net.pth +3 -0
  2. EXPLANATION.md +189 -0
  3. README.md +138 -12
  4. app.py +424 -0
  5. requirements.txt +5 -0
  6. space.json +12 -0
CIFAR Net.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6744133a43fe90290fdb9770d7caa0bddaa453682bd4f8a7e8f2482feb852950
+ size 251604
EXPLANATION.md ADDED
@@ -0,0 +1,189 @@
+ # PyTorch Neural Network Classifier - Detailed Explanation
+
+ ## Overview
+
+ This application provides a user-friendly interface for running predictions on a trained PyTorch neural network model. The model is based on the exact implementation from the [PyTorch Neural Networks Tutorial](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html), which implements a simplified version of the LeNet-5 architecture.
+
+ ## Model Architecture Breakdown
+
+ The neural network implements the exact architecture from the PyTorch tutorial:
+
+ 1. **Input Layer**: Accepts grayscale images of size 32×32 pixels (1 channel)
+ 2. **First Convolutional Block**:
+    - Conv2d layer: 1 input channel → 6 output channels, 5×5 kernel
+    - ReLU activation function
+    - MaxPool2d layer: 2×2 pooling window
+ 3. **Second Convolutional Block**:
+    - Conv2d layer: 6 input channels → 16 output channels, 5×5 kernel
+    - ReLU activation function
+    - MaxPool2d layer: 2×2 pooling window
+ 4. **Fully Connected Layers**:
+    - First FC layer: 400 inputs → 120 outputs with ReLU activation
+    - Second FC layer: 120 inputs → 84 outputs with ReLU activation
+    - Output layer: 84 inputs → 10 outputs (for 10 classes)
+
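+ The 400 inputs to the first FC layer follow directly from the shapes: a 32×32 input becomes 28×28 after the first 5×5 convolution, 14×14 after pooling, 10×10 after the second convolution, and 5×5 after pooling, giving 16 × 5 × 5 = 400 features. A quick, self-contained sketch to verify the shapes (illustrative only; the real `Net` class lives in `app.py`):
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ x = torch.randn(1, 1, 32, 32)           # (N, C, H, W) dummy grayscale input
+ c1 = F.relu(nn.Conv2d(1, 6, 5)(x))      # -> (1, 6, 28, 28)
+ s2 = F.max_pool2d(c1, 2)                # -> (1, 6, 14, 14)
+ c3 = F.relu(nn.Conv2d(6, 16, 5)(s2))    # -> (1, 16, 10, 10)
+ s4 = F.max_pool2d(c3, 2)                # -> (1, 16, 5, 5)
+ print(torch.flatten(s4, 1).shape)       # torch.Size([1, 400])
+ ```
+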
+ ## How the Application Works
+
+ ### 1. Model Loading
+ When the application starts, it attempts to load your trained model weights from a file named `model.pth`. This file should contain the state dictionary of a model with the exact architecture defined in the `Net` class, matching the PyTorch tutorial.
+
+ ### 2. Image Preprocessing
+ Before making predictions, any input image goes through preprocessing:
+ - Converted to grayscale if it's in color
+ - Resized to 32×32 pixels to match the model's expected input size
+ - Converted to a PyTorch tensor
+ - Batch dimension added (required by PyTorch)
+
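+ The corresponding logic, condensed from `preprocess_image()` in `app.py`:
+
+ ```python
+ import torchvision.transforms as transforms
+ from PIL import Image
+
+ def preprocess_image(image: Image.Image):
+     if image.mode != 'L':                 # convert color images to grayscale
+         image = image.convert('L')
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),      # match the model's expected input size
+         transforms.ToTensor(),            # scales pixel values to [0, 1]
+     ])
+     return transform(image).unsqueeze(0)  # add batch dimension -> (1, 1, 32, 32)
+ ```
+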
+ ### 3. Prediction Process
+ When you submit an image for classification, the process exactly matches the PyTorch tutorial:
+
+ ```python
+ model.eval()
+ with torch.no_grad():
+     output = model(input_tensor)
+     probabilities = F.softmax(output, dim=1)
+     probabilities = probabilities.numpy()[0]
+ ```
+
+ This implementation:
+ - Sets the model to evaluation mode with `model.eval()`
+ - Disables gradient computation with `torch.no_grad()` for efficiency
+ - Applies softmax to convert raw outputs to probabilities
+ - Extracts the first (and only) batch result
+
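+ If you need a single predicted class rather than the full distribution (the app's `predict()` returns all ten probabilities), the top class is just the argmax of the array above:
+
+ ```python
+ import numpy as np
+
+ predicted_class = int(np.argmax(probabilities))     # index of the highest probability
+ confidence = float(probabilities[predicted_class])  # its probability
+ ```
+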
+ ### 4. User Interface Features
+ The Gradio interface provides several ways to interact with the model:
+
+ - **Image Upload**: Upload any image file from your computer
+ - **Drawing Tool**: Draw an image directly in the browser
+ - **Example Images**: Use pre-made examples to quickly test the model
+ - **Real-time Results**: See prediction probabilities for all 10 classes
+ - **Responsive Design**: Works well on both desktop and mobile devices
+
+ ## Image Input Capabilities
+
+ ### Supported Image Formats
+ The application accepts all common image formats:
+ - JPEG, PNG, BMP, TIFF, GIF, and WebP
+ - Color images (automatically converted to grayscale)
+ - Images of any resolution (automatically resized to 32×32)
+
+ ### Robustness Features
+ The model has been designed to handle various image conditions:
+ - **Resolution Independence**: Works with images of any size (resized to 32×32)
+ - **Color Conversion**: Automatically converts color images to grayscale
+ - **Contrast Handling**: Works with both high and low contrast images
+ - **Noise Tolerance**: Can handle some image noise
+ - **Rotation Tolerance**: Some tolerance to slight rotations
+ - **Scale Invariance**: Works with digits of different sizes
+
+ ### Best Practices for Good Results
+ To get the best classification results:
+ 1. **Center the digit** in the image area
+ 2. **Use clear contrast** between the digit and background
+ 3. **Fill most of the image area** with the digit
+ 4. **Avoid excessive noise** or artifacts
+ 5. **Use dark digits on a light background** or vice versa
+
+ ### Image Preprocessing Pipeline
+ The complete preprocessing pipeline:
+ 1. Image upload or drawing
+ 2. Automatic color-to-grayscale conversion
+ 3. Resize to 32×32 pixels using bilinear interpolation
+ 4. Conversion to a PyTorch tensor with values scaled to [0, 1]
+ 5. Addition of a batch dimension for model inference
+
+ ## Technical Implementation Details
+
+ ### Custom CSS Styling
+ The application features a modern UI with:
+ - Animated gradient background
+ - Glass-morphism design elements
+ - Responsive layout that adapts to different screen sizes
+ - Interactive buttons with hover effects
+ - Clean typography using Google Fonts
+
+ ### Error Handling
+ The application gracefully handles:
+ - Missing model files (shows an error message)
+ - Empty inputs (returns zero probabilities)
+ - Various image formats (automatically converted to grayscale)
+
+ ### Performance Optimizations
+ - Model loaded once at startup
+ - Gradients disabled during inference
+ - Efficient tensor operations
+ - Caching of example predictions
+
+ ## Deployment to Hugging Face Spaces
+
+ To deploy this application to Hugging Face Spaces:
+
+ 1. Create a new Space with the "Gradio" SDK
+ 2. Upload all files from this directory
+ 3. Ensure your `model.pth` file is included
+ 4. The Space will automatically install dependencies from `requirements.txt`
+ 5. The application will start automatically
+
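+ The upload step can also be scripted with the `huggingface_hub` client. A minimal sketch, assuming you are already logged in via `huggingface-cli login` and own a Space (the repo id below is a placeholder):
+
+ ```python
+ from huggingface_hub import HfApi
+
+ api = HfApi()  # uses your cached access token
+ api.upload_folder(
+     folder_path=".",                         # directory with app.py, model.pth, requirements.txt, ...
+     repo_id="your-username/img-classifier",  # placeholder Space id
+     repo_type="space",
+ )
+ ```
+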
+ ## Customization Guide
+
+ ### Using a Different Model File
+ If your model is saved with a different filename:
+ 1. Modify the `model_path` variable in the `load_model()` function
+ 2. Ensure the model architecture matches the `Net` class
+
+ ### Changing Class Labels
+ To customize the class labels:
+ 1. Modify the `labels` list in the `predict()` function
+ 2. Update the range in the list comprehension to match your number of classes
+
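+ For example, since the bundled checkpoint is named `CIFAR Net.pth`, a CIFAR-10 deployment might swap in the tutorial's class names (a sketch; use whatever labels match your training data):
+
+ ```python
+ # Hypothetical CIFAR-10 labels replacing the generic "Class {i}" names
+ labels = ['plane', 'car', 'bird', 'cat', 'deer',
+           'dog', 'frog', 'horse', 'ship', 'truck']
+ ```
+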
+ ### Adjusting Image Preprocessing
+ To modify how images are preprocessed:
+ 1. Edit the `preprocess_image()` function
+ 2. Change the resize dimensions if your model expects a different input size
+ 3. Add normalization if your model was trained with normalized inputs, as sketched below
+
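+ For step 3, the transform could be extended like this (a sketch; substitute the mean and std actually used during your training):
+
+ ```python
+ transform = transforms.Compose([
+     transforms.Resize((32, 32)),
+     transforms.ToTensor(),
+     transforms.Normalize((0.5,), (0.5,)),  # single-channel mean/std - placeholder values
+ ])
+ ```
+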
+ ## Troubleshooting Common Issues
+
+ ### Model Not Loading
+ - Verify `model.pth` is in the same directory as `app.py`
+ - Ensure the model architecture matches the `Net` class definition exactly
+ - Check that the file is not corrupted
+
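+ A quick way to check a checkpoint against the expected architecture is to load it and compare its keys with a fresh `Net` instance (illustrative sketch; `Net` is the class from `app.py`):
+
+ ```python
+ import torch
+
+ state_dict = torch.load("model.pth", map_location="cpu")
+ print(list(state_dict.keys()))     # expect conv1/conv2/fc1/fc2/fc3 weights and biases
+ Net().load_state_dict(state_dict)  # raises RuntimeError on any key or shape mismatch
+ ```
+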
+ ### Poor Prediction Accuracy
+ - Verify your model was trained on similar data
+ - Check if the preprocessing matches what was used during training
+ - Ensure input images are similar to the training data
+
+ ### UI Display Issues
+ - Update Gradio to the latest version
+ - Check browser compatibility
+ - Clear the browser cache if styles aren't loading correctly
+
+ ## File Structure
+ ```
+ classification-app/
+ ├── app.py              # Main application file
+ ├── requirements.txt    # Python dependencies
+ ├── README.md           # User guide
+ ├── EXPLANATION.md      # This file
+ ├── model.pth           # Your trained model (to be added)
+ └── space.json          # Hugging Face Spaces configuration
+ ```
+
+ ## Requirements Explanation
+
+ - **torch>=1.7.0**: Core PyTorch library for neural network operations
+ - **torchvision>=0.8.0**: Computer vision utilities, including image transforms
+ - **gradio>=4.0.0**: Framework for creating machine learning web interfaces
+ - **pillow>=8.0.0**: Python Imaging Library fork for image processing
+ - **numpy>=1.19.0**: Numerical computing library for array operations
+
+ ## Example Use Cases
+
+ 1. **Digit Recognition**: Classify handwritten digits (0-9)
+ 2. **Educational Tool**: Demonstrate how convolutional neural networks work
+ 3. **Model Showcase**: Present your trained model to others in an interactive way
+ 4. **Testing Platform**: Evaluate model performance on custom inputs
+
+ This application provides a complete solution for deploying a PyTorch model with an attractive, user-friendly interface that can be easily shared with others through Hugging Face Spaces. The implementation follows the PyTorch tutorial exactly, ensuring compatibility with models trained using the same approach.
README.md CHANGED
@@ -1,12 +1,138 @@
- ---
- title: Img Classifier
- emoji: 🏃
- colorFrom: blue
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.42.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # PyTorch Neural Network Classifier
+
+ This is a Gradio interface for a convolutional neural network based on the [PyTorch Neural Networks Tutorial](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html). The model is a simplified version of LeNet-5, designed for image classification tasks.
+
+ ## Model Architecture
+
+ The neural network has the following architecture (exactly as shown in the PyTorch tutorial):
+ - Two convolutional layers with ReLU activation and max pooling
+ - Three fully connected layers
+ - Designed for 32×32 grayscale input images
+
+ ```
+ Input → Conv2d(1, 6, 5) → ReLU → MaxPool2d(2, 2) →
+ Conv2d(6, 16, 5) → ReLU → MaxPool2d(2, 2) →
+ Flatten → Linear(16*5*5, 120) → ReLU →
+ Linear(120, 84) → ReLU → Linear(84, 10) → Output
+ ```
+
+ ## Features
+
+ - Interactive image classification interface with modern UI
+ - Example images for quick testing
+ - Real-time predictions with probability scores
+ - Support for custom image uploads
+ - Built-in drawing tool for creating test images
+ - Responsive design with gradient backgrounds and animations
+ - Automatic image preprocessing (resize, grayscale conversion)
+
+ ## How to Use with Your Existing Model
+
+ 1. Place your trained PyTorch model file in the app directory and name it `model.pth`
+ 2. Ensure your model uses the same architecture as defined in the `Net` class
+ 3. Install the required dependencies:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ 4. Run the application:
+ ```bash
+ python app.py
+ ```
+
+ 5. Access the interface at `http://localhost:7860` (or the URL provided in the terminal)
+
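+ If you trained the tutorial model yourself, saving the weights under the expected name is a single call (standard PyTorch; `net` stands for your trained instance):
+
+ ```python
+ torch.save(net.state_dict(), "model.pth")  # save only the weights, as load_model() expects
+ ```
+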
+ ## Image Input Capabilities
+
+ The model can handle various types of image inputs:
+
+ ### Supported Image Formats
+ - JPG, PNG, BMP, TIFF, and other common image formats
+ - Color images (automatically converted to grayscale)
+ - Any resolution (automatically resized to 32×32 pixels)
+
+ ### Robustness Features
+ - **Resolution Independence**: Works with images of any size (resized to 32×32)
+ - **Color Conversion**: Automatically converts color images to grayscale
+ - **Contrast Handling**: Works with both high and low contrast images
+ - **Noise Tolerance**: Can handle some image noise
+ - **Rotation Tolerance**: Some tolerance to slight rotations
+
+ ### Best Practices for Good Results
+ 1. **Center the digit** in the image area
+ 2. **Use clear contrast** between the digit and background
+ 3. **Fill most of the image area** with the digit
+ 4. **Avoid excessive noise** or artifacts
+ 5. **Use dark digits on a light background** or vice versa
+
+ ## Deployment to Hugging Face Spaces
+
+ This application can be deployed to Hugging Face Spaces by:
+
+ 1. Creating a new Space on Hugging Face
+ 2. Uploading these files to the repository
+ 3. Setting the SDK to "Gradio"
+ 4. Adding the dependencies to the `requirements.txt` file
+ 5. Uploading your `model.pth` file
+
+ The Space will automatically run the `app.py` file as the entry point.
+
+ ## Example Usage
+
+ The interface comes with hand-drawn example images that demonstrate how the classifier works. You can:
+ 1. Click on any example image to load it
+ 2. Upload your own image using the file browser
+ 3. Draw an image using the built-in sketch tool
+ 4. View the classification probabilities for each class
+
+ Try these examples:
+ - Handwritten digits of different styles
+ - Printed digits
+ - Digits with varying thickness
+ - Digits with different backgrounds
+
+ ## Technical Details
+
+ This implementation follows the PyTorch tutorial exactly and includes:
+ - Gradio interface with custom CSS styling
+ - Image preprocessing pipeline (resize to 32×32, grayscale conversion)
+ - Softmax probability output (as shown in the tutorial)
+ - Example generation for demonstration
+ - Model loading functionality for your trained weights
+
+ The prediction function exactly matches the tutorial:
+ ```python
+ model.eval()
+ with torch.no_grad():
+     output = model(input_tensor)
+     probabilities = F.softmax(output, dim=1)
+ ```
+
+ The UI features:
+ - Animated gradient background
+ - Glass-morphism design elements
+ - Responsive layout for all screen sizes
+ - Interactive buttons with hover effects
+ - Clean, modern typography
+
+ ## Requirements
+
+ - Python 3.8+ (required by Gradio 4.x)
+ - PyTorch >= 1.7.0
+ - TorchVision >= 0.8.0
+ - Gradio >= 4.0.0
+ - Pillow >= 8.0.0
+ - NumPy >= 1.19.0
+
+ Install with:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Troubleshooting
+
+ If you encounter issues:
+ 1. Ensure your `model.pth` file is in the same directory as `app.py`
+ 2. Verify that your model uses the same architecture as defined in the `Net` class
+ 3. Check that all required dependencies are installed
+ 4. Make sure you're using a compatible version of Python (3.8+)
app.py ADDED
@@ -0,0 +1,424 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import numpy as np
+ import torchvision.transforms as transforms
+ from PIL import Image, ImageDraw, ImageFont
+ import os
+
+ # Define the neural network model from the PyTorch tutorial
+ class Net(nn.Module):
+     def __init__(self):
+         super(Net, self).__init__()
+         # 1 input image channel, 6 output channels, 5x5 square convolution kernel
+         self.conv1 = nn.Conv2d(1, 6, 5)
+         self.conv2 = nn.Conv2d(6, 16, 5)
+         # an affine operation: y = Wx + b
+         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)
+
+     def forward(self, x):
+         # Convolution layer C1: 1 input image channel, 6 output channels,
+         # 5x5 square convolution, it uses RELU activation function, and
+         # outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
+         c1 = F.relu(self.conv1(x))
+         # Subsampling layer S2: 2x2 grid, purely functional,
+         # this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
+         s2 = F.max_pool2d(c1, (2, 2))
+         # Convolution layer C3: 6 input channels, 16 output channels,
+         # 5x5 square convolution, it uses RELU activation function, and
+         # outputs a (N, 16, 10, 10) Tensor
+         c3 = F.relu(self.conv2(s2))
+         # Subsampling layer S4: 2x2 grid, purely functional,
+         # this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
+         s4 = F.max_pool2d(c3, 2)
+         # Flatten operation: purely functional, outputs a (N, 400) Tensor
+         s4 = torch.flatten(s4, 1)
+         # Fully connected layer F5: (N, 400) Tensor input,
+         # and outputs a (N, 120) Tensor, it uses RELU activation function
+         f5 = F.relu(self.fc1(s4))
+         # Fully connected layer F6: (N, 120) Tensor input,
+         # and outputs a (N, 84) Tensor, it uses RELU activation function
+         f6 = F.relu(self.fc2(f5))
+         # Gaussian layer OUTPUT: (N, 84) Tensor input, and
+         # outputs a (N, 10) Tensor
+         output = self.fc3(f6)
+         return output
+
+ # Initialize the model
+ model = Net()
+
+ # Load the trained model weights
+ def load_model():
+     model_path = "model.pth"  # Update this path to where your model is stored
+     if os.path.exists(model_path):
+         # Load the trained model weights
+         model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+         print("Loaded trained model weights")
+         return True
+     else:
+         print("No trained model found at", model_path)
+         return False
+
+ # Preprocessing function for input images
+ def preprocess_image(image):
+     # Convert to grayscale if needed
+     if image.mode != 'L':
+         image = image.convert('L')
+
+     # Resize to 32x32 (expected input size for the network)
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),
+         transforms.ToTensor(),
+     ])
+
+     image_tensor = transform(image)
+     # Add batch dimension (1, 1, 32, 32)
+     image_tensor = image_tensor.unsqueeze(0)
+     return image_tensor
+
+ # Prediction function - matches the PyTorch tutorial exactly
+ def predict(image):
+     if image is None:
+         return {f"Class {i}": 0 for i in range(10)}
+
+     # Preprocess the image
+     input_tensor = preprocess_image(image)
+
+     # Make prediction - exactly as shown in the PyTorch tutorial
+     model.eval()
+     with torch.no_grad():
+         output = model(input_tensor)
+         # Apply softmax to get probabilities
+         probabilities = F.softmax(output, dim=1)
+         probabilities = probabilities.numpy()[0]
+
+     # Create labels (0-9 for MNIST-like classification)
+     labels = [f"Class {i}" for i in range(10)]
+
+     # Return as a dictionary for Gradio
+     return {label: float(prob) for label, prob in zip(labels, probabilities)}
+
+ # Create example images with different qualities and styles
+ def create_example_images():
+     examples = []
+
+     # Create hand-drawn style digits
+     for i in range(10):
+         # Create a 64x64 image for better quality
+         img = Image.new('L', (64, 64), color=255)  # White background
+         draw = ImageDraw.Draw(img)
+
+         # Draw a simple representation of each digit
+         if i == 0:
+             # Draw a 0 (oval)
+             draw.ellipse([10, 10, 54, 54], outline=0, width=5)
+         elif i == 1:
+             # Draw a 1 (simple line)
+             draw.line([32, 10, 32, 54], fill=0, width=5)
+         elif i == 2:
+             # Draw a 2 (connected lines)
+             draw.line([15, 15, 49, 15], fill=0, width=5)  # Top line
+             draw.line([49, 15, 49, 35], fill=0, width=5)  # Right line
+             draw.line([49, 35, 15, 35], fill=0, width=5)  # Middle line
+             draw.line([15, 35, 15, 54], fill=0, width=5)  # Left line
+             draw.line([15, 54, 49, 54], fill=0, width=5)  # Bottom line
+         elif i == 3:
+             # Draw a 3 (two semi-circles)
+             draw.arc([15, 10, 49, 35], 270, 90, fill=0, width=5)  # Top semi-circle
+             draw.arc([15, 35, 49, 60], 90, 270, fill=0, width=5)  # Bottom semi-circle
+         elif i == 4:
+             # Draw a 4 (perpendicular lines)
+             draw.line([35, 10, 35, 54], fill=0, width=5)  # Vertical line
+             draw.line([15, 10, 35, 30], fill=0, width=5)  # Diagonal line
+             draw.line([10, 30, 54, 30], fill=0, width=5)  # Horizontal line
+         elif i == 5:
+             # Draw a 5 (connected lines)
+             draw.line([15, 15, 49, 15], fill=0, width=5)  # Top line
+             draw.line([15, 15, 15, 35], fill=0, width=5)  # Left line
+             draw.line([15, 35, 49, 35], fill=0, width=5)  # Middle line
+             draw.line([49, 35, 49, 54], fill=0, width=5)  # Right line
+             draw.line([15, 54, 49, 54], fill=0, width=5)  # Bottom line
+         elif i == 6:
+             # Draw a 6 (circle with line)
+             draw.ellipse([15, 20, 49, 54], outline=0, width=5)
+             draw.line([15, 20, 25, 10], fill=0, width=5)  # Top line
+         elif i == 7:
+             # Draw a 7 (diagonal with horizontal)
+             draw.line([15, 15, 49, 15], fill=0, width=5)  # Top line
+             draw.line([49, 15, 20, 54], fill=0, width=5)  # Diagonal line
+         elif i == 8:
+             # Draw an 8 (two circles)
+             draw.ellipse([15, 10, 49, 32], outline=0, width=5)  # Top circle
+             draw.ellipse([15, 32, 49, 54], outline=0, width=5)  # Bottom circle
+         elif i == 9:
+             # Draw a 9 (circle with line)
+             draw.ellipse([15, 10, 49, 44], outline=0, width=5)
+             draw.line([49, 44, 40, 54], fill=0, width=5)  # Bottom line
+
+         examples.append(img)
+
+     return examples
+
+ # Custom CSS for enhanced UI
+ custom_css = """
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap');
+
+ body {
+     font-family: 'Roboto', sans-serif;
+     background: linear-gradient(135deg, #1a2a6c, #b21f1f, #1a2a6c);
+     background-size: 400% 400%;
+     animation: gradientBG 15s ease infinite;
+     color: white;
+     min-height: 100vh;
+ }
+
+ @keyframes gradientBG {
+     0% { background-position: 0% 50%; }
+     50% { background-position: 100% 50%; }
+     100% { background-position: 0% 50%; }
+ }
+
+ .gradio-container {
+     background: rgba(0, 0, 0, 0.7) !important;
+     backdrop-filter: blur(10px);
+     border-radius: 20px !important;
+     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.5);
+     border: 1px solid rgba(255, 255, 255, 0.1);
+     max-width: 1200px !important;
+     margin: 20px auto !important;
+ }
+
+ .container {
+     max-width: 100% !important;
+ }
+
+ h1 {
+     background: linear-gradient(to right, #ff7e5f, #feb47b);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     text-align: center;
+     font-weight: 700 !important;
+     font-size: 2.5em !important;
+     margin-bottom: 10px !important;
+     text-shadow: 0 2px 4px rgba(0,0,0,0.2);
+ }
+
+ h2 {
+     color: #feb47b;
+     border-bottom: 2px solid #ff7e5f;
+     padding-bottom: 10px;
+ }
+
+ .markdown {
+     background: rgba(255, 255, 255, 0.05);
+     border-radius: 15px;
+     padding: 20px;
+     margin-bottom: 20px;
+     border: 1px solid rgba(255, 255, 255, 0.1);
+ }
+
+ .gradio-button {
+     background: linear-gradient(45deg, #ff7e5f, #feb47b) !important;
+     border: none !important;
+     color: white !important;
+     font-weight: 600 !important;
+     transition: all 0.3s ease !important;
+     box-shadow: 0 4px 15px rgba(255, 126, 95, 0.3) !important;
+ }
+
+ .gradio-button:hover {
+     transform: translateY(-3px) !important;
+     box-shadow: 0 6px 20px rgba(255, 126, 95, 0.5) !important;
+ }
+
+ .gradio-button:active {
+     transform: translateY(1px) !important;
+ }
+
+ .gradio-image {
+     border-radius: 15px !important;
+     overflow: hidden !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.4) !important;
+     border: 2px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ .gradio-label {
+     background: rgba(255, 255, 255, 0.08) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
+ }
+
+ label {
+     color: #feb47b !important;
+     font-weight: 500 !important;
+ }
+
+ .examples {
+     background: rgba(255, 255, 255, 0.05) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     margin-top: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ footer {
+     display: none !important;
+ }
+
+ @media (max-width: 768px) {
+     .gradio-container {
+         margin: 10px !important;
+     }
+
+     h1 {
+         font-size: 2em !important;
+     }
+ }
+ """
+
+ # Initialize the model
+ model_loaded = load_model()
+
+ # Create the Gradio interface with enhanced styling
+ with gr.Blocks(
+     title="PyTorch Neural Network Classifier",
+     css=custom_css,
+     theme=gr.themes.Default(
+         font=["Roboto", "Arial", "sans-serif"]
+     )
+ ) as demo:
+     gr.Markdown("""
+     # 🔥 PyTorch Neural Network Classifier
+     ## Convolutional Neural Network for Image Classification
+
+     This is a demonstration of a convolutional neural network based on the
+     [PyTorch Neural Networks Tutorial](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html).
+
+     The model architecture consists of:
+     - 2 Convolutional Layers with ReLU activation
+     - 2 MaxPooling Layers
+     - 3 Fully Connected Layers
+     """)
+
+     # Show model loading status
+     if model_loaded:
+         gr.Markdown("✅ Model successfully loaded")
+     else:
+         gr.Markdown("❌ Model not found. Please place your 'model.pth' file in the app directory.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 📥 Input")
+             input_image = gr.Image(type="pil", label="Upload or Draw an Image", height=300)
+             with gr.Row():
+                 submit_btn = gr.Button("Classify Image", elem_classes=["custom-button"])
+                 clear_btn = gr.Button("Clear")
+
+             gr.Markdown("""
+             ### 🎯 Model Architecture
+             ```
+             Input → Conv2D(1×32×32) → ReLU → MaxPool2D
+             → Conv2D → ReLU → MaxPool2D
+             → Flatten → Linear → ReLU
+             → Linear → ReLU → Linear(10)
+             → Output
+             ```
+             """)
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 📊 Classification Results")
+             output_label = gr.Label(label="Prediction Probabilities", num_top_classes=5)
+
+             gr.Markdown("""
+             ### ℹ️ Instructions
+             1. Upload an image or draw one using the editor
+             2. The image will be automatically resized to 32×32 pixels
+             3. Click "Classify Image" to get predictions
+             4. Results show probabilities for 10 classes
+
+             ### 📝 Notes
+             - Model expects grayscale images
+             - Best results with MNIST-style digits
+             - Classes 0-9 represent digits
+             """)
+
+     with gr.Row():
+         gr.Markdown("### 📋 Example Images")
+         gr.Markdown("""
+         The examples below show hand-drawn style digits. Try clicking on any example to load it,
+         or use the drawing tool to create your own digits. The model can handle:
+         - Different handwriting styles
+         - Various image sizes (automatically resized to 32×32)
+         - Both black and white backgrounds
+         - Low-resolution images
+         """)
+
+     # Create a grid of example images
+     example_images = create_example_images()
+     with gr.Row():
+         for i in range(5):
+             with gr.Column():
+                 gr.Examples(
+                     label=f"Digit {i}",
+                     examples=[example_images[i]],
+                     inputs=input_image,
+                     outputs=output_label,
+                     fn=predict
+                 )
+
+     with gr.Row():
+         for i in range(5, 10):
+             with gr.Column():
+                 gr.Examples(
+                     label=f"Digit {i}",
+                     examples=[example_images[i]],
+                     inputs=input_image,
+                     outputs=output_label,
+                     fn=predict
+                 )
+
+     gr.Markdown("""
+     ### 🧪 Testing Different Image Qualities
+
+     This model is robust to various image conditions:
+     - **Resolution**: Works with images of any resolution (automatically resized to 32×32)
+     - **Contrast**: Handles both high and low contrast images
+     - **Noise**: Can tolerate some image noise
+     - **Rotation**: Some tolerance to slight rotations
+     - **Scale**: Works with digits of different sizes within the image
+
+     For best results:
+     1. Center the digit in the image
+     2. Use clear contrast between the digit and background
+     3. Avoid excessive noise or artifacts
+     4. Fill most of the image area with the digit
+     """)
+
+     # Event handling
+     submit_btn.click(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+     clear_btn.click(
+         fn=lambda: (None, {f"Class {i}": 0 for i in range(10)}),
+         inputs=None,
+         outputs=[input_image, output_label]
+     )
+
+     # Allow image upload to trigger prediction automatically
+     input_image.change(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=1.7.0
+ torchvision>=0.8.0
+ gradio>=4.0.0
+ pillow>=8.0.0
+ numpy>=1.19.0
space.json ADDED
@@ -0,0 +1,12 @@
+ {
+     "sdk": "gradio",
+     "sdk_version": "4.0.0",
+     "app_file": "app.py",
+     "requirements": [
+         "torch",
+         "torchvision",
+         "gradio",
+         "pillow",
+         "numpy"
+     ]
+ }