Spaces:
Configuration error
Configuration error
parokshsaxena
commited on
Commit
Β·
d52990b
1
Parent(s):
72b00c6
using enhanced garment net based on the claude suggestions
Browse files- src/enhanced_garment_net.py +123 -0
- src/tryon_pipeline.py +5 -1
src/enhanced_garment_net.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class ResidualBlock(nn.Module):
|
| 6 |
+
def __init__(self, in_channels, out_channels):
|
| 7 |
+
super(ResidualBlock, self).__init__()
|
| 8 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
|
| 9 |
+
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
|
| 10 |
+
self.bn1 = nn.BatchNorm2d(out_channels)
|
| 11 |
+
self.bn2 = nn.BatchNorm2d(out_channels)
|
| 12 |
+
self.relu = nn.ReLU(inplace=True)
|
| 13 |
+
self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else None
|
| 14 |
+
|
| 15 |
+
def forward(self, x):
|
| 16 |
+
residual = x
|
| 17 |
+
out = self.relu(self.bn1(self.conv1(x)))
|
| 18 |
+
out = self.bn2(self.conv2(out))
|
| 19 |
+
if self.downsample:
|
| 20 |
+
residual = self.downsample(x)
|
| 21 |
+
out += residual
|
| 22 |
+
return self.relu(out)
|
| 23 |
+
|
| 24 |
+
class EnhancedGarmentNet(nn.Module):
|
| 25 |
+
def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4):
|
| 26 |
+
super(EnhancedGarmentNet, self).__init__()
|
| 27 |
+
|
| 28 |
+
self.initial = nn.Sequential(
|
| 29 |
+
nn.Conv2d(in_channels, base_channels, kernel_size=7, padding=3),
|
| 30 |
+
nn.BatchNorm2d(base_channels),
|
| 31 |
+
nn.ReLU(inplace=True)
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
self.encoder1 = self._make_layer(base_channels, base_channels, num_residual_blocks)
|
| 35 |
+
self.encoder2 = self._make_layer(base_channels, base_channels*2, num_residual_blocks)
|
| 36 |
+
self.encoder3 = self._make_layer(base_channels*2, base_channels*4, num_residual_blocks)
|
| 37 |
+
|
| 38 |
+
self.bridge = self._make_layer(base_channels*4, base_channels*8, num_residual_blocks)
|
| 39 |
+
|
| 40 |
+
self.decoder3 = self._make_layer(base_channels*8, base_channels*4, num_residual_blocks)
|
| 41 |
+
self.decoder2 = self._make_layer(base_channels*4, base_channels*2, num_residual_blocks)
|
| 42 |
+
self.decoder1 = self._make_layer(base_channels*2, base_channels, num_residual_blocks)
|
| 43 |
+
|
| 44 |
+
self.final = nn.Conv2d(base_channels, in_channels, kernel_size=7, padding=3)
|
| 45 |
+
|
| 46 |
+
self.downsample = nn.MaxPool2d(2)
|
| 47 |
+
self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
| 48 |
+
|
| 49 |
+
def _make_layer(self, in_channels, out_channels, num_blocks):
|
| 50 |
+
layers = []
|
| 51 |
+
layers.append(ResidualBlock(in_channels, out_channels))
|
| 52 |
+
for _ in range(1, num_blocks):
|
| 53 |
+
layers.append(ResidualBlock(out_channels, out_channels))
|
| 54 |
+
return nn.Sequential(*layers)
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
# Initial convolution
|
| 58 |
+
x = self.initial(x)
|
| 59 |
+
|
| 60 |
+
# Encoder
|
| 61 |
+
e1 = self.encoder1(x)
|
| 62 |
+
e2 = self.encoder2(self.downsample(e1))
|
| 63 |
+
e3 = self.encoder3(self.downsample(e2))
|
| 64 |
+
|
| 65 |
+
# Bridge
|
| 66 |
+
b = self.bridge(self.downsample(e3))
|
| 67 |
+
|
| 68 |
+
# Decoder with skip connections
|
| 69 |
+
d3 = self.decoder3(torch.cat([self.upsample(b), e3], dim=1))
|
| 70 |
+
d2 = self.decoder2(torch.cat([self.upsample(d3), e2], dim=1))
|
| 71 |
+
d1 = self.decoder1(torch.cat([self.upsample(d2), e1], dim=1))
|
| 72 |
+
|
| 73 |
+
# Final convolution
|
| 74 |
+
out = self.final(d1)
|
| 75 |
+
|
| 76 |
+
return out, [e1, e2, e3, b]
|
| 77 |
+
|
| 78 |
+
class EnhancedGarmentNetWithTimestep(nn.Module):
|
| 79 |
+
def __init__(self, in_channels=3, base_channels=64, num_residual_blocks=4, time_emb_dim=256):
|
| 80 |
+
super(EnhancedGarmentNetWithTimestep, self).__init__()
|
| 81 |
+
|
| 82 |
+
self.garment_net = EnhancedGarmentNet(in_channels, base_channels, num_residual_blocks)
|
| 83 |
+
|
| 84 |
+
# Timestep embedding
|
| 85 |
+
self.time_mlp = nn.Sequential(
|
| 86 |
+
nn.Linear(1, time_emb_dim),
|
| 87 |
+
nn.SiLU(),
|
| 88 |
+
nn.Linear(time_emb_dim, time_emb_dim)
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# Projection for text embeddings
|
| 92 |
+
self.text_proj = nn.Linear(768, time_emb_dim) # Assuming text embeddings are 768-dimensional
|
| 93 |
+
|
| 94 |
+
# Combine garment features with time and text embeddings
|
| 95 |
+
self.combine = nn.ModuleList([
|
| 96 |
+
nn.Conv2d(base_channels + time_emb_dim, base_channels, kernel_size=1),
|
| 97 |
+
nn.Conv2d(base_channels*2 + time_emb_dim, base_channels*2, kernel_size=1),
|
| 98 |
+
nn.Conv2d(base_channels*4 + time_emb_dim, base_channels*4, kernel_size=1),
|
| 99 |
+
nn.Conv2d(base_channels*8 + time_emb_dim, base_channels*8, kernel_size=1)
|
| 100 |
+
])
|
| 101 |
+
|
| 102 |
+
def forward(self, x, t, text_embeds):
|
| 103 |
+
# Get garment features
|
| 104 |
+
garment_out, garment_features = self.garment_net(x)
|
| 105 |
+
|
| 106 |
+
# Process timestep
|
| 107 |
+
t_emb = self.time_mlp(t.unsqueeze(-1)).unsqueeze(-1).unsqueeze(-1)
|
| 108 |
+
|
| 109 |
+
# Process text embeddings
|
| 110 |
+
text_emb = self.text_proj(text_embeds).unsqueeze(-1).unsqueeze(-1)
|
| 111 |
+
|
| 112 |
+
# Combine embeddings
|
| 113 |
+
cond_emb = t_emb + text_emb
|
| 114 |
+
|
| 115 |
+
# Combine garment features with conditional embedding
|
| 116 |
+
combined_features = []
|
| 117 |
+
for feat, comb_layer in zip(garment_features, self.combine):
|
| 118 |
+
# Expand conditional embedding to match feature map size
|
| 119 |
+
expanded_cond_emb = cond_emb.expand(-1, -1, feat.size(2), feat.size(3))
|
| 120 |
+
combined = comb_layer(torch.cat([feat, expanded_cond_emb], dim=1))
|
| 121 |
+
combined_features.append(combined)
|
| 122 |
+
|
| 123 |
+
return garment_out, combined_features
|
src/tryon_pipeline.py
CHANGED
|
@@ -56,6 +56,8 @@ from diffusers.utils import (
|
|
| 56 |
from diffusers.utils.torch_utils import randn_tensor
|
| 57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 58 |
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
if is_torch_xla_available():
|
|
@@ -398,6 +400,7 @@ class StableDiffusionXLInpaintPipeline(
|
|
| 398 |
force_zeros_for_empty_prompt: bool = True,
|
| 399 |
):
|
| 400 |
super().__init__()
|
|
|
|
| 401 |
|
| 402 |
self.register_modules(
|
| 403 |
vae=vae,
|
|
@@ -1781,7 +1784,8 @@ class StableDiffusionXLInpaintPipeline(
|
|
| 1781 |
if ip_adapter_image is not None:
|
| 1782 |
added_cond_kwargs["image_embeds"] = image_embeds
|
| 1783 |
# down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
|
| 1784 |
-
down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
|
|
|
|
| 1785 |
# print(type(reference_features))
|
| 1786 |
# print(reference_features)
|
| 1787 |
reference_features = list(reference_features)
|
|
|
|
| 56 |
from diffusers.utils.torch_utils import randn_tensor
|
| 57 |
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 58 |
|
| 59 |
+
from enhanced_garment_net import EnhancedGarmentNetWithTimestep
|
| 60 |
+
|
| 61 |
|
| 62 |
|
| 63 |
if is_torch_xla_available():
|
|
|
|
| 400 |
force_zeros_for_empty_prompt: bool = True,
|
| 401 |
):
|
| 402 |
super().__init__()
|
| 403 |
+
self.garment_net = EnhancedGarmentNetWithTimestep()
|
| 404 |
|
| 405 |
self.register_modules(
|
| 406 |
vae=vae,
|
|
|
|
| 1784 |
if ip_adapter_image is not None:
|
| 1785 |
added_cond_kwargs["image_embeds"] = image_embeds
|
| 1786 |
# down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
|
| 1787 |
+
# down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
|
| 1788 |
+
garment_out, reference_features = self.garment_net(cloth, t, text_embeds_cloth)
|
| 1789 |
# print(type(reference_features))
|
| 1790 |
# print(reference_features)
|
| 1791 |
reference_features = list(reference_features)
|