# Yukang Chen
# Initial commit for qwen2-1.5b-longvila-256f-internal-reasoning-run3-wei-notimestamp
# 7ae6739
import glob
import io
import os
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union

import cv2
import numpy as np
import PIL
import PIL.Image
import requests
from transformers import PretrainedConfig

# from llava.constants import MEDIA_TOKENS
from llava.media import Image, Video

# from llava.utils import make_list
# from llava.utils.logging import logger
# Placeholder strings that ``extract_media`` writes into message text in place of
# each image / video part, keyed by media kind.
MEDIA_TOKENS = dict(
    image="<image>",
    video="<vila/video>",
)
class Media:
    """Base marker class for the media reference types defined in this module."""
class File(Media):
    """Media item that refers to its content through a ``path`` string."""

    def __init__(self, path: str) -> None:
        # Stored verbatim; loaders decide whether it is a local path or a URL.
        self.path = path
class Image(File):
    """Image referenced by a local file path or an ``http://`` / ``https://`` URL."""

    # NOTE(review): this definition shadows the ``Image`` imported from ``llava.media``
    # at the top of the file, so ``isinstance(..., Image)`` checks in this module only
    # accept instances of this local class — confirm callers construct this one.
def make_list(obj: Any) -> List:
    """Wrap ``obj`` in a one-element list unless it already is a ``list``.

    An existing list is returned as-is (not copied); any other value, including
    tuples and strings, becomes the single element of a new list.
    """
    if isinstance(obj, list):
        return obj
    return [obj]
def _extract_image(image: Union[Image, PIL.Image.Image]) -> PIL.Image.Image:
    """Return a PIL image, opening it from disk or over HTTP(S) for ``Image`` references.

    Args:
        image: Either an already-opened ``PIL.Image.Image`` (returned unchanged) or an
            ``Image`` whose ``path`` is a local file path or an ``http://`` / ``https://`` URL.

    Returns:
        The opened PIL image.

    Raises:
        requests.HTTPError: If a remote URL answers with an error status code.
    """
    if not isinstance(image, Image):
        return image

    if image.path.startswith(("http://", "https://")):
        with requests.get(image.path) as response:
            # Surface the HTTP status instead of letting PIL fail on an error page.
            response.raise_for_status()
            # ``.content`` is the decoded body (honours Content-Encoding), whereas
            # ``.raw`` is the undecoded transport stream.
            payload = response.content
        return PIL.Image.open(io.BytesIO(payload))

    return PIL.Image.open(image.path)
def _load_video_frames_from_directory(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    """Open ``num_frames`` uniformly spaced files from a directory of frames (sorted by name)."""
    frame_paths = sorted(glob.glob(os.path.join(video_path, "*")))
    if not frame_paths:
        # Without this guard linspace(0, -1, ...) produces indices 0 / -1 and the
        # lookup below fails with an opaque IndexError.
        raise ValueError(f"Video directory '{video_path}' has no frames.")
    indices = np.round(np.linspace(0, len(frame_paths) - 1, num_frames)).astype(int)
    return [PIL.Image.open(frame_paths[index]) for index in indices]


def _load_video_frames_from_file(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    """Decode ``num_frames`` uniformly spaced frames from a video file with OpenCV."""
    vidcap = cv2.VideoCapture(video_path)
    try:
        if not vidcap.isOpened():
            raise ValueError(f"Failed to open video '{video_path}'.")

        # The reported frame count might not be accurate: walk back from it until a
        # frame can actually be grabbed, so sampling never targets a missing frame.
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        while frame_count > 0:
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, frame_count - 1)
            if vidcap.grab():
                break
            frame_count -= 1
        else:
            raise ValueError(f"Video '{video_path}' has no frames.")

        indices = np.round(np.linspace(0, frame_count - 1, num_frames)).astype(int)
        frames = {}
        for index in indices:
            if index in frames:
                # Repeated index (e.g. num_frames exceeds frame_count): decode once, reuse below.
                continue
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, index)
            success, frame = vidcap.read()
            if not success:
                print(f"Failed to read frame {index} from video '{video_path}'. Skipped.")
                continue
            # OpenCV yields BGR arrays; convert to RGB before building the PIL image.
            frames[index] = PIL.Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even when an exception is raised above.
        vidcap.release()

    # Repeated indices repeat the same frame; frames that failed to decode are dropped.
    return [frames[index] for index in indices if index in frames]


