import os
import datetime

import cv2
import dlib
import face_recognition
import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from dataset.loader import normalize_data

from .config import load_config
from .genconvit import GenConViT

device = "cuda" if torch.cuda.is_available() else "cpu"

# Keep downloaded weights in a project-local cache instead of the user-level one.
torch.hub.set_dir("./cache")
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"


def load_genconvit(net, ed_weight, vae_weight, fp16):
    """Build GenConViT from the given ED and VAE weights and prepare it for inference."""
    model = GenConViT(ed=ed_weight, vae=vae_weight, net=net, fp16=fp16)
    model.to(device)
    model.eval()
    if fp16:
        # Half precision roughly halves GPU memory use during inference.
        model.half()
    return model
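
# Minimal usage sketch. The checkpoint names below are placeholders, not part of
# this module; pass whatever weight files your configuration provides:
#
#   model = load_genconvit(
#       net="genconvit",
#       ed_weight="genconvit_ed_inference",
#       vae_weight="genconvit_vae_inference",
#       fp16=False,
#   )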


def face_rec(frames, p=None, klass=None):
    """Detect faces in the given frames and return up to len(frames) 224x224 crops."""
    temp_face = np.zeros((len(frames), 224, 224, 3), dtype=np.uint8)
    count = 0
    # Use dlib's CNN detector only when CUDA is available; fall back to HOG on CPU.
    mod = "cnn" if dlib.DLIB_USE_CUDA else "hog"

    for frame in tqdm(frames, total=len(frames)):
        # OpenCV decodes frames as BGR; face_recognition expects RGB.
        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        face_locations = face_recognition.face_locations(
            frame, number_of_times_to_upsample=0, model=mod
        )

        for face_location in face_locations:
            if count >= len(frames):
                break
            top, right, bottom, left = face_location
            face_image = frame[top:bottom, left:right]
            face_image = cv2.resize(
                face_image, (224, 224), interpolation=cv2.INTER_AREA
            )
            # Swap back so the stored crop keeps the original BGR channel order.
            face_image = cv2.cvtColor(face_image, cv2.COLOR_RGB2BGR)

            temp_face[count] = face_image
            count += 1

    return ([], 0) if count == 0 else (temp_face[:count], count)
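
# Example: when faces are found, `faces, n = face_rec(frames)` yields a uint8
# array with faces.shape == (n, 224, 224, 3); when no face is detected it
# returns ([], 0), which callers must check before preprocessing.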


def preprocess_frame(frame):
    """Convert face crops (N, H, W, C) into a normalized NCHW float tensor."""
    df_tensor = torch.tensor(frame, device=device).float()
    df_tensor = df_tensor.permute((0, 3, 1, 2))  # NHWC -> NCHW

    # Scale each crop to [0, 1] and apply the dataset's video normalization.
    normalize = normalize_data()["vid"]
    for i in range(len(df_tensor)):
        df_tensor[i] = normalize(df_tensor[i] / 255.0)

    return df_tensor


def pred_vid(df, model):
    """Run the model on a batch of face crops and reduce to one video-level prediction."""
    with torch.no_grad():
        return max_prediction_value(torch.softmax(model(df), dim=1).squeeze())


def max_prediction_value(y_pred):
    """Average the per-frame probabilities and return (predicted class, confidence)."""
    mean_val = torch.mean(y_pred, dim=0)
    # For a two-class softmax the means sum to 1, so both branches below
    # report the class-0 probability as the confidence value.
    return (
        torch.argmax(mean_val).item(),
        mean_val[0].item()
        if mean_val[0] > mean_val[1]
        else abs(1 - mean_val[1]).item(),
    )
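
# Worked example: with per-frame probabilities [[0.8, 0.2], [0.6, 0.4]], the mean
# is [0.7, 0.3], argmax gives class 0, and the confidence is 0.7; with
# [[0.3, 0.7], [0.1, 0.9]] the mean is [0.2, 0.8], argmax gives class 1, and the
# confidence is 1 - 0.8 = 0.2 (still the class-0 probability).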


def real_or_fake(prediction):
    # `prediction ^ 1` flips the class index before the lookup,
    # so class 0 maps to "FAKE" and class 1 to "REAL".
    return {0: "REAL", 1: "FAKE"}[prediction ^ 1]


def extract_frames(video_file, frames_nums=15):
    """Sample up to `frames_nums` frames, evenly spaced across the video."""
    cap = cv2.VideoCapture(video_file)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    step_size = max(1, frame_count // frames_nums)

    for i in range(0, frame_count, step_size):
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        if len(frames) >= frames_nums:
            break

    cap.release()
    return np.array(frames)
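
# Example of the sampling arithmetic: a 300-frame video with frames_nums=15 gives
# step_size = 300 // 15 = 20, so frames 0, 20, 40, ..., 280 are decoded; videos
# shorter than frames_nums fall back to step_size = 1 and yield every frame.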


def df_face(vid, num_frames, net):
    """Extract frames, crop the detected faces, and return a normalized tensor."""
    s1 = datetime.datetime.now()
    img = extract_frames(vid, num_frames)
    e1 = datetime.datetime.now()
    print("Time taken for frame extraction:", e1 - s1)

    s2 = datetime.datetime.now()
    face, count = face_rec(img)
    e2 = datetime.datetime.now()
    print("Time taken for face recognition:", e2 - s2)
    print("Total time taken for image processing:", e2 - s1)

    return preprocess_frame(face) if count > 0 else []


def is_video(vid):
    """Return True when `vid` is an existing file with a supported video extension."""
    return os.path.isfile(vid) and vid.endswith(
        (".avi", ".mp4", ".mpg", ".mpeg", ".mov")
    )


def set_result():
    """Create an empty results dictionary for video-level predictions."""
    return {
        "video": {
            "name": [],
            "pred": [],
            "klass": [],
            "pred_label": [],
            "correct_label": [],
        }
    }


def store_result(
    result, filename, y, y_val, klass, correct_label=None, compression=None
):
    result["video"]["name"].append(filename)
    result["video"]["pred"].append(y_val)
    result["video"]["klass"].append(klass.lower())
    result["video"]["pred_label"].append(real_or_fake(y))

    if correct_label is not None:
        result["video"]["correct_label"].append(correct_label)

    if compression is not None:
        # set_result() does not pre-create this key, so create it on first use.
        result["video"].setdefault("compression", []).append(compression)

    return result
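

# End-to-end usage sketch (assumption: run as a script against a local video
# file; the checkpoint names are placeholders, as noted above):
#
#   if __name__ == "__main__":
#       video_path = "sample.mp4"
#       if is_video(video_path):
#           model = load_genconvit(
#               net="genconvit",
#               ed_weight="genconvit_ed_inference",
#               vae_weight="genconvit_vae_inference",
#               fp16=False,
#           )
#           faces = df_face(video_path, num_frames=15, net="genconvit")
#           if len(faces) > 0:
#               y, y_val = pred_vid(faces, model)
#               result = store_result(
#                   set_result(), video_path, y, y_val, klass="uncategorized"
#               )
#               print(real_or_fake(y), y_val, result)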