import spaces
import requests
import tempfile
import os
import logging
import cv2
import pandas as pd
import torch
from urllib.parse import urlparse

from genconvit.pred_func import df_face, load_genconvit, pred_vid

torch.hub.set_dir('./cache')
os.environ["HUGGINGFACE_HUB_CACHE"] = "./cache"

# Set up logging
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Seconds requests waits on a single connect/read operation while downloading
# a video. This bounds a stalled server; it is not a cap on total download time.
REQUEST_TIMEOUT = 60

# detect_faces() samples one frame every this many seconds of video.
FACE_SAMPLE_INTERVAL_SECONDS = 5

# Directory where detect_faces() writes annotated frames.
OUTPUT_DIR = './output'

# Frame count passed to df_face() by genconvit_video_prediction().
# NOTE(review): the `factor` argument currently only affects logging and the
# reported 'frames_processed'; confirm this fixed value is intentional
# (e.g. to bound GPU time).
NUM_FRAMES_FOR_PREDICTION = 11


def load_model():
    """Load the GenConViT network with the ED and VAE inference weights (fp16 off).

    Returns:
        The model object returned by ``load_genconvit``.

    Raises:
        Exception: any error raised while loading is logged and re-raised.
    """
    try:
        ed_weight = 'genconvit_ed_inference'
        vae_weight = 'genconvit_vae_inference'
        net = 'genconvit'
        fp16 = False
        model = load_genconvit(net, ed_weight, vae_weight, fp16)
        logging.info("Model loaded successfully.")
        return model
    except Exception as e:
        logging.error(f"Error loading model: {e}")
        raise


# Loaded once at import time and shared by genconvit_video_prediction().
model = load_model()


def _download_to_tempfile(video_url):
    """Download ``video_url`` into a new ``.mp4`` temp file and return its path.

    The file is created with ``delete=False``, so the caller owns it and must
    remove it (see ``_remove_file``).

    Raises:
        requests.RequestException: on network errors, timeouts or HTTP error statuses.
    """
    response = requests.get(video_url, timeout=REQUEST_TIMEOUT)
    response.raise_for_status()  # Raise an exception for HTTP errors
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
        temp_file.write(response.content)
        return temp_file.name


def _remove_file(path):
    """Best-effort delete of ``path``; does nothing for ``None``, logs on failure."""
    if path is None:
        return
    try:
        os.unlink(path)
    except OSError as e:
        logging.warning(f"Could not remove temporary file {path}: {e}")


def detect_faces(video_url):
    """Download a video, box detected faces on sampled frames and save them as JPEGs.

    The sampling step is ``int(fps * FACE_SAMPLE_INTERVAL_SECONDS)`` frames,
    with the FPS taken from OpenCV. Each sampled frame is:

    * converted to grayscale;
    * run through the Haar cascade at ``./utils/face_detection.xml``;
    * annotated with a (255, 0, 0) BGR rectangle per detection;
    * written to ``OUTPUT_DIR/<video_name>_<index>.jpg``.

    Args:
        video_url: URL of the video to download.

    Returns:
        List of paths of the frames that were written successfully; an empty
        list if anything fails (the error is logged).
    """
    temp_file_path = None
    cap = None
    try:
        # Use only the URL path so query strings (e.g. signed URLs) do not end
        # up in output filenames.
        video_name = os.path.basename(urlparse(video_url).path)
        temp_file_path = _download_to_tempfile(video_url)

        face_cascade = cv2.CascadeClassifier('./utils/face_detection.xml')
        cap = cv2.VideoCapture(temp_file_path)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not cap.isOpened() or fps <= 0:
            # Without a usable FPS neither the duration nor the sampling
            # step can be computed.
            logging.error(f"Error processing video: could not read FPS for {video_url}")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps
        # Keep the truncating int() of the original, but never go below 1 so
        # the modulo below cannot divide by zero at very low frame rates.
        sample_step = max(1, int(fps * FACE_SAMPLE_INTERVAL_SECONDS))

        # cv2.imwrite returns False (no exception) when the directory is missing.
        os.makedirs(OUTPUT_DIR, exist_ok=True)

        frames = []
        frame_count = 0
        time_count = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % sample_step == 0:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                faces = face_cascade.detectMultiScale(
                    gray, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30)
                )
                for (x, y, w, h) in faces:
                    cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
                frame_name = f"{OUTPUT_DIR}/{video_name}_{time_count}.jpg"
                if cv2.imwrite(frame_name, frame):
                    frames.append(frame_name)
                    logging.info(f"Processed frame saved: {frame_name}")
                else:
                    logging.warning(f"Failed to write frame: {frame_name}")
                time_count += 1
            frame_count += 1

        logging.info(f"Total video duration: {duration:.2f} seconds")
        # time_count already counts sampled frames; no further division needed.
        logging.info(f"Total frames processed: {time_count}")
        return frames
    except Exception as e:
        logging.exception(f"Error processing video: {e}")
        return []
    finally:
        if cap is not None:
            cap.release()
        _remove_file(temp_file_path)


# @spaces.GPU(duration=300)
def genconvit_video_prediction(video_url, factor):
    """Download a video and score it with the module-level GenConViT ``model``.

    Args:
        video_url: URL of the video to download.
        factor: fraction of the video's frames used for the logged
            "frames to process" figure and the returned ``frames_processed``.
            NOTE(review): the model itself always receives
            ``NUM_FRAMES_FOR_PREDICTION``, independent of ``factor``.

    Returns:
        On success, ``{'score': <y_val * 100 rounded to 2 dp>,
        'frames_processed': round(num_frames * factor)}``. When ``df_face``
        returns nothing, the score defaults to 50.0.
        On error, ``{'score': 0, 'prediction': 'ERROR', 'frames_processed': 0}``.
    """
    temp_file_path = None
    try:
        logging.info(f"Processing video URL: {video_url}")
        temp_file_path = _download_to_tempfile(video_url)

        num_frames = get_video_frame_count(temp_file_path)
        frames_requested = round(num_frames * factor)
        logging.info(f"Number of frames in video: {num_frames}")
        logging.info(f"Number of frames to process: {frames_requested}")

        df = df_face(temp_file_path, NUM_FRAMES_FOR_PREDICTION, model)
        if len(df) >= 1:
            _, y_val = pred_vid(df, model)
        else:
            # df_face returned nothing (presumably no faces found): neutral score.
            y_val = 0.5

        result = {
            'score': round(y_val * 100, 2),
            'frames_processed': frames_requested,
        }
        logging.info(f"Prediction result: {result}")
        return result
    except Exception as e:
        logging.exception(f"Error in video prediction: {e}")
        return {
            'score': 0,
            'prediction': 'ERROR',
            'frames_processed': 0,
        }
    finally:
        # Runs on both success and failure so the download never leaks.
        _remove_file(temp_file_path)


def get_video_frame_count(video_path):
    """Return ``CAP_PROP_FRAME_COUNT`` for ``video_path`` as an int (0 on error)."""
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    except Exception as e:
        logging.error(f"Error getting video frame count: {e}")
        return 0
    finally:
        if cap is not None:
            cap.release()