"""Gradio demo that runs MeiGen MultiTalk inference on top of Wan2.1-I2V-14B-480P.

On start-up this script:
  1. downloads the three required model repositories into ``./weights``;
  2. swaps the MultiTalk index file and weights into the base model directory;
  3. checks that a supported CUDA GPU is present and picks the value passed to
     ``--num_persistent_param_in_dit`` for it;
  4. builds and launches a Gradio UI whose button runs ``generate_multitalk.py``.
"""

import json
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
import torch
from huggingface_hub import snapshot_download

# ---------------------------------------------------------------------------
# Download all required models using `snapshot_download`
# ---------------------------------------------------------------------------

# Wan2.1-I2V-14B-480P base model
wan_model_path = snapshot_download(
    repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
    local_dir="./weights/Wan2.1-I2V-14B-480P",
)

# Chinese wav2vec2 audio encoder
wav2vec_path = snapshot_download(
    repo_id="TencentGameMate/chinese-wav2vec2-base",
    local_dir="./weights/chinese-wav2vec2-base",
)

# MeiGen MultiTalk weights
multitalk_path = snapshot_download(
    repo_id="MeiGen-AI/MeiGen-MultiTalk",
    local_dir="./weights/MeiGen-MultiTalk",
)

# ---------------------------------------------------------------------------
# Install the MultiTalk index + weights into the base model directory
# ---------------------------------------------------------------------------

base_model_dir = "./weights/Wan2.1-I2V-14B-480P"
multitalk_dir = "./weights/MeiGen-MultiTalk"

original_index = os.path.join(
    base_model_dir, "diffusion_pytorch_model.safetensors.index.json"
)
backup_index = os.path.join(
    base_model_dir, "diffusion_pytorch_model.safetensors.index.json_old"
)

# Back up the original index only once. On a re-run the file at
# `original_index` may already be the MultiTalk copy; renaming it again would
# overwrite the genuine backup with the replacement.
if os.path.exists(original_index) and not os.path.exists(backup_index):
    os.rename(original_index, backup_index)
    print("Renamed original index file to .json_old")

# Copy updated index file from MultiTalk
shutil.copy2(
    os.path.join(multitalk_dir, "diffusion_pytorch_model.safetensors.index.json"),
    base_model_dir,
)

# Copy MultiTalk model weights
shutil.copy2(
    os.path.join(multitalk_dir, "multitalk.safetensors"),
    base_model_dir,
)

print("Copied MultiTalk files into base model directory.")

# ---------------------------------------------------------------------------
# GPU check and VRAM-dependent settings
# ---------------------------------------------------------------------------

if not torch.cuda.is_available():
    raise RuntimeError("No CUDA-compatible GPU found. An A100 or L4 GPU is required.")

gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
print(f"Current GPU: {gpu_name}")

# Enforce GPU requirement
if "A100" not in gpu_name and "L4" not in gpu_name:
    raise RuntimeError(
        f"This notebook requires an A100 or L4 GPU. \nFound: {gpu_name}"
    )
if "L4" in gpu_name:
    print("Warning: L4 is supported, but A100 is recommended for faster inference.")

# Value passed to `--num_persistent_param_in_dit`, keyed by exact device name.
GPU_TO_VRAM_PARAMS = {
    "NVIDIA A100": 11000000000,
    "NVIDIA A100-SXM4-40GB": 11000000000,
    "NVIDIA A100-SXM4-80GB": 22000000000,
    "NVIDIA L4": 5000000000,
}


def _resolve_vram_params(name: str) -> int:
    """Return the persistent-parameter budget for the GPU called *name*.

    An exact match in ``GPU_TO_VRAM_PARAMS`` wins. Device names that passed the
    substring check above but are not listed (e.g. PCIe A100 variants) would
    otherwise raise ``KeyError``; for those, fall back to the table entry of
    the matching family.
    """
    if name in GPU_TO_VRAM_PARAMS:
        return GPU_TO_VRAM_PARAMS[name]
    if "A100" in name:
        # NOTE(review): assumes any A100 whose name mentions 80GB can use the
        # same budget as the SXM4-80GB entry — confirm for other variants.
        if "80GB" in name:
            return GPU_TO_VRAM_PARAMS["NVIDIA A100-SXM4-80GB"]
        return GPU_TO_VRAM_PARAMS["NVIDIA A100"]
    # Remaining names matched "L4" above; use the smallest budget in the table.
    return GPU_TO_VRAM_PARAMS["NVIDIA L4"]


