Spaces:
Paused
Paused
Commit
·
e3d2366
1
Parent(s):
d8035da
Update app.py
Browse files
app.py
CHANGED
|
@@ -71,14 +71,18 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
|
|
| 71 |
torch.manual_seed(seed)
|
| 72 |
x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
|
| 73 |
t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
|
| 74 |
-
step_list = utils.get_spliced_ddpm_cosine_schedule(t)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
if(not clip_guided):
|
| 76 |
outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
|
| 77 |
else:
|
| 78 |
extra_args = {'clip_embed': clip_embed}
|
| 79 |
cond_fn_ = cond_fn
|
| 80 |
model_fn = make_cond_model_fn(model, cond_fn_)
|
| 81 |
-
outs = sampling.plms_sample(model_fn, x,
|
| 82 |
images_out = []
|
| 83 |
for i, out in enumerate(outs):
|
| 84 |
images_out.append(utils.to_pil_image(out))
|
|
|
|
| 71 |
torch.manual_seed(seed)
|
| 72 |
x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
|
| 73 |
t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
|
| 74 |
+
#step_list = utils.get_spliced_ddpm_cosine_schedule(t)
|
| 75 |
+
if model.min_t == 0:
|
| 76 |
+
step_list = utils.get_spliced_ddpm_cosine_schedule(t)
|
| 77 |
+
else:
|
| 78 |
+
step_list = utils.get_ddpm_schedule(t)
|
| 79 |
if(not clip_guided):
|
| 80 |
outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
|
| 81 |
else:
|
| 82 |
extra_args = {'clip_embed': clip_embed}
|
| 83 |
cond_fn_ = cond_fn
|
| 84 |
model_fn = make_cond_model_fn(model, cond_fn_)
|
| 85 |
+
outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
|
| 86 |
images_out = []
|
| 87 |
for i, out in enumerate(outs):
|
| 88 |
images_out.append(utils.to_pil_image(out))
|