Kidbea committed on
Commit e0a1d8c · 1 Parent(s): 024adaf
Files changed (1)
  1. app.py +13 -9
app.py CHANGED
@@ -3,36 +3,40 @@ import gradio as gr
 import torch
 from diffusers import DiffusionPipeline
 
-# Read token from environment (configured as a Space secret)
+# Read token and optional model override from environment
 token = os.environ.get("HUGGINGFACE_TOKEN")
-if token is None:
+if not token:
     raise ValueError("Environment variable HUGGINGFACE_TOKEN is not set.")
 
-model_id = "Wan-AI/Wan2.1-I2V-14B-480P"
+# Use the Diffusers-ready model repository by default
+model_id = os.environ.get(
+    "WAN_MODEL_ID", "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
+)
 
-# Load pipeline directly from the Hub, using the token
+# Load the pipeline with remote code support
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 pipe = DiffusionPipeline.from_pretrained(
     model_id,
-    torch_dtype=torch.float16,
+    torch_dtype=torch_dtype,
     trust_remote_code=True,
     use_auth_token=token
-).to("cuda")
+).to("cuda" if torch.cuda.is_available() else "cpu")
 
 # Enable memory-saving features
 pipe.enable_attention_slicing()
 
 # Generation function
 def generate_video(image, prompt, num_frames=16, steps=50, guidance_scale=7.5):
-    result = pipe(
+    output = pipe(
         prompt=prompt,
         init_image=image,
         num_inference_steps=steps,
         guidance_scale=guidance_scale,
         num_frames=num_frames
     )
-    return result.videos
+    return output.videos
 
-# Gradio UI definition
+# Gradio UI
 def main():
     with gr.Blocks() as demo:
         gr.Markdown("# Wan2.1 Image-to-Video Demo")
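
For context, the committed code now reads its configuration from two environment variables. A minimal sketch of how they could be set when exercising app.py outside the Space; HUGGINGFACE_TOKEN and WAN_MODEL_ID are the names the diff reads, while the launcher script and the placeholder token value are purely illustrative:

    # Illustrative launcher, not part of this commit: set the variables app.py
    # expects, then run it as a subprocess. The token value is a placeholder.
    import os
    import subprocess

    env = dict(os.environ)
    env["HUGGINGFACE_TOKEN"] = "hf_..."  # required; app.py raises ValueError if unset
    env["WAN_MODEL_ID"] = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"  # optional override
    subprocess.run(["python", "app.py"], env=env, check=True)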