"""FastAPI routes exposing SadTalker video generation over HTTP and WebSocket."""

import asyncio
import base64
import contextlib
import functools
import os
import shutil
import tempfile
import uuid
from typing import List, Optional
from urllib.parse import urlparse

import requests
from fastapi import (
    APIRouter,
    FastAPI,
    File,
    Form,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)
from fastapi.responses import JSONResponse
# from fastapi.middleware.cors import CORSMiddleware

from extensions import *
from main import *
# NOTE(review): the original file carried the commented-out fragment
# "#from main import import sadtalker_instance" at this point.
# `sadtalker_instance` is presumably provided by `from main import *` -- confirm.
from tts_api import *
from sadtalker_utils import *
from stt_api import *
from text_generation import *

router = APIRouter()

# Applied by `requests` to the connect phase and to each socket read,
# not to the download as a whole.
_DOWNLOAD_TIMEOUT_SECONDS = 30
_DOWNLOAD_CHUNK_SIZE = 8192

# Fixed assets used for every video produced by the WebSocket endpoint.
_WS_SOURCE_IMAGE = './examples/source_image/cyarh.png'
_WS_REF_VIDEO = './examples/driven_video/vid_xdd.mp4'


def _remove_quietly(paths: List[str]) -> None:
    """Delete every file in *paths*, ignoring files that are already gone."""
    for path in paths:
        with contextlib.suppress(FileNotFoundError):
            os.remove(path)


