aiqcamp commited on
Commit
fc668b9
ยท
verified ยท
1 Parent(s): 9aac7f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -45,13 +45,14 @@ pipe.scheduler = TCDScheduler.from_config(pipe.scheduler.config)
45
 
46
  # ํ…์ŠคํŠธ ์ธ์ฝ”๋”๋ฅผ float16์œผ๋กœ ๊ฐ•์ œ ๋ณ€ํ™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
- # ์ถ”๊ฐ€: text_projection์˜ forward๋ฅผ ์˜ค๋ฒ„๋ผ์ด๋”ฉํ•˜์—ฌ ์ž…๋ ฅ์ด float16์ด ์•„๋‹ˆ๋ฉด half๋กœ ์บ์ŠคํŒ…
49
- original_text_projection_forward = pipe.text_encoder.text_projection.forward
50
- def fixed_text_projection_forward(x):
51
- if x.dtype != torch.float16:
52
- x = x.half()
53
- return original_text_projection_forward(x)
54
- pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
 
55
 
56
  def can_expand(source_width, source_height, target_width, target_height, alignment):
57
  """Checks if the image can be expanded based on the alignment."""
 
45
 
46
  # ํ…์ŠคํŠธ ์ธ์ฝ”๋”๋ฅผ float16์œผ๋กœ ๊ฐ•์ œ ๋ณ€ํ™˜
47
  pipe.text_encoder = pipe.text_encoder.to("cuda", dtype=torch.float16)
48
+ # ๋งŒ์•ฝ text_projection ์†์„ฑ์ด ์žˆ๋‹ค๋ฉด, ์ž…๋ ฅ์ด float16์ด ์•„๋‹ˆ๋ฉด half๋กœ ์บ์ŠคํŒ…ํ•˜๋„๋ก ์˜ค๋ฒ„๋ผ์ด๋”ฉ
49
+ if hasattr(pipe.text_encoder, "text_projection"):
50
+ original_text_projection_forward = pipe.text_encoder.text_projection.forward
51
+ def fixed_text_projection_forward(x):
52
+ if x.dtype != torch.float16:
53
+ x = x.half()
54
+ return original_text_projection_forward(x)
55
+ pipe.text_encoder.text_projection.forward = fixed_text_projection_forward
56
 
57
  def can_expand(source_width, source_height, target_width, target_height, alignment):
58
  """Checks if the image can be expanded based on the alignment."""