ford442 committed
Commit 065b416 · verified · 1 Parent(s): e794bf9

Update app.py

Files changed (1): app.py +22 -1
app.py CHANGED
@@ -105,6 +105,9 @@ import torch
 import time
 import gc
 
+import torch.nn.functional as F
+from sageattention import sageattn
+
 torch.backends.cuda.matmul.allow_tf32 = False
 torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False
 torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
@@ -388,7 +391,6 @@ def uploadNote(prompt,num_inference_steps,guidance_scale,timestamp):
     return filename
 
 '''
-
 pyx = cyper.inline(code, fast_indexing=True, directives=dict(boundscheck=False, wraparound=False, language_level=3))
 
 @spaces.GPU(duration=40)
@@ -401,9 +403,14 @@ def generate_30(
     height: int = 768,
     guidance_scale: float = 4,
     num_inference_steps: int = 125,
+    sage: bool = False,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)
 ):
+    if Sage==True:
+        F.scaled_dot_product_attention = sageattn
+    if Sage==False:
+        F.scaled_dot_product_attention = F.scaled_dot_product_attention
     seed = random.randint(0, MAX_SEED)
     generator = torch.Generator(device='cuda').manual_seed(seed)
     options = {
@@ -450,6 +457,7 @@ def generate_60(
     height: int = 768,
     guidance_scale: float = 4,
     num_inference_steps: int = 125,
+    sage: bool = False,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)
 ):
@@ -494,6 +502,7 @@ def generate_90(
     height: int = 768,
     guidance_scale: float = 4,
     num_inference_steps: int = 125,
+    sage: bool = False,
     use_resolution_binning: bool = True,
     progress=gr.Progress(track_tqdm=True)
 ):
@@ -622,6 +631,15 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             step=10,
             value=180,
         )
+        options = [True, False]
+        sage = gr.Radio(
+            show_label=True,
+            container=True,
+            interactive=True,
+            choices=options,
+            value=False,
+            label="Use SageAttention: ",
+        )
 
         gr.Examples(
             examples=examples,
@@ -651,6 +669,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             height,
             guidance_scale,
             num_inference_steps,
+            sage,
         ],
         outputs=[result],
     )
@@ -670,6 +689,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
             height,
            guidance_scale,
            num_inference_steps,
+           sage,
        ],
        outputs=[result],
    )
@@ -689,6 +709,7 @@ with gr.Blocks(theme=gr.themes.Origin(),css=css) as demo:
            height,
           guidance_scale,
           num_inference_steps,
+          sage,
       ],
      outputs=[result],
    )
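Review note: as committed, the toggle in generate_30 has two problems. `if Sage==True:` references `Sage` with a capital S while the parameter is lowercase `sage`, so the function would raise NameError on every call, and the `Sage==False` branch assigns `F.scaled_dot_product_attention` to itself, which cannot restore the stock kernel once it has already been patched to `sageattn`. A minimal sketch of what the toggle appears to intend, assuming the lowercase parameter name from the signature and the same drop-in use of `sageattn` as the commit; the `set_attention_backend` helper is hypothetical, not part of the app:

import torch.nn.functional as F
from sageattention import sageattn  # assumes the sageattention package is installed

# Save the stock implementation once, before any patching, so it can be restored.
_original_sdpa = F.scaled_dot_product_attention

def set_attention_backend(sage: bool) -> None:
    # Route torch's scaled_dot_product_attention through SageAttention when
    # requested; otherwise restore the saved original. A self-assignment, as in
    # the committed code, is a no-op after the first patch.
    F.scaled_dot_product_attention = sageattn if sage else _original_sdpa

Each generate_* function could then open with `set_attention_backend(sage)` in place of the two if statements. Note that this patches a process-global symbol, so in a shared Space the most recent request's choice applies to any generation running concurrently. On the UI side, a gr.Checkbox would express the single boolean flag more directly than a gr.Radio over `[True, False]`.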