def _load_video(video_path: str, *, num_frames: int) -> List[PIL.Image.Image]:
    """Sample ``num_frames`` uniformly spaced frames from a video.

    Args:
        video_path: Either a directory holding one file per frame, or a video file
            that OpenCV can decode.
        num_frames: Number of frame indices to sample, spread evenly from the first
            to the last frame.

    Returns:
        The sampled frames as PIL images. Repeated indices (more frames requested
        than available) repeat the same frame. Video frames that fail to decode are
        skipped, so the result may be shorter than ``num_frames``.

    Raises:
        ValueError: If the directory is empty, or the video file cannot be opened or
            has no readable frames.
    """
    if os.path.isdir(video_path):
        return _load_video_frames_from_directory(video_path, num_frames=num_frames)
    return _load_video_frames_from_file(video_path, num_frames=num_frames)
def _extract_video(video: Video, config: PretrainedConfig) -> List[PIL.Image.Image]:
    """Load the frames of ``video`` using the sampling settings in ``config``.

    Args:
        video: Video reference whose ``path`` is handed to ``_load_video``.
        config: Model config; ``num_video_frames`` sets how many frames are sampled.
            An optional non-zero ``fps`` attribute is not supported and only triggers
            a warning.

    Returns:
        The sampled frames as PIL images.

    Raises:
        ValueError: If ``config`` is ``None`` (``extract_media`` defaults it to ``None``).
    """
    if config is None:
        raise ValueError("A model config is required to extract frames from a video.")
    num_frames = config.num_video_frames
    # ``fps`` may be absent or None on the config; only a truthy value means the
    # caller asked for FPS-based sampling, which is not implemented.
    if getattr(config, "fps", 0):
        print("Extracting frames from video with specified FPS is not supported yet. Ignored.")
    return _load_video(video.path, num_frames=num_frames)
def extract_media(
    messages: List[Dict[str, Any]],
    config: Optional[PretrainedConfig] = None,
    draft: bool = False,
) -> Dict[str, List[Any]]:
    """Pull media items out of chat messages and replace them with placeholder tokens.

    Each message's ``"value"`` (a single part or a list of parts) is flattened into
    one string that is written back to ``message["value"]``, so the input messages
    are modified in place. Text parts are kept, with any literal media tokens
    stripped out. Image and video parts are replaced by ``MEDIA_TOKENS["image"]`` /
    ``MEDIA_TOKENS["video"]`` and collected, in order of appearance, into the
    returned mapping.

    Args:
        messages: Messages carrying a ``"value"`` key; mutated in place.
        config: Model config forwarded to ``_extract_video``. Only used for video
            parts when ``draft`` is False.
        draft: If True, collect the raw image/video objects without loading them.

    Returns:
        A ``defaultdict(list)`` mapping ``"image"`` / ``"video"`` to the collected items.

    Raises:
        ValueError: If a part is not a ``str``, an image, or a ``Video``.
    """
    media = defaultdict(list)
    for message in messages:
        pieces = []
        for part in make_list(message["value"]):
            if isinstance(part, str):
                # User text must not smuggle in placeholders of its own.
                for token in MEDIA_TOKENS.values():
                    if token in part:
                        print(f"Media token '{token}' found in text: '{part}'. Removed.")
                        part = part.replace(token, "").strip()
                pieces.append(part)
            elif isinstance(part, (Image, PIL.Image.Image)):
                media["image"].append(part if draft else _extract_image(part))
                pieces.append(MEDIA_TOKENS["image"])
            elif isinstance(part, Video):
                media["video"].append(part if draft else _extract_video(part, config))
                pieces.append(MEDIA_TOKENS["video"])
            else:
                raise ValueError(f"Unsupported prompt part type: {type(part)}")
        message["value"] = "".join(pieces)
    return media