#!/usr/bin/env python3
"""
Soccer QA Inference - Single Class, Clean API

Usage in Colab:

    from soccer_qa_inference import SoccerQA

    model = SoccerQA("soccer-qa-3b-unified")
    answer = model.ask("video.mp4", "What happened?", max_tokens=128)
"""
import os
import json
import torch
import torch.nn as nn
import numpy as np
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from decord import VideoReader
# Project-local modules: video transforms and the vision transformer backbone
import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms
from src.models.vision_transformer import vit_giant_rope
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def get_video(fname, num_frames=16):
    """Load video and sample frames uniformly"""
    vr = VideoReader(fname)
    frame_idx = np.linspace(0, len(vr) - 1, num=num_frames).astype(np.int64)
    video = vr.get_batch(frame_idx).asnumpy()
    return video
def build_video_transform(img_size):
    """Build video preprocessing transform"""
    short_side_size = int(256.0 / 224 * img_size)
    eval_transform = video_transforms.Compose([
        video_transforms.Resize(short_side_size, interpolation="bilinear"),
        video_transforms.CenterCrop(size=(img_size, img_size)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    ])
    return eval_transform
class SoccerQA:
    """Single class for Soccer QA inference - Clean Colab API"""

    def __init__(self, model_dir="/home/varunkodathala/jepa_llm/soccer_pretrain/soccer-qa-3b-unified"):
        """Initialize Soccer QA model

        Args:
            model_dir: Path to merged model directory
        """
        self.model_dir = model_dir
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"🚀 Loading Soccer QA from {model_dir}...")

        # Load config and tokenizer
        self._load_config()
        self._load_tokenizer()

        # Build models
        self._build_vision_model()
        self._build_text_model()
        self._build_projection()

        # Load all weights
        self._load_weights()

        # Build video transforms
        self.video_transform = build_video_transform(self.img_size)
        print("✅ Soccer QA ready!")
    def _load_config(self):
        """Load model configuration"""
        config_path = os.path.join(self.model_dir, "config.json")
        with open(config_path, 'r') as f:
            self.config = json.load(f)
        self.vision_dim = self.config["vision_dim"]          # 1408
        self.projection_dim = self.config["projection_dim"]  # 2048
        self.text_dim = self.config["text_dim"]              # 3072
        self.img_size = self.config["img_size"]              # 256
        self.num_frames = self.config["num_frames"]          # 16
    def _load_tokenizer(self):
        """Load tokenizer with <video> token"""
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_dir)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
    def _build_vision_model(self):
        """Build the vision transformer backbone from the src modules"""
        self.vision_model = vit_giant_rope(
            img_size=(self.img_size, self.img_size),
            num_frames=self.num_frames
        )
        self.vision_model.to(self.device).eval()

        # Freeze vision model
        for param in self.vision_model.parameters():
            param.requires_grad = False
    def _build_text_model(self):
        """Build the base text model; merged weights are loaded later in _load_weights()"""
        self.text_model = AutoModelForCausalLM.from_pretrained(
            "meta-llama/Llama-3.2-3B",
            torch_dtype=torch.float32,
            device_map=self.device,
            trust_remote_code=True
        )
        # Resize embeddings for the added <video> token so sizes match the saved model
        self.text_model.resize_token_embeddings(len(self.tokenizer))
        self.text_model.eval()
    def _build_projection(self):
        """Build vision projection layer"""
        self.vision_projection = nn.Sequential(
            nn.Linear(self.vision_dim, self.projection_dim),  # 1408 -> 2048
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(self.projection_dim, self.text_dim),    # 2048 -> 3072
            nn.LayerNorm(self.text_dim)
        ).to(self.device)
    def _load_weights(self):
        """Load all weights from safetensors - optimized approach"""
        model_path = os.path.join(self.model_dir, "model.safetensors")
        print(f"Loading weights from: {model_path}")
        state_dict = load_file(model_path, device=str(self.device))

        # Load vision encoder weights
        vision_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_encoder."):
                new_key = key.replace("vision_encoder.", "")
                vision_state[new_key] = value
        msg = self.vision_model.load_state_dict(vision_state, strict=False)
        print(f"Vision model loaded: {msg}")

        # Load projection weights
        projection_state = {}
        for key, value in state_dict.items():
            if key.startswith("vision_projection."):
                new_key = key.replace("vision_projection.", "")
                projection_state[new_key] = value
        self.vision_projection.load_state_dict(projection_state)
        print("Projection layer loaded")

        # Load text model weights directly from merged state_dict
        text_state = {}
        for key, value in state_dict.items():
            if key.startswith("text_model."):
                new_key = key.replace("text_model.", "")
                text_state[new_key] = value

        # Apply merged weights directly to text model
        missing_keys, unexpected_keys = self.text_model.load_state_dict(text_state, strict=False)
        if missing_keys:
            print(f"Missing keys in text model: {len(missing_keys)} (expected when loading with strict=False)")
        if unexpected_keys:
            print(f"Unexpected keys in text model: {len(unexpected_keys)}")
        print("✅ Text model loaded with merged weights")

        # Free the merged state dict and cached GPU memory
        del state_dict
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    def _get_video_embeddings(self, video_path):
        """Extract video embeddings from video file"""
        with torch.inference_mode():
            # Load video
            video = get_video(video_path, self.num_frames)
            video = torch.from_numpy(video).permute(0, 3, 1, 2)  # T x C x H x W

            # Preprocess
            x = self.video_transform(video).to(self.device).unsqueeze(0)  # [1, 3, 16, 256, 256]

            # Extract features
            features = self.vision_model(x)  # [1, 2048, 1408]

            # Handle reshaping
            squeezed = features.squeeze(0)  # [2048, 1408]
            if squeezed.shape[0] % 2048 == 0:
                num_clips = squeezed.shape[0] // 2048
                reshaped = squeezed.view(num_clips, 2048, 1408)
            else:
                reshaped = squeezed.unsqueeze(0)  # [1, 2048, 1408]
            return reshaped
    def _project_vision_features(self, vision_features):
        """Project vision features to text embedding space"""
        # vision_features: [num_clips, 2048, 1408]
        num_clips, patches_per_clip, feature_dim = vision_features.shape

        # Flatten: [num_clips * 2048, 1408]
        flattened = vision_features.view(-1, feature_dim)

        # Project: [num_clips * 2048, 3072]
        projected = self.vision_projection(flattened)

        # Return flattened for sequence: [total_patches, 3072]
        return projected
    def ask(self, video_path, question, max_tokens=128, temperature=0.7, top_p=0.9,
            repetition_penalty=1.2, no_repeat_ngram_size=3):
        """Ask a question about a video

        Args:
            video_path: Path to video file
            question: Question about the video
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            repetition_penalty: Penalty for repetition
            no_repeat_ngram_size: N-gram size for repetition blocking

        Returns:
            Generated answer as string
        """
        with torch.no_grad():
            # Get video embeddings
            video_features = self._get_video_embeddings(video_path)        # [num_clips, 2048, 1408]
            vision_embeds = self._project_vision_features(video_features)  # [total_patches, 3072]
            vision_embeds = vision_embeds.unsqueeze(0)                     # [1, total_patches, 3072]

            # Process question (remove <video> token if present)
            question_clean = question.replace("<video>", "").strip()

            # Tokenize question
            question_tokens = self.tokenizer(
                question_clean,
                return_tensors="pt",
                add_special_tokens=True
            ).to(self.device)

            # Get text embeddings
            text_embeds = self.text_model.get_input_embeddings()(question_tokens.input_ids)

            # Combine vision and text embeddings
            combined_embeds = torch.cat([vision_embeds, text_embeds], dim=1)

            # Create attention mask
            vision_attention = torch.ones(
                1, vision_embeds.shape[1],
                dtype=question_tokens.attention_mask.dtype,
                device=self.device
            )
            combined_attention_mask = torch.cat([vision_attention, question_tokens.attention_mask], dim=1)

            # Generate response
            generated_ids = self.text_model.generate(
                inputs_embeds=combined_embeds,
                attention_mask=combined_attention_mask,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=self.tokenizer.pad_token_id,
                eos_token_id=self.tokenizer.eos_token_id,
                repetition_penalty=repetition_penalty,
                no_repeat_ngram_size=no_repeat_ngram_size,
                use_cache=True,
                return_dict_in_generate=False
            )

            # Handle different return formats from generate(): when called with
            # inputs_embeds only, transformers typically returns just the new tokens
            if generated_ids.shape[1] > combined_embeds.shape[1]:
                # Full sequence returned - slice off the prompt length
                new_tokens = generated_ids[:, combined_embeds.shape[1]:]
            else:
                # Only new tokens returned - use all
                new_tokens = generated_ids

            generated_text = self.tokenizer.batch_decode(
                new_tokens,
                skip_special_tokens=True
            )[0]
            return generated_text.strip()
    def batch_ask(self, video_path, questions, **kwargs):
        """Ask multiple questions about the same video

        Args:
            video_path: Path to video file
            questions: List of questions
            **kwargs: Generation parameters

        Returns:
            List of {"question": str, "answer": str} dicts
        """
        results = []
        for question in questions:
            answer = self.ask(video_path, question, **kwargs)
            results.append({"question": question, "answer": answer})
        return results
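# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). The model directory
# and video path below are placeholders - point them at your local copy of the
# merged checkpoint and at a soccer clip before running.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    qa = SoccerQA(model_dir="soccer-qa-3b-unified")

    # Single question about one clip
    print(qa.ask("sample_clip.mp4", "What happened?", max_tokens=128))

    # Several questions against the same clip; batch_ask returns a list of
    # {"question": ..., "answer": ...} dicts
    for item in qa.batch_ask(
        "sample_clip.mp4",
        ["Which team is attacking?", "Was there a foul?"],
        max_tokens=64,
    ):
        print(f"Q: {item['question']}\nA: {item['answer']}")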