angelahzyuan commited on
Commit
56549e9
·
verified ·
1 Parent(s): 375e6d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -17
app.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
6
  import spaces
7
 
8
 
9
- MODEL="UCLA-AGI/SPIN-Diffusion-iter3"
10
 
11
  def set_seed(seed=5775709):
12
  random.seed(seed)
@@ -17,24 +16,15 @@ def set_seed(seed=5775709):
17
 
18
  set_seed()
19
 
 
 
20
 
21
-
22
- def get_pipeline(device='cuda'):
23
- model_id = "runwayml/stable-diffusion-v1-5"
24
- #pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker = None, requires_safety_checker = False)
25
- if torch.cuda.is_available():
26
- pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float32)
27
-
28
- # load finetuned model
29
- unet_id = MODEL
30
- unet = UNet2DConditionModel.from_pretrained(unet_id, subfolder="unet", torch_dtype=torch.float32)
31
- pipe.unet = unet
32
- pipe = pipe.to(device)
33
- return pipe
34
 
35
  @spaces.GPU(enable_queue=True)
36
  def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
37
- pipe = get_pipeline()
38
  generator = torch.Generator(pipe.device).manual_seed(5775709)
39
  # Ensure num_images is an integer
40
  num_images = int(num_images)
@@ -69,5 +59,5 @@ with gr.Blocks() as demo:
69
 
70
  generate_btn.click(fn=generate, inputs=[prompt_input, num_images_input, guidance_scale], outputs=gallery)
71
 
72
- if __name__ == "__main__":
73
- demo.launch(share=True)
 
6
  import spaces
7
 
8
 
 
9
 
10
  def set_seed(seed=5775709):
11
  random.seed(seed)
 
16
 
17
  set_seed()
18
 
19
+ if torch.cuda.is_available():
20
+ pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, variant="fp16")
21
 
22
+ unet = UNet2DConditionModel.from_pretrained("UCLA-AGI/SPIN-Diffusion-iter3", subfolder="unet", torch_dtype=torch.float16)
23
+ pipe.unet = unet
24
+ pipe = pipe.to("cuda")
 
 
 
 
 
 
 
 
 
 
25
 
26
  @spaces.GPU(enable_queue=True)
27
  def generate(prompt: str, num_images: int=5, guidance_scale=7.5):
 
28
  generator = torch.Generator(pipe.device).manual_seed(5775709)
29
  # Ensure num_images is an integer
30
  num_images = int(num_images)
 
59
 
60
  generate_btn.click(fn=generate, inputs=[prompt_input, num_images_input, guidance_scale], outputs=gallery)
61
 
62
+
63
+ demo.queue().launch()