seawolf2357 committed on
Commit
432e503
·
verified ·
1 Parent(s): 831716a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -24
app.py CHANGED
@@ -12,6 +12,7 @@ import random
12
  import logging
13
  import torchaudio
14
  import os
 
15
 
16
  # MMAudio imports
17
  try:
@@ -20,6 +21,10 @@ except ImportError:
20
  os.system("pip install -e .")
21
  import mmaudio
22
 
 
 
 
 
23
  from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
24
  setup_eval_logging)
25
  from mmaudio.model.flow_matching import FlowMatching
@@ -27,6 +32,18 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
27
  from mmaudio.model.sequence_config import SequenceConfig
28
  from mmaudio.model.utils.features_utils import FeaturesUtils
29
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  # Video generation model setup
31
  MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
32
  LORA_REPO_ID = "Kijai/WanVideo_comfy"
@@ -53,26 +70,39 @@ log = logging.getLogger()
53
  device = 'cuda'
54
  dtype = torch.bfloat16
55
 
56
- audio_model: ModelConfig = all_model_cfg['large_44k_v2']
57
- audio_model.download_if_needed()
58
- setup_eval_logging()
59
-
60
- def get_audio_model() -> tuple[MMAudio, FeaturesUtils, SequenceConfig]:
61
- seq_cfg = audio_model.seq_cfg
62
- net: MMAudio = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
63
- net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
64
- log.info(f'Loaded weights from {audio_model.model_path}')
65
 
66
- feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
67
- synchformer_ckpt=audio_model.synchformer_ckpt,
68
- enable_conditions=True,
69
- mode=audio_model.mode,
70
- bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
71
- need_vae_encoder=False)
72
- feature_utils = feature_utils.to(device, dtype).eval()
73
- return net, feature_utils, seq_cfg
74
-
75
- audio_net, audio_feature_utils, audio_seq_cfg = get_audio_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  # Constants
78
  MOD_VALUE = 32
@@ -292,6 +322,13 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_
292
  gr.Warning("Error attempting to calculate new dimensions")
293
  return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
294
 
 
 
 
 
 
 
 
295
  def get_duration(input_image, prompt, height, width,
296
  negative_prompt, duration_seconds,
297
  guidance_scale, steps,
@@ -315,6 +352,9 @@ def get_duration(input_image, prompt, height, width,
315
  def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
316
  audio_seed, audio_steps, audio_cfg_strength):
317
  """Add audio to video using MMAudio"""
 
 
 
318
  rng = torch.Generator(device=device)
319
  if audio_seed >= 0:
320
  rng.manual_seed(audio_seed)
@@ -327,14 +367,14 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
327
  clip_frames = video_info.clip_frames.unsqueeze(0)
328
  sync_frames = video_info.sync_frames.unsqueeze(0)
329
  duration = video_info.duration_sec
330
- audio_seq_cfg.duration = duration
331
- audio_net.update_seq_lengths(audio_seq_cfg.latent_seq_len, audio_seq_cfg.clip_seq_len, audio_seq_cfg.sync_seq_len)
332
 
333
  audios = generate(clip_frames,
334
  sync_frames, [audio_prompt],
335
  negative_text=[audio_negative_prompt],
336
- feature_utils=audio_feature_utils,
337
- net=audio_net,
338
  fm=fm,
339
  rng=rng,
340
  cfg_strength=audio_cfg_strength)
@@ -342,7 +382,7 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
342
 
343
  # Save video with audio
344
  video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
345
- make_video(video_info, video_with_audio_path, audio, sampling_rate=audio_seq_cfg.sampling_rate)
346
 
347
  return video_with_audio_path
348
 
@@ -391,6 +431,10 @@ def generate_video(input_image, prompt, height, width,
391
  audio_seed, audio_steps, audio_cfg_strength
392
  )
393
 
 
 
 
 
394
  return video_path, video_with_audio_path, current_seed
395
 
396
  def update_audio_visibility(audio_mode):
 
12
  import logging
13
  import torchaudio
14
  import os
15
+ import gc
16
 
17
  # MMAudio imports
18
  try:
 
21
  os.system("pip install -e .")
22
  import mmaudio
23
 
24
+ # Set environment variables for better memory management
25
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
26
+ os.environ['HF_HUB_CACHE'] = '/tmp/hub' # Use temp directory to avoid filling persistent storage
27
+
28
  from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
29
  setup_eval_logging)
30
  from mmaudio.model.flow_matching import FlowMatching
 
32
  from mmaudio.model.sequence_config import SequenceConfig
33
  from mmaudio.model.utils.features_utils import FeaturesUtils
34
 