USED_VRAM_PARAMS = _resolve_vram_params(gpu_name)
print("Using", USED_VRAM_PARAMS, "for num_persistent_param_in_dit")

# ---------------------------------------------------------------------------
# Inference helpers
# ---------------------------------------------------------------------------

# Base name handed to `--save_file`.
_OUTPUT_BASENAME = "multi_long_mediumvram_exp"


def create_temp_input_json(prompt: str, cond_image_path: str, cond_audio_path: str) -> str:
    """Write the inference request to a temporary JSON file and return its path.

    Args:
        prompt: Text prompt describing the scene.
        cond_image_path: Path to the conditioning image.
        cond_audio_path: Path to the conditioning audio, stored under the
            ``"person1"`` key of ``"cond_audio"``.

    Returns:
        Path of the JSON file. The file is created with ``delete=False``, so
        the caller is responsible for removing it.
    """
    data = {
        "prompt": prompt,
        "cond_image": cond_image_path,
        "cond_audio": {
            "person1": cond_audio_path,
        },
    }

    # The context manager guarantees the handle is closed (and the data
    # flushed) even if serialisation fails.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", mode="w", encoding="utf-8"
    ) as temp_json:
        json.dump(data, temp_json, indent=4)
        temp_json_path = temp_json.name

    print(f"Temporary input JSON saved to: {temp_json_path}")
    return temp_json_path


def infer(prompt, cond_image_path, cond_audio_path):
    """Run ``generate_multitalk.py`` on the user's inputs and return the video path.

    Args:
        prompt: Text prompt from the UI.
        cond_image_path: File path of the uploaded conditioning image.
        cond_audio_path: File path of the uploaded conditioning audio.

    Returns:
        Path of the generated ``.mp4`` file.

    Raises:
        gr.Error: If the image or the audio input is missing.
        subprocess.CalledProcessError: If the generation script exits non-zero.
    """
    if not cond_image_path:
        raise gr.Error("Please provide a conditioning image.")
    if not cond_audio_path:
        raise gr.Error("Please provide a conditioning audio file.")

    input_json_path = create_temp_input_json(prompt, cond_image_path, cond_audio_path)

    cmd = [
        # Use the running interpreter so the child sees the same environment.
        sys.executable, "generate_multitalk.py",
        "--ckpt_dir", "weights/Wan2.1-I2V-14B-480P",
        "--wav2vec_dir", "weights/chinese-wav2vec2-base",
        "--input_json", input_json_path,
        "--sample_steps", "20",
        "--num_persistent_param_in_dit", str(USED_VRAM_PARAMS),
        "--mode", "streaming",
        "--use_teacache",
        "--save_file", _OUTPUT_BASENAME,
    ]

    try:
        subprocess.run(cmd, check=True)
    finally:
        # The request file is only needed while the subprocess runs.
        try:
            os.remove(input_json_path)
        except OSError:
            pass

    # NOTE(review): assumes generate_multitalk.py appends ".mp4" to the
    # `--save_file` value — confirm against that script.
    return f"{_OUTPUT_BASENAME}.mp4"


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------

with gr.Blocks(title="MultiTalk Inference") as demo:
    gr.Markdown("## 🎤 MultiTalk Inference Demo")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Describe the scene...",
                lines=4,
            )
            image_input = gr.Image(
                type="filepath",
                label="Conditioning Image",
            )
            audio_input = gr.Audio(
                type="filepath",
                label="Conditioning Audio (.wav)",
            )
            submit_btn = gr.Button("Generate")

        with gr.Column():
            output_video = gr.Video(label="Generated Video")

    submit_btn.click(
        fn=infer,
        inputs=[prompt_input, image_input, audio_input],
        outputs=output_video,
    )

demo.launch()