JohnWeck commited on
Commit
08587f3
·
verified ·
1 Parent(s): 898c085

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -30,12 +30,10 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
30
  text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
31
  vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")
32
  unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")
33
- try:
34
- medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
35
- unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))
36
- except Exception as e:
37
- print("[Error] Failed to load checkpoint:", e)
38
- pipeline = None
39
  vae.requires_grad_(False)
40
  text_encoder.requires_grad_(False)
41
 
 
30
  text_encoder = CLIPTextModel.from_pretrained(model_repo_id, subfolder="text_encoder")
31
  vae = AutoencoderKL.from_pretrained(model_repo_id, subfolder="vae")
32
  unet = UNet2DConditionModel.from_config(model_repo_id, subfolder="unet")
33
+
34
+ medsegfactory_ckpt = hf_hub_download(repo_id=medsegfactory_id, filename=filename)
35
+ unet.load_state_dict(torch.load(medsegfactory_ckpt, map_location='cpu'))
36
+
 
 
37
  vae.requires_grad_(False)
38
  text_encoder.requires_grad_(False)
39