# NOTE(review): the six lines that used to sit here ("Spaces:", "Running",
# file size, commit hashes, and a 1..205 line-number gutter) were page-scrape
# residue from a hosting UI, not Python, and broke the module at import time.
# They carried no program content and have been commented away.
import os
import tempfile
import uuid
import asyncio
import shutil
import requests
from urllib.parse import urlparse
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
from fastapi.responses import JSONResponse
from fastapi import APIRouter
from extensions import *
from main import *
from tts_api import *
from sadtalker_utils import *
import base64
from stt_api import *
from text_generation import *
router = APIRouter()
async def _resolve_media_input(upload_file, url_or_path, download_suffix):
    """Resolve one media input that may arrive as an upload, a URL, or a local path.

    Returns ``(path, tmp_file)`` where ``tmp_file`` is the ``NamedTemporaryFile``
    the caller must delete after use, or ``None`` when ``path`` is a
    caller-owned local file (or when no input was given at all, in which case
    ``path`` is also ``None``).
    """
    if upload_file:
        suffix = os.path.splitext(upload_file.filename)[1]
        tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        tmp.write(await upload_file.read())
        # Close so buffered bytes are flushed before anything reads the path.
        tmp.close()
        return tmp.name, tmp
    if url_or_path:
        if urlparse(url_or_path).scheme in ("http", "https"):
            # NOTE(review): requests is synchronous and blocks the event loop
            # for the duration of the download; acceptable for now, but an
            # executor or an async HTTP client would be the proper fix.
            response = requests.get(url_or_path, stream=True)
            response.raise_for_status()
            tmp = tempfile.NamedTemporaryFile(suffix=download_suffix, delete=False)
            with tmp:
                for chunk in response.iter_content(chunk_size=8192):
                    tmp.write(chunk)
            return tmp.name, tmp
        return url_or_path, None
    return None, None


@router.post("/sadtalker")
async def create_video(
    source_image: str = Form(None),
    source_image_file: UploadFile = File(None),
    driven_audio: str = Form(None),
    driven_audio_file: UploadFile = File(None),
    preprocess: str = Form('crop'),
    still_mode: bool = Form(False),
    use_enhancer: bool = Form(False),
    batch_size: int = Form(1),
    size: int = Form(256),
    pose_style: int = Form(0),
    exp_scale: float = Form(1.0),
    use_ref_video: bool = Form(False),
    ref_video: str = Form(None),
    ref_video_file: UploadFile = File(None),
    ref_info: str = Form(None),
    use_idle_mode: bool = Form(False),
    length_of_audio: int = Form(0),
    use_blink: bool = Form(True),
    checkpoint_dir: str = Form('checkpoints'),
    config_dir: str = Form('src/config'),
    old_version: bool = Form(False),
    tts_text: str = Form(None),
    tts_lang: str = Form('en'),
):
    """Generate a SadTalker talking-head video.

    Each of source image / driven audio / reference video may be supplied as
    an upload, an http(s) URL, or a server-local path — but not as both an
    upload and a string at once (400).  The source image is mandatory; audio
    is optional (``tts_text`` can drive the synthesis instead).

    Returns ``{"video_url": <path>}``; raises 400 on bad input combinations
    and 500 on any inference failure.
    """
    if source_image_file and source_image:
        raise HTTPException(status_code=400, detail="source_image and source_image_file cannot be both not None")
    if driven_audio and driven_audio_file:
        raise HTTPException(status_code=400, detail="driven_audio and driven_audio_file cannot be both not None")
    if ref_video and ref_video_file:
        raise HTTPException(status_code=400, detail="ref_video and ref_video_file cannot be both not None")
    temp_files = []
    try:
        source_image_path, tmp = await _resolve_media_input(source_image_file, source_image, '.png')
        if tmp:
            temp_files.append(tmp)
        if source_image_path is None:
            raise HTTPException(status_code=400, detail="source_image not provided")
        driven_audio_path, tmp = await _resolve_media_input(driven_audio_file, driven_audio, '.wav')
        if tmp:
            temp_files.append(tmp)
        ref_video_path, tmp = await _resolve_media_input(ref_video_file, ref_video, '.mp4')
        if tmp:
            temp_files.append(tmp)
        loop = asyncio.get_running_loop()
        # run_in_executor() does not forward keyword arguments to the callable
        # (the original passed tts_text=/tts_lang= straight to run_in_executor,
        # which raises TypeError) — wrap the call so kwargs are applied.
        output_path = await loop.run_in_executor(
            None,
            lambda: sadtalker_instance.test(
                source_image_path,
                driven_audio_path,
                preprocess,
                still_mode,
                use_enhancer,
                batch_size,
                size,
                pose_style,
                exp_scale,
                use_ref_video,
                ref_video_path,
                ref_info,
                use_idle_mode,
                length_of_audio,
                use_blink,
                './results/',
                tts_text=tts_text,
                tts_lang=tts_lang,
            ),
        )
        return {"video_url": output_path}
    except HTTPException:
        # Preserve deliberate 4xx responses; don't remap them to 500 below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        # Delete every temp file we created, even when an early download or
        # validation step failed (the original leaked them in that case).
        for tmp in temp_files:
            try:
                os.remove(tmp.name)
            except OSError:
                pass
