import logging import os import io import re import base64 import uuid from typing import Dict, Any, Optional, List, Literal from dataclasses import dataclass from asyncio import Lock, Queue import asyncio import time import datetime from contextlib import asynccontextmanager from collections import defaultdict from aiohttp import web, ClientSession from huggingface_hub import InferenceClient, HfApi from gradio_client import Client import random import yaml import json from api_config import * # User role type UserRole = Literal['anon', 'normal', 'pro', 'admin'] # Configure logging logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' ) logger = logging.getLogger(__name__) def generate_seed(): """Generate a random positive 32-bit integer seed.""" return random.randint(0, 2**32 - 1) def sanitize_yaml_response(response_text: str) -> str: """ Sanitize and format AI response into valid YAML. Returns properly formatted YAML string. """ response_text = response_text.split("```")[0] # Remove any markdown code block indicators and YAML document markers clean_text = re.sub(r'```yaml|```|---|\.\.\.$', '', response_text.strip()) # Split into lines and process each line lines = clean_text.split('\n') sanitized_lines = [] current_field = None for line in lines: stripped = line.strip() if not stripped: continue # Handle field starts if stripped.startswith('title:') or stripped.startswith('description:'): # Ensure proper YAML format with space after colon and proper quoting field_name = stripped.split(':', 1)[0] field_value = stripped.split(':', 1)[1].strip().strip('"\'') # Quote the value if it contains special characters if any(c in field_value for c in ':[]{},&*#?|-<>=!%@`'): field_value = f'"{field_value}"' sanitized_lines.append(f"{field_name}: {field_value}") current_field = field_name elif stripped.startswith('tags:'): sanitized_lines.append('tags:') current_field = 'tags' elif stripped.startswith('-') and current_field == 'tags': # Process tag values tag = stripped[1:].strip().strip('"\'') if tag: # Clean and format tag tag = re.sub(r'[^\x00-\x7F]+', '', tag) # Remove non-ASCII tag = re.sub(r'[^a-zA-Z0-9\s-]', '', tag) # Keep only alphanumeric and hyphen tag = tag.strip().lower().replace(' ', '-') if tag: sanitized_lines.append(f" - {tag}") elif current_field in ['title', 'description']: # Handle multi-line title/description continuation value = stripped.strip('"\'') if value: # Append to previous line prev = sanitized_lines[-1] sanitized_lines[-1] = f"{prev} {value}" # Ensure the YAML has all required fields required_fields = {'title', 'description', 'tags'} found_fields = {line.split(':')[0].strip() for line in sanitized_lines if ':' in line} for field in required_fields - found_fields: if field == 'tags': sanitized_lines.extend(['tags:', ' - default']) else: sanitized_lines.append(f'{field}: "No {field} provided"') return '\n'.join(sanitized_lines) @dataclass class Endpoint: id: int url: str busy: bool = False last_used: float = 0 error_count: int = 0 error_until: float = 0 # Timestamp until which this endpoint is considered in error state class EndpointManager: def __init__(self): self.endpoints: List[Endpoint] = [] self.lock = Lock() self.initialize_endpoints() self.last_used_index = -1 # Track the last used endpoint for round-robin def initialize_endpoints(self): """Initialize the list of endpoints""" for i, url in enumerate(VIDEO_ROUND_ROBIN_ENDPOINT_URLS): endpoint = Endpoint(id=i + 1, url=url) self.endpoints.append(endpoint) def _get_next_free_endpoint(self): """Get the next available non-busy endpoint, or oldest endpoint if all are busy""" current_time = time.time() # First priority: Get any non-busy and non-error endpoint free_endpoints = [ ep for ep in self.endpoints if not ep.busy and current_time > ep.error_until ] if free_endpoints: # Return the least recently used free endpoint return min(free_endpoints, key=lambda ep: ep.last_used) # Second priority: If all busy/error, use round-robin but skip error endpoints tried_count = 0 next_index = self.last_used_index while tried_count < len(self.endpoints): next_index = (next_index + 1) % len(self.endpoints) tried_count += 1 # If endpoint is not in error state, use it if current_time > self.endpoints[next_index].error_until: self.last_used_index = next_index return self.endpoints[next_index] # If all endpoints are in error state, use the one with earliest error expiry self.last_used_index = next_index return min(self.endpoints, key=lambda ep: ep.error_until) @asynccontextmanager async def get_endpoint(self, max_wait_time: int = 10): """Get the next available endpoint using a context manager""" start_time = time.time() endpoint = None try: while True: if time.time() - start_time > max_wait_time: raise TimeoutError(f"Could not acquire an endpoint within {max_wait_time} seconds") async with self.lock: # Get the next available endpoint using our selection strategy endpoint = self._get_next_free_endpoint() # Mark it as busy endpoint.busy = True endpoint.last_used = time.time() logger.info(f"Using endpoint {endpoint.id} (busy: {endpoint.busy}, last used: {endpoint.last_used})") break yield endpoint finally: if endpoint: async with self.lock: endpoint.busy = False endpoint.last_used = time.time() # We don't need to put back into queue - our strategy now picks directly from the list class ChatRoom: def __init__(self): self.messages = [] self.connected_clients = set() self.max_history = 100 def add_message(self, message): self.messages.append(message) if len(self.messages) > self.max_history: self.messages.pop(0) def get_recent_messages(self, limit=50): return self.messages[-limit:] class VideoGenerationAPI: def __init__(self): self.inference_client = InferenceClient(token=HF_TOKEN) self.hf_api = HfApi(token=HF_TOKEN) self.endpoint_manager = EndpointManager() self.active_requests: Dict[str, asyncio.Future] = {} self.chat_rooms = defaultdict(ChatRoom) self.video_events: Dict[str, List[Dict[str, Any]]] = defaultdict(list) self.event_history_limit = 50 # Cache for user roles to avoid repeated API calls self.user_role_cache: Dict[str, Dict[str, Any]] = {} # Cache expiration time (10 minutes) self.cache_expiration = 600 def _add_event(self, video_id: str, event: Dict[str, Any]): """Add an event to the video's history and maintain the size limit""" events = self.video_events[video_id] events.append(event) if len(events) > self.event_history_limit: events.pop(0) async def validate_user_token(self, token: str) -> UserRole: """ Validates a Hugging Face token and determines the user's role. Returns one of: - 'anon': Anonymous user (no token or invalid token) - 'normal': Standard Hugging Face user - 'pro': Hugging Face Pro user - 'admin': Admin user (username in ADMIN_ACCOUNTS) """ # If no token is provided, the user is anonymous if not token: return 'anon' # Check if we have a cached result for this token current_time = time.time() if token in self.user_role_cache: cached_data = self.user_role_cache[token] # If the cache is still valid if current_time - cached_data['timestamp'] < self.cache_expiration: logger.info(f"Using cached user role: {cached_data['role']}") return cached_data['role'] # No valid cache, need to check the token with the HF API try: # Use HF API to validate the token and get user info logger.info("Validating Hugging Face token...") # Run in executor to avoid blocking the event loop user_info = await asyncio.get_event_loop().run_in_executor( None, lambda: self.hf_api.whoami(token=token) ) logger.info(f"Token valid for user: {user_info.name}") # Determine the user role based on the information user_role: UserRole # Check if the user is an admin if user_info.name in ADMIN_ACCOUNTS: user_role = 'admin' # Check if the user has a pro account elif hasattr(user_info, 'is_pro') and user_info.is_pro: user_role = 'pro' else: user_role = 'normal' # Cache the result self.user_role_cache[token] = { 'role': user_role, 'timestamp': current_time, 'username': user_info.name } return user_role except Exception as e: logger.error(f"Failed to validate Hugging Face token: {str(e)}") # If validation fails, the user is treated as anonymous return 'anon' async def download_video(self, url: str) -> bytes: """Download video file from URL and return bytes""" async with ClientSession() as session: async with session.get(url) as response: if response.status != 200: raise Exception(f"Failed to download video: HTTP {response.status}") return await response.read() async def search_video(self, query: str, search_count: int = 0, attempt_count: int = 0) -> Optional[dict]: """Generate a single search result using HF text generation""" prompt = f"""# Instruction Your response MUST be a YAML object containing a title, description, and tags, consistent with what we can find on a video sharing platform. Format your YAML response with only those fields: "title" (a short string), "description" (string caption of the scene), and "tags" (array of 3 to 4 strings). Do not add any other field. In the description field, describe in a very synthetic way the visuals of the first shot (first scene), eg "