ksj47 committed
Commit b862b3f · verified · 1 Parent(s): e59c64c

Upload 7 files

Files changed (7)
  1. EXPLANATION.md +202 -0
  2. README.md +37 -24
  3. app.py +470 -0
  4. model.pth +3 -0
  5. requirements.txt +5 -0
  6. space.json +13 -0
  7. test_model.py +20 -0
EXPLANATION.md ADDED
@@ -0,0 +1,202 @@
+ # CIFAR-10 Image Classifier - Detailed Explanation
+
+ ## Overview
+
+ This application provides a user-friendly interface for running predictions with a trained PyTorch neural network. The model is based on the implementation from the [PyTorch CIFAR-10 Tutorial](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html), which trains a convolutional neural network to classify images from the CIFAR-10 dataset.
+
+ ## Model Architecture Breakdown
+
+ The neural network implements the architecture from the PyTorch CIFAR-10 tutorial (a code sketch follows the list):
+
+ 1. **Input Layer**: Accepts RGB images of size 32×32 pixels (3 channels)
+ 2. **First Convolutional Block**:
+    - Conv2d layer: 3 input channels → 6 output channels, 5×5 kernel
+    - ReLU activation function
+    - MaxPool2d layer: 2×2 pooling window
+ 3. **Second Convolutional Block**:
+    - Conv2d layer: 6 input channels → 16 output channels, 5×5 kernel
+    - ReLU activation function
+    - MaxPool2d layer: 2×2 pooling window
+ 4. **Fully Connected Layers**:
+    - First FC layer: 400 inputs (the flattened 16×5×5 feature map) → 120 outputs with ReLU activation
+    - Second FC layer: 120 inputs → 84 outputs with ReLU activation
+    - Output layer: 84 inputs → 10 outputs (one per CIFAR-10 class)
+
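+ For reference, here is a condensed sketch of the `Net` class defined in `app.py`, with the tensor shape after each stage noted in comments:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class Net(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.conv1 = nn.Conv2d(3, 6, 5)        # (N, 3, 32, 32) -> (N, 6, 28, 28)
+         self.conv2 = nn.Conv2d(6, 16, 5)       # (N, 6, 14, 14) -> (N, 16, 10, 10)
+         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 400 flattened features -> 120
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)           # 10 logits, one per class
+
+     def forward(self, x):
+         x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # -> (N, 6, 14, 14)
+         x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # -> (N, 16, 5, 5)
+         x = torch.flatten(x, 1)                     # -> (N, 400)
+         x = F.relu(self.fc1(x))
+         x = F.relu(self.fc2(x))
+         return self.fc3(x)                          # raw logits; softmax is applied at inference
+ ```
+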
+ ## CIFAR-10 Dataset
+
+ The CIFAR-10 dataset consists of 60,000 32×32 color images in 10 classes, with 6,000 images per class. The 10 classes are:
+ 1. **Airplane** - Aircraft flying in the sky
+ 2. **Automobile** - Cars and vehicles on the road
+ 3. **Bird** - Flying or perched birds
+ 4. **Cat** - Domestic cats and felines
+ 5. **Deer** - Wild deer and similar animals
+ 6. **Dog** - Domestic dogs and canines
+ 7. **Frog** - Amphibians like frogs
+ 8. **Horse** - Horses and similar animals
+ 9. **Ship** - Boats and ships on water
+ 10. **Truck** - Trucks and heavy vehicles
+
+ ## How the Application Works
+
+ ### 1. Model Loading
+ When the application starts, it attempts to load your trained model weights from a file named `model.pth`. This file should contain the state dictionary of a model with the exact architecture defined in the `Net` class, matching the PyTorch CIFAR-10 tutorial.
+
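+ A minimal sketch of what `load_model()` in `app.py` does (CPU-only loading of a state dictionary):
+
+ ```python
+ import torch
+
+ model = Net()  # architecture must match the checkpoint exactly
+ state_dict = torch.load("model.pth", map_location=torch.device("cpu"))
+ model.load_state_dict(state_dict)  # raises an error if layer names or shapes differ
+ model.eval()
+ ```
+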
+ ### 2. Image Preprocessing
+ Before making predictions, every input image goes through preprocessing (sketched below):
+ - Converted to RGB (3 channels) if needed - color information is preserved
+ - Resized to 32×32 pixels to match the model's expected input size
+ - Converted to a PyTorch tensor
+ - Batch dimension added (required by PyTorch)
+
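+ This mirrors `preprocess_image()` in `app.py`:
+
+ ```python
+ import torchvision.transforms as transforms
+
+ transform = transforms.Compose([
+     transforms.Resize((32, 32)),  # bilinear interpolation by default
+     transforms.ToTensor(),        # HWC uint8 [0, 255] -> CHW float [0, 1]
+ ])
+
+ input_tensor = transform(image).unsqueeze(0)  # shape: (1, 3, 32, 32)
+ ```
+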
+ ### 3. Prediction Process
+ When you submit an image for classification, the process follows the PyTorch tutorial:
+
+ ```python
+ model.eval()
+ with torch.no_grad():
+     output = model(input_tensor)
+     probabilities = F.softmax(output, dim=1)
+     probabilities = probabilities.numpy()[0]
+ ```
+
+ This implementation:
+ - Sets the model to evaluation mode with `model.eval()`
+ - Disables gradient computation with `torch.no_grad()` for efficiency
+ - Applies softmax to convert the raw outputs (logits) to probabilities
+ - Extracts the first (and only) batch result; a snippet for mapping these probabilities to class labels follows
+
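+ The resulting array lines up with the class list, as in `predict()` in `app.py`:
+
+ ```python
+ cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer",
+                    "Dog", "Frog", "Horse", "Ship", "Truck"]
+
+ # Pair each class name with its probability; Gradio's Label component
+ # accepts this dictionary directly.
+ result = {label: float(p) for label, p in zip(cifar10_classes, probabilities)}
+ top_class = max(result, key=result.get)  # e.g. "Cat"
+ ```
+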
+ ### 4. User Interface Features
+ The Gradio interface provides several ways to interact with the model:
+
+ - **Image Upload**: Upload any image file from your computer
+ - **Drawing Tool**: Draw an image directly in the browser
+ - **Example Images**: Use pre-made examples representing each CIFAR-10 class
+ - **Real-time Results**: See prediction probabilities for all 10 classes
+ - **Responsive Design**: Works well on both desktop and mobile devices
+
+ ## Image Input Capabilities
+
+ ### Supported Image Formats
+ The application accepts all common image formats:
+ - JPEG, PNG, BMP, TIFF, GIF, and WebP
+ - Color images (processed as RGB with 3 channels)
+ - Images of any resolution (automatically resized to 32×32)
+
+ ### Robustness Features
+ The model is designed to handle a range of image conditions:
+ - **Resolution Independence**: Works with images of any size (resized to 32×32)
+ - **Color Preservation**: Maintains RGB color information
+ - **Contrast Handling**: Works with both high and low contrast images
+ - **Noise Tolerance**: Can handle some image noise
+ - **Rotation Tolerance**: Some tolerance to slight rotations
+ - **Scale Tolerance**: Some robustness to objects of different sizes
+
+ ### Best Practices for Good Results
+ To get the best classification results:
+ 1. **Center the object** in the image area
+ 2. **Use clear contrast** between the object and background
+ 3. **Fill most of the image** area with the object
+ 4. **Avoid excessive noise** or artifacts
+ 5. **Ensure the object is clearly visible**
+
+ ### Image Preprocessing Pipeline
+ The complete preprocessing pipeline:
+ 1. Image upload or drawing
+ 2. Resizing to 32×32 pixels using bilinear interpolation
+ 3. Conversion to a PyTorch tensor with values scaled to [0, 1]
+ 4. Addition of a batch dimension for model inference
+
+ ## Technical Implementation Details
+
+ ### Custom CSS Styling
+ The application features a modern UI with:
+ - Animated gradient background
+ - Glass-morphism design elements
+ - Responsive layout that adapts to different screen sizes
+ - Interactive buttons with hover effects
+ - Clean typography using Google Fonts
+
+ ### Error Handling
+ The application gracefully handles:
+ - Missing model files (shows a warning and falls back to random weights)
+ - Empty inputs (returns zero probabilities)
+ - Various image formats (normalized to RGB)
+
+ ### Performance Optimizations
+ - Model loaded once at startup
+ - Gradients disabled during inference
+ - Efficient tensor operations
+ - Caching of example predictions
+
+ ## Deployment to Hugging Face Spaces
+
+ To deploy this application to Hugging Face Spaces:
+
+ 1. Create a new Space with the "Gradio" SDK
+ 2. Upload all files from this directory
+ 3. Ensure your `model.pth` file is included
+ 4. The Space will automatically install dependencies from `requirements.txt`
+ 5. The application will start automatically
+
+ ## Customization Guide
+
+ ### Using a Different Model File
+ If your model is saved under a different filename:
+ 1. Modify the `model_path` variable in the `load_model()` function
+ 2. Ensure the model architecture matches the `Net` class definition exactly
+
+ ### Changing Class Labels
+ To customize the class labels:
+ 1. Modify the `cifar10_classes` list in the `predict()` function
+ 2. Update the example images in the `create_example_images()` function to match your new classes
+
+ ### Adjusting Image Preprocessing
+ To modify how images are preprocessed:
+ 1. Edit the `preprocess_image()` function
+ 2. Change the resize dimensions if your model expects a different input size
+ 3. Add normalization if your model was trained with normalized inputs (see the sketch after this list)
+
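+ For example, the PyTorch CIFAR-10 tutorial normalizes each channel with mean and standard deviation 0.5; if your training followed it, add the same step (adjust the statistics to whatever your model was trained with):
+
+ ```python
+ transform = transforms.Compose([
+     transforms.Resize((32, 32)),
+     transforms.ToTensor(),
+     # The tutorial's normalization maps each channel to [-1, 1];
+     # replace these values with your own training statistics if they differ.
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+ ])
+ ```
+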
+ ## Troubleshooting Common Issues
+
+ ### Model Not Loading
+ - Verify `model.pth` is in the same directory as `app.py`
+ - Ensure the model architecture matches the `Net` class definition exactly (the snippet below helps diagnose this)
+ - Check that the file is not corrupted
+
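+ A quick diagnostic, extending the bundled `test_model.py`, that compares checkpoint keys against a fresh `Net` (assumes the `Net` class from `app.py` is defined or imported):
+
+ ```python
+ import torch
+
+ checkpoint = torch.load("model.pth", map_location="cpu")
+ expected = Net().state_dict()
+
+ # Both key names and tensor shapes must match for load_state_dict to succeed.
+ print("missing keys:", set(expected) - set(checkpoint))
+ print("unexpected keys:", set(checkpoint) - set(expected))
+ for k in set(expected) & set(checkpoint):
+     if expected[k].shape != checkpoint[k].shape:
+         print(f"shape mismatch at {k}: {expected[k].shape} vs {checkpoint[k].shape}")
+ ```
+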
+ ### Poor Prediction Accuracy
+ - Verify your model was trained on similar data (CIFAR-10 or similar)
+ - Check that the preprocessing matches what was used during training (especially normalization)
+ - Ensure input images are similar to the training data
+
+ ### UI Display Issues
+ - Update Gradio to the latest version
+ - Check browser compatibility
+ - Clear the browser cache if styles aren't loading correctly
+
+ ## File Structure
+ ```
+ cifar10-classifier/
+ ├── app.py              # Main application file
+ ├── requirements.txt    # Python dependencies
+ ├── README.md           # User guide
+ ├── EXPLANATION.md      # This file
+ ├── model.pth           # Your trained model weights
+ ├── test_model.py       # Sanity check for loading model.pth
+ └── space.json          # Hugging Face Spaces configuration
+ ```
+
+ ## Requirements Explanation
+
+ - **torch>=1.7.0**: Core PyTorch library for neural network operations
+ - **torchvision>=0.8.0**: Computer vision utilities, including image transforms
+ - **gradio>=4.44.1**: Framework for creating machine learning web interfaces
+ - **pillow>=8.0.0**: Python Imaging Library for image processing
+ - **numpy>=1.19.0**: Numerical computing library for array operations
+
+ ## Example Use Cases
+
+ 1. **Object Recognition**: Classify images into 10 common object categories
+ 2. **Educational Tool**: Demonstrate how convolutional neural networks work on real image data
+ 3. **Model Showcase**: Present your trained model to others in an interactive way
+ 4. **Testing Platform**: Evaluate model performance on custom inputs
+
+ This application provides a complete solution for deploying a PyTorch model trained on CIFAR-10 behind an attractive, user-friendly interface that can easily be shared through Hugging Face Spaces. Because the implementation follows the PyTorch CIFAR-10 tutorial, it is compatible with models trained the same way.
README.md CHANGED
@@ -1,6 +1,6 @@
 ---
- title: PyTorch Neural Network Classifier
- emoji: 🧠
+ title: CIFAR-10 Image Classifier
+ emoji: 🚀
 colorFrom: blue
 colorTo: red
 sdk: gradio
@@ -9,19 +9,19 @@ app_file: app.py
 pinned: false
 ---

- # PyTorch Neural Network Classifier
+ # CIFAR-10 Image Classifier

- This is a Gradio interface for a convolutional neural network based on the [PyTorch Neural Networks Tutorial](https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html). The model is a simplified version of LeNet-5, designed for image classification tasks.
+ This is a Gradio interface for a convolutional neural network trained on the CIFAR-10 dataset. The model can classify images into 10 different object categories: Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, and Truck.

 ## Model Architecture

- The neural network has the following architecture (exactly as shown in the PyTorch tutorial):
+ The neural network has the following architecture (based on the PyTorch CIFAR-10 Tutorial):
 - Two convolutional layers with ReLU activation and max pooling
 - Three fully connected layers
- - Designed for 32x32 grayscale input images
+ - Designed for 32x32 RGB input images

 ```
- Input → Conv2d(1, 6, 5) → ReLU → MaxPool2d(2, 2) →
+ Input → Conv2d(3, 6, 5) → ReLU → MaxPool2d(2, 2) →
 Conv2d(6, 16, 5) → ReLU → MaxPool2d(2, 2) →
 Flatten → Linear(16*5*5, 120) → ReLU →
 Linear(120, 84) → ReLU → Linear(84, 10) → Output
@@ -30,12 +30,12 @@ Linear(120, 84) → ReLU → Linear(84, 10) → Output
 ## Features

 - Interactive image classification interface with modern UI
- - Example images for quick testing
+ - Example images for each CIFAR-10 class
 - Real-time predictions with probability scores
 - Support for custom image uploads
 - Built-in drawing tool for creating test images
 - Responsive design with gradient backgrounds and animations
- - Automatic image preprocessing (resize, grayscale conversion)
+ - Automatic image preprocessing (resize to 32×32)

 ## How to Use with Your Existing Model

@@ -59,22 +59,36 @@ The model can handle various types of image inputs:
 ### Supported Image Formats
 - JPG, PNG, BMP, TIFF, and other common image formats
- - Color images (automatically converted to grayscale)
+ - Color images (RGB with 3 channels)
 - Any resolution (automatically resized to 32×32 pixels)

 ### Robustness Features
 - **Resolution Independence**: Works with images of any size (resized to 32×32)
- - **Color Conversion**: Automatically converts color images to grayscale
+ - **Color Preservation**: Maintains RGB color information
 - **Contrast Handling**: Works with both high and low contrast images
 - **Noise Tolerance**: Can handle some image noise
 - **Rotation Tolerance**: Some tolerance to slight rotations

 ### Best Practices for Good Results
- 1. **Center the digit** in the image area
- 2. **Use clear contrast** between the digit and background
- 3. **Fill most of the image** area with the digit
+ 1. **Center the object** in the image area
+ 2. **Use clear contrast** between the object and background
+ 3. **Fill most of the image** area with the object
 4. **Avoid excessive noise** or artifacts
- 5. **Use dark digits on light background** or vice versa
+ 5. **Ensure the object is clearly visible**
+
+ ## CIFAR-10 Classes
+
+ The model classifies images into these 10 categories:
+ 1. **Airplane** - Aircraft flying in the sky
+ 2. **Automobile** - Cars and vehicles on the road
+ 3. **Bird** - Flying or perched birds
+ 4. **Cat** - Domestic cats and felines
+ 5. **Deer** - Wild deer and similar animals
+ 6. **Dog** - Domestic dogs and canines
+ 7. **Frog** - Amphibians like frogs
+ 8. **Horse** - Horses and similar animals
+ 9. **Ship** - Boats and ships on water
+ 10. **Truck** - Trucks and heavy vehicles

 ## Deployment to Hugging Face Spaces

@@ -90,28 +104,27 @@ The Space will automatically run the `app.py` file as the entry point.
 ## Example Usage

- The interface comes with hand-drawn example images that demonstrate how the classifier works. You can:
+ The interface comes with simple example images representing each CIFAR-10 class. You can:
 1. Click on any example image to load it
 2. Upload your own image using the file browser
 3. Draw an image using the built-in sketch tool
 4. View the classification probabilities for each class

 Try these examples:
- - Handwritten digits of different styles
- - Printed digits
- - Digits with varying thickness
- - Digits with different backgrounds
+ - Simple drawings of objects from each class
+ - Photos of objects that match the CIFAR-10 categories
+ - Images with varying styles and backgrounds

 ## Technical Details

- This implementation follows the PyTorch tutorial exactly and includes:
+ This implementation is based on the PyTorch CIFAR-10 tutorial and includes:
 - Gradio interface with custom CSS styling
- - Image preprocessing pipeline (resize to 32x32, grayscale conversion)
+ - Image preprocessing pipeline (resize to 32x32)
- - Softmax probability output (as shown in the tutorial)
+ - Softmax probability output
 - Example generation for demonstration
 - Model loading functionality for your trained weights

- The prediction function exactly matches the tutorial:
+ The prediction function:
 ```python
 model.eval()
 with torch.no_grad():
app.py ADDED
@@ -0,0 +1,470 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import gradio as gr
+ import numpy as np
+ import torchvision.transforms as transforms
+ from PIL import Image, ImageDraw
+ import os
+
+ # Define the neural network model - matching your trained model with 3 input channels
+ class Net(nn.Module):
+     def __init__(self):
+         super(Net, self).__init__()
+         # 3 input image channels (RGB), 6 output channels, 5x5 square convolution kernel
+         self.conv1 = nn.Conv2d(3, 6, 5)
+         self.conv2 = nn.Conv2d(6, 16, 5)
+         # an affine operation: y = Wx + b
+         self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
+         self.fc2 = nn.Linear(120, 84)
+         self.fc3 = nn.Linear(84, 10)
+
+     def forward(self, x):
+         # Convolution layer C1: 3 input image channels, 6 output channels,
+         # 5x5 square convolution, uses a ReLU activation function, and
+         # outputs a Tensor with size (N, 6, 28, 28), where N is the size of the batch
+         c1 = F.relu(self.conv1(x))
+         # Subsampling layer S2: 2x2 grid, purely functional,
+         # this layer does not have any parameter, and outputs a (N, 6, 14, 14) Tensor
+         s2 = F.max_pool2d(c1, (2, 2))
+         # Convolution layer C3: 6 input channels, 16 output channels,
+         # 5x5 square convolution, uses a ReLU activation function, and
+         # outputs a (N, 16, 10, 10) Tensor
+         c3 = F.relu(self.conv2(s2))
+         # Subsampling layer S4: 2x2 grid, purely functional,
+         # this layer does not have any parameter, and outputs a (N, 16, 5, 5) Tensor
+         s4 = F.max_pool2d(c3, 2)
+         # Flatten operation: purely functional, outputs a (N, 400) Tensor
+         s4 = torch.flatten(s4, 1)
+         # Fully connected layer F5: (N, 400) Tensor input,
+         # outputs a (N, 120) Tensor, uses a ReLU activation function
+         f5 = F.relu(self.fc1(s4))
+         # Fully connected layer F6: (N, 120) Tensor input,
+         # outputs a (N, 84) Tensor, uses a ReLU activation function
+         f6 = F.relu(self.fc2(f5))
+         # Output layer: (N, 84) Tensor input, outputs a (N, 10) Tensor of class logits
+         output = self.fc3(f6)
+         return output
+
+ # Initialize the model
+ model = Net()
+
+ # Load the trained model weights
+ def load_model():
+     model_path = "model.pth"  # Update this path to where your model is stored
+     if os.path.exists(model_path):
+         try:
+             # Load the trained model weights,
+             # handling different PyTorch versions
+             try:
+                 # PyTorch 2.6+ changed the default of weights_only; pass it explicitly
+                 model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=False))
+             except TypeError:
+                 # Older PyTorch versions don't accept the weights_only parameter
+                 model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+             print("Loaded trained model weights")
+             return True
+         except Exception as e:
+             print(f"Error loading model: {e}")
+             return False
+     else:
+         print("No trained model found at", model_path)
+         # Initialize with random weights for demonstration
+         for m in model.modules():
+             if isinstance(m, (nn.Conv2d, nn.Linear)):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0)
+         return False
+
+ # Preprocessing function for input images - handles RGB images
+ def preprocess_image(image):
+     # Ensure 3 channels: drawings and screenshots may arrive as RGBA or grayscale
+     image = image.convert("RGB")
+     # Resize to 32x32 (expected input size for the network)
+     transform = transforms.Compose([
+         transforms.Resize((32, 32)),
+         transforms.ToTensor(),
+     ])
+
+     image_tensor = transform(image)
+     # Add batch dimension -> (1, 3, 32, 32)
+     image_tensor = image_tensor.unsqueeze(0)
+     return image_tensor
+
+ # Prediction function - follows the PyTorch tutorial
+ def predict(image):
+     # Labels for the CIFAR-10 classes
+     cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]
+
+     if image is None:
+         return {label: 0.0 for label in cifar10_classes}
+
+     # Preprocess the image
+     input_tensor = preprocess_image(image)
+
+     # Make prediction - as shown in the PyTorch tutorial
+     model.eval()
+     with torch.no_grad():
+         output = model(input_tensor)
+         # Apply softmax to get probabilities
+         probabilities = F.softmax(output, dim=1)
+         probabilities = probabilities.numpy()[0]
+
+     # Return as a dictionary for Gradio's Label component
+     return {label: float(prob) for label, prob in zip(cifar10_classes, probabilities)}
+
+ # Create example images representing CIFAR-10 classes
+ def create_example_images():
+     examples = []
+
+     # CIFAR-10 class names
+     cifar10_classes = ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]
+
+     # Create simple representations of CIFAR-10 classes
+     for i, class_name in enumerate(cifar10_classes):
+         # Create a 64x64 RGB image for better quality
+         img = Image.new('RGB', (64, 64), color=(255, 255, 255))  # White background
+         draw = ImageDraw.Draw(img)
+
+         # Draw simple representations of each class
+         if i == 0:  # Airplane
+             # Draw a simple airplane shape
+             draw.polygon([(32, 10), (20, 30), (44, 30)], fill=(169, 169, 169))  # Main body
+             draw.rectangle([25, 30, 39, 35], fill=(105, 105, 105))  # Wings
+             draw.rectangle([30, 35, 34, 45], fill=(128, 128, 128))  # Tail
+         elif i == 1:  # Automobile
+             # Draw a simple car shape
+             draw.rectangle([15, 30, 49, 45], fill=(0, 0, 255))  # Body
+             draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
+             draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
+             draw.rectangle([25, 20, 39, 30], fill=(0, 0, 255))  # Top
+         elif i == 2:  # Bird
+             # Draw a simple bird shape
+             draw.ellipse([25, 25, 39, 39], fill=(255, 165, 0))  # Body
+             draw.polygon([(32, 15), (25, 25), (39, 25)], fill=(255, 140, 0))  # Head
+             draw.line([20, 30, 10, 20], fill=(255, 165, 0), width=3)  # Wing
+             draw.line([44, 30, 54, 20], fill=(255, 165, 0), width=3)  # Wing
+         elif i == 3:  # Cat
+             # Draw a simple cat shape
+             draw.ellipse([25, 25, 39, 39], fill=(128, 128, 128))  # Body
+             draw.ellipse([30, 20, 40, 30], fill=(169, 169, 169))  # Head
+             draw.polygon([(35, 22), (33, 27), (37, 27)], fill=(0, 0, 0))  # Ear
+             draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
+         elif i == 4:  # Deer
+             # Draw a simple deer shape
+             draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
+             draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
+             draw.line([35, 15, 40, 25], fill=(139, 69, 19), width=3)  # Antler
+             draw.line([20, 35, 10, 30], fill=(139, 69, 19), width=2)  # Leg
+         elif i == 5:  # Dog
+             # Draw a simple dog shape
+             draw.ellipse([25, 30, 39, 44], fill=(139, 69, 19))  # Body
+             draw.ellipse([30, 25, 40, 35], fill=(160, 82, 45))  # Head
+             draw.ellipse([32, 28, 34, 30], fill=(0, 0, 0))  # Eye
+             draw.ellipse([36, 32, 38, 34], fill=(0, 0, 0))  # Nose
+         elif i == 6:  # Frog
+             # Draw a simple frog shape
+             draw.ellipse([25, 30, 39, 44], fill=(34, 139, 34))  # Body
+             draw.ellipse([30, 25, 40, 35], fill=(0, 100, 0))  # Head
+             draw.ellipse([27, 32, 29, 34], fill=(0, 0, 0))  # Eye
+             draw.ellipse([35, 32, 37, 34], fill=(0, 0, 0))  # Eye
+         elif i == 7:  # Horse
+             # Draw a simple horse shape
+             draw.ellipse([25, 30, 39, 44], fill=(169, 169, 169))  # Body
+             draw.ellipse([35, 20, 45, 30], fill=(128, 128, 128))  # Head
+             draw.line([40, 25, 50, 15], fill=(105, 105, 105), width=3)  # Mane
+         elif i == 8:  # Ship
+             # Draw a simple ship shape
+             draw.polygon([(20, 35), (44, 35), (38, 45), (26, 45)], fill=(139, 69, 19))  # Hull
+             draw.rectangle([30, 20, 34, 35], fill=(169, 169, 169))  # Mast
+             draw.polygon([(30, 20), (32, 15), (34, 20)], fill=(255, 255, 255))  # Sail
+         elif i == 9:  # Truck
+             # Draw a simple truck shape
+             draw.rectangle([15, 25, 49, 45], fill=(255, 0, 0))  # Cab
+             draw.rectangle([25, 15, 45, 25], fill=(255, 0, 0))  # Load area
+             draw.ellipse([20, 40, 30, 50], fill=(0, 0, 0))  # Wheels
+             draw.ellipse([34, 40, 44, 50], fill=(0, 0, 0))
+
+         examples.append(img)
+
+     return examples
+
+ # Custom CSS for enhanced UI
+ custom_css = """
+ @import url('https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&display=swap');
+
+ body {
+     font-family: 'Roboto', sans-serif;
+     background: linear-gradient(135deg, #1a2a6c, #b21f1f, #1a2a6c);
+     background-size: 400% 400%;
+     animation: gradientBG 15s ease infinite;
+     color: white;
+     min-height: 100vh;
+ }
+
+ @keyframes gradientBG {
+     0% { background-position: 0% 50%; }
+     50% { background-position: 100% 50%; }
+     100% { background-position: 0% 50%; }
+ }
+
+ .gradio-container {
+     background: rgba(0, 0, 0, 0.7) !important;
+     backdrop-filter: blur(10px);
+     border-radius: 20px !important;
+     box-shadow: 0 10px 30px rgba(0, 0, 0, 0.5);
+     border: 1px solid rgba(255, 255, 255, 0.1);
+     max-width: 1200px !important;
+     margin: 20px auto !important;
+ }
+
+ .container {
+     max-width: 100% !important;
+ }
+
+ h1 {
+     background: linear-gradient(to right, #ff7e5f, #feb47b);
+     -webkit-background-clip: text;
+     -webkit-text-fill-color: transparent;
+     text-align: center;
+     font-weight: 700 !important;
+     font-size: 2.5em !important;
+     margin-bottom: 10px !important;
+     text-shadow: 0 2px 4px rgba(0,0,0,0.2);
+ }
+
+ h2 {
+     color: #feb47b;
+     border-bottom: 2px solid #ff7e5f;
+     padding-bottom: 10px;
+ }
+
+ .markdown {
+     background: rgba(255, 255, 255, 0.05);
+     border-radius: 15px;
+     padding: 20px;
+     margin-bottom: 20px;
+     border: 1px solid rgba(255, 255, 255, 0.1);
+ }
+
+ .gradio-button {
+     background: linear-gradient(45deg, #ff7e5f, #feb47b) !important;
+     border: none !important;
+     color: white !important;
+     font-weight: 600 !important;
+     transition: all 0.3s ease !important;
+     box-shadow: 0 4px 15px rgba(255, 126, 95, 0.3) !important;
+ }
+
+ .gradio-button:hover {
+     transform: translateY(-3px) !important;
+     box-shadow: 0 6px 20px rgba(255, 126, 95, 0.5) !important;
+ }
+
+ .gradio-button:active {
+     transform: translateY(1px) !important;
+ }
+
+ .gradio-image {
+     border-radius: 15px !important;
+     overflow: hidden !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.4) !important;
+     border: 2px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ .gradio-label {
+     background: rgba(255, 255, 255, 0.08) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+     box-shadow: 0 8px 25px rgba(0, 0, 0, 0.3) !important;
+ }
+
+ label {
+     color: #feb47b !important;
+     font-weight: 500 !important;
+ }
+
+ .examples {
+     background: rgba(255, 255, 255, 0.05) !important;
+     border-radius: 15px !important;
+     padding: 20px !important;
+     margin-top: 20px !important;
+     border: 1px solid rgba(255, 255, 255, 0.1) !important;
+ }
+
+ footer {
+     display: none !important;
+ }
+
+ @media (max-width: 768px) {
+     .gradio-container {
+         margin: 10px !important;
+     }
+
+     h1 {
+         font-size: 2em !important;
+     }
+ }
+ """
+
+ # Load the model weights
+ model_loaded = load_model()
+
+ # Create the Gradio interface with enhanced styling
+ with gr.Blocks(
+     title="CIFAR-10 Image Classifier",
+     css=custom_css,
+     theme=gr.themes.Default(
+         font=["Roboto", "Arial", "sans-serif"]
+     )
+ ) as demo:
+     gr.Markdown("""
+     # 🚀 CIFAR-10 Image Classifier
+     ## Convolutional Neural Network for Object Recognition
+
+     This is a demonstration of a convolutional neural network trained on the CIFAR-10 dataset.
+     The model can classify images into 10 different object categories.
+
+     The model architecture consists of:
+     - 2 Convolutional Layers with ReLU activation
+     - 2 MaxPooling Layers
+     - 3 Fully Connected Layers
+     """)
+
+     # Show model loading status
+     if model_loaded:
+         gr.Markdown("✅ Model successfully loaded")
+     else:
+         gr.Markdown("⚠️ Model not found or error loading. Using random weights for demonstration.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown("### 📥 Input")
+             input_image = gr.Image(type="pil", label="Upload or Draw an Image", height=300)
+             with gr.Row():
+                 submit_btn = gr.Button("Classify Image", elem_classes=["custom-button"])
+                 clear_btn = gr.Button("Clear")
+
+             gr.Markdown("""
+             ### 🎯 Model Architecture
+             ```
+             Input → Conv2D(3×32×32) → ReLU → MaxPool2D
+             → Conv2D → ReLU → MaxPool2D
+             → Flatten → Linear → ReLU
+             → Linear → ReLU → Linear(10)
+             → Output
+             ```
+             """)
+
+         with gr.Column(scale=1):
+             gr.Markdown("### 📊 Classification Results")
+             output_label = gr.Label(label="Prediction Probabilities", num_top_classes=5)
+
+             gr.Markdown("""
+             ### ℹ️ Instructions
+             1. Upload an image or draw one using the editor
+             2. The image will be automatically resized to 32×32 pixels
+             3. Click "Classify Image" to get predictions
+             4. Results show probabilities for 10 CIFAR-10 classes
+
+             ### 📝 Notes
+             - Model expects RGB images of 32×32 pixels
+             - Trained on the CIFAR-10 dataset
+             - Classes: Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck
+             """)
+
+     with gr.Row():
+         gr.Markdown("### 📋 Example Images")
+         gr.Markdown("""
+         The examples below are simple drawings representing each CIFAR-10 class.
+         Try clicking on any example to load it, or use the drawing tool to create your own images. The model can handle:
+         - Various image sizes (automatically resized to 32×32)
+         - Both black and white backgrounds
+         - Low-resolution images
+
+         Classes: Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck
+         """)
+
+     # Create the examples directory if it doesn't exist
+     examples_dir = "examples"
+     if not os.path.exists(examples_dir):
+         os.makedirs(examples_dir)
+
+     # Save the generated example images if they are missing, then collect their paths
+     example_paths = []
+     for i, img in enumerate(create_example_images()):
+         example_path = os.path.join(examples_dir, f"example_{i}.png")
+         if not os.path.exists(example_path):
+             img.save(example_path)
+         example_paths.append(example_path)
+
+     gr.Examples(
+         examples=example_paths,
+         inputs=input_image,
+         outputs=output_label,
+         fn=predict,
+         cache_examples=True
+     )
+
+     gr.Markdown("""
+     ### 🧪 Testing Different Image Qualities
+
+     This model is robust to various image conditions:
+     - **Resolution**: Works with images of any resolution (automatically resized to 32×32)
+     - **Contrast**: Handles both high and low contrast images
+     - **Noise**: Can tolerate some image noise
+     - **Rotation**: Some tolerance to slight rotations
+     - **Scale**: Works with objects of different sizes within the image
+
+     For best results:
+     1. Center the object in the image
+     2. Use clear contrast between the object and background
+     3. Avoid excessive noise or artifacts
+     4. Fill most of the image area with the object
+
+     ### 🎯 CIFAR-10 Classes
+
+     The model can classify images into these 10 categories:
+     1. **Airplane** - Aircraft flying in the sky
+     2. **Automobile** - Cars and vehicles on the road
+     3. **Bird** - Flying or perched birds
+     4. **Cat** - Domestic cats and felines
+     5. **Deer** - Wild deer and similar animals
+     6. **Dog** - Domestic dogs and canines
+     7. **Frog** - Amphibians like frogs
+     8. **Horse** - Horses and similar animals
+     9. **Ship** - Boats and ships on water
+     10. **Truck** - Trucks and heavy vehicles
+     """)
+
+     # Event handling
+     submit_btn.click(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+     clear_btn.click(
+         fn=lambda: (None, {cifar10_class: 0.0 for cifar10_class in ["Airplane", "Automobile", "Bird", "Cat", "Deer", "Dog", "Frog", "Horse", "Ship", "Truck"]}),
+         inputs=None,
+         outputs=[input_image, output_label]
+     )
+
+     # Allow image upload to trigger prediction automatically
+     input_image.change(
+         fn=predict,
+         inputs=input_image,
+         outputs=output_label
+     )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch(share=True)
model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a43cea8f5cb7725b1d5074767f28d1f5c4ff81d5b1435ba2350dd7b7b77a6a63
+ size 252005
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch>=1.7.0
+ torchvision>=0.8.0
+ gradio>=4.44.1
+ pillow>=8.0.0
+ numpy>=1.19.0
space.json ADDED
@@ -0,0 +1,13 @@
+ {
+   "title": "CIFAR-10 Image Classifier",
+   "sdk": "gradio",
+   "sdk_version": "4.44.1",
+   "app_file": "app.py",
+   "requirements": [
+     "torch",
+     "torchvision",
+     "gradio",
+     "pillow",
+     "numpy"
+   ]
+ }
test_model.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ import os
+
+ # Check if the model file exists
+ model_path = "model.pth"
+ if os.path.exists(model_path):
+     print(f"Model file exists at {model_path}")
+     print(f"File size: {os.path.getsize(model_path)} bytes")
+
+     try:
+         # Try to load the model
+         model_data = torch.load(model_path, map_location=torch.device('cpu'))
+         print("Model loaded successfully!")
+         print(f"Model type: {type(model_data)}")
+         if isinstance(model_data, dict):
+             print(f"Model keys: {list(model_data.keys())}")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+ else:
+     print(f"Model file not found at {model_path}")