@router.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Voice/text -> talking-head pipeline over a websocket.

    Each JSON message carries either ``text`` (synthesized directly) or
    ``audio`` (base64 WAV: transcribed, answered by the reasoning stream,
    then synthesized).  The resulting speech drives SadTalker with a fixed
    avatar image and reference video; the video path is sent back as
    ``{"video_url": ...}``.  Per-message failures in the audio branch are
    reported as ``{"error": ...}`` and the loop continues.
    """
    await websocket.accept()
    tts_model = TTSTalker()
    loop = asyncio.get_running_loop()
    try:
        while True:
            data = await websocket.receive_json()
            text = data.get("text")
            audio_base64 = data.get("audio")
            if text:
                audio_path = await loop.run_in_executor(None, tts_model.test, text)
            elif audio_base64:
                # Reset per message: the original checked `'tmp_audio_file' in
                # locals()`, which reused a stale handle from a previous
                # iteration and made cleanup raise FileNotFoundError.
                tmp_audio_file = None
                try:
                    audio_bytes = base64.b64decode(audio_base64)
                    tmp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
                    tmp_audio_file.write(audio_bytes)
                    # Close so the bytes are flushed before STT reads the path.
                    tmp_audio_file.close()
                    transcription_text_file = speech_to_text_func(tmp_audio_file.name)
                    with open(transcription_text_file, 'r') as f:
                        transcription_text = f.read()
                    response_text = ""
                    for chunk in perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2):
                        if chunk == "<END_STREAM>":
                            break
                        response_text += chunk
                    audio_path = await loop.run_in_executor(None, tts_model.test, response_text)
                except Exception as e:
                    await websocket.send_json({"error": str(e)})
                    continue
                finally:
                    if tmp_audio_file is not None:
                        try:
                            os.remove(tmp_audio_file.name)
                        except OSError:
                            pass
            else:
                # Message had neither "text" nor "audio" — ignore it.
                continue
            source_image_path = './examples/source_image/cyarh.png'
            ref_video_path = './examples/driven_video/vid_xdd.mp4'
            # Heavy inference off the event loop; positional args mirror
            # sadtalker_instance.test's signature as used elsewhere in file.
            output = await loop.run_in_executor(
                None,
                sadtalker_instance.test,
                source_image_path,
                audio_path,
                'full',
                True,
                True,
                1,
                256,
                0,
                1,
                True,
                ref_video_path,
                "pose+blink",
                False,
                0,
                True,
                './results/',
            )
            await websocket.send_json({"video_url": output})
    except Exception as e:
        # NOTE(review): if the client already disconnected, this send will
        # itself raise — consider catching WebSocketDisconnect separately.
        print(e)
        await websocket.send_json({"error": str(e)})
# NOTE(review): this file previously re-created `router = APIRouter()` here
# and re-registered "/sadtalker" and "/ws" by hand, silently discarding the
# registrations the decorators above had already made on the original router
# and then duplicating them.  The decorator-registered router already exposes
# both routes, so the manual re-registration has been removed.
|