def _write_temp_file(data: bytes, suffix: str, temp_paths: List[str]) -> str:
    """Write *data* to a new temporary file and return its path.

    The path is appended to *temp_paths* as soon as the file exists so the
    caller's cleanup removes it even if the write fails. The file is closed
    (and therefore flushed) before the path is returned.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        temp_paths.append(handle.name)
        handle.write(data)
    return handle.name


def _download_to_temp(url: str, suffix: str, temp_paths: List[str]) -> str:
    """Stream *url* into a new temporary file and return its path.

    Blocking; intended to be run in an executor. Raises the `requests`
    exception on connection errors, timeouts, or non-2xx responses.
    """
    # SECURITY(review): `url` comes straight from the request, so this is a
    # server-side fetch of an arbitrary address (SSRF). Restrict the allowed
    # hosts if this endpoint is reachable by untrusted clients.
    with requests.get(
        url, stream=True, timeout=_DOWNLOAD_TIMEOUT_SECONDS
    ) as response:
        response.raise_for_status()
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
            temp_paths.append(handle.name)
            for chunk in response.iter_content(chunk_size=_DOWNLOAD_CHUNK_SIZE):
                handle.write(chunk)
    return handle.name


async def _resolve_media(
    value: Optional[str],
    upload: Optional[UploadFile],
    default_suffix: str,
    temp_paths: List[str],
) -> Optional[str]:
    """Turn an (upload, string) input pair into a local file path.

    * An upload is saved to a temp file that keeps the upload's extension.
    * An http(s) URL is downloaded to a temp file with *default_suffix*.
    * Any other non-empty string is returned unchanged and used as a path.
    * Returns None when neither input was supplied.

    Temp files created here are recorded in *temp_paths* for later removal.
    """
    if upload:
        # `filename` may be None for some multipart clients.
        suffix = os.path.splitext(upload.filename or "")[1]
        content = await upload.read()
        return _write_temp_file(content, suffix, temp_paths)
    if not value:
        return None
    if urlparse(value).scheme in ("http", "https"):
        # Keep the blocking download off the event loop.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, _download_to_temp, value, default_suffix, temp_paths
        )
    # SECURITY(review): a non-URL value is used as a server-side filesystem
    # path exactly as the client sent it.
    return value


@router.post("/sadtalker")
async def create_video(
    source_image: str = Form(None),
    source_image_file: UploadFile = File(None),
    driven_audio: str = Form(None),
    driven_audio_file: UploadFile = File(None),
    preprocess: str = Form('crop'),
    still_mode: bool = Form(False),
    use_enhancer: bool = Form(False),
    batch_size: int = Form(1),
    size: int = Form(256),
    pose_style: int = Form(0),
    exp_scale: float = Form(1.0),
    use_ref_video: bool = Form(False),
    ref_video: str = Form(None),
    ref_video_file: UploadFile = File(None),
    ref_info: str = Form(None),
    use_idle_mode: bool = Form(False),
    length_of_audio: int = Form(0),
    use_blink: bool = Form(True),
    checkpoint_dir: str = Form('checkpoints'),
    config_dir: str = Form('src/config'),
    old_version: bool = Form(False),
    tts_text: str = Form(None),
    tts_lang: str = Form('en'),
):
    """Generate a video with `sadtalker_instance.test` and return its path.

    Each of the source image, driving audio and reference video may be given
    either as an uploaded file or as a string (http(s) URL or local path),
    but not both. The source image is mandatory; the other two are optional.
    `checkpoint_dir`, `config_dir` and `old_version` are accepted for
    interface compatibility but are not used by this handler.

    Returns:
        {"video_url": <value returned by sadtalker_instance.test>}

    Raises:
        HTTPException 400: conflicting inputs, or no source image supplied.
        HTTPException 500: any failure while fetching inputs or generating.
    """
    if source_image_file and source_image:
        raise HTTPException(status_code=400, detail="source_image and source_image_file cannot be both not None")
    if driven_audio and driven_audio_file:
        raise HTTPException(status_code=400, detail="driven_audio and driven_audio_file cannot be both not None")
    if ref_video and ref_video_file:
        raise HTTPException(status_code=400, detail="ref_video and ref_video_file cannot be both not None")

    # Every temp file created for this request, removed in `finally` no
    # matter which stage fails.
    temp_paths: List[str] = []
    try:
        source_image_path = await _resolve_media(
            source_image, source_image_file, '.png', temp_paths
        )
        if source_image_path is None:
            raise HTTPException(status_code=400, detail="source_image not provided")
        driven_audio_path = await _resolve_media(
            driven_audio, driven_audio_file, '.wav', temp_paths
        )
        ref_video_path = await _resolve_media(
            ref_video, ref_video_file, '.mp4', temp_paths
        )

        # run_in_executor() forwards positional arguments only, so the
        # keyword arguments are bound with functools.partial.
        job = functools.partial(
            sadtalker_instance.test,
            source_image_path,
            driven_audio_path,
            preprocess,
            still_mode,
            use_enhancer,
            batch_size,
            size,
            pose_style,
            exp_scale,
            use_ref_video,
            ref_video_path,
            ref_info,
            use_idle_mode,
            length_of_audio,
            use_blink,
            './results/',
            tts_text=tts_text,
            tts_lang=tts_lang,
        )
        loop = asyncio.get_running_loop()
        output_path = await loop.run_in_executor(None, job)
        return {"video_url": output_path}
    except HTTPException:
        # Preserve deliberate 4xx responses instead of rewrapping them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        _remove_quietly(temp_paths)


@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve a loop that turns each incoming message into a video.

    Each message is a JSON object with either:
      * "text"  -- synthesized to speech with `TTSTalker.test`, or
      * "audio" -- base64 audio that is transcribed, answered by
        `perform_reasoning_stream`, and the answer synthesized to speech.
    The resulting audio drives `sadtalker_instance.test` with fixed example
    assets, and {"video_url": ...} is sent back. Messages with neither key
    are ignored. Failures are reported as {"error": ...}.
    """
    await websocket.accept()
    tts_model = TTSTalker()
    loop = asyncio.get_running_loop()
    try:
        while True:
            data = await websocket.receive_json()
            if not isinstance(data, dict):
                await websocket.send_json({"error": "message must be a JSON object"})
                continue
            text = data.get("text")
            audio_base64 = data.get("audio")

            if text:
                audio_path = await loop.run_in_executor(None, tts_model.test, text)
            elif audio_base64:
                # Scoped to this message so a failure never touches files
                # left over from an earlier iteration.
                temp_paths: List[str] = []
                try:
                    audio_bytes = base64.b64decode(audio_base64)
                    # Closed (flushed) before the transcriber opens it.
                    recording_path = _write_temp_file(audio_bytes, ".wav", temp_paths)
                    transcription_text_file = speech_to_text_func(recording_path)
                    with open(transcription_text_file, 'r') as f:
                        transcription_text = f.read()
                    response_stream = perform_reasoning_stream(f"respond to this sentence in 10 words or less {transcription_text}", 0.7, 40, 0.0, 1.2)
                    parts = []
                    for chunk in response_stream:
                        # An empty chunk marks the end of the reply.
                        if chunk == "":
                            break
                        parts.append(chunk)
                    response_text = "".join(parts)
                    audio_path = await loop.run_in_executor(None, tts_model.test, response_text)
                except Exception as e:
                    # Per-message failure: report it and keep the session open.
                    await websocket.send_json({"error": str(e)})
                    continue
                finally:
                    _remove_quietly(temp_paths)
            else:
                continue

            output = await loop.run_in_executor(
                None,
                sadtalker_instance.test,
                _WS_SOURCE_IMAGE,
                audio_path,
                'full',
                True,
                True,
                1,
                256,
                0,
                1,
                True,
                _WS_REF_VIDEO,
                "pose+blink",
                False,
                0,
                True,
                './results/',
            )
            await websocket.send_json({"video_url": output})
    except WebSocketDisconnect:
        # Client went away; there is nobody left to notify.
        return
    except Exception as e:
        print(e)
        try:
            await websocket.send_json({"error": str(e)})
        except (RuntimeError, WebSocketDisconnect):
            # The socket is already closed; the error was printed above.
            pass