35
+ # Clean up temp files periodically
36
def cleanup_temp_files(keep=()):
    """Delete leftover media files from the system temp directory.

    Every entry directly inside ``tempfile.gettempdir()`` whose name ends in
    ``.mp4``, ``.flac`` or ``.wav`` is removed. Removal is best-effort: an
    entry that cannot be deleted (a directory, a file removed concurrently,
    a permission problem) is skipped and reported at debug level.

    Args:
        keep: Optional iterable of paths that must survive the sweep, e.g.
            outputs that are about to be returned to the caller. ``None``
            entries are ignored. Defaults to an empty tuple, which deletes
            every matching file.
    """
    temp_dir = tempfile.gettempdir()
    # Resolve both sides through realpath so a symlinked temp dir still matches.
    protected = {os.path.realpath(path) for path in (keep or ()) if path}

    for filename in os.listdir(temp_dir):
        if not filename.endswith(('.mp4', '.flac', '.wav')):
            continue
        filepath = os.path.join(temp_dir, filename)
        if os.path.realpath(filepath) in protected:
            continue
        try:
            os.remove(filepath)
        except OSError as err:
            # Only filesystem errors are expected here; anything else
            # (KeyboardInterrupt, programming errors) should propagate.
            logging.getLogger().debug("Could not remove temp file %s: %s", filepath, err)
46
+
47
  # Video generation model setup
48
  MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
49
  LORA_REPO_ID = "Kijai/WanVideo_comfy"
 
70
  device = 'cuda'
71
  dtype = torch.bfloat16
72
 
73
+ # Global variables for audio model (loaded on demand)
74
+ audio_model = None
75
+ audio_net = None
76
+ audio_feature_utils = None
77
+ audio_seq_cfg = None
 
 
 
 
78
 
79
def load_audio_model():
    """Return the MMAudio network, feature utilities and sequence config.

    The objects are built on the first call and stored in the module-level
    globals ``audio_model``, ``audio_net``, ``audio_feature_utils`` and
    ``audio_seq_cfg``; later calls return the stored objects unchanged.

    Returns:
        A ``(audio_net, audio_feature_utils, audio_seq_cfg)`` tuple.
    """
    global audio_model, audio_net, audio_feature_utils, audio_seq_cfg

    # Already built by an earlier call: hand back the cached objects.
    if audio_net is not None:
        return audio_net, audio_feature_utils, audio_seq_cfg

    audio_model = all_model_cfg['small_16k']  # Use smaller model
    audio_model.download_if_needed()
    setup_eval_logging()

    sequence_config = audio_model.seq_cfg

    model = get_my_mmaudio(audio_model.model_name)
    model = model.to(device, dtype).eval()
    state = torch.load(audio_model.model_path, map_location=device, weights_only=True)
    model.load_weights(state)
    log.info(f'Loaded weights from {audio_model.model_path}')

    utils = FeaturesUtils(
        tod_vae_ckpt=audio_model.vae_path,
        synchformer_ckpt=audio_model.synchformer_ckpt,
        enable_conditions=True,
        mode=audio_model.mode,
        bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
        need_vae_encoder=False,
    ).to(device, dtype).eval()

    # Publish the globals only after everything loaded, so a failed attempt
    # leaves audio_net as None and the next call retries.
    audio_net, audio_feature_utils, audio_seq_cfg = model, utils, sequence_config

    return audio_net, audio_feature_utils, audio_seq_cfg
106
 
107
  # Constants
108
  MOD_VALUE = 32
 
322
  gr.Warning("Error attempting to calculate new dimensions")
323
  return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
324
 
325
def clear_cache():
    """Free unreachable Python objects, then release cached CUDA memory.

    The garbage collector runs first: ``torch.cuda.empty_cache()`` only
    returns blocks that no tensor owns, so tensors that are merely
    collectable garbage must be destroyed before the cache is emptied.
    When CUDA is unavailable only the garbage collection step runs.
    """
    gc.collect()
    if torch.cuda.is_available():
        # Let queued GPU work finish before asking the allocator to hand
        # its unused blocks back to the driver.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
331
+
332
  def get_duration(input_image, prompt, height, width,
333
  negative_prompt, duration_seconds,
334
  guidance_scale, steps,
 
352
  def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
353
  audio_seed, audio_steps, audio_cfg_strength):
354
  """Add audio to video using MMAudio"""
355
+ # Load audio model on demand
356
+ net, feature_utils, seq_cfg = load_audio_model()
357
+
358
  rng = torch.Generator(device=device)
359
  if audio_seed >= 0:
360
  rng.manual_seed(audio_seed)
 
367
  clip_frames = video_info.clip_frames.unsqueeze(0)
368
  sync_frames = video_info.sync_frames.unsqueeze(0)
369
  duration = video_info.duration_sec
370
+ seq_cfg.duration = duration
371
+ net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
372
 
373
  audios = generate(clip_frames,
374
  sync_frames, [audio_prompt],
375
  negative_text=[audio_negative_prompt],
376
+ feature_utils=feature_utils,
377
+ net=net,
378
  fm=fm,
379
  rng=rng,
380
  cfg_strength=audio_cfg_strength)
 
382
 
383
  # Save video with audio
384
  video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
385
+ make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
386
 
387
  return video_with_audio_path
388
 
 
431
  audio_seed, audio_steps, audio_cfg_strength
432
  )
433
 
434
+ # Clear cache to free memory
435
+ clear_cache()
436
+ cleanup_temp_files() # Clean up temp files
437
+
438
  return video_path, video_with_audio_path, current_seed
439
 
440
  def update_audio_visibility(audio_mode):