import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import numpy as np
import math
import torchvision
import os
import re
import shutil
from yacs.config import CfgNode as CN
import requests
import subprocess
import cv2
from collections import OrderedDict
from basicsr.archs.rrdbnet_arch import RRDBNet  # default backbone for RealESRGANer below


def img2tensor(imgs, bgr2rgb=True, float32=True):
    # Convert an HWC numpy array (or a PIL image) to a CHW torch tensor.
    if isinstance(imgs, np.ndarray):
        if imgs.ndim == 2:
            # Grayscale: add a trailing channel axis so the transpose below works.
            imgs = imgs[..., np.newaxis]
        imgs = torch.from_numpy(imgs.transpose((2, 0, 1)))
    elif isinstance(imgs, Image.Image):
        imgs = torch.from_numpy(np.array(imgs)).permute(2, 0, 1)
    else:
        raise TypeError(f'Type `{type(imgs)}` is not suitable for img2tensor')
    if bgr2rgb:
        if imgs.shape[0] == 3:
            imgs = imgs[[2, 1, 0], :, :]
    if float32:
        imgs = imgs.float() / 255.
    return imgs


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    # Convert a CHW torch tensor in the [min_max] range back to an HWC image.
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f'Input tensor should be torch.Tensor, but got {type(tensor)}')
    tensor = tensor.float().cpu()
    tensor = tensor.clamp_(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
    output_img = tensor.mul(255).round()
    output_img = np.transpose(output_img.numpy(), (1, 2, 0))
    output_img = np.clip(output_img, 0, 255).astype(np.uint8)
    if rgb2bgr:
        output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
    return output_img if out_type == np.uint8 else output_img.astype(out_type) / 255.
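
# Usage sketch for the two converters above (illustrative only; 'face.png' is a
# hypothetical path, not an asset shipped with this code):
#     img = cv2.imread('face.png')          # HWC, BGR, uint8
#     t = img2tensor(img, bgr2rgb=True)     # CHW, RGB, float32 in [0, 1]
#     back = tensor2img(t, rgb2bgr=True)    # HWC, BGR, uint8 again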


class RealESRGANer():

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=0, half=False, device=None, gpu_id=None):
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        self.mod_scale = None
        self.half = half
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device
        if model is None:
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
        if half:
            model.half()
        loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
        # Checkpoints may store weights under 'params' / 'params_ema' or as a plain state dict.
        if 'params' in loadnet:
            model.load_state_dict(loadnet['params'], strict=True)
        elif 'params_ema' in loadnet:
            model.load_state_dict(loadnet['params_ema'], strict=True)
        else:
            model.load_state_dict(loadnet, strict=True)
        model.eval()
        self.model = model.to(self.device)
    def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
        h_input, w_input = img.shape[0:2]
        if outscale is None:
            outscale = self.scale
        if tile is None:
            tile = self.tile
        if tile_pad is None:
            tile_pad = self.tile_pad
        if pre_pad is None:
            pre_pad = self.pre_pad
        if half is None:
            half = self.half
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        img_tensor = img2tensor(img)
        img_tensor = img_tensor.unsqueeze(0).to(self.device)
        if half:
            img_tensor = img_tensor.half()
        mod_scale = self.mod_scale
        h_pad, w_pad = 0, 0
        if mod_scale is not None:
            # Pad so height and width become multiples of mod_scale.
            h_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input)
            w_pad = int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
            img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
        window_size = 256
        scale = self.scale
        overlap_ratio = 0.5
        # Small images are processed in a single forward pass.
        if w_input * h_input < window_size**2:
            tile = None
        if tile is not None and tile > 0:
            tile_overlap = tile * overlap_ratio
            sf = scale
            stride_w = math.ceil(tile - tile_overlap)
            stride_h = math.ceil(tile - tile_overlap)
            numW = math.ceil((w_input + tile_overlap) / stride_w)
            numH = math.ceil((h_input + tile_overlap) / stride_h)
            paddingW = (numW - 1) * stride_w + tile - w_input
            paddingH = (numH - 1) * stride_h + tile - h_input
            padding_bottom = int(max(paddingH, 0))
            padding_right = int(max(paddingW, 0))
            padding_left, padding_top = 0, 0
            img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
            # The upscaled canvas must cover the padded input; overlapping tile writes
            # are averaged with a weight map so overlap regions are not double-counted.
            output_h = (h_input + padding_top + padding_bottom) * scale
            output_w = (w_input + padding_left + padding_right) * scale
            output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
            weight_tensor = torch.zeros_like(output_tensor)
            windows = []
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
            results = []
            batch_size = 8
            for i in range(0, len(windows), batch_size):
                # After padding every tile has identical spatial size, so tiles can be batched.
                batch_windows = torch.cat(windows[i:i + batch_size], dim=0)
                with torch.no_grad():
                    results.append(self.model(batch_windows))
            results = torch.cat(results, dim=0)
            count = 0
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    out_start_x, out_start_y = start_x * sf, start_y * sf
                    out_end_x, out_end_y = end_x * sf, end_y * sf
                    output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count][:, :out_end_y - out_start_y, :out_end_x - out_start_x]
                    weight_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += 1
                    count += 1
            output_tensor = output_tensor / weight_tensor.clamp(min=1)
            forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
        else:
            with torch.no_grad():
                forward_img = self.model(img_tensor)
        if half:
            forward_img = forward_img.float()
        output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
        if mod_scale is not None:
            output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
        return [output_img, None]
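
# Usage sketch for the upscaler above. The checkpoint path is a placeholder, not a
# file provided here; enhance() expects an RGB uint8 frame and returns [frame, None].
#     upsampler = RealESRGANer(scale=2, model_path='weights/RealESRGAN_x2plus.pth', tile=400)
#     restored_rgb, _ = upsampler.enhance(frame_rgb)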


def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
    try:
        watermark = imageio.imread(watermark_path)
        if watermark.ndim == 3 and watermark.shape[-1] == 4:
            # Drop the alpha channel so the logo can be pasted into RGB frames.
            watermark = watermark[..., :3]
    except FileNotFoundError:
        watermark = None
    writer = imageio.get_writer(output_path, fps=25)
    try:
        for frame in tqdm(video_frames, 'Generating video'):
            if watermark is not None:
                frame_h, frame_w = frame.shape[:2]
                watermark_h, watermark_w = watermark.shape[:2]
                if watermark_h > frame_h or watermark_w > frame_w:
                    # transform.resize returns a float image; convert back to uint8 before pasting.
                    watermark = img_as_ubyte(transform.resize(watermark, (frame_h // 4, frame_w // 4)))
                    watermark_h, watermark_w = watermark.shape[:2]
                start_h = frame_h - watermark_h - 10
                start_w = frame_w - watermark_w - 10
                frame[start_h:start_h + watermark_h, start_w:start_w + watermark_w, :] = watermark
            writer.append_data(img_as_ubyte(frame))
    except Exception as e:
        print(f"Error in video writing: {e}")
    finally:
        writer.close()
    if audio_path is not None:
        try:
            command = "ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}".format(audio_path, output_path, output_path.replace('.mp4', '_with_audio.mp4'))
            subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            os.remove(output_path)
            os.rename(output_path.replace('.mp4', '_with_audio.mp4'), output_path)
        except Exception as e:
            print(f"Error adding audio to video: {e}")


def paste_pic(video_path, pic_path, crop_info, audio_path, output_path):
    try:
        y_start, y_end, x_start, x_end, old_size, cropped_size = crop_info[0][0], crop_info[0][1], crop_info[1][0], crop_info[1][1], crop_info[2], crop_info[3]
        source_image_h, source_image_w = old_size
        cropped_h, cropped_w = cropped_size
        delta_h, delta_w = source_image_h - cropped_h, source_image_w - cropped_w
        box = [x_start, y_start, source_image_w - x_end, source_image_h - y_end]
        # Crop the region from the still picture (input [1]) and overlay it on the video
        # (input [0]) at the same position; '[s]' labels the crop output for the overlay filter.
        command = "ffmpeg -y -i {} -i {} -filter_complex \"[1]crop=w={}:h={}:x={}:y={}[s];[0][s]overlay=x={}:y={}\" -codec:a copy {}".format(video_path, pic_path, cropped_w, cropped_h, box[0], box[1], box[0], box[1], output_path)
        subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except Exception as e:
        print(f"Error pasting picture to video: {e}")


def color_transfer_batch(source, target, mode='numpy'):
    # Colour transfer by matching LAB channel mean/std of `target` to `source` (Reinhard-style).
    # Keep both images in RGB end-to-end so the LAB conversions see the channel order they expect.
    source_np = tensor2img(source, rgb2bgr=False)
    target_np = tensor2img(target, rgb2bgr=False)
    source_lab = cv2.cvtColor(source_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    source_mu = np.mean(source_lab, axis=(0, 1), keepdims=True)
    source_std = np.std(source_lab, axis=(0, 1), keepdims=True)
    target_mu = np.mean(target_lab, axis=(0, 1), keepdims=True)
    target_std = np.std(target_lab, axis=(0, 1), keepdims=True)
    # The small epsilon guards against division by zero for flat channels.
    transfer_lab = (target_lab - target_mu) * (source_std / (target_std + 1e-6)) + source_mu
    transfer_rgb = cv2.cvtColor(np.clip(transfer_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
    transfer_rgb_tensor = img2tensor(transfer_rgb, bgr2rgb=False)
    return transfer_rgb_tensor.unsqueeze(0).to(source.device)
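
# Usage sketch: match the colour statistics of a generated frame to its source frame.
# Both inputs are CHW tensors in [0, 1]; the result comes back with a batch dimension.
#     matched = color_transfer_batch(source_tensor, generated_tensor)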


def load_video_to_cv2(path, resize=None):
    # Read a video into a list of RGB frames, optionally resizing each frame.
    video = []
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise Exception("Error opening video stream or file")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if resize is not None:
                frame_rgb = cv2.resize(frame_rgb, resize)
            video.append(frame_rgb)
    except Exception as e:
        print(f"Error loading video: {e}")
    finally:
        # Release the capture even if reading fails partway through.
        cap.release()
    return video
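
# Usage sketch (path is a placeholder): load a clip as RGB frames, optionally resized
# to the (width, height) tuple that cv2.resize expects.
#     frames = load_video_to_cv2('result.mp4', resize=(256, 256))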


def get_prior_from_bfm(bfm_path):
    # Load the BFM prior: principal components (pc_*) and means (u_*) for texture and expression.
    mat_path = os.path.join(bfm_path, 'BFM_prior.mat')
    C = loadmat(mat_path)
    pc_tex = torch.tensor(C['pc_tex'].astype(np.float32)).unsqueeze(0)
    pc_exp = torch.tensor(C['pc_exp'].astype(np.float32)).unsqueeze(0)
    u_tex = torch.tensor(C['u_tex'].astype(np.float32)).unsqueeze(0)
    u_exp = torch.tensor(C['u_exp'].astype(np.float32)).unsqueeze(0)
    prior_coeff = {
        'pc_tex': pc_tex,
        'pc_exp': pc_exp,
        'u_tex': u_tex,
        'u_exp': u_exp
    }
    return prior_coeff
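

# Minimal end-to-end sketch tying the helpers together. Every path below is a
# hypothetical example, not a file bundled with this module.
if __name__ == '__main__':
    frames = load_video_to_cv2('examples/input.mp4')
    if frames:
        upsampler = RealESRGANer(scale=2, model_path='weights/RealESRGAN_x2plus.pth')
        enhanced = [upsampler.enhance(f)[0] for f in frames]
        save_video_with_watermark(enhanced, 'examples/input_audio.wav', 'examples/output.mp4')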