"""
Soccer QA Inference - Single Class, Clean API

Usage in Colab:

    from soccer_qa_inference import SoccerQA

    model = SoccerQA("soccer-qa-3b-unified")
    answer = model.ask("video.mp4", "What happened?", max_tokens=128)
"""

import os
import json

import torch
import torch.nn as nn
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from decord import VideoReader

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.vision_transformer import vit_giant_rope

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def get_video(fname, num_frames=16):
    """Load video and sample frames uniformly"""
    vr = VideoReader(fname)
    frame_idx = np.linspace(0, len(vr) - 1, num=num_frames).astype(np.int64)
    video = vr.get_batch(frame_idx).asnumpy()
    return video


def build_video_transform(img_size):
    """Build video preprocessing transform"""
    short_side_size = int(256.0 / 224 * img_size)
    eval_transform = video_transforms.Compose([
        video_transforms.Resize(short_side_size, interpolation="bilinear"),
        video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    return eval_transform
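
# Minimal sketch of using the helpers above on their own (file name and sizes are
# illustrative): get_video() returns a (num_frames, H, W, 3) uint8 array, which is
# permuted to (num_frames, 3, H, W) before the transform, mirroring how
# SoccerQA._get_video_embeddings() uses them below.
#
#     video = get_video("sample_clip.mp4", num_frames=16)
#     transform = build_video_transform(img_size=224)
#     clip = transform(torch.from_numpy(video).permute(0, 3, 1, 2))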


class SoccerQA:
    """Single class for Soccer QA inference - Clean Colab API"""

    def __init__(self, model_dir="/home/varunkodathala/jepa_llm/soccer_pretrain/soccer-qa-3b-unified"):
        """Initialize the Soccer QA model.

        Args:
            model_dir: Path to the merged model directory
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        print(f"🚀 Loading Soccer QA from {model_dir}...")

        # Configuration and tokenizer
        self._load_config()
        self._load_tokenizer()

        # Sub-models: frozen vision encoder, text model, and vision-to-text projection
        self._build_vision_model()
        self._build_text_model()
        self._build_projection()

        # Merged weights for all three components
        self._load_weights()

        # Video preprocessing
        self.video_transform = build_video_transform(self.img_size)

        print("✅ Soccer QA ready!")

    def _load_config(self):
        """Load model configuration from config.json."""
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)

        self.vision_dim = self.config["vision_dim"]
        self.projection_dim = self.config["projection_dim"]
        self.text_dim = self.config["text_dim"]
        self.img_size = self.config["img_size"]
        self.num_frames = self.config["num_frames"]
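
        # Illustrative config.json for this class (values are assumptions for a
        # ViT-giant + Llama-3.2-3B setup; the checkpoint's own config.json is
        # authoritative):
        # {
        #     "vision_dim": 1408,
        #     "projection_dim": 2048,
        #     "text_dim": 3072,
        #     "img_size": 224,
        #     "num_frames": 16
        # }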

    def _load_tokenizer(self):
        """Load the tokenizer (including the added <video> token) from the model directory."""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_vision_model(self):
        """Build the ViT-giant (RoPE) vision encoder from the local src modules."""
        self.vision_model = vit_giant_rope(
            img_size=(self.img_size, self.img_size),
            num_frames=self.num_frames
        )
        self.vision_model.to(self.device).eval()

        # The vision encoder stays frozen for inference
        for param in self.vision_model.parameters():
            param.requires_grad = False

    def _build_text_model(self):
        """Build the text model; the merged weights are loaded later in _load_weights()."""
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B",
            torch_dtype=torch.float32,
            device_map=self.device,
            trust_remote_code=True
        )

        # Match the embedding table to the tokenizer, which includes the <video> token
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        self.text_model.eval()

    def _build_projection(self):
        """Build the vision-to-text projection layer."""
        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_dim, self.projection_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.projection_dim, self.text_dim),
            nn.LayerNorm(self.text_dim)
        ).to(self.device)

    def _load_weights(self):
        """Load vision, projection, and text weights from the merged safetensors file."""
        model_path = os.path.join(self.model_dir, "model.safetensors")
        print(f"Loading weights from: {model_path}")
        state_dict = load_file(model_path, device=str(self.device))

        # Vision encoder weights ("vision_encoder." prefix)
        vision_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_encoder."):
                new_key = key.replace("vision_encoder.", "")
                vision_state[new_key] = value

        msg = self.vision_model.load_state_dict(vision_state, strict=False)
        print(f"Vision model loaded: {msg}")

        # Projection layer weights ("vision_projection." prefix)
        projection_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_projection."):
                new_key = key.replace("vision_projection.", "")
                projection_state[new_key] = value

        self.vision_projection.load_state_dict(projection_state)
        print("Projection layer loaded")

        # Text model weights ("text_model." prefix)
        text_state = {}
        for key, value in state_dict.items():
            if key.startswith("text_model."):
                new_key = key.replace("text_model.", "")
                text_state[new_key] = value

        missing_keys, unexpected_keys = self.text_model.load_state_dict(text_state, strict=False)
        if missing_keys:
            print(f"Missing keys in text model: {len(missing_keys)} (this is normal)")
        if unexpected_keys:
            print(f"Unexpected keys in text model: {len(unexpected_keys)}")
        print("✅ Text model loaded with merged weights")

        # Free the flat state dict once it has been consumed
        del state_dict
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _get_video_embeddings(self, video_path):
        """Extract video embeddings from a video file."""
        with torch.inference_mode():
            # Load and preprocess the clip: (T, H, W, C) uint8 -> (T, C, H, W) before the transform
            video = get_video(video_path, self.num_frames)
            video = torch.from_numpy(video).permute(0, 3, 1, 2)

            x = self.video_transform(video).to(self.device).unsqueeze(0)

            # Encode with the frozen vision transformer
            features = self.vision_model(x)

            # Group patch tokens into clips: 2048 is the expected number of patch
            # tokens per clip, 1408 the ViT-giant feature dimension.
            squeezed = features.squeeze(0)
            if squeezed.shape[0] % 2048 == 0:
                num_clips = squeezed.shape[0] // 2048
                reshaped = squeezed.view(num_clips, 2048, 1408)
            else:
                reshaped = squeezed.unsqueeze(0)

            return reshaped

    def _project_vision_features(self, vision_features):
        """Project vision features to the text embedding space."""
        num_clips, patches_per_clip, feature_dim = vision_features.shape

        # Flatten (num_clips, patches, dim) -> (num_clips * patches, dim) before projecting
        flattened = vision_features.view(-1, feature_dim)

        projected = self.vision_projection(flattened)

        return projected

    def ask(self, video_path, question, max_tokens=128, temperature=0.7, top_p=0.9,
            repetition_penalty=1.2, no_repeat_ngram_size=3):
        """Ask a question about a video.

        Args:
            video_path: Path to the video file
            question: Question about the video
            max_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            repetition_penalty: Penalty for repetition
            no_repeat_ngram_size: N-gram size for repetition blocking

        Returns:
            Generated answer as a string
        """
        with torch.no_grad():
            # Encode the video and project it into the text embedding space
            video_features = self._get_video_embeddings(video_path)
            vision_embeds = self._project_vision_features(video_features)
            vision_embeds = vision_embeds.unsqueeze(0)

            # The video is passed as embeddings, so drop any literal <video> placeholder
            question_clean = question.replace("<video>", "").strip()

            question_tokens = self.tokenizer(
                question_clean,
                return_tensors="pt",
                add_special_tokens=True
            ).to(self.device)

            text_embeds = self.text_model.get_input_embeddings()(question_tokens.input_ids)

            # Prepend the projected video tokens to the question embeddings
            combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

            # Attention mask: all video positions are attended to
            vision_attention = torch.ones(
                1, vision_embeds.shape[1],
                dtype=question_tokens.attention_mask.dtype,
                device=self.device
            )
            combined_attention_mask = torch.cat([vision_attention, question_tokens.attention_mask], dim=1)

            generated_ids = self.text_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                use_cache=True,
                return_dict_in_generate=False
            )

            # When generating from inputs_embeds, generate() typically returns only the
            # new tokens; strip the prompt positions if they happen to be included.
            if generated_ids.shape[1] > combined_embeds.shape[1]:
                new_tokens = generated_ids[:, combined_embeds.shape[1]:]
            else:
                new_tokens = generated_ids

            generated_text = self.tokenizer.batch_decode(
                new_tokens,
                skip_special_tokens=True
            )[0]

            return generated_text.strip()
    def batch_ask(self, video_path, questions, **kwargs):
        """Ask multiple questions about the same video.

        Args:
            video_path: Path to video file
            questions: List of questions
            **kwargs: Generation parameters

        Returns:
            List of {"question": str, "answer": str} dicts
        """
        results = []
        for question in questions:
            answer = self.ask(video_path, question, **kwargs)
            results.append({"question": question, "answer": answer})
        return results
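

if __name__ == "__main__":
    # Minimal usage sketch (the video path and questions below are illustrative,
    # not files shipped with this repo): load the merged checkpoint, then query a
    # clip once and in batch.
    model = SoccerQA("soccer-qa-3b-unified")

    answer = model.ask("sample_clip.mp4", "What happened in this play?", max_tokens=128)
    print(answer)

    qa_pairs = model.batch_ask(
        "sample_clip.mp4",
        ["Which team is attacking?", "Was there a goal?"],
        max_tokens=64,
    )
    for item in qa_pairs:
        print(f"Q: {item['question']}\nA: {item['answer']}")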