"""🎬 Keras Video Classification CNN-RNN model

Hugging Face Space demonstrating how to use the model.

Author:
- Thomas Chaigneau @ChainYo
"""

import os

import cv2
import gradio as gr
import numpy as np
from tensorflow import keras
from huggingface_hub import from_pretrained_keras


IMG_SIZE = 224
NUM_FEATURES = 2048

# Load the pretrained CNN-RNN video classifier from the Hugging Face Hub.
model = from_pretrained_keras("keras-io/video-classification-cnn-rnn")


# Collect the example videos shipped with the Space for the Gradio demo.
samples = []
for file in os.listdir("samples"):
    samples.append([f"samples/{file}"])


def crop_center_square(frame):
    """Crop the largest centered square out of a frame."""
    y, x = frame.shape[0:2]
    min_dim = min(y, x)
    start_x = (x // 2) - (min_dim // 2)
    start_y = (y // 2) - (min_dim // 2)
    return frame[start_y : start_y + min_dim, start_x : start_x + min_dim]


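# For example, a 640x480 frame becomes its central 480x480 square, so the
# resize to (IMG_SIZE, IMG_SIZE) below does not distort the aspect ratio.

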
def load_video(path, max_frames=0, resize=(IMG_SIZE, IMG_SIZE)):
    """Read a video file into an array of square RGB frames.

    With the default max_frames=0, every frame is read until the stream ends.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = crop_center_square(frame)
            frame = cv2.resize(frame, resize)
            frame = frame[:, :, [2, 1, 0]]  # BGR (OpenCV) -> RGB
            frames.append(frame)

            if len(frames) == max_frames:
                break
    finally:
        cap.release()
    return np.array(frames)


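# Usage sketch (the file name is only illustrative):
#   frames = load_video("samples/v_Punch_g01_c01.avi")
#   frames.shape  # -> (num_frames, IMG_SIZE, IMG_SIZE, 3)

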
def build_feature_extractor():
    """Build an InceptionV3 backbone that turns one frame into a feature vector."""
    feature_extractor = keras.applications.InceptionV3(
        weights="imagenet",
        include_top=False,
        pooling="avg",
        input_shape=(IMG_SIZE, IMG_SIZE, 3),
    )
    preprocess_input = keras.applications.inception_v3.preprocess_input

    inputs = keras.Input((IMG_SIZE, IMG_SIZE, 3))
    preprocessed = preprocess_input(inputs)

    outputs = feature_extractor(preprocessed)
    return keras.Model(inputs, outputs, name="feature_extractor")


feature_extractor = build_feature_extractor()


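# Sanity-check sketch: with include_top=False and pooling="avg", InceptionV3
# maps one frame to a NUM_FEATURES-dimensional (2048) vector.
#   dummy = np.zeros((1, IMG_SIZE, IMG_SIZE, 3), dtype="float32")
#   feature_extractor.predict(dummy).shape  # -> (1, NUM_FEATURES)

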
def prepare_video(frames, max_seq_length: int = 20):
    """Extract per-frame features and a boolean mask for one video."""
    frames = frames[None, ...]
    frame_mask = np.zeros(shape=(1, max_seq_length), dtype="bool")
    frame_features = np.zeros(shape=(1, max_seq_length, NUM_FEATURES), dtype="float32")

    for i, batch in enumerate(frames):
        video_length = batch.shape[0]
        length = min(max_seq_length, video_length)
        for j in range(length):
            frame_features[i, j, :] = feature_extractor.predict(batch[None, j, :])
        frame_mask[i, :length] = 1  # mark the real (non-padded) time steps

    return frame_features, frame_mask


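# The classifier consumes fixed-length sequences: frame_features has shape
# (1, 20, NUM_FEATURES) and frame_mask has shape (1, 20), with True where a
# time step holds a real frame and False where the sequence is zero-padded.

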
def sequence_prediction(path):
    """Run the classifier on a video file and return per-class probabilities."""
    class_vocab = ["CricketShot", "PlayingCello", "Punch", "ShavingBeard", "TennisSwing"]

    frames = load_video(path)
    frame_features, frame_mask = prepare_video(frames)
    probabilities = model.predict([frame_features, frame_mask])[0]

    preds = {}
    for i in np.argsort(probabilities)[::-1]:
        preds[class_vocab[i]] = float(probabilities[i])
    return preds


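# Example call (hypothetical file): sequence_prediction("samples/v_TennisSwing_g01_c01.avi")
# returns a dict mapping every class name to its probability, ordered from most
# to least likely, which gr.Label renders as a ranked list.

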
article = "<div style='text-align: center;'><a href='https://github.com/ChainYo' target='_blank'>Space by Thomas Chaigneau</a><br><a href='https://keras.io/examples/vision/video_classification/' target='_blank'>Keras example by Sayak Paul</a></div>"


app = gr.Interface(
    fn=sequence_prediction,
    inputs=[gr.Video(label="Video")],
    outputs=gr.Label(label="Prediction"),
    title="Keras Video Classification with CNN-RNN",
    description="Video classification demo using a CNN-RNN based model.",
    article=article,
    examples=samples,
)

app.launch()