import torch
import torch.nn as nn
import torch.nn.functional as F
import yaml
from PIL import Image
from skimage import img_as_ubyte, transform
import safetensors
import librosa
from pydub import AudioSegment
import imageio
from scipy.io import loadmat, savemat, wavfile
import glob
import tempfile
from tqdm import tqdm
import numpy as np
import math
import torchvision
import os
import re
import shutil
from yacs.config import CfgNode as CN
import requests
import subprocess
import cv2
from collections import OrderedDict
# RRDBNet is the default backbone built by RealESRGANer below; this import path
# assumes the basicsr package is installed.
from basicsr.archs.rrdbnet_arch import RRDBNet
def img2tensor(imgs, bgr2rgb=True, float32=True):
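    """Convert an HWC numpy array or a PIL image to a CHW torch tensor.

    Grayscale arrays gain a channel axis; `bgr2rgb` reorders 3-channel inputs
    and `float32` rescales uint8 values to [0, 1].
    """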
if isinstance(imgs, np.ndarray):
if imgs.ndim == 3:
imgs = imgs[..., np.newaxis]
imgs = torch.from_numpy(imgs.transpose((2, 0, 1)))
elif isinstance(imgs, Image.Image):
imgs = torch.from_numpy(np.array(imgs)).permute(2, 0, 1)
else:
raise TypeError(f'Type `{type(imgs)}` is not suitable for img2tensor')
if bgr2rgb:
if imgs.shape[0] == 3:
imgs = imgs[[2, 1, 0], :, :]
if float32:
imgs = imgs.float() / 255.
return imgs
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
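    """Convert a CHW torch tensor in the `min_max` range to an HWC uint8 image.

    With `rgb2bgr=True` the output uses OpenCV's BGR channel order; a non-uint8
    `out_type` returns the image as floats scaled to [0, 1].
    """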
if not isinstance(tensor, torch.Tensor):
raise TypeError(f'Input tensor should be torch.Tensor, but got {type(tensor)}')
tensor = tensor.float().cpu()
tensor = tensor.clamp_(*min_max)
tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
output_img = tensor.mul(255).round()
output_img = np.transpose(output_img.numpy(), (1, 2, 0))
output_img = np.clip(output_img, 0, 255).astype(np.uint8)
if rgb2bgr:
output_img = cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR)
return output_img if out_type == np.uint8 else output_img.astype(out_type) / 255.
class RealESRGANer():
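    """Lightweight inference wrapper around an RRDBNet super-resolution model,
    with optional half precision and tiled processing for large images.
    """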
def __init__(self, scale, model_path, model=None, tile=0, tile_pad=10, pre_pad=0, half=False, device=None, gpu_id=None):
self.scale = scale
self.tile = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.mod_scale = None
self.half = half
if device is None:
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
self.device = device
if model is None:
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=scale)
if half:
model.half()
loadnet = torch.load(model_path, map_location=lambda storage, loc: storage)
if 'params' in loadnet:
model.load_state_dict(loadnet['params'], strict=True)
elif 'params_ema' in loadnet:
model.load_state_dict(loadnet['params_ema'], strict=True)
else:
model.load_state_dict(loadnet, strict=True)
model.eval()
self.model = model.to(self.device)
def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
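        """Upscale an RGB uint8 image, tiling it when `tile` > 0 and the input is large.

        Returns [upscaled RGB image, None]; the second element is a placeholder.
        """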
h_input, w_input = img.shape[0:2]
if outscale is None:
outscale = self.scale
if tile is None:
tile = self.tile
if tile_pad is None:
tile_pad = self.tile_pad
if pre_pad is None:
pre_pad = self.pre_pad
if half is None:
half = self.half
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
img_tensor = img2tensor(img)
img_tensor = img_tensor.unsqueeze(0).to(self.device)
if half:
img_tensor = img_tensor.half()
mod_scale = self.mod_scale
h_pad, w_pad = 0, 0
if mod_scale is not None:
h_pad, w_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input), int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
window_size = 256
scale = self.scale
overlap_ratio = 0.5
if w_input * h_input < window_size**2:
tile = None
if tile is not None and tile > 0:
tile_overlap = tile * overlap_ratio
sf = scale
stride_w = math.ceil(tile - tile_overlap)
stride_h = math.ceil(tile - tile_overlap)
numW = math.ceil((w_input + tile_overlap) / stride_w)
numH = math.ceil((h_input + tile_overlap) / stride_h)
paddingW = (numW - 1) * stride_w + tile - w_input
paddingH = (numH - 1) * stride_h + tile - h_input
padding_bottom = int(max(paddingH, 0))
padding_right = int(max(paddingW, 0))
padding_left, padding_top = 0, 0
img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
            # Size the output canvas to hold the upscaled, padded image.
            output_h, output_w = img_tensor.shape[2] * scale, img_tensor.shape[3] * scale
            output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
            # Count how many tiles contribute to each output pixel so overlapping
            # regions can be averaged rather than summed.
            weight_tensor = torch.zeros_like(output_tensor)
            windows = []
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
            results = []
            batch_size = 8
            for i in range(0, len(windows), batch_size):
                # Each window is [1, C, H, W]; concatenate along the batch dimension.
                batch_windows = torch.cat(windows[i:i + batch_size], dim=0)
                with torch.no_grad():
                    results.append(self.model(batch_windows))
            results = torch.cat(results, dim=0)
            count = 0
            for row in range(numH):
                for col in range(numW):
                    start_x = col * stride_w
                    start_y = row * stride_h
                    end_x = min(start_x + tile, img_tensor.shape[3])
                    end_y = min(start_y + tile, img_tensor.shape[2])
                    out_start_x, out_start_y = start_x * sf, start_y * sf
                    out_end_x, out_end_y = end_x * sf, end_y * sf
                    output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count:count + 1][:, :, :out_end_y - out_start_y, :out_end_x - out_start_x]
                    weight_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += 1
                    count += 1
            # Average overlapping tiles before cropping back to the unpadded size.
            output_tensor = output_tensor / weight_tensor.clamp_(min=1)
            forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
else:
with torch.no_grad():
forward_img = self.model(img_tensor)
        if half:
            forward_img = forward_img.float()
        output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
        if mod_scale is not None:
            output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
        # Honour a requested output scale that differs from the model's native scale.
        if outscale is not None and outscale != self.scale:
            output_img = cv2.resize(output_img, (int(w_input * outscale), int(h_input * outscale)), interpolation=cv2.INTER_LANCZOS4)
        return [output_img, None]
def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
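    """Write `video_frames` to `output_path` at 25 fps, stamping the logo from
    `watermark_path` in the bottom-right corner when it is available, then mux in
    `audio_path` with ffmpeg if one is given.
    """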
try:
watermark = imageio.imread(watermark_path)
except FileNotFoundError:
watermark = None
writer = imageio.get_writer(output_path, fps=25)
try:
for frame in tqdm(video_frames, 'Generating video'):
if watermark is not None:
frame_h, frame_w = frame.shape[:2]
watermark_h, watermark_w = watermark.shape[:2]
                if watermark_h > frame_h or watermark_w > frame_w:
                    # skimage's resize returns floats in [0, 1]; convert back to uint8.
                    watermark = img_as_ubyte(transform.resize(watermark, (frame_h // 4, frame_w // 4)))
                    watermark_h, watermark_w = watermark.shape[:2]
                start_h = frame_h - watermark_h - 10
                start_w = frame_w - watermark_w - 10
                # Slice by the watermark width horizontally and drop any alpha channel.
                frame[start_h:start_h + watermark_h, start_w:start_w + watermark_w, :] = watermark[:, :, :frame.shape[2]]
writer.append_data(img_as_ubyte(frame))
except Exception as e:
print(f"Error in video writing: {e}")
finally:
writer.close()
if audio_path is not None:
        try:
            temp_path = output_path.replace('.mp4', '_with_audio.mp4')
            command = "ffmpeg -y -i {} -i {} -strict -2 -q:v 1 {}".format(audio_path, output_path, temp_path)
            subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            # Only replace the silent video once ffmpeg has produced the muxed file.
            if os.path.exists(temp_path):
                os.remove(output_path)
                os.rename(temp_path, output_path)
except Exception as e:
print(f"Error adding audio to video: {e}")
def paste_pic(video_path, pic_path, crop_info, audio_path, output_path):
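    """Composite the cropped face region of `pic_path` over `video_path` with an
    ffmpeg crop/overlay filtergraph and write the result to `output_path`.
    `crop_info` supplies the crop box plus the original and cropped frame sizes;
    `audio_path` is accepted for API compatibility but not used here.
    """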
try:
y_start, y_end, x_start, x_end, old_size, cropped_size = crop_info[0][0], crop_info[0][1], crop_info[1][0], crop_info[1][1], crop_info[2], crop_info[3]
source_image_h, source_image_w = old_size
cropped_h, cropped_w = cropped_size
delta_h, delta_w = source_image_h - cropped_h, source_image_w - cropped_w
box = [x_start, y_start, source_image_w - x_end, source_image_h - y_end]
            # Crop the face region out of the still image and overlay it on the video.
            command = "ffmpeg -y -i {} -i {} -filter_complex \"[1]crop=w={}:h={}:x={}:y={}[s];[0][s]overlay=x={}:y={}\" -codec:a copy {}".format(video_path, pic_path, cropped_w, cropped_h, box[0], box[1], box[0], box[1], output_path)
subprocess.call(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
except Exception as e:
print(f"Error pasting picture to video: {e}")
def color_transfer_batch(source, target, mode='numpy'):
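    """Reinhard-style colour transfer in LAB space: re-colour `target` so its
    channel statistics match those of `source`. Returns a [1, C, H, W] tensor on
    `source`'s device (`mode` is currently unused).
    """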
source_np = tensor2img(source)
target_np = tensor2img(target)
source_lab = cv2.cvtColor(source_np, cv2.COLOR_RGB2LAB).astype(np.float32)
target_lab = cv2.cvtColor(target_np, cv2.COLOR_RGB2LAB).astype(np.float32)
source_mu = np.mean(source_lab, axis=(0, 1), keepdims=True)
source_std = np.std(source_lab, axis=(0, 1), keepdims=True)
target_mu = np.mean(target_lab, axis=(0, 1), keepdims=True)
target_std = np.std(target_lab, axis=(0, 1), keepdims=True)
    # Guard against zero variance in the target image.
    transfer_lab = (target_lab - target_mu) * (source_std / (target_std + 1e-6)) + source_mu
transfer_rgb = cv2.cvtColor(np.clip(transfer_lab, 0, 255).astype(np.uint8), cv2.COLOR_LAB2RGB)
transfer_rgb_tensor = img2tensor(transfer_rgb)
return transfer_rgb_tensor.unsqueeze(0).to(source.device)
def load_video_to_cv2(path, resize=None):
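    """Read every frame of the video at `path` as an RGB numpy array, optionally
    resized to `resize` (width, height). Returns an empty list if the file
    cannot be opened.
    """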
video = []
try:
cap = cv2.VideoCapture(path)
if not cap.isOpened():
raise Exception("Error opening video stream or file")
while(cap.isOpened()):
ret, frame = cap.read()
if ret:
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
if resize is not None:
frame_rgb = cv2.resize(frame_rgb, resize)
video.append(frame_rgb)
else:
break
cap.release()
except Exception as e:
print(f"Error loading video: {e}")
return video
def get_prior_from_bfm(bfm_path):
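    """Load the texture/expression PCA bases and means from `BFM_prior.mat`
    inside `bfm_path` and return them as a dict of float tensors with a leading
    batch dimension.
    """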
mat_path = os.path.join(bfm_path, 'BFM_prior.mat')
C = loadmat(mat_path)
pc_tex = torch.tensor(C['pc_tex'].astype(np.float32)).unsqueeze(0)
pc_exp = torch.tensor(C['pc_exp'].astype(np.float32)).unsqueeze(0)
u_tex = torch.tensor(C['u_tex'].astype(np.float32)).unsqueeze(0)
u_exp = torch.tensor(C['u_exp'].astype(np.float32)).unsqueeze(0)
prior_coeff = {
'pc_tex': pc_tex,
'pc_exp': pc_exp,
'u_tex': u_tex,
'u_exp': u_exp
}
return prior_coeff
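

# Minimal usage sketch for the upscaler above. The checkpoint and image paths are
# illustrative assumptions, not files shipped with this module.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Upscale one image with RealESRGANer')
    parser.add_argument('--input', required=True, help='path to an input image')
    parser.add_argument('--model_path', required=True, help='path to an x4 RRDBNet checkpoint, e.g. RealESRGAN_x4plus.pth')
    parser.add_argument('--output', default='upscaled.png')
    args = parser.parse_args()

    upsampler = RealESRGANer(scale=4, model_path=args.model_path, tile=400, half=False)
    image = cv2.cvtColor(cv2.imread(args.input), cv2.COLOR_BGR2RGB)  # enhance() expects RGB input
    restored, _ = upsampler.enhance(image)
    cv2.imwrite(args.output, cv2.cvtColor(restored, cv2.COLOR_RGB2BGR))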