Spaces:
Runtime error
Runtime error
Commit
·
4a10914
1
Parent(s):
2620eb0
different sizes compatibility
Browse files
app.py
CHANGED
|
@@ -2,7 +2,7 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
from PIL import Image
|
| 4 |
from torchvision import transforms
|
| 5 |
-
from utils import normalize_lab, denormalize_lab
|
| 6 |
from model import Generator
|
| 7 |
import kornia.color as color
|
| 8 |
|
|
@@ -15,15 +15,13 @@ model = model.to(device)
|
|
| 15 |
model.eval()
|
| 16 |
|
| 17 |
|
| 18 |
-
# Define preprocessing transforms
|
| 19 |
-
transform = transforms.Compose([
|
| 20 |
-
transforms.Resize((256, 256), Image.BICUBIC),
|
| 21 |
-
transforms.ToTensor(),
|
| 22 |
-
])
|
| 23 |
-
|
| 24 |
-
|
| 25 |
def preprocess(image):
|
| 26 |
image = image.convert('RGB')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
image = transform(image)
|
| 28 |
image = image.to(device)
|
| 29 |
image = color.rgb_to_lab(image)
|
|
@@ -33,8 +31,13 @@ def preprocess(image):
|
|
| 33 |
print(L.shape)
|
| 34 |
return L.unsqueeze(0)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def predict(image):
|
|
|
|
| 38 |
L = preprocess(image)
|
| 39 |
with torch.no_grad():
|
| 40 |
output = model(L)
|
|
@@ -42,6 +45,7 @@ def predict(image):
|
|
| 42 |
L, ab = denormalize_lab(L, output)
|
| 43 |
output = torch.cat([L, ab], dim=1)
|
| 44 |
output = color.lab_to_rgb(output)
|
|
|
|
| 45 |
image = transforms.ToPILImage()(output.squeeze().cpu())
|
| 46 |
|
| 47 |
return image
|
|
|
|
| 2 |
import torch
|
| 3 |
from PIL import Image
|
| 4 |
from torchvision import transforms
|
| 5 |
+
from utils import normalize_lab, denormalize_lab, pad_image
|
| 6 |
from model import Generator
|
| 7 |
import kornia.color as color
|
| 8 |
|
|
|
|
| 15 |
model.eval()
|
| 16 |
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def preprocess(image):
|
| 19 |
image = image.convert('RGB')
|
| 20 |
+
image = pad_image(image)
|
| 21 |
+
transform = transforms.Compose([
|
| 22 |
+
#transforms.Resize((height, width), Image.BICUBIC),
|
| 23 |
+
transforms.ToTensor(),
|
| 24 |
+
])
|
| 25 |
image = transform(image)
|
| 26 |
image = image.to(device)
|
| 27 |
image = color.rgb_to_lab(image)
|
|
|
|
| 31 |
print(L.shape)
|
| 32 |
return L.unsqueeze(0)
|
| 33 |
|
| 34 |
+
def crop_to_original_size(image, original_size):
|
| 35 |
+
width, height = original_size
|
| 36 |
+
return transforms.functional.crop(image, top=0, left=0, height=height, width=width)
|
| 37 |
+
|
| 38 |
|
| 39 |
def predict(image):
|
| 40 |
+
original_size = image.size
|
| 41 |
L = preprocess(image)
|
| 42 |
with torch.no_grad():
|
| 43 |
output = model(L)
|
|
|
|
| 45 |
L, ab = denormalize_lab(L, output)
|
| 46 |
output = torch.cat([L, ab], dim=1)
|
| 47 |
output = color.lab_to_rgb(output)
|
| 48 |
+
output = crop_to_original_size(output, original_size)
|
| 49 |
image = transforms.ToPILImage()(output.squeeze().cpu())
|
| 50 |
|
| 51 |
return image
|
utils.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
def normalize_lab(L, ab):
|
| 2 |
"""
|
| 3 |
Normalize the L and ab channels of an image in Lab color space.
|
|
@@ -15,4 +17,33 @@ def denormalize_lab(L, ab):
|
|
| 15 |
"""
|
| 16 |
L = (L + 1) * 50.
|
| 17 |
ab = ab * 110.
|
| 18 |
-
return L, ab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torchvision import transforms
|
| 2 |
+
|
| 3 |
def normalize_lab(L, ab):
|
| 4 |
"""
|
| 5 |
Normalize the L and ab channels of an image in Lab color space.
|
|
|
|
| 17 |
"""
|
| 18 |
L = (L + 1) * 50.
|
| 19 |
ab = ab * 110.
|
| 20 |
+
return L, ab
|
| 21 |
+
|
| 22 |
+
def decide_size(image):
|
| 23 |
+
height = image.size[1]
|
| 24 |
+
width = image.size[0]
|
| 25 |
+
|
| 26 |
+
new_height = 2
|
| 27 |
+
new_width = 2
|
| 28 |
+
|
| 29 |
+
while new_height < height:
|
| 30 |
+
new_height *= 2
|
| 31 |
+
while new_width < width:
|
| 32 |
+
new_width *= 2
|
| 33 |
+
|
| 34 |
+
return new_height, new_width
|
| 35 |
+
|
| 36 |
+
def pad_image(image):
|
| 37 |
+
height = image.size[1]
|
| 38 |
+
width = image.size[0]
|
| 39 |
+
|
| 40 |
+
new_height, new_width = decide_size(image)
|
| 41 |
+
|
| 42 |
+
pad_height = new_height - height
|
| 43 |
+
pad_width = new_width - width
|
| 44 |
+
|
| 45 |
+
padding = (0, 0, pad_width, pad_height)
|
| 46 |
+
|
| 47 |
+
image = transforms.Pad(padding, padding_mode='reflect')(image)
|
| 48 |
+
|
| 49 |
+
return image
|