Kidbea committed
Commit 3c6bbec · 1 Parent(s): 3d550a9

Files changed (1):
  1. app.py (+21 -19)
app.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import gradio as gr
 import torch
+import ftfy
+import spaces
 from diffusers import DiffusionPipeline
 
 # Read token and optional model override from environment
@@ -9,24 +11,23 @@ if not token:
     raise ValueError("Environment variable HUGGINGFACE_TOKEN is not set.")
 
 # Use the Diffusers-ready model repository by default
-model_id = os.environ.get(
-    "WAN_MODEL_ID", "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
-)
-
-# Load the pipeline with remote code support
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-pipe = DiffusionPipeline.from_pretrained(
-    model_id,
-    torch_dtype=torch_dtype,
-    trust_remote_code=True,
-    use_auth_token=token
-).to("cuda")
-
-# Enable memory-saving features
-pipe.enable_attention_slicing()
-
-# Generation function
+model_id = os.environ.get("WAN_MODEL_ID", "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
+
+@spaces.GPU # GPU is only activated when this function is called
 def generate_video(image, prompt, num_frames=16, steps=25, guidance_scale=7.5):
+    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+    # Load pipeline inside the GPU-allocated function
+    pipe = DiffusionPipeline.from_pretrained(
+        model_id,
+        torch_dtype=torch_dtype,
+        trust_remote_code=True,
+        use_auth_token=token
+    ).to("cuda")
+
+    pipe.enable_attention_slicing()
+
+    # Generate video
     output = pipe(
         prompt=prompt,
         image=image,
@@ -34,12 +35,13 @@ def generate_video(image, prompt, num_frames=16, steps=25, guidance_scale=7.5):
         guidance_scale=guidance_scale,
         num_frames=num_frames
     )
+
     return output.videos
 
 # Gradio UI
 def main():
     with gr.Blocks() as demo:
-        gr.Markdown("# Wan2.1 Image-to-Video Demo")
+        gr.Markdown("# Wan2.1 Image-to-Video Demo (ZeroGPU Edition)")
         with gr.Row():
             img_in = gr.Image(type="pil", label="Input Image")
             txt_p = gr.Textbox(label="Prompt")
@@ -49,4 +51,4 @@ def main():
     return demo
 
 if __name__ == "__main__":
-    main().launch()
+    main().launch()
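
Note: as committed, generate_video calls DiffusionPipeline.from_pretrained inside the @spaces.GPU function, so every request pays the full 14B-parameter load. A common alternative ZeroGPU pattern loads the pipeline once at startup on CPU and only moves it to CUDA inside the decorated call. The sketch below is not part of this commit: it reuses the commit's model_id, token handling, and output.videos return, and the num_inference_steps keyword is an assumption (that argument is elided from the diff shown above).

import os
import torch
import spaces
from diffusers import DiffusionPipeline

token = os.environ.get("HUGGINGFACE_TOKEN")
model_id = os.environ.get("WAN_MODEL_ID", "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")

# Heavy initialization happens once at startup, on CPU; ZeroGPU only grants
# a GPU while a @spaces.GPU-decorated function is executing.
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    use_auth_token=token,
)
pipe.enable_attention_slicing()

@spaces.GPU
def generate_video(image, prompt, num_frames=16, steps=25, guidance_scale=7.5):
    pipe.to("cuda")  # move the cached pipeline onto the just-attached GPU
    output = pipe(
        prompt=prompt,
        image=image,
        num_inference_steps=steps,  # assumed keyword; not shown in the diff
        guidance_scale=guidance_scale,
        num_frames=num_frames,
    )
    return output.videos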