import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import numpy as np
import math
import torchvision
import os
import re
import shutil
from yacs.config import CfgNode as CN
import requests
import subprocess
import cv2
from collections import OrderedDict


def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert an HWC image (NumPy array or PIL image) to a CHW ``torch.Tensor``.

    Args:
        imgs: ``np.ndarray`` of shape (H, W) or (H, W, C), or a ``PIL.Image.Image``.
            Two-dimensional (grayscale) input gets a single channel axis.
        bgr2rgb: Reverse the channel order when the image has exactly 3 channels.
        float32: Cast to float32 and divide by 255 (assumes 0-255 input values).

    Returns:
        Tensor of shape (C, H, W).

    Raises:
        TypeError: If ``imgs`` is neither an ``np.ndarray`` nor a PIL image.
    """
    if isinstance(imgs, Image.Image):
        array = np.array(imgs)
    elif isinstance(imgs, np.ndarray):
        array = imgs
    else:
        raise TypeError(f'Type `{type(imgs)}` is not suitable for img2tensor')

    if array.ndim == 2:
        # Grayscale (H, W): add the channel axis so the HWC -> CHW transpose below is valid.
        array = array[..., np.newaxis]
    # ascontiguousarray also copes with negatively-strided views (e.g. img[..., ::-1]),
    # which torch.from_numpy rejects.
    tensor = torch.from_numpy(np.ascontiguousarray(array.transpose(2, 0, 1)))
    if bgr2rgb and tensor.shape[0] == 3:
        tensor = tensor[[2, 1, 0], :, :]
    if float32:
        tensor = tensor.float() / 255.
    return tensor


def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """Convert an image tensor to an HWC NumPy image.

    Args:
        tensor: (C, H, W) or (H, W) tensor, or (1, C, H, W) holding a single image.
        rgb2bgr: Swap RGB -> BGR with OpenCV when the image has 3 or 4 channels.
        out_type: ``np.uint8`` returns 0-255 values; any other dtype returns the
            quantised image cast to ``out_type`` and divided by 255.
        min_max: Value range that is clamped and mapped linearly onto [0, 255].

    Returns:
        (H, W, C) array, or (H, W) for 2-D input. The input tensor is not modified.

    Raises:
        TypeError: If ``tensor`` is not a ``torch.Tensor``.
        ValueError: If a 4-D tensor has a batch size other than 1.
    """
    if not isinstance(tensor, torch.Tensor):
        raise TypeError(f'Input tensor should be torch.Tensor, but got {type(tensor)}')
    # `.float()` / `.cpu()` return the *same* tensor when no conversion is needed, so an
    # in-place clamp_ here would overwrite the caller's data; clamp out of place instead.
    tensor = tensor.detach().float().cpu().clamp(*min_max)
    if tensor.dim() == 4:
        if tensor.shape[0] != 1:
            raise ValueError(f'Expected a batch of one image, got shape {tuple(tensor.shape)}')
        tensor = tensor[0]
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
    output_img = tensor.mul(255).round().numpy()
    if output_img.ndim == 3:
        output_img = np.transpose(output_img, (1, 2, 0))
    output_img = np.clip(output_img, 0, 255).astype(np.uint8)
    # COLOR_RGB2BGR needs 3- or 4-channel input; leave other layouts untouched.
    if rgb2bgr and output_img.ndim == 3 and output_img.shape[2] in (3, 4):
        output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
    return output_img if out_type == np.uint8 else output_img.astype(out_type) / 255.
class RealESRGANer():
    """Run a super-resolution network on RGB images, optionally in overlapping tiles.

    Args:
        scale: Upsampling factor of ``model``.
        model_path: Checkpoint read with ``torch.load``. If it contains a ``'params'``
            (checked first) or ``'params_ema'`` key, that entry is the state dict;
            otherwise the loaded object itself is used as the state dict.
        model: Network to load the weights into. When None an ``RRDBNet`` is built.
            NOTE(review): ``RRDBNet`` is not imported in this module — confirm it is in
            scope where this runs, otherwise callers must pass ``model``.
        tile: Default tile size for ``enhance`` (0 or None disables tiling).
        tile_pad, pre_pad: Stored on the instance; ``enhance`` does not use them.
        half: Convert the model (and, by default, the inputs) to float16.
        device: Target device; defaults to CUDA when available, otherwise CPU.
        gpu_id: Accepted but unused.
    """

    def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=0,
                 half=False, device=None, gpu_id=None):
        self.scale = scale
        self.tile = tile
        self.tile_pad = tile_pad
        self.pre_pad = pre_pad
        # When set, inputs are reflect-padded up to a multiple of this before inference.
        self.mod_scale = None
        self.half = half
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        if model is None:
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23,
                            num_grow_ch=32, scale=scale)
        if half:
            model.half()
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
        if 'params' in loadnet:
            state_dict = loadnet['params']
        elif 'params_ema' in loadnet:
            state_dict = loadnet['params_ema']
        else:
            state_dict = loadnet
        model.load_state_dict(state_dict, strict=True)
        model.eval()
        self.model = model.to(self.device)

    def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
        """Upscale one RGB image.

        Args:
            img: RGB image array of shape (H, W, 3); presumably uint8, since
                ``img2tensor`` divides by 255.
            outscale: Final scale relative to the input. When it differs from
                ``self.scale`` the network output is resized to
                ``(W * outscale, H * outscale)`` with Lanczos interpolation.
                Defaults to ``self.scale`` (no resize).
            tile: Tile size override (defaults to ``self.tile``). Tiling is skipped
                for images with fewer than 256 * 256 pixels.
            tile_pad, pre_pad: Accepted for interface compatibility; unused.
            half: float16 override for the input tensor (defaults to ``self.half``).

        Returns:
            ``[output_img, None]`` where ``output_img`` is the upscaled RGB uint8 image.
        """
        h_input, w_input = img.shape[0:2]
        if outscale is None:
            outscale = self.scale
        if tile is None:
            tile = self.tile
        if half is None:
            half = self.half

        # img2tensor swaps BGR -> RGB, so convert first to hand the model an RGB tensor.
        img_tensor = img2tensor(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        img_tensor = img_tensor.unsqueeze(0).to(self.device)
        if half:
            img_tensor = img_tensor.half()

        mod_scale = self.mod_scale
        if mod_scale is not None:
            h_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input)
            w_pad = int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
            img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')

        # Small images are processed in a single pass regardless of the tile setting.
        if h_input * w_input < 256 ** 2:
            tile = None

        with torch.no_grad():
            if tile is not None and tile > 0:
                forward_img = self._tiled_forward(img_tensor, tile, self.scale)
            else:
                forward_img = self.model(img_tensor)
        if half:
            forward_img = forward_img.float()

        # tensor2img converts RGB -> BGR by default; convert back to RGB below.
        output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
        if mod_scale is not None:
            output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)

        if outscale != self.scale:
            output_img = cv2.resize(output_img, (int(w_input * outscale), int(h_input * outscale)),
                                    interpolation=cv2.INTER_LANCZOS4)
        return [output_img, None]

    def _tiled_forward(self, img_tensor, tile, scale, overlap_ratio=0.5, batch_size=8):
        """Run ``self.model`` over overlapping tiles of ``img_tensor`` and blend the results.

        Tiles are ``tile`` pixels high/wide (clamped to the image size) and advance by
        ``tile * (1 - overlap_ratio)``. The bottom/right edges are reflect-padded so the
        tiles cover the image exactly. Overlapping outputs are averaged, not summed, so
        every output pixel stays in the model's value range.

        Args:
            img_tensor: (1, C, H, W) input tensor.
            tile: Nominal tile size in input pixels.
            scale: Integer upsampling factor of ``self.model``.
            overlap_ratio: Fraction of a tile shared with its neighbour.
            batch_size: Number of tiles per model call.

        Returns:
            Float32 tensor of shape (1, C_out, H * scale, W * scale).
        """
        scale = int(scale)
        _, _, height, width = img_tensor.shape
        tile_h, tile_w = min(int(tile), height), min(int(tile), width)
        stride_h = max(1, math.ceil(tile_h - tile_h * overlap_ratio))
        stride_w = max(1, math.ceil(tile_w - tile_w * overlap_ratio))
        # Fewest tiles reaching the far edge; the leftover pad is < stride <= tile <= size,
        # which keeps it within what reflect padding accepts.
        num_h = math.ceil((height - tile_h) / stride_h) + 1
        num_w = math.ceil((width - tile_w) / stride_w) + 1
        pad_bottom = (num_h - 1) * stride_h + tile_h - height
        pad_right = (num_w - 1) * stride_w + tile_w - width
        if pad_bottom or pad_right:
            img_tensor = F.pad(img_tensor, (0, pad_right, 0, pad_bottom), mode='reflect')

        origins = [(row * stride_h, col * stride_w) for row in range(num_h) for col in range(num_w)]
        out_h, out_w = img_tensor.shape[2] * scale, img_tensor.shape[3] * scale
        output = weight = None
        for start in range(0, len(origins), batch_size):
            chunk = origins[start:start + batch_size]
            # Each window is (1, C, tile_h, tile_w); concatenate along the batch axis.
            batch = torch.cat([img_tensor[:, :, y:y + tile_h, x:x + tile_w] for y, x in chunk], dim=0)
            results = self.model(batch).float()
            if output is None:
                output = results.new_zeros((1, results.shape[1], out_h, out_w))
                weight = results.new_zeros((1, 1, out_h, out_w))
            for (y, x), patch in zip(chunk, results):
                ys, xs = y * scale, x * scale
                output[:, :, ys:ys + patch.shape[1], xs:xs + patch.shape[2]] += patch
                weight[:, :, ys:ys + patch.shape[1], xs:xs + patch.shape[2]] += 1
        output /= weight
        return output[:, :, :height * scale, :width * scale]


def _fit_watermark(watermark, frame_h, frame_w):
    """Return ``watermark`` as a uint8 (H, W, C) array, shrunk to a quarter of the frame if it does not fit."""
    if watermark.ndim == 2:
        watermark = watermark[..., np.newaxis]
    if watermark.shape[0] > frame_h or watermark.shape[1] > frame_w:
        # transform.resize returns floats in [0, 1]; img_as_ubyte converts back to 0-255.
        watermark = transform.resize(watermark, (frame_h // 4, frame_w // 4))
    return img_as_ubyte(watermark)


def _stamp_watermark(frame, watermark, margin=10):
    """Paste ``watermark`` into ``frame`` in place, ``margin`` px from the bottom-right corner.

    A watermark with one more channel than the frame (e.g. RGBA on RGB) is alpha-blended
    using its last channel; a single-channel watermark is broadcast over all channels.
    """
    frame_h, frame_w = frame.shape[:2]
    top = max(frame_h - watermark.shape[0] - margin, 0)
    left = max(frame_w - watermark.shape[1] - margin, 0)
    region = frame[top:top + watermark.shape[0], left:left + watermark.shape[1]]
    mark = watermark[:region.shape[0], :region.shape[1]]
    channels = region.shape[2]
    if mark.shape[2] == channels + 1:
        alpha = mark[..., -1:].astype(np.float32) / 255.0
        blended = alpha * mark[..., :channels] + (1.0 - alpha) * region
        region[...] = np.round(blended).astype(frame.dtype)
    else:
        region[...] = mark[..., :channels]


def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
    """Write frames to a 25 fps video, stamp an optional watermark, then optionally mux audio.

    The image at ``watermark_path`` (skipped if the file is missing) is placed 10 px from
    each frame's bottom-right corner; the frames in ``video_frames`` are not modified.
    When ``audio_path`` is given, ffmpeg muxes it into the video and the result replaces
    ``output_path`` only if ffmpeg succeeds. Errors are printed rather than raised.
    """
    try:
        watermark = imageio.imread(watermark_path)
    except FileNotFoundError:
        watermark = None

    writer = imageio.get_writer(output_path, fps=25)
    try:
        fitted = {}  # watermark resized per frame size, computed once
        for frame in tqdm(video_frames, 'Generating video'):
            frame = img_as_ubyte(frame)
            if watermark is not None:
                size = frame.shape[:2]
                if size not in fitted:
                    fitted[size] = _fit_watermark(watermark, *size)
                frame = frame.copy()  # never draw into the caller's frame
                _stamp_watermark(frame, fitted[size])
            writer.append_data(frame)
    except Exception as e:
        print(f"Error in video writing: {e}")
    finally:
        writer.close()

    if audio_path is not None:
        root, ext = os.path.splitext(output_path)
        temp_path = f"{root}_with_audio{ext or '.mp4'}"
        # Argument list (no shell) so paths with spaces or shell metacharacters are safe.
        command = ['ffmpeg', '-y', '-i', audio_path, '-i', output_path,
                   '-strict', '-2', '-q:v', '1', temp_path]
        try:
            # subprocess.run drains the pipes; subprocess.call with PIPE can deadlock.
            result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            if result.returncode != 0:
                print(f"Error adding audio to video: ffmpeg exited with code {result.returncode}")
            else:
                os.replace(temp_path, output_path)
        except Exception as e:
            print(f"Error adding audio to video: {e}")


def paste_pic(video_path, pic_path, crop_info, audio_path, output_path):
    """Overlay a crop of ``pic_path`` onto ``video_path`` with ffmpeg, writing ``output_path``.

    ``crop_info`` is indexed as: ``crop_info[0][0]`` = y_start, ``crop_info[1][0]`` =
    x_start and ``crop_info[3]`` = (cropped_h, cropped_w). Input 1 (the picture) is
    cropped to cropped_w x cropped_h at (x_start, y_start) and overlaid onto input 0
    (the video) at the same offset; the video's audio is copied. ``audio_path`` is unused.
    NOTE(review): this pastes the still picture over the video — if the intent is to paste
    the video into the picture, the input order should be swapped; confirm with callers.
    Errors are printed rather than raised.
    """
    try:
        y_start = crop_info[0][0]
        x_start = crop_info[1][0]
        cropped_h, cropped_w = crop_info[3]
        # "[1]crop=...[s]" labels the crop output directly; a comma before "[s]" would start
        # an unnamed filter, which ffmpeg rejects.
        filter_graph = (f"[1]crop=w={cropped_w}:h={cropped_h}:x={x_start}:y={y_start}[s];"
                        f"[0][s]overlay=x={x_start}:y={y_start}")
        command = ['ffmpeg', '-y', '-i', video_path, '-i', pic_path,
                   '-filter_complex', filter_graph, '-codec:a', 'copy', output_path]
        result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if result.returncode != 0:
            print(f"Error pasting picture to video: ffmpeg exited with code {result.returncode}")
    except Exception as e:
        print(f"Error pasting picture to video: {e}")


def _color_transfer_single(source, target, eps=1e-6):
    """Recolour one (3, H, W) RGB ``target`` tensor with the Lab mean/std of ``source``."""
    # rgb2bgr=False keeps RGB order so COLOR_RGB2LAB / COLOR_LAB2RGB see the layout they
    # expect; detach().clone() guarantees the caller's tensors are never touched.
    source_np = tensor2img(source.detach().clone(), rgb2bgr=False)
    target_np = tensor2img(target.detach().clone(), rgb2bgr=False)
    source_lab = cv2.cvtColor(source_np, cv2.COLOR_RGB2LAB).astype(np.float32)
    target_lab = cv2.cvtColor(target_np, cv2.COLOR_RGB2LAB).astype(np.float32)

    source_mu = source_lab.mean(axis=(0, 1), keepdims=True)
    source_std = source_lab.std(axis=(0, 1), keepdims=True)
    target_mu = target_lab.mean(axis=(0, 1), keepdims=True)
    target_std = target_lab.std(axis=(0, 1), keepdims=True)

    # eps keeps a flat (zero-variance) target channel from dividing by zero.
    transfer_lab = (target_lab - target_mu) * (source_std / np.maximum(target_std, eps)) + source_mu
    transfer_rgb = cv2.cvtColor(np.clip(transfer_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
    return img2tensor(transfer_rgb, bgr2rgb=False)


def color_transfer_batch(source, target, mode='numpy'):
    """Match each target image's per-channel Lab mean/std to the corresponding source image.

    Args:
        source: RGB tensor, (3, H, W) or (B, 3, H, W); values are clamped to [0, 1].
            Supplies the colour statistics.
        target: RGB tensor with the same layout and batch size; its content is recoloured.
        mode: Accepted but unused.

    Returns:
        Float32 tensor of shape (B, 3, H_target, W_target) on ``source.device``
        (B == 1 for 3-D input).

    Raises:
        ValueError: If ``source`` and ``target`` hold different numbers of images.
    """
    sources = source.unsqueeze(0) if source.dim() == 3 else source
    targets = target.unsqueeze(0) if target.dim() == 3 else target
    if sources.shape[0] != targets.shape[0]:
        raise ValueError(f'Batch size mismatch: source {sources.shape[0]} vs target {targets.shape[0]}')
    transferred = [_color_transfer_single(src, tgt) for src, tgt in zip(sources, targets)]
    return torch.stack(transferred, dim=0).to(source.device)


def load_video_to_cv2(path, resize=None):
    """Read every frame of the video at ``path`` as an RGB array.

    Args:
        path: Video file path passed to ``cv2.VideoCapture``.
        resize: Optional ``dsize`` (width, height) for ``cv2.resize``.

    Returns:
        List of frames. On error the message is printed and the frames read so far
        are returned; the capture is released in every case.
    """
    video = []
    cap = None
    try:
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            raise OSError("Error opening video stream or file")
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            if resize is not None:
                frame_rgb = cv2.resize(frame_rgb, resize)
            video.append(frame_rgb)
    except Exception as e:
        print(f"Error loading video: {e}")
    finally:
        if cap is not None:
            cap.release()
    return video


def get_prior_from_bfm(bfm_path):
    """Load ``BFM_prior.mat`` from the directory ``bfm_path``.

    Returns:
        Dict with keys ``'pc_tex'``, ``'pc_exp'``, ``'u_tex'``, ``'u_exp'``; each value is
        the matching .mat array as a float32 tensor with a new leading axis of size 1.
    """
    mat = loadmat(os.path.join(bfm_path, 'BFM_prior.mat'))
    return {
        key: torch.tensor(mat[key].astype(np.float32)).unsqueeze(0)
        for key in ('pc_tex', 'pc_exp', 'u_tex', 'u_exp')
    }