Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import random
|
|
12 |
import logging
|
13 |
import torchaudio
|
14 |
import os
|
|
|
15 |
|
16 |
# MMAudio imports
|
17 |
try:
|
@@ -20,6 +21,10 @@ except ImportError:
|
|
20 |
os.system("pip install -e .")
|
21 |
import mmaudio
|
22 |
|
|
|
|
|
|
|
|
|
23 |
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
24 |
setup_eval_logging)
|
25 |
from mmaudio.model.flow_matching import FlowMatching
|
@@ -27,6 +32,18 @@ from mmaudio.model.networks import MMAudio, get_my_mmaudio
|
|
27 |
from mmaudio.model.sequence_config import SequenceConfig
|
28 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
# Video generation model setup
|
31 |
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
32 |
LORA_REPO_ID = "Kijai/WanVideo_comfy"
|
@@ -53,26 +70,39 @@ log = logging.getLogger()
|
|
53 |
device = 'cuda'
|
54 |
dtype = torch.bfloat16
|
55 |
|
56 |
-
|
57 |
-
audio_model
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
seq_cfg = audio_model.seq_cfg
|
62 |
-
net: MMAudio = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
|
63 |
-
net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
|
64 |
-
log.info(f'Loaded weights from {audio_model.model_path}')
|
65 |
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
# Constants
|
78 |
MOD_VALUE = 32
|
@@ -292,6 +322,13 @@ def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_
|
|
292 |
gr.Warning("Error attempting to calculate new dimensions")
|
293 |
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
|
294 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
def get_duration(input_image, prompt, height, width,
|
296 |
negative_prompt, duration_seconds,
|
297 |
guidance_scale, steps,
|
@@ -315,6 +352,9 @@ def get_duration(input_image, prompt, height, width,
|
|
315 |
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
|
316 |
audio_seed, audio_steps, audio_cfg_strength):
|
317 |
"""Add audio to video using MMAudio"""
|
|
|
|
|
|
|
318 |
rng = torch.Generator(device=device)
|
319 |
if audio_seed >= 0:
|
320 |
rng.manual_seed(audio_seed)
|
@@ -327,14 +367,14 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
|
|
327 |
clip_frames = video_info.clip_frames.unsqueeze(0)
|
328 |
sync_frames = video_info.sync_frames.unsqueeze(0)
|
329 |
duration = video_info.duration_sec
|
330 |
-
|
331 |
-
|
332 |
|
333 |
audios = generate(clip_frames,
|
334 |
sync_frames, [audio_prompt],
|
335 |
negative_text=[audio_negative_prompt],
|
336 |
-
feature_utils=
|
337 |
-
net=
|
338 |
fm=fm,
|
339 |
rng=rng,
|
340 |
cfg_strength=audio_cfg_strength)
|
@@ -342,7 +382,7 @@ def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_pr
|
|
342 |
|
343 |
# Save video with audio
|
344 |
video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
345 |
-
make_video(video_info, video_with_audio_path, audio, sampling_rate=
|
346 |
|
347 |
return video_with_audio_path
|
348 |
|
@@ -391,6 +431,10 @@ def generate_video(input_image, prompt, height, width,
|
|
391 |
audio_seed, audio_steps, audio_cfg_strength
|
392 |
)
|
393 |
|
|
|
|
|
|
|
|
|
394 |
return video_path, video_with_audio_path, current_seed
|
395 |
|
396 |
def update_audio_visibility(audio_mode):
|
|
|
12 |
import logging
|
13 |
import torchaudio
|
14 |
import os
|
15 |
+
import gc
|
16 |
|
17 |
# MMAudio imports
|
18 |
try:
|
|
|
21 |
os.system("pip install -e .")
|
22 |
import mmaudio
|
23 |
|
24 |
+
# Set environment variables for better memory management
|
25 |
+
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
|
26 |
+
os.environ['HF_HUB_CACHE'] = '/tmp/hub' # Use temp directory to avoid filling persistent storage
|
27 |
+
|
28 |
from mmaudio.eval_utils import (ModelConfig, all_model_cfg, generate, load_video, make_video,
|
29 |
setup_eval_logging)
|
30 |
from mmaudio.model.flow_matching import FlowMatching
|
|
|
32 |
from mmaudio.model.sequence_config import SequenceConfig
|
33 |
from mmaudio.model.utils.features_utils import FeaturesUtils
|
34 |
|
35 |
+
# Clean up temp files periodically
|
36 |
+
def cleanup_temp_files():
|
37 |
+
"""Clean up temporary files to save storage"""
|
38 |
+
temp_dir = tempfile.gettempdir()
|
39 |
+
for filename in os.listdir(temp_dir):
|
40 |
+
filepath = os.path.join(temp_dir, filename)
|
41 |
+
try:
|
42 |
+
if filename.endswith(('.mp4', '.flac', '.wav')):
|
43 |
+
os.remove(filepath)
|
44 |
+
except:
|
45 |
+
pass
|
46 |
+
|
47 |
# Video generation model setup
|
48 |
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
49 |
LORA_REPO_ID = "Kijai/WanVideo_comfy"
|
|
|
70 |
device = 'cuda'
|
71 |
dtype = torch.bfloat16
|
72 |
|
73 |
+
# Global variables for audio model (loaded on demand)
|
74 |
+
audio_model = None
|
75 |
+
audio_net = None
|
76 |
+
audio_feature_utils = None
|
77 |
+
audio_seq_cfg = None
|
|
|
|
|
|
|
|
|
78 |
|
79 |
+
def load_audio_model():
|
80 |
+
"""Load audio model on demand to save storage"""
|
81 |
+
global audio_model, audio_net, audio_feature_utils, audio_seq_cfg
|
82 |
+
|
83 |
+
if audio_net is None:
|
84 |
+
audio_model = all_model_cfg['small_16k'] # Use smaller model
|
85 |
+
audio_model.download_if_needed()
|
86 |
+
setup_eval_logging()
|
87 |
+
|
88 |
+
seq_cfg = audio_model.seq_cfg
|
89 |
+
net = get_my_mmaudio(audio_model.model_name).to(device, dtype).eval()
|
90 |
+
net.load_weights(torch.load(audio_model.model_path, map_location=device, weights_only=True))
|
91 |
+
log.info(f'Loaded weights from {audio_model.model_path}')
|
92 |
+
|
93 |
+
feature_utils = FeaturesUtils(tod_vae_ckpt=audio_model.vae_path,
|
94 |
+
synchformer_ckpt=audio_model.synchformer_ckpt,
|
95 |
+
enable_conditions=True,
|
96 |
+
mode=audio_model.mode,
|
97 |
+
bigvgan_vocoder_ckpt=audio_model.bigvgan_16k_path,
|
98 |
+
need_vae_encoder=False)
|
99 |
+
feature_utils = feature_utils.to(device, dtype).eval()
|
100 |
+
|
101 |
+
audio_net = net
|
102 |
+
audio_feature_utils = feature_utils
|
103 |
+
audio_seq_cfg = seq_cfg
|
104 |
+
|
105 |
+
return audio_net, audio_feature_utils, audio_seq_cfg
|
106 |
|
107 |
# Constants
|
108 |
MOD_VALUE = 32
|
|
|
322 |
gr.Warning("Error attempting to calculate new dimensions")
|
323 |
return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
|
324 |
|
325 |
+
def clear_cache():
|
326 |
+
"""Clear GPU and CPU cache to free memory"""
|
327 |
+
if torch.cuda.is_available():
|
328 |
+
torch.cuda.empty_cache()
|
329 |
+
torch.cuda.synchronize()
|
330 |
+
gc.collect()
|
331 |
+
|
332 |
def get_duration(input_image, prompt, height, width,
|
333 |
negative_prompt, duration_seconds,
|
334 |
guidance_scale, steps,
|
|
|
352 |
def add_audio_to_video(video_path, duration_sec, audio_prompt, audio_negative_prompt,
|
353 |
audio_seed, audio_steps, audio_cfg_strength):
|
354 |
"""Add audio to video using MMAudio"""
|
355 |
+
# Load audio model on demand
|
356 |
+
net, feature_utils, seq_cfg = load_audio_model()
|
357 |
+
|
358 |
rng = torch.Generator(device=device)
|
359 |
if audio_seed >= 0:
|
360 |
rng.manual_seed(audio_seed)
|
|
|
367 |
clip_frames = video_info.clip_frames.unsqueeze(0)
|
368 |
sync_frames = video_info.sync_frames.unsqueeze(0)
|
369 |
duration = video_info.duration_sec
|
370 |
+
seq_cfg.duration = duration
|
371 |
+
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
|
372 |
|
373 |
audios = generate(clip_frames,
|
374 |
sync_frames, [audio_prompt],
|
375 |
negative_text=[audio_negative_prompt],
|
376 |
+
feature_utils=feature_utils,
|
377 |
+
net=net,
|
378 |
fm=fm,
|
379 |
rng=rng,
|
380 |
cfg_strength=audio_cfg_strength)
|
|
|
382 |
|
383 |
# Save video with audio
|
384 |
video_with_audio_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
385 |
+
make_video(video_info, video_with_audio_path, audio, sampling_rate=seq_cfg.sampling_rate)
|
386 |
|
387 |
return video_with_audio_path
|
388 |
|
|
|
431 |
audio_seed, audio_steps, audio_cfg_strength
|
432 |
)
|
433 |
|
434 |
+
# Clear cache to free memory
|
435 |
+
clear_cache()
|
436 |
+
cleanup_temp_files() # Clean up temp files
|
437 |
+
|
438 |
return video_path, video_with_audio_path, current_seed
|
439 |
|
440 |
def update_audio_visibility(audio_mode):
|