Spaces:
Paused
Paused
Commit
·
568d1c7
1
Parent(s):
d6f9b71
Update app.py
Browse files
app.py
CHANGED
|
@@ -52,10 +52,16 @@ def spherical_dist_loss(x, y):
|
|
| 52 |
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
| 53 |
|
| 54 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
|
|
|
| 55 |
model = get_model('cc12m_1_cfg')()
|
| 56 |
_, side_y, side_x = model.shape
|
| 57 |
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
| 58 |
model = model.half().cuda().eval().requires_grad_(False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
|
| 60 |
clip_model.eval().requires_grad_(False)
|
| 61 |
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
|
@@ -138,7 +144,7 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
|
|
| 138 |
else:
|
| 139 |
extra_args = {'clip_embed': clip_embed}
|
| 140 |
cond_fn_ = cond_fn
|
| 141 |
-
model_fn = make_cond_model_fn(
|
| 142 |
outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
|
| 143 |
images_out = []
|
| 144 |
for i, out in enumerate(outs):
|
|
|
|
| 52 |
return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
|
| 53 |
|
| 54 |
cc12m_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1_cfg.pth")
|
| 55 |
+
cc12m_small_model = hf_hub_download(repo_id="multimodalart/crowsonkb-v-diffusion-cc12m-1-cfg", filename="cc12m_1.pth")
|
| 56 |
model = get_model('cc12m_1_cfg')()
|
| 57 |
_, side_y, side_x = model.shape
|
| 58 |
model.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
| 59 |
model = model.half().cuda().eval().requires_grad_(False)
|
| 60 |
+
|
| 61 |
+
model_small = get_model('cc12m_1')()
|
| 62 |
+
model_small.load_state_dict(torch.load(cc12m_model, map_location='cpu'))
|
| 63 |
+
model_small = model.half().cuda().eval().requires_grad_(False)
|
| 64 |
+
|
| 65 |
clip_model = clip.load(model.clip_model, jit=False, device='cuda')[0]
|
| 66 |
clip_model.eval().requires_grad_(False)
|
| 67 |
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
|
|
|
|
| 144 |
else:
|
| 145 |
extra_args = {'clip_embed': clip_embed}
|
| 146 |
cond_fn_ = cond_fn
|
| 147 |
+
model_fn = make_cond_model_fn(model_small, cond_fn_)
|
| 148 |
outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
|
| 149 |
images_out = []
|
| 150 |
for i, out in enumerate(outs):
|