import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import safetensors.torch  # needed for safetensors.torch.load_file below
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import math
import torchaudio
import urllib.request

REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"

kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"


class CN(dict):
    """Minimal yacs-style config node (stand-in for yacs.config.CfgNode): a dict with attribute access."""

    def __init__(self, init=None, **kwargs):
        super().__init__()
        if init:
            self.update(init)
        self.update(kwargs)

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name, value):
        self[name] = value


def download_model(url, filename, checkpoint_dir):
    if not os.path.exists(os.path.join(checkpoint_dir, filename)):
        print(f"Downloading {filename}...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
        print(f"{filename} downloaded.")
    else:
        print(f"{filename} already exists.")


def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
    AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")


def load_wav_util(path, sr):
    return librosa.core.load(path, sr=sr)[0]


def save_wav_util(wav, path, sr):
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))


class OcclusionAwareKPDetector(nn.Module):
    def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
        super(OcclusionAwareKPDetector, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        kp = {'value': x.view(x.size(0), -1)}
        return kp


class Wav2Vec2Model(nn.Module):
    def __init__(self):
        super(Wav2Vec2Model, self).__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, 2048)

    def forward(self, audio):
        x = audio.unsqueeze(1)
        x = self.relu(self.bn(self.conv(x)))
        x = torch.mean(x, dim=-1)
        x = self.fc(x)
        return x


class AudioCoeffsPredictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AudioCoeffsPredictor, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, audio_embedding):
        return self.linear(audio_embedding)


class MappingNet(nn.Module):
    def __init__(self, num_coeffs, num_layers, hidden_dim):
        super(MappingNet, self).__init__()
        layers = []
        input_dim = num_coeffs * 2
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_coeffs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class DenseMotionNetwork(nn.Module):
    def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
        super(DenseMotionNetwork, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)

    def forward(self, kp_source, kp_driving, jacobian):
        x = self.relu(self.conv1(kp_source))
        x = self.conv2(x)
        sparse_motion = {'dense_motion': x}
        return sparse_motion


class Hourglass(nn.Module):
    def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
        super(Hourglass, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(max_features),
            nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh())

    def forward(self, source_image, kp_driving, **kwargs):
        x = self.encoder(source_image)
        x = self.decoder(x)
        B, C, H, W = x.size()
        video = []
        for _ in range(10):
            frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            video.append(frame)
        return video


class Face3DHelper:
    def __init__(self, local_pca_path, device):
        self.local_pca_path = local_pca_path
        self.device = device

    def run(self, source_image):
        h, w, _ = source_image.shape
        x_min = w // 4
        y_min = h // 4
        x_max = x_min + w // 2
        y_max = y_min + h // 2
        return [x_min, y_min, x_max, y_max]


class Face3DHelperOld(Face3DHelper):
    def __init__(self, local_pca_path, device):
        super(Face3DHelperOld, self).__init__(local_pca_path, device)


class MouthDetector:
    def __init__(self):
        pass

    def detect(self, image):
        h, w = image.shape[:2]
        return (w // 2, h // 2)


class KeypointNorm(nn.Module):
    def __init__(self, device):
        super(KeypointNorm, self).__init__()
        self.device = device

    def forward(self, kp_driving):
        return kp_driving


def save_video_with_watermark(video_frames, audio_path, output_path):
    H, W, _ = video_frames[0].shape
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for frame in video_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()


def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
    shutil.copy(video_path, output_path)


class TTSTalker:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_model = None

    def load_model(self):
        self.tts_model = self

    def tokenizer(self, text):
        return [ord(c) for c in text]

    def __call__(self, input_tokens):
        return torch.zeros(1, 16000, device=self.device)

    def test(self, text, lang='en'):
        if self.tts_model is None:
            self.load_model()
        output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
        os.makedirs('./results', exist_ok=True)
        tokens = self.tokenizer(text)
        input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
        with torch.no_grad():
            audio_output = self(input_tokens)
        torchaudio.save(output_path, audio_output.cpu(), 16000)
        return output_path


class SadTalker:
    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256,
                 preprocess='crop', old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        self.cfg['INPUT_IMAGE'] = {}
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        download_model(kp_url, kp_file, checkpoint_path)
        download_model(aud_url, aud_file, checkpoint_path)
        download_model(wav_url, wav_file, checkpoint_path)
        download_model(gen_url, gen_file, checkpoint_path)
        download_model(mapx_url, mapx_file, checkpoint_path)
        download_model(den_url, den_file, checkpoint_path)
        download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
        download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])

    def get_cfg_defaults(self):
        return CN(
            MODEL=CN(
                CHECKPOINTS_DIR='',
                CONFIG_DIR='',
                DEVICE=self.device,
                SCALE=64,
                NUM_VOXEL_FRAMES=8,
                NUM_MOTION_FRAMES=10,
                MAX_FEATURES=256,
                DRIVEN_AUDIO_SAMPLE_RATE=16000,
                VIDEO_FPS=25,
                OUTPUT_VIDEO_FPS=None,
                OUTPUT_AUDIO_SAMPLE_RATE=None,
                USE_ENHANCER=False,
                ENHANCER_NAME='',
                BG_UPSAMPLER=None,
                IS_HALF=False
            ),
            INPUT_IMAGE=CN()
        )

    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                cfg_from_file = yaml.safe_load(f)
            self.cfg.MODEL.update(CN(cfg_from_file['MODEL']))
            self.cfg.INPUT_IMAGE.update(CN(cfg_from_file['INPUT_IMAGE']))

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True,
             result_dir='./results/', tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer,
                                  batch_size, size, pose_style, exp_scale, use_ref_video, ref_video,
                                  ref_info, use_idle_mode, length_of_audio, use_blink, result_dir,
                                  tts_text, tts_lang)
        return self.sadtalker_model.save_result()


class SadTalkerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg.MODEL.get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor
        self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping
        self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff
        self.animate_from_coeff = self.sadtalker.animate_from_coeff
        self.face_enhancer = self.sadtalker.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True,
             result_dir='./results/', tts_text=None, tts_lang='en',
             jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode,
                                         use_enhancer, batch_size, size, pose_style, exp_scale,
                                         use_ref_video, ref_video, ref_info, use_idle_mode,
                                         length_of_audio, use_blink, result_dir, tts_text, tts_lang,
                                         jitter_amount, jitter_source_image)
        return self.inner_test.test()

    def save_result(self):
        return self.inner_test.save_result()


class SadTalkerInner:
    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                 batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
                 use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
                 jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None

    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            tts_wav = tts.test(self.tts_text, self.tts_lang)
            # move the synthesized audio into the temp dir so it is cleaned up in post_processing
            shutil.copy(tts_wav, audio_path)
            self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            source_image_pil = Image.fromarray(
                np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode or self.use_idle_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        elif self.use_idle_mode:
            ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
            ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(
            self.driven_audio, self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE)
        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.to(self.device),  # process_audio already returns a (1, N) tensor
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate

    def run_inference(self, batch):
        kp_extractor = self.sadtalker_model.kp_extractor
        generator = self.sadtalker_model.generator
        mapping = self.sadtalker_model.mapping
        he_estimator = self.sadtalker_model.he_estimator
        audio_to_coeff = self.sadtalker_model.audio_to_coeff
        animate_from_coeff = self.sadtalker_model.animate_from_coeff
        proc = self.sadtalker_model.preprocesser
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                ref_pose_coeff = batch['ref_pose_coeff']
                ref_expression_coeff = batch['ref_expression_coeff']
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
            elif self.use_idle_mode:
                ref_pose_coeff = batch['ref_pose_coeff']
                ref_expression_coeff = batch['ref_expression_coeff']
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
            else:
                if self.use_ref_video:
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
                                                               use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink:
                coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else:
                coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]
            kp_norm = animate_from_coeff.normalize_kp(kp_driving)
            coeff['kp_driving'] = kp_norm
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator,
                                                       mapping, he_estimator, batch['audio'],
                                                       batch['source_image_crop'],
                                                       face_enhancer=face_enhancer)
        return output_video

    def post_processing(self, output_video, audio_sample_rate, batch):
        proc = self.sadtalker_model.preprocesser
        os.makedirs(self.result_dir, exist_ok=True)
        # name the result after the actual source image rather than the config placeholder
        base_name = os.path.splitext(os.path.basename(self.source_image))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
        self.output_path = output_video_path
        video_fps = (self.sadtalker_model.cfg.MODEL.VIDEO_FPS
                     if self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS is None
                     else self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS)
        audio_output_sample_rate = (self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE
                                    if self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE is None
                                    else self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE)
        if self.use_enhancer:
            enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
            save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
            paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
                      output_video_path)
            os.remove(enhanced_path)
        else:
            save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None:
            shutil.rmtree(os.path.dirname(self.driven_audio))

    def save_result(self):
        return self.output_path

    def __call__(self):
        return self.output_path

    def test(self):
        batch, audio_sample_rate = self.get_test_data()
        output_video = self.run_inference(batch)
        self.post_processing(output_video, audio_sample_rate, batch)
        return self.save_result()


class SadTalkerInnerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg.MODEL.DEVICE
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg.MODEL.USE_ENHANCER else None
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)


class Preprocesser:
    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg
        self.device = device
        if self.cfg.INPUT_IMAGE.get('OLD_VERSION', False):
            self.face3d_helper = Face3DHelperOld(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
        else:
            self.face3d_helper = Face3DHelper(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()

    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)
        x_center = (x_max + x_min) / 2
        y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop':
            face_size = max(x_max - x_min, y_max - y_min)
            x_min = int(x_center - face_size / 2)
            y_min = int(y_center - face_size / 2)
            x_max = int(x_center + face_size / 2)
            y_max = int(y_center + face_size / 2)
        else:
            x_min -= int((x_max - x_min) * 0.1)
            y_min -= int((y_max - y_min) * 0.1)
            x_max += int((x_max - x_min) * 0.1)
            y_max += int((y_max - y_min) * 0.1)
        h, w = source_image.shape[:2]
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(w, x_max)
        y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]
        cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil)
        crop_info = [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size]
        return source_image_tensor, crop_info, os.path.basename(self.cfg.INPUT_IMAGE.get('SOURCE_IMAGE', ''))

    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        return torch.FloatTensor(img)

    def video_to_tensor(self, video, device):
        import torchvision.transforms as transforms
        transform_func = transforms.ToTensor()
        video_tensor_list = []
        for frame in video:
            frame_pil = Image.fromarray(frame)
            frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
            video_tensor_list.append(frame_tensor)
        video_tensor = torch.cat(video_tensor_list, dim=0)
        return video_tensor

    def process_audio(self, audio_path, sample_rate):
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
        return wav_tensor, sample_rate

    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style)
        end_pose = self.generate_still_pose(pose_style)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            # squeeze the (1, 64) still pose to (64,) before writing it into the frame slot
            ref_pose_coeff[frame_idx] = ((1 - alpha) * start_pose + alpha * end_pose).squeeze(0)
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0)
        end_exp = self.generate_still_expression(1.0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = ((1 - alpha) * start_exp + alpha * end_exp).squeeze(0)
        return ref_expression_coeff


class KeyPointExtractor(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES,
                                                     num_kp=10, num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'kp_detector.safetensors')
        self.load_kp_detector(checkpoint_path, device)

    def load_kp_detector(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            # the public checkpoint does not match this stub module, so load whatever keys overlap
            self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, x):
        kp = self.kp_extractor(x)
        return kp


class Audio2Coeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'wav2vec2.pth')
        self.load_audio_model(checkpoint_path, device)
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        # use the same file name that download_model saved (aud_file) to avoid a path mismatch
        mapping_checkpoint = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, aud_file)
        self.load_mapping_model(mapping_checkpoint, device)

    def load_audio_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def load_mapping_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}), strict=False)
            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}), strict=False)
            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        blink_coeff = self.blink_mapper(audio_embedding)
        return blink_coeff

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff, expression_coeff, blink_coeff = (self.pose_mapper(audio_embedding),
                                                     self.exp_mapper(audio_embedding),
                                                     self.blink_mapper(audio_embedding))
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / std


class AnimateFromCoeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio,
                 source_image_crop, face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        with torch.no_grad():
            if blink_coeff is not None:
                # the dense-motion stub is convolutional, so feed it the 3-channel source image
                sparse_motion = he_estimator(source_image, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                video_deocclusion = generator(source_image, dense_motion,
                                              bg_param={'mask': None, 'color': None})
                face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
                video_output = self.make_animation(video_output)
            else:
                sparse_motion = he_estimator(source_image, kp_driving, jacobian)
                dense_motion = sparse_motion['dense_motion']
                face_3d = mapping(expression_coeff, pose_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_3d['video_3d']
                video_output = self.make_animation(video_output)
        if face_enhancer is not None:
            video_output_enhanced = []
            for frame in tqdm(video_output, 'Face enhancer running'):
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                enhanced_image = face_enhancer(np.array(pil_image))
                video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
            video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array):
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array:
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4')
        os.remove('./tmp.mp4')
        return video


class Generator(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__()
        self.generator = Hourglass(block_expansion=sadtalker_cfg.MODEL.SCALE,
                                   num_blocks=sadtalker_cfg.MODEL.NUM_VOXEL_FRAMES,
                                   max_features=sadtalker_cfg.MODEL.MAX_FEATURES,
                                   num_channels=3, kp_size=10,
                                   num_deform_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'generator.pth')
        self.load_generator(checkpoint_path, device)

    def load_generator(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.generator.load_state_dict(checkpoint.get('generator', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        if face_3d_param is not None:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
                                      face_3d_param=face_3d_param)
        else:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
        return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}


class Mapping(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'mapping.pth')
        self.load_mapping_net(checkpoint_path, device)
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def load_mapping_net(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.mapping_net.load_state_dict(checkpoint.get('mapping', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d


class OcclusionAwareDenseMotion(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__()
        self.dense_motion_network = DenseMotionNetwork(num_kp=10, num_channels=3,
                                                       block_expansion=sadtalker_cfg.MODEL.SCALE,
                                                       num_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES - 1,
                                                       max_features=sadtalker_cfg.MODEL.MAX_FEATURES).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'dense_motion.pth')
        self.load_dense_motion_network(checkpoint_path, device)

    def load_dense_motion_network(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, kp_source, kp_driving, jacobian):
        sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
        return sparse_motion


class FaceEnhancer(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__()
        enhancer_name = sadtalker_cfg.MODEL.ENHANCER_NAME
        bg_upsampler = sadtalker_cfg.MODEL.BG_UPSAMPLER
        if enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(
                model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'GFPGANv1.4.pth'),
                upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
        elif enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            half = False if device == 'cpu' else sadtalker_cfg.MODEL.IS_HALF
            self.face_enhancer = RealESRGANer(
                scale=2,
                model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'RealESRGAN_x2plus.pth'),
                tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
        else:
            self.face_enhancer = None

    def forward(self, x):
        if self.face_enhancer is None:
            return x
        if type(self.face_enhancer).__name__ == 'GFPGANer':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img)
            _, _, restored_img = self.face_enhancer.enhance(x, has_aligned=False, only_center_face=False,
                                                            paste_back=True)
            return restored_img
        # RealESRGANer.enhance returns (output, img_mode)
        output, _ = self.face_enhancer.enhance(x, outscale=1)
        return output
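

# Illustrative usage sketch (an addition, not part of the pipeline above): instantiating SadTalker
# downloads the checkpoints listed at the top into `checkpoints/` and builds the stub pipeline;
# `test()` renders a talking-head clip and returns its path. The image/audio paths below are placeholders.
if __name__ == "__main__":
    talker = SadTalker(checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop')
    result_path = talker.test(source_image='examples/source_image.png',   # placeholder input image
                              driven_audio='examples/driven_audio.wav',   # placeholder driving audio
                              still_mode=True, use_enhancer=False,
                              result_dir='./results/')
    print(f"Result saved to: {result_path}")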