import os
import shutil
import uuid
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import safetensors.torch  # required for safetensors.torch.load_file used in the loaders below
import librosa
from pydub import AudioSegment
import imageio
from scipy import signal
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import math
import torchaudio
import urllib.request
REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"

kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
kp_file = "kp_detector.safetensors"
aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
aud_file = "auido2pose_00140-model.pth"
wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
wav_file = "wav2vec2.pth"
gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
gen_file = "generator.pth"
mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
mapx_file = "mapping.pth"
den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
den_file = "dense_motion.pth"
def download_model(url, filename, checkpoint_dir):
    if not os.path.exists(os.path.join(checkpoint_dir, filename)):
        print(f"Downloading {filename}...")
        os.makedirs(checkpoint_dir, exist_ok=True)
        urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
        print(f"{filename} downloaded.")
    else:
        print(f"{filename} already exists.")
def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
    AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")

def load_wav_util(path, sr):
    return librosa.core.load(path, sr=sr)[0]

def save_wav_util(wav, path, sr):
    # Peak-normalize into the int16 range; the 0.01 floor keeps near-silent input from blowing up.
    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
    wavfile.write(path, sr, wav.astype(np.int16))
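# Example usage (hypothetical file names): convert an MP3 to a 16 kHz WAV, load it, and write a copy.
#   mp3_to_wav_util("speech.mp3", "speech_16k.wav", 16000)
#   wav = load_wav_util("speech_16k.wav", 16000)
#   save_wav_util(wav, "speech_copy.wav", 16000)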
class OcclusionAwareKPDetector(nn.Module):
    def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
        super(OcclusionAwareKPDetector, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.conv2(x)
        kp = {'value': x.view(x.size(0), -1)}
        return kp
class Wav2Vec2Model(nn.Module):
    def __init__(self):
        super(Wav2Vec2Model, self).__init__()
        self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
        self.bn = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, 2048)

    def forward(self, audio):
        x = audio.unsqueeze(1)
        x = self.relu(self.bn(self.conv(x)))
        x = torch.mean(x, dim=-1)
        x = self.fc(x)
        return x
class AudioCoeffsPredictor(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(AudioCoeffsPredictor, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, audio_embedding):
        return self.linear(audio_embedding)

class MappingNet(nn.Module):
    def __init__(self, num_coeffs, num_layers, hidden_dim):
        super(MappingNet, self).__init__()
        layers = []
        input_dim = num_coeffs * 2
        for _ in range(num_layers):
            layers.append(nn.Linear(input_dim, hidden_dim))
            layers.append(nn.ReLU())
            input_dim = hidden_dim
        layers.append(nn.Linear(hidden_dim, num_coeffs))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
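# Shape sketch (illustrative values only): MappingNet consumes the concatenation of expression and
# pose coefficients, i.e. a (B, num_coeffs * 2) tensor, and returns refined coefficients of shape (B, num_coeffs).
#   net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128)
#   out = net(torch.randn(2, 128))   # -> torch.Size([2, 64])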
class DenseMotionNetwork(nn.Module):
    def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
        super(DenseMotionNetwork, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)

    def forward(self, kp_source, kp_driving, jacobian):
        # This simplified network convolves its first argument directly, so it must be a
        # (B, num_channels, H, W) tensor; kp_driving and jacobian are accepted but unused here.
        x = self.relu(self.conv1(kp_source))
        x = self.conv2(x)
        sparse_motion = {'dense_motion': x}
        return sparse_motion

class Hourglass(nn.Module):
    def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
        super(Hourglass, self).__init__()
        self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
                                     nn.BatchNorm2d(max_features), nn.ReLU())
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())

    def forward(self, source_image, kp_driving, **kwargs):
        x = self.encoder(source_image)
        x = self.decoder(x)
        B, C, H, W = x.size()
        video = []
        for _ in range(10):
            frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
            video.append(frame)
        return video
class Face3DHelper:
    def __init__(self, local_pca_path, device):
        self.local_pca_path = local_pca_path
        self.device = device

    def run(self, source_image):
        # Placeholder detector: returns a centered box covering the middle half of the image.
        h, w, _ = source_image.shape
        x_min = w // 4
        y_min = h // 4
        x_max = x_min + w // 2
        y_max = y_min + h // 2
        return [x_min, y_min, x_max, y_max]

class Face3DHelperOld(Face3DHelper):
    def __init__(self, local_pca_path, device):
        super(Face3DHelperOld, self).__init__(local_pca_path, device)

class MouthDetector:
    def __init__(self):
        pass

    def detect(self, image):
        h, w = image.shape[:2]
        return (w // 2, h // 2)

class KeypointNorm(nn.Module):
    def __init__(self, device):
        super(KeypointNorm, self).__init__()
        self.device = device

    def forward(self, kp_driving):
        return kp_driving
def save_video_with_watermark(video_frames, audio_path, output_path):
    # Writes the frames only; the audio track is not muxed in by this helper.
    H, W, _ = video_frames[0].shape
    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
    for frame in video_frames:
        out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    out.release()

def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
    shutil.copy(video_path, output_path)
class TTSTalker:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tts_model = None

    def load_model(self):
        self.tts_model = self

    def tokenizer(self, text):
        return [ord(c) for c in text]

    def __call__(self, input_tokens):
        # Placeholder synthesis: returns one second of silence at 16 kHz.
        return torch.zeros(1, 16000, device=self.device)

    def test(self, text, lang='en'):
        if self.tts_model is None:
            self.load_model()
        output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
        os.makedirs('./results', exist_ok=True)
        tokens = self.tokenizer(text)
        input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
        with torch.no_grad():
            audio_output = self(input_tokens)
        torchaudio.save(output_path, audio_output.cpu(), 16000)
        return output_path
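# Example usage: synthesize placeholder audio for a sentence and get the path of the written WAV file.
#   tts = TTSTalker()
#   wav_path = tts.test("Hello from SadTalker", lang="en")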
class CN(dict):
    """Minimal yacs-style config node: a dict that also supports attribute access (e.g. cfg.MODEL.DEVICE)."""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__

class SadTalker:
    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
                 old_version=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.cfg = self.get_cfg_defaults()
        self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
        self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
        self.cfg['MODEL']['CONFIG_DIR'] = config_path
        self.cfg['MODEL']['DEVICE'] = self.device
        # Keep any INPUT_IMAGE values merged from the YAML config instead of discarding them.
        self.cfg['INPUT_IMAGE'] = CN(self.cfg.get('INPUT_IMAGE') or {})
        self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
        self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
        self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
        self.cfg['INPUT_IMAGE']['SIZE'] = size
        self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
        download_model(kp_url, kp_file, checkpoint_path)
        download_model(aud_url, aud_file, checkpoint_path)
        download_model(wav_url, wav_file, checkpoint_path)
        download_model(gen_url, gen_file, checkpoint_path)
        download_model(mapx_url, mapx_file, checkpoint_path)
        download_model(den_url, den_file, checkpoint_path)
        download_model(GFPGAN_URL, 'GFPGANv1.4.pth', checkpoint_path)
        download_model(REALESRGAN_URL, 'RealESRGAN_x2plus.pth', checkpoint_path)
        self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])

    def get_cfg_defaults(self):
        return CN(
            MODEL=CN(
                CHECKPOINTS_DIR='',
                CONFIG_DIR='',
                DEVICE=self.device,
                SCALE=64,
                NUM_VOXEL_FRAMES=8,
                NUM_MOTION_FRAMES=10,
                MAX_FEATURES=256,
                DRIVEN_AUDIO_SAMPLE_RATE=16000,
                VIDEO_FPS=25,
                OUTPUT_VIDEO_FPS=None,
                OUTPUT_AUDIO_SAMPLE_RATE=None,
                USE_ENHANCER=False,
                ENHANCER_NAME='',
                BG_UPSAMPLER=None,
                IS_HALF=False
            ),
            INPUT_IMAGE=CN()
        )
    def merge_from_file(self, filepath):
        if os.path.exists(filepath):
            with open(filepath, 'r') as f:
                cfg_from_file = yaml.safe_load(f) or {}
            self.cfg.MODEL.update(CN(cfg_from_file.get('MODEL', {})))
            self.cfg.INPUT_IMAGE.update(CN(cfg_from_file.get('INPUT_IMAGE', {})))

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en'):
        self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
                                  pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                                  length_of_audio, use_blink, result_dir, tts_text, tts_lang)
        return self.sadtalker_model.save_result()
class SadTalkerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg.MODEL.get('DEVICE', 'cpu')
        self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
        self.preprocesser = self.sadtalker.preprocesser
        self.kp_extractor = self.sadtalker.kp_extractor
        self.generator = self.sadtalker.generator
        self.mapping = self.sadtalker.mapping
        self.he_estimator = self.sadtalker.he_estimator
        self.audio_to_coeff = self.sadtalker.audio_to_coeff
        self.animate_from_coeff = self.sadtalker.animate_from_coeff
        self.face_enhancer = self.sadtalker.face_enhancer

    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
             batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
             ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
             tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
        self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                                         batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
                                         use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
                                         jitter_amount, jitter_source_image)
        return self.inner_test.test()

    def save_result(self):
        return self.inner_test.save_result()
class SadTalkerInner:
    def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
                 batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
                 length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
        self.sadtalker_model = sadtalker_model
        self.source_image = source_image
        self.driven_audio = driven_audio
        self.preprocess = preprocess
        self.still_mode = still_mode
        self.use_enhancer = use_enhancer
        self.batch_size = batch_size
        self.size = size
        self.pose_style = pose_style
        self.exp_scale = exp_scale
        self.use_ref_video = use_ref_video
        self.ref_video = ref_video
        self.ref_info = ref_info
        self.use_idle_mode = use_idle_mode
        self.length_of_audio = length_of_audio
        self.use_blink = use_blink
        self.result_dir = result_dir
        self.tts_text = tts_text
        self.tts_lang = tts_lang
        self.jitter_amount = jitter_amount
        self.jitter_source_image = jitter_source_image
        self.device = self.sadtalker_model.device
        self.output_path = None
    def get_test_data(self):
        proc = self.sadtalker_model.preprocesser
        if self.tts_text is not None:
            # Synthesize the driving audio from text; keep it in its own temp dir for easy cleanup later.
            temp_dir = tempfile.mkdtemp()
            audio_path = os.path.join(temp_dir, 'audio.wav')
            tts = TTSTalker()
            tts_wav = tts.test(self.tts_text, self.tts_lang)
            shutil.move(tts_wav, audio_path)
            self.driven_audio = audio_path
        source_image_pil = Image.open(self.source_image).convert('RGB')
        if self.jitter_source_image:
            jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
            source_image_pil = Image.fromarray(
                np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
        source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
        if self.still_mode:
            ref_pose_coeff = proc.generate_still_pose(self.pose_style)
            ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
        elif self.use_idle_mode:
            ref_pose_coeff = proc.generate_idles_pose(self.length_of_audio, self.pose_style)
            ref_expression_coeff = proc.generate_idles_expression(self.length_of_audio)
        else:
            ref_pose_coeff = None
            ref_expression_coeff = None
        audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
                                                             self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE)
        batch = {
            'source_image': source_image_tensor.unsqueeze(0).to(self.device),
            'audio': audio_tensor.unsqueeze(0).to(self.device),
            'ref_pose_coeff': ref_pose_coeff,
            'ref_expression_coeff': ref_expression_coeff,
            'source_image_crop': cropped_image,
            'crop_info': crop_info,
            'use_blink': self.use_blink,
            'pose_style': self.pose_style,
            'exp_scale': self.exp_scale,
            'ref_video': self.ref_video,
            'use_ref_video': self.use_ref_video,
            'ref_info': self.ref_info,
        }
        return batch, audio_sample_rate
    def run_inference(self, batch):
        kp_extractor = self.sadtalker_model.kp_extractor
        generator = self.sadtalker_model.generator
        mapping = self.sadtalker_model.mapping
        he_estimator = self.sadtalker_model.he_estimator
        audio_to_coeff = self.sadtalker_model.audio_to_coeff
        animate_from_coeff = self.sadtalker_model.animate_from_coeff
        proc = self.sadtalker_model.preprocesser
        with torch.no_grad():
            kp_source = kp_extractor(batch['source_image'])
            if self.still_mode or self.use_idle_mode:
                # Both modes rely on the reference coefficients prepared in get_test_data().
                pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
            else:
                if self.use_ref_video:
                    kp_ref = kp_extractor(batch['source_image'])
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
                                                               use_ref_info=batch['ref_info'])
                else:
                    pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
                expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
            coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
            if self.use_blink:
                coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
            else:
                coeff['blink_coeff'] = None
            kp_driving = audio_to_coeff(batch['audio'])[0]
            kp_norm = animate_from_coeff.normalize_kp(kp_driving)
            coeff['kp_driving'] = kp_norm
            coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
            face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
            output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
                                                       he_estimator, batch['audio'], batch['source_image_crop'],
                                                       face_enhancer=face_enhancer)
        return output_video
    def post_processing(self, output_video, audio_sample_rate, batch):
        proc = self.sadtalker_model.preprocesser
        os.makedirs(self.result_dir, exist_ok=True)
        base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
        audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
        output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
        self.output_path = output_video_path
        video_fps = self.sadtalker_model.cfg.MODEL.VIDEO_FPS if self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS is None \
            else self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS
        audio_output_sample_rate = self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE if \
            self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE is None else \
            self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE
        if self.use_enhancer:
            enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
            save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
            paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
                      output_video_path)
            os.remove(enhanced_path)
        else:
            save_video_with_watermark(output_video, self.driven_audio, output_video_path)
        if self.tts_text is not None:
            # The driving audio was synthesized into its own temporary directory; remove it.
            shutil.rmtree(os.path.dirname(self.driven_audio))

    def save_result(self):
        return self.output_path

    def __call__(self):
        return self.output_path

    def test(self):
        batch, audio_sample_rate = self.get_test_data()
        output_video = self.run_inference(batch)
        self.post_processing(output_video, audio_sample_rate, batch)
        return self.save_result()
class SadTalkerInnerModel:
    def __init__(self, sadtalker_cfg, device_id=[0]):
        self.cfg = sadtalker_cfg
        self.device = sadtalker_cfg.MODEL.DEVICE
        self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
        self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
        self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
        self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg.MODEL.USE_ENHANCER else None
        self.generator = Generator(sadtalker_cfg, self.device)
        self.mapping = Mapping(sadtalker_cfg, self.device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
class Preprocesser:
    def __init__(self, sadtalker_cfg, device):
        self.cfg = sadtalker_cfg
        self.device = device
        if self.cfg.INPUT_IMAGE.get('OLD_VERSION', False):
            self.face3d_helper = Face3DHelperOld(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
        else:
            self.face3d_helper = Face3DHelper(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
        self.mouth_detector = MouthDetector()

    def crop(self, source_image_pil, preprocess_type, size=256):
        source_image = np.array(source_image_pil)
        face_info = self.face3d_helper.run(source_image)
        if face_info is None:
            raise Exception("No face detected")
        x_min, y_min, x_max, y_max = face_info[:4]
        old_size = (x_max - x_min, y_max - y_min)
        x_center = (x_max + x_min) / 2
        y_center = (y_max + y_min) / 2
        if preprocess_type == 'crop':
            face_size = max(x_max - x_min, y_max - y_min)
            x_min = int(x_center - face_size / 2)
            y_min = int(y_center - face_size / 2)
            x_max = int(x_center + face_size / 2)
            y_max = int(y_center + face_size / 2)
        else:
            # Expand the detected box by 10% on each side, using the original width/height as the margin base.
            margin_x = int((x_max - x_min) * 0.1)
            margin_y = int((y_max - y_min) * 0.1)
            x_min -= margin_x
            y_min -= margin_y
            x_max += margin_x
            y_max += margin_y
        h, w = source_image.shape[:2]
        x_min = max(0, x_min)
        y_min = max(0, y_min)
        x_max = min(w, x_max)
        y_max = min(h, y_max)
        cropped_image = source_image[y_min:y_max, x_min:x_max]
        cropped_image_pil = Image.fromarray(cropped_image)
        if size is not None and size != 0:
            cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
        source_image_tensor = self.img2tensor(cropped_image_pil)
        # Returns the image tensor, crop metadata, and the basename of the configured source image
        # (the basename is only used later to build the output file name).
        return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], \
            os.path.basename(self.cfg.INPUT_IMAGE.get('SOURCE_IMAGE', ''))
    def img2tensor(self, img):
        img = np.array(img).astype(np.float32) / 255.0
        img = np.transpose(img, (2, 0, 1))
        return torch.FloatTensor(img)

    def video_to_tensor(self, video, device):
        video_tensor_list = []
        import torchvision.transforms as transforms
        transform_func = transforms.ToTensor()
        for frame in video:
            frame_pil = Image.fromarray(frame)
            frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
            video_tensor_list.append(frame_tensor)
        video_tensor = torch.cat(video_tensor_list, dim=0)
        return video_tensor

    def process_audio(self, audio_path, sample_rate):
        wav = load_wav_util(audio_path, sample_rate)
        wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
        return wav_tensor, sample_rate

    def generate_still_pose(self, pose_style):
        ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
        return ref_pose_coeff

    def generate_still_expression(self, exp_scale):
        ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
        ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
        return ref_expression_coeff

    def generate_idles_pose(self, length_of_audio, pose_style):
        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
        ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_pose = self.generate_still_pose(pose_style)
        end_pose = self.generate_still_pose(pose_style)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose[0] + alpha * end_pose[0]
        return ref_pose_coeff

    def generate_idles_expression(self, length_of_audio):
        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
        ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
        start_exp = self.generate_still_expression(1.0)
        end_exp = self.generate_still_expression(1.0)
        for frame_idx in range(num_frames):
            alpha = frame_idx / num_frames
            ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp[0] + alpha * end_exp[0]
        return ref_expression_coeff
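# Example usage (illustrative only; assumes a populated config object named `cfg` and local files):
#   proc = Preprocesser(cfg, device="cpu")
#   img_tensor, crop_info, src_name = proc.crop(Image.open("face.png").convert("RGB"), "crop", size=256)
#   audio_tensor, sr = proc.process_audio("speech_16k.wav", 16000)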
class KeyPointExtractor(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(KeyPointExtractor, self).__init__()
        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES,
                                                     num_kp=10,
                                                     num_dilation_blocks=2,
                                                     dropout_rate=0.1).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'kp_detector.safetensors')
        self.load_kp_detector(checkpoint_path, device)

    def load_kp_detector(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            # The downloaded checkpoints do not necessarily match this simplified architecture,
            # so load whatever overlaps instead of failing on missing/unexpected keys.
            self.kp_extractor.load_state_dict(checkpoint.get('kp_detector', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, x):
        kp = self.kp_extractor(x)
        return kp
class Audio2Coeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Audio2Coeff, self).__init__()
        self.audio_model = Wav2Vec2Model().to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'wav2vec2.pth')
        self.load_audio_model(checkpoint_path, device)
        self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
        self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
        # Must match the filename used when downloading `aud_url` above.
        mapping_checkpoint = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, aud_file)
        self.load_mapping_model(mapping_checkpoint, device)

    def load_audio_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.audio_model.load_state_dict(checkpoint.get("wav2vec2", {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def load_mapping_model(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.pose_mapper.load_state_dict(checkpoint.get("pose_predictor", {}), strict=False)
            self.exp_mapper.load_state_dict(checkpoint.get("exp_predictor", {}), strict=False)
            self.blink_mapper.load_state_dict(checkpoint.get("blink_predictor", {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
        audio_embedding = self.audio_model(audio_tensor)
        pose_coeff = self.pose_mapper(audio_embedding)
        if ref_pose_coeff is not None:
            pose_coeff = ref_pose_coeff
        if kp_ref is not None and use_ref_info == 'pose':
            ref_pose_6d = kp_ref['value'][:, :6]
            pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1, keepdim=True)
        return pose_coeff

    def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
        audio_embedding = self.audio_model(audio_tensor)
        expression_coeff = self.exp_mapper(audio_embedding)
        if ref_expression_coeff is not None:
            expression_coeff = ref_expression_coeff
        return expression_coeff

    def get_blink_coeff(self, audio_tensor):
        audio_embedding = self.audio_model(audio_tensor)
        blink_coeff = self.blink_mapper(audio_embedding)
        return blink_coeff

    def forward(self, audio):
        audio_embedding = self.audio_model(audio)
        pose_coeff = self.pose_mapper(audio_embedding)
        expression_coeff = self.exp_mapper(audio_embedding)
        blink_coeff = self.blink_mapper(audio_embedding)
        return pose_coeff, expression_coeff, blink_coeff

    def mean_std_normalize(self, coeff):
        mean = coeff.mean(dim=1, keepdim=True)
        std = coeff.std(dim=1, keepdim=True)
        return (coeff - mean) / std
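# Shape sketch (illustrative; assumes a populated config and downloaded checkpoints): for a single
# 16 kHz waveform of shape (1, num_samples), Audio2Coeff returns a 64-dim pose vector, a 64-dim
# expression vector, and a single blink value per clip.
#   a2c = Audio2Coeff(cfg, device="cpu")
#   pose, exp, blink = a2c(torch.randn(1, 16000))   # -> (1, 64), (1, 64), (1, 1)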
class AnimateFromCoeff(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(AnimateFromCoeff, self).__init__()
        self.generator = Generator(sadtalker_cfg, device)
        self.mapping = Mapping(sadtalker_cfg, device)
        self.kp_norm = KeypointNorm(device=device)
        self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)

    def normalize_kp(self, kp_driving):
        return self.kp_norm(kp_driving)

    def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
                 face_enhancer=None):
        kp_driving = coeff['kp_driving']
        jacobian = coeff['jacobian']
        pose_coeff = coeff['pose_coeff']
        expression_coeff = coeff['expression_coeff']
        blink_coeff = coeff['blink_coeff']
        with torch.no_grad():
            # The simplified dense-motion network convolves an image-shaped tensor, so the source image
            # is passed as its first argument; kp_driving and the jacobian are carried along unused.
            sparse_motion = he_estimator(source_image, kp_driving, jacobian)
            dense_motion = sparse_motion['dense_motion']
            if blink_coeff is not None:
                video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
                face_3d = mapping(expression_coeff, pose_coeff, blink_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
                video_output = self.make_animation(video_output)
            else:
                face_3d = mapping(expression_coeff, pose_coeff)
                video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None},
                                     face_3d_param=face_3d)
                video_output = video_3d['video_3d']
                video_output = self.make_animation(video_output)
            if face_enhancer is not None:
                video_output_enhanced = []
                for frame in tqdm(video_output, 'Face enhancer running'):
                    # The enhancers expect BGR input, so convert, enhance, and convert back to RGB.
                    enhanced_image = face_enhancer(cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2BGR))
                    video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
                video_output = video_output_enhanced
        return video_output

    def make_animation(self, video_array):
        # Round-trip the frames through a temporary MP4 file and decode them back with imageio.
        H, W, _ = video_array[0].shape
        out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
        for img in video_array:
            out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        out.release()
        video = imageio.mimread('./tmp.mp4', memtest=False)
        os.remove('./tmp.mp4')
        return video
class Generator(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Generator, self).__init__()
        self.generator = Hourglass(block_expansion=sadtalker_cfg.MODEL.SCALE,
                                   num_blocks=sadtalker_cfg.MODEL.NUM_VOXEL_FRAMES,
                                   max_features=sadtalker_cfg.MODEL.MAX_FEATURES,
                                   num_channels=3,
                                   kp_size=10,
                                   num_deform_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'generator.pth')
        self.load_generator(checkpoint_path, device)

    def load_generator(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.generator.load_state_dict(checkpoint.get('generator', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
        if face_3d_param is not None:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param,
                                      face_3d_param=face_3d_param)
        else:
            video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param)
        return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
class Mapping(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(Mapping, self).__init__()
        self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'mapping.pth')
        self.load_mapping_net(checkpoint_path, device)
        self.f_3d_mean = torch.zeros(1, 64, device=device)

    def load_mapping_net(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.mapping_net.load_state_dict(checkpoint.get('mapping', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
        coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
        face_3d = self.mapping_net(coeff) + self.f_3d_mean
        if blink_coeff is not None:
            face_3d[:, -1:] = blink_coeff
        return face_3d
class OcclusionAwareDenseMotion(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(OcclusionAwareDenseMotion, self).__init__()
        self.dense_motion_network = DenseMotionNetwork(num_kp=10,
                                                       num_channels=3,
                                                       block_expansion=sadtalker_cfg.MODEL.SCALE,
                                                       num_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES - 1,
                                                       max_features=sadtalker_cfg.MODEL.MAX_FEATURES).to(device)
        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'dense_motion.pth')
        self.load_dense_motion_network(checkpoint_path, device)

    def load_dense_motion_network(self, checkpoint_path, device):
        if os.path.exists(checkpoint_path):
            if checkpoint_path.endswith('safetensors'):
                checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
            else:
                checkpoint = torch.load(checkpoint_path, map_location=device)
            self.dense_motion_network.load_state_dict(checkpoint.get('dense_motion', {}), strict=False)
        else:
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    def forward(self, kp_source, kp_driving, jacobian):
        sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
        return sparse_motion
class FaceEnhancer(nn.Module):
    def __init__(self, sadtalker_cfg, device):
        super(FaceEnhancer, self).__init__()
        enhancer_name = sadtalker_cfg.MODEL.ENHANCER_NAME
        bg_upsampler = sadtalker_cfg.MODEL.BG_UPSAMPLER
        self.enhancer_name = enhancer_name
        if enhancer_name == 'gfpgan':
            from gfpgan import GFPGANer
            self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'GFPGANv1.4.pth'),
                                          upscale=1,
                                          arch='clean',
                                          channel_multiplier=2,
                                          bg_upsampler=bg_upsampler)
        elif enhancer_name == 'realesrgan':
            from realesrgan import RealESRGANer
            half = False if device == 'cpu' else sadtalker_cfg.MODEL.IS_HALF
            self.face_enhancer = RealESRGANer(scale=2,
                                              model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR,
                                                                      'RealESRGAN_x2plus.pth'),
                                              tile=0,
                                              tile_pad=10,
                                              pre_pad=0,
                                              half=half,
                                              device=device)
        else:
            self.face_enhancer = None

    def forward(self, x):
        # Expects a BGR uint8 image and returns the restored image (also BGR).
        if self.face_enhancer is None:
            return x
        if self.enhancer_name == 'gfpgan':
            # GFPGANer.enhance returns (cropped_faces, restored_faces, restored_img).
            _, _, restored_img = self.face_enhancer.enhance(x, has_aligned=False, only_center_face=False,
                                                            paste_back=True)
            return restored_img
        # RealESRGANer.enhance returns (output, img_mode).
        return self.face_enhancer.enhance(x, outscale=1)[0]
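# Minimal end-to-end sketch. The image and audio paths below are placeholders; this assumes the
# checkpoints can be downloaded into ./checkpoints and that a config file exists under src/config.
if __name__ == "__main__":
    sad_talker = SadTalker(checkpoint_path="checkpoints", config_path="src/config", size=256, preprocess="crop")
    result_video = sad_talker.test(source_image="examples/source_image.png",
                                   driven_audio="examples/driven_audio.wav",
                                   still_mode=True,
                                   use_enhancer=False,
                                   result_dir="./results/")
    print(f"Talking-head video written to: {result_video}")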