Upload 28 files
- api.py +62 -162
- background_tasks.py +6 -40
- codegen_api.py +5 -5
- configs.py +27 -5
- constants.py +1 -7
- extensions.py +81 -1
- imagegen_api.py +5 -4
- main.py +5 -34
- model_loader.py +1 -1
- models.py +1 -1
- musicgen_api.py +1 -1
- sadtalker_api.py +5 -1
- sadtalker_utils.py +56 -57
- sentiment_api.py +15 -12
- stt_api.py +7 -9
- summarization_api.py +6 -7
- text_generation.py +94 -104
- tokenxxx.py +1 -1
- translation_api.py +7 -9
- tts_api.py +10 -6
api.py
CHANGED
@@ -1,18 +1,17 @@
 from main import *
-from tts_api import
-from stt_api import
-from sentiment_api import
-from imagegen_api import
-from musicgen_api import
-from translation_api import
-from codegen_api import
-from text_to_video_api import
-from summarization_api import
-from image_to_3d_api import
+from tts_api import tts_api as tts_route
+from stt_api import stt_api as stt_route
+from sentiment_api import sentiment_api as sentiment_route
+from imagegen_api import imagegen_api as imagegen_route
+from musicgen_api import musicgen_api as musicgen_route
+from translation_api import translation_api as translation_route
+from codegen_api import codegen_api as codegen_route
+from text_to_video_api import text_to_video_api as text_to_video_route
+from summarization_api import summarization_api as summarization_route
+from image_to_3d_api import image_to_3d_api as image_to_3d_route
 from flask import Flask, request, jsonify, Response, send_file, stream_with_context
 from flask_cors import CORS
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 import numpy as np
@@ -22,9 +21,12 @@ import tempfile
 import queue
 import json
 import base64
+from markupsafe import Markup
+from markupsafe import escape
 
 app = Flask(__name__)
 CORS(app)
+
 html_code = """<!DOCTYPE html>
 <html lang="en">
 <head>
@@ -225,59 +227,6 @@ html_code = """<!DOCTYPE html>
 """
 feedback_queue = queue.Queue()
 
-class TextGenerationModel(nn.Module):
-    def __init__(self, vocab_size, embed_dim, hidden_dim):
-        super(TextGenerationModel, self).__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
-        self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
-        self.fc = nn.Linear(hidden_dim, vocab_size)
-    def forward(self, x, hidden=None):
-        x = self.embedding(x)
-        out, hidden = self.rnn(x, hidden)
-        out = self.fc(out)
-        return out, hidden
-
-vocab = ["hola", "mundo", "este", "es", "un", "ejemplo", "de", "texto", "generado", "con", "torch"]
-vocab_size = len(vocab)
-embed_dim = 16
-hidden_dim = 32
-text_model = TextGenerationModel(vocab_size, embed_dim, hidden_dim)
-text_model.eval()
-
-def tokenize(text):
-    tokens = text.lower().split()
-    indices = [vocab.index(token) if token in vocab else 0 for token in tokens]
-    return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
-
-def perform_reasoning_stream(text, temperature, top_k, top_p, repetition_penalty):
-    input_tensor = tokenize(text)
-    hidden = None
-    while True:
-        outputs, hidden = text_model(input_tensor, hidden)
-        logits = outputs[:, -1, :] / temperature
-        probs = F.softmax(logits, dim=-1)
-        topk_probs, topk_indices = torch.topk(probs, min(top_k, logits.shape[-1]))
-        chosen_index = topk_indices[0, torch.multinomial(topk_probs[0], 1).item()].item()
-        token_str = vocab[chosen_index]
-        yield token_str
-        input_tensor = torch.cat([input_tensor, torch.tensor([[chosen_index]], dtype=torch.long)], dim=1)
-        if token_str == "mundo":
-            yield "<END_STREAM>"
-            break
-
-class SentimentModel(nn.Module):
-    def __init__(self, input_dim, hidden_dim, output_dim):
-        super(SentimentModel, self).__init__()
-        self.fc1 = nn.Linear(input_dim, hidden_dim)
-        self.fc2 = nn.Linear(hidden_dim, output_dim)
-    def forward(self, x):
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
-
-sentiment_model = SentimentModel(10, 16, 2)
-sentiment_model.eval()
 
 @app.route("/")
 def index():
@@ -290,16 +239,30 @@ def generate_stream():
     top_k = int(request.args.get("top_k", 40))
     top_p = float(request.args.get("top_p", 0.0))
     reppenalty = float(request.args.get("reppenalty", 1.2))
+    response_queue = queue.Queue()
+    reasoning_queue.put({
+        'text_input': text,
+        'temperature': temp,
+        'top_k': top_k,
+        'top_p': top_p,
+        'repetition_penalty': reppenalty,
+        'response_queue': response_queue
+    })
     @stream_with_context
     def event_stream():
-        ...
+        while True:
+            output = response_queue.get()
+            if "error" in output:
+                yield "data: <ERROR>\n\n"
+                break
+            text_chunk = output.get("text")
+            if text_chunk:
+                for word in text_chunk.split(' '):
+                    clean_word = word.strip()
+                    if clean_word:
+                        yield "data: " + clean_word + "\n\n"
+                yield "data: <END_STREAM>\n\n"
+                break
     return Response(event_stream(), mimetype="text/event-stream")
 
 @app.route("/api/v1/generate", methods=["POST"])
@@ -310,15 +273,20 @@ def generate():
     top_k = int(data.get("top_k", 40))
     top_p = float(data.get("top_p", 0.0))
     reppenalty = float(data.get("reppenalty", 1.2))
-    ...
+    response_queue = queue.Queue()
+    reasoning_queue.put({
+        'text_input': text,
+        'temperature': temp,
+        'top_k': top_k,
+        'top_p': top_p,
+        'repetition_penalty': reppenalty,
+        'response_queue': response_queue
+    })
+    output = response_queue.get()
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    result_text = output.get("text", "").strip()
+    return jsonify({"response": result_text})
 
 @app.route("/api/v1/feedback", methods=["POST"])
 def feedback():
@@ -332,116 +300,48 @@ def feedback():
 
 @app.route("/api/v1/tts", methods=["POST"])
 def tts_api():
-    ...
-    text = data.get("text", "")
-    sr = 22050
-    duration = 3.0
-    t = torch.linspace(0, duration, int(sr * duration))
-    frequency = 440.0
-    audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
-    audio = audio.unsqueeze(0)
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-        torchaudio.save(tmp.name, audio, sr)
-        tmp_path = tmp.name
-    return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
+    return tts_route()
 
 @app.route("/api/v1/stt", methods=["POST"])
 def stt_api():
-    ...
-    audio_b64 = data.get("audio", "")
-    if audio_b64:
-        audio_bytes = base64.b64decode(audio_b64)
-        buf = io.BytesIO(audio_bytes)
-        waveform, sr = torchaudio.load(buf)
-        mean_amp = waveform.abs().mean().item()
-        recognized_text = f"Audio processed with mean amplitude {mean_amp:.3f}"
-        return jsonify({"text": recognized_text})
-    return jsonify({"text": ""})
+    return stt_route()
 
 @app.route("/api/v1/sentiment", methods=["POST"])
 def sentiment_api():
-    ...
-    text = data.get("text", "")
-    if not text:
-        return jsonify({"sentiment": "neutral"})
-    ascii_vals = [ord(c) for c in text[:10]]
-    while len(ascii_vals) < 10:
-        ascii_vals.append(0)
-    features = torch.tensor(ascii_vals, dtype=torch.float32).unsqueeze(0)
-    output = sentiment_model(features)
-    sentiment_idx = torch.argmax(output, dim=1).item()
-    sentiment = "positivo" if sentiment_idx == 1 else "negativo"
-    return jsonify({"sentiment": sentiment})
+    return sentiment_route()
 
 @app.route("/api/v1/imagegen", methods=["POST"])
 def imagegen_api():
-    ...
-    prompt = data.get("prompt", "")
-    image_tensor = torch.rand(3, 256, 256)
-    np_image = image_tensor.mul(255).clamp(0, 255).byte().numpy().transpose(1, 2, 0)
-    img = Image.fromarray(np_image)
-    buf = io.BytesIO()
-    img.save(buf, format="PNG")
-    buf.seek(0)
-    return send_file(buf, mimetype="image/png", as_attachment=True, download_name="image.png")
+    return imagegen_route()
 
 @app.route("/api/v1/musicgen", methods=["POST"])
 def musicgen_api():
-    ...
-    prompt = data.get("prompt", "")
-    sr = 22050
-    duration = 5.0
-    t = torch.linspace(0, duration, int(sr * duration))
-    frequency = 440.0
-    audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
-    audio = audio.unsqueeze(0)
-    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-        torchaudio.save(tmp.name, tmp.name, sr)
-        tmp_path = tmp.name
-    return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="music.wav")
+    return musicgen_route()
 
 @app.route("/api/v1/translation", methods=["POST"])
 def translation_api():
-    ...
-    text = data.get("text", "")
-    translated = " ".join(text.split()[::-1])
-    return jsonify({"translated_text": translated})
+    return translation_route()
 
 @app.route("/api/v1/codegen", methods=["POST"])
 def codegen_api():
-    ...
-    prompt = data.get("prompt", "")
-    generated_code = f"# Generated code based on prompt: {prompt}\nprint('Hello from Torch-generated code')"
-    return jsonify({"code": generated_code})
+    return codegen_route()
 
 @app.route("/api/v1/text_to_video", methods=["POST"])
 def text_to_video_api():
-    ...
-    prompt = data.get("prompt", "")
-    video_tensor = torch.randint(0, 255, (10, 3, 64, 64), dtype=torch.uint8)
-    video_bytes = video_tensor.numpy().tobytes()
-    buf = io.BytesIO(video_bytes)
-    return send_file(buf, mimetype="video/mp4", as_attachment=True, download_name="video.mp4")
+    return text_to_video_route()
 
 @app.route("/api/v1/summarization", methods=["POST"])
 def summarization_api():
-    ...
-    text = data.get("text", "")
-    sentences = text.split('.')
-    summary = sentences[0] if sentences[0] else text
-    return jsonify({"summary": summary})
+    return summarization_route()
 
 @app.route("/api/v1/image_to_3d", methods=["POST"])
 def image_to_3d_api():
-    ...
-    prompt = data.get("prompt", "")
-    obj_data = "o Cube\nv 0 0 0\nv 1 0 0\nv 1 1 0\nv 0 1 0\nf 1 2 3 4"
-    buf = io.BytesIO(obj_data.encode("utf-8"))
-    return send_file(buf, mimetype="text/plain", as_attachment=True, download_name="model.obj")
+    return image_to_3d_route()
 
-@app.route("/api/v1/sadtalker", methods=["
+@app.route("/api/v1/sadtalker", methods=["POST"])
 def sadtalker():
-    ...
+    from sadtalker_api import router as sadtalker_router
+    return sadtalker_router.create_video()
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", port=7860)
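Note: a minimal client sketch for the rewritten api.py endpoints, assuming the Space is reachable on the port that app.run() binds above (7860). The parameter names mirror what the handlers in this diff read; any field not visible in the diff (for example the exact "text"/"temp" keys) is an assumption.

import requests

BASE_URL = "http://localhost:7860"  # assumed local address; the app binds 0.0.0.0:7860

# Blocking endpoint: POST a JSON payload, read back the aggregated response.
payload = {"text": "hello", "temp": 0.7, "top_k": 40, "top_p": 0.0, "reppenalty": 1.2}
resp = requests.post(f"{BASE_URL}/api/v1/generate", json=payload, timeout=300)
print(resp.json().get("response"))

# Streaming endpoint: server-sent events, one "data: <word>" line per token,
# terminated by the "data: <END_STREAM>" sentinel emitted by event_stream().
with requests.get(f"{BASE_URL}/api/v1/generate_stream", params=payload, stream=True, timeout=300) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        token = line[len("data: "):]
        if token == "<END_STREAM>":
            break
        print(token, end=" ")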
background_tasks.py
CHANGED
@@ -114,46 +114,12 @@ def background_training():
         except Exception:
             time.sleep(5)
 
-class ReasoningModel(nn.Module):
-    def __init__(self, vocab_size, embed_dim=128, hidden_dim=128):
-        super(ReasoningModel, self).__init__()
-        self.embedding = nn.Embedding(vocab_size, embed_dim)
-        self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
-        self.fc = nn.Linear(hidden_dim, vocab_size)
-    def forward(self, x, hidden=None):
-        emb = self.embedding(x)
-        output, hidden = self.rnn(emb, hidden)
-        logits = self.fc(output)
-        return logits, hidden
-    def generate(self, input_seq, max_length=999999999, temperature=1.0):
-        self.eval()
-        tokens = input_seq.copy()
-        hidden = None
-        generated = []
-        while True:
-            input_tensor = torch.tensor([tokens], dtype=torch.long)
-            logits, hidden = self.forward(input_tensor, hidden)
-            next_token_logits = logits[0, -1, :] / temperature
-            probabilities = torch.softmax(next_token_logits, dim=0)
-            next_token = torch.multinomial(probabilities, 1).item()
-            tokens.append(next_token)
-            generated.append(next_token)
-            if next_token == word_to_index.get("<EOS>"):
-                break
-            if len(generated) > max_length:
-                break
-        return generated
-
-reasoning_model = ReasoningModel(len(vocabulary))
-
 def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
-    ...
-    yield vocabulary[idx] + " "
-    yield "<END_STREAM>"
+    for token in sample_sequence(text_input, model_gpt2, enc, length=999999999, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device):
+        if token == "<END_STREAM>":
+            yield "<END_STREAM>"
+            break
+        yield token + " "
 
 def background_reasoning_queue():
     global reasoning_queue, seen_responses
@@ -179,7 +145,7 @@ def background_reasoning_queue():
             if chunk == "<END_STREAM>":
                 break
             full_response += chunk
-            cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "")
+            cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
             if cleaned_response in seen_responses:
                 final_response = "**Response is repetitive. Please try again or rephrase your query.**";
                 resp_queue.put({"text": final_response})
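Note: the generate handlers in api.py and the worker in background_tasks.py communicate through a queue hand-off. The self-contained sketch below reproduces that pattern with a dummy token stream; it is not the repository's actual worker, it only illustrates the protocol of passing a per-request response_queue and the <END_STREAM> sentinel.

import queue
import threading

reasoning_queue = queue.Queue()

def fake_stream(text):
    # Stand-in for perform_reasoning_stream(): yields tokens, then the sentinel.
    for tok in ["echo:", *text.split()]:
        yield tok + " "
    yield "<END_STREAM>"

def worker():
    # Mirrors background_reasoning_queue(): drain requests, stream tokens,
    # and reply on the queue each request carries with it.
    while True:
        req = reasoning_queue.get()
        full_response = ""
        for chunk in fake_stream(req["text_input"]):
            if chunk == "<END_STREAM>":
                break
            full_response += chunk
        req["response_queue"].put({"text": full_response})

threading.Thread(target=worker, daemon=True).start()

resp_q = queue.Queue()
reasoning_queue.put({"text_input": "hello world", "response_queue": resp_q})
print(resp_q.get()["text"])  # -> "echo: hello world "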
codegen_api.py
CHANGED
@@ -2,10 +2,10 @@ from flask import jsonify, send_file, request
 from main import *
 
 def generate_code(prompt, output_path="output_code.py"):
-    if codegen_model is None:
-        return "Code generation model not initialized."
-    input_ids = codegen_tokenizer
-    output = codegen_model.generate(input_ids, max_length=
+    if codegen_model is None or codegen_tokenizer is None:
+        return "Code generation model or tokenizer not initialized."
+    input_ids = codegen_tokenizer(prompt, return_tensors='pt').to(device)
+    output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
     code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
     with open(output_path, "w") as file:
         file.write(code)
@@ -17,6 +17,6 @@ def codegen_api():
     if not prompt:
         return jsonify({"error": "Prompt is required"}), 400
     output_file = generate_code(prompt)
-    if output_file == "Code generation model not initialized.":
+    if output_file == "Code generation model or tokenizer not initialized.":
         return jsonify({"error": "Code generation failed"}), 500
     return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
configs.py
CHANGED
@@ -58,11 +58,33 @@ class CodeGenConfig:
 
 class SummarizationConfig:
     def __init__(self):
-        self.vocab_size =
-        self.
-        self.
-        self.
-        self.
+        self.vocab_size = 50265
+        self.max_position_embeddings = 1024
+        self.encoder_layers = 12
+        self.encoder_ffn_dim = 4096
+        self.encoder_attention_heads = 16
+        self.decoder_layers = 12
+        self.decoder_ffn_dim = 4096
+        self.decoder_attention_heads = 16
+        self.encoder_layerdrop = 0.0
+        self.decoder_layerdrop = 0.0
+        self.activation_function = "gelu"
+        self.d_model = 1024
+        self.dropout = 0.1
+        self.attention_dropout = 0.0
+        self.activation_dropout = 0.0
+        self.init_std = 0.02
+        self.classifier_dropout = 0.0
+        self.num_labels = 3
+        self.pad_token_id = 1
+        self.bos_token_id = 0
+        self.eos_token_id = 2
+        self.layer_norm_eps = 1e-05
+        self.num_beams = 4
+        self.early_stopping = True
+        self.max_length = 100
+        self.min_length = 30
+        self.scale_embedding = False
 
 class Clip4ClipConfig:
     def __init__(self, vocab_size=30522, hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
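Note: the filled-in SummarizationConfig values correspond to a BART-large-sized seq2seq setup (d_model 1024, 12 encoder and 12 decoder layers, 16 attention heads, a 50265-token vocabulary). A small sanity-check sketch, assuming the repo-local configs module is importable:

from configs import SummarizationConfig  # module from this repository

cfg = SummarizationConfig()
# Per-head dimension and the beam-search settings used by the summarization endpoint.
assert cfg.d_model % cfg.encoder_attention_heads == 0
head_dim = cfg.d_model // cfg.encoder_attention_heads
print(head_dim, cfg.num_beams, cfg.min_length, cfg.max_length)  # 64 4 30 100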
constants.py
CHANGED
@@ -158,13 +158,7 @@ html_code = """<!DOCTYPE html>
                 top_p: top_p_val,
                 reppenalty: repetition_penalty_val
             };
-            eventSource = new EventSource('/generate_stream'
-                headers: {
-                    'Content-Type': 'application/json'
-                },
-                method: 'POST',
-                body: JSON.stringify(requestData)
-            });
+            eventSource = new EventSource('/api/v1/generate_stream?' + new URLSearchParams(requestData).toString());
             eventSource.onmessage = function(event) {
                 if (event.data === "<END_STREAM>") {
                     eventSource.close();
extensions.py
CHANGED
@@ -159,6 +159,86 @@ class RealESRGANer():
         output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
         return [output_img, None]
 
+    def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
+        h_input, w_input = img.shape[0:2]
+        if outscale is None:
+            outscale = self.scale
+        if tile is None:
+            tile = self.tile
+        if tile_pad is None:
+            tile_pad = self.tile_pad
+        if pre_pad is None:
+            pre_pad = self.pre_pad
+        if half is None:
+            half = self.half
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+        img_tensor = img2tensor(img)
+        img_tensor = img_tensor.unsqueeze(0).to(self.device)
+        if half:
+            img_tensor = img_tensor.half()
+        mod_scale = self.mod_scale
+        h_pad, w_pad = 0, 0
+        if mod_scale is not None:
+            h_pad, w_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input), int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
+            img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
+        window_size = 256
+        scale = self.scale
+        overlap_ratio = 0.5
+        if w_input * h_input < window_size**2:
+            tile = None
+        if tile is not None and tile > 0:
+            tile_overlap = tile * overlap_ratio
+            sf = scale
+            stride_w = math.ceil(tile - tile_overlap)
+            stride_h = math.ceil(tile - tile_overlap)
+            numW = math.ceil((w_input + tile_overlap) / stride_w)
+            numH = math.ceil((h_input + tile_overlap) / stride_h)
+            paddingW = (numW - 1) * stride_w + tile - w_input
+            paddingH = (numH - 1) * stride_h + tile - h_input
+            padding_bottom = int(max(paddingH, 0))
+            padding_right = int(max(paddingW, 0))
+            padding_left, padding_top = 0, 0
+            img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
+            output_h, output_w = padding_top + h_input * scale + padding_bottom, padding_left + w_input * scale + padding_right
+            output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
+            windows = []
+            for row in range(numH):
+                for col in range(numW):
+                    start_x = col * stride_w
+                    start_y = row * stride_h
+                    end_x = min(start_x + tile, img_tensor.shape[3])
+                    end_y = min(start_y + tile, img_tensor.shape[2])
+                    windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
+            results = []
+            batch_size = 8
+            for i in range(0, len(windows), batch_size):
+                batch_windows = torch.stack(windows[i:min(i + batch_size, len(windows))], dim=0)
+                with torch.no_grad():
+                    results.append(self.model(batch_windows))
+            results = torch.cat(results, dim=0)
+            count = 0
+            for row in range(numH):
+                for col in range(numW):
+                    start_x = col * stride_w
+                    start_y = row * stride_h
+                    end_x = min(start_x + tile, img_tensor.shape[3])
+                    end_y = min(start_y + tile, img_tensor.shape[2])
+                    out_start_x, out_start_y = start_x * sf, start_y * sf
+                    out_end_x, out_end_y = end_x * sf, end_y * sf
+                    output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count][:, :, :end_y * sf - out_start_y, :end_x * sf - out_start_x]
+                    count += 1
+            forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
+        else:
+            with torch.no_grad():
+                forward_img = self.model(img_tensor)
+        if half:
+            forward_img = forward_img.float()
+        output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
+        if mod_scale is not None:
+            output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
+        output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
+        return [output_img, None]
+
 def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
     try:
         watermark = imageio.imread(watermark_path)
@@ -249,4 +329,4 @@ def get_prior_from_bfm(bfm_path):
         'u_tex': u_tex,
         'u_exp': u_exp
     }
-    return prior_coeff
+    return prior_coeff
imagegen_api.py
CHANGED
@@ -10,10 +10,11 @@ def generate_image(prompt, output_path="output_image.png"):
         return "Image generation model not initialized."
 
     generator = torch.Generator(device=device).manual_seed(0)
-    ...
+    with torch.no_grad():
+        image = imagegen_model(
+            prompt,
+            generator=generator,
+        ).images[0]
     image.save(output_path)
     return output_path
 
main.py
CHANGED
@@ -7,7 +7,7 @@ import re
 import json
 from flask import Flask
 from flask_cors import CORS
-from api import
+from api import app
 from extensions import *
 from constants import *
 from configs import *
@@ -17,8 +17,7 @@ from model_loader import *
 from utils import *
 from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
 from text_generation import *
-from sadtalker_utils import
-import torch
+from sadtalker_utils import SadTalker
 
 state_dict = None
 enc = None
@@ -59,7 +58,7 @@ tts_model = None
 musicgen_model = None
 
 def load_models():
-    global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model
+    global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model
     model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
     translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
     codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
@@ -71,35 +70,7 @@ def load_models():
     stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
     tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
     musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
-
-class SimpleClassifier(torch.nn.Module):
-    def __init__(self, vocab_size, num_classes):
-        super(SimpleClassifier, self).__init__()
-        self.embedding = torch.nn.Embedding(vocab_size, 128)
-        self.linear = torch.nn.Linear(128, num_classes)
-    def forward(self, x):
-        embedded = self.embedding(x)
-        pooled = torch.mean(embedded, dim=1)
-        return self.linear(pooled)
-
-def tokenize_text(text):
-    global vocabulary, word_to_index, index_to_word
-    tokens = text.lower().split()
-    for token in tokens:
-        if token not in vocabulary:
-            vocabulary.add(token)
-            word_to_index[token] = len(index_to_word)
-            index_to_word.append(token)
-    return tokens
-
-def text_to_vector(text):
-    global vocabulary, word_to_index
-    tokens = tokenize_text(text)
-    vector = torch.zeros(len(vocabulary))
-    for token in tokens:
-        if token in word_to_index:
-            vector[word_to_index[token]] += 1
-    return vector
+    sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
 
 if __name__ == "__main__":
     nltk.download('punkt')
@@ -115,4 +86,4 @@ if __name__ == "__main__":
     background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
     for thread in background_threads:
         thread.start()
-    app.run(host='0.0.0.0', port=7860)
+    app.run(host='0.0.0.0', port=7860)
model_loader.py
CHANGED
@@ -265,7 +265,7 @@ class ResnetBlock(nn.Module):
         sc = self.conv_shortcut(x)
         h = F.silu(self.norm1(x))
         h = self.conv1(h)
-        h = F.silu(self.norm2(
+        h = F.silu(self.norm2(x))
         h = self.conv2(h)
         return h + sc
 
models.py
CHANGED
@@ -91,4 +91,4 @@ class MusicGenModel(nn.Module):
             audio_output.append(predicted_token.cpu())
             input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
         audio_output = torch.cat(audio_output, dim=1).float()
-        return audio_output
+        return audio_output
musicgen_api.py
CHANGED
@@ -11,7 +11,7 @@ def generate_music(prompt, output_path="output_music.wav"):
 
     attributes = [prompt]
     sample_rate = 32000
-    duration =
+    duration = 10
    audio_values = musicgen_model.sample(
         attributes=attributes,
         sample_rate=sample_rate,
sadtalker_api.py
CHANGED
@@ -157,7 +157,7 @@ async def websocket_endpoint(websocket: WebSocket):
                 transcription_text_file = speech_to_text_func(tmp_audio_file.name)
                 with open(transcription_text_file, 'r') as f:
                     transcription_text = f.read()
-                response_stream = perform_reasoning_stream(
+                response_stream = perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2)
                 response_text = ""
                 for chunk in response_stream:
                     if chunk == "<END_STREAM>":
@@ -198,3 +198,7 @@ async def websocket_endpoint(websocket: WebSocket):
     except Exception as e:
         print(e)
         await websocket.send_json({"error":str(e)})
+
+router = APIRouter()
+router.add_api_route("/sadtalker", create_video, methods=["POST"])
+router.add_api_websocket_route("/ws", websocket_endpoint)
sadtalker_utils.py
CHANGED
@@ -269,32 +269,33 @@ class SadTalker:
         self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
 
     def get_cfg_defaults(self):
-        return
-        ...
+        return CN(
+            MODEL=CN(
+                CHECKPOINTS_DIR='',
+                CONFIG_DIR='',
+                DEVICE=self.device,
+                SCALE=64,
+                NUM_VOXEL_FRAMES=8,
+                NUM_MOTION_FRAMES=10,
+                MAX_FEATURES=256,
+                DRIVEN_AUDIO_SAMPLE_RATE=16000,
+                VIDEO_FPS=25,
+                OUTPUT_VIDEO_FPS=None,
+                OUTPUT_AUDIO_SAMPLE_RATE=None,
+                USE_ENHANCER=False,
+                ENHANCER_NAME='',
+                BG_UPSAMPLER=None,
+                IS_HALF=False
+            ),
+            INPUT_IMAGE=CN()
+        )
 
     def merge_from_file(self, filepath):
         if os.path.exists(filepath):
             with open(filepath, 'r') as f:
                 cfg_from_file = yaml.safe_load(f)
-            self.cfg.update(cfg_from_file)
+            self.cfg.MODEL.update(CN(cfg_from_file['MODEL']))
+            self.cfg.INPUT_IMAGE.update(CN(cfg_from_file['INPUT_IMAGE']))
 
     def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
              batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
@@ -310,7 +311,7 @@ class SadTalkerModel:
 
     def __init__(self, sadtalker_cfg, device_id=[0]):
         self.cfg = sadtalker_cfg
-        self.device = sadtalker_cfg
+        self.device = sadtalker_cfg.MODEL.get('DEVICE', 'cpu')
         self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
         self.preprocesser = self.sadtalker.preprocesser
         self.kp_extractor = self.sadtalker.kp_extractor
@@ -389,7 +390,7 @@ class SadTalkerInner:
         ref_pose_coeff = None
         ref_expression_coeff = None
         audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
-                                                             self.sadtalker_model.cfg
+                                                             self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE)
         batch = {
             'source_image': source_image_tensor.unsqueeze(0).to(self.device),
             'audio': audio_tensor.unsqueeze(0).to(self.device),
@@ -455,12 +456,11 @@ class SadTalkerInner:
         audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
         output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
         self.output_path = output_video_path
-        video_fps = self.sadtalker_model.cfg
-        ...
-        self.sadtalker_model.cfg
-        self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
+        video_fps = self.sadtalker_model.cfg.MODEL.VIDEO_FPS if self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS is None else \
+            self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS
+        audio_output_sample_rate = self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE if \
+            self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE is None else \
+            self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE
         if self.use_enhancer:
             enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
             save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
@@ -489,13 +489,12 @@ class SadTalkerInnerModel:
 
     def __init__(self, sadtalker_cfg, device_id=[0]):
         self.cfg = sadtalker_cfg
-        self.device = sadtalker_cfg
+        self.device = sadtalker_cfg.MODEL.DEVICE
         self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
         self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
         self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
         self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
-        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg
-            'USE_ENHANCER'] else None
+        self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg.MODEL.USE_ENHANCER else None
         self.generator = Generator(sadtalker_cfg, self.device)
         self.mapping = Mapping(sadtalker_cfg, self.device)
         self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
@@ -506,10 +505,10 @@ class Preprocesser:
     def __init__(self, sadtalker_cfg, device):
         self.cfg = sadtalker_cfg
         self.device = device
-        if self.cfg
-            self.face3d_helper = Face3DHelperOld(self.cfg
+        if self.cfg.INPUT_IMAGE.get('OLD_VERSION', False):
+            self.face3d_helper = Face3DHelperOld(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
         else:
-            self.face3d_helper = Face3DHelper(self.cfg
+            self.face3d_helper = Face3DHelper(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
         self.mouth_detector = MouthDetector()
 
     def crop(self, source_image_pil, preprocess_type, size=256):
@@ -543,7 +542,7 @@ class Preprocesser:
         cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
         source_image_tensor = self.img2tensor(cropped_image_pil)
         return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
-            self.cfg
+            self.cfg.INPUT_IMAGE.get('SOURCE_IMAGE', ''))
 
     def img2tensor(self, img):
         img = np.array(img).astype(np.float32) / 255.0
@@ -577,7 +576,7 @@ class Preprocesser:
         return ref_expression_coeff
 
     def generate_idles_pose(self, length_of_audio, pose_style):
-        num_frames = int(length_of_audio * self.cfg
+        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
         ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
         start_pose = self.generate_still_pose(pose_style)
         end_pose = self.generate_still_pose(pose_style)
@@ -587,7 +586,7 @@ class Preprocesser:
         return ref_pose_coeff
 
     def generate_idles_expression(self, length_of_audio):
-        num_frames = int(length_of_audio * self.cfg
+        num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
         ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
         start_exp = self.generate_still_expression(1.0)
         end_exp = self.generate_still_expression(1.0)
@@ -601,11 +600,11 @@ class KeyPointExtractor(nn.Module):
 
     def __init__(self, sadtalker_cfg, device):
         super(KeyPointExtractor, self).__init__()
-        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg
+        self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES,
                                                      num_kp=10,
                                                      num_dilation_blocks=2,
                                                      dropout_rate=0.1).to(device)
-        checkpoint_path = os.path.join(sadtalker_cfg
+        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'kp_detector.safetensors')
         self.load_kp_detector(checkpoint_path, device)
 
     def load_kp_detector(self, checkpoint_path, device):
@@ -628,12 +627,12 @@ class Audio2Coeff(nn.Module):
     def __init__(self, sadtalker_cfg, device):
         super(Audio2Coeff, self).__init__()
         self.audio_model = Wav2Vec2Model().to(device)
-        checkpoint_path = os.path.join(sadtalker_cfg
+        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'wav2vec2.pth')
         self.load_audio_model(checkpoint_path, device)
         self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
         self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
-        mapping_checkpoint = os.path.join(sadtalker_cfg
+        mapping_checkpoint = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'audio2pose_00140-model.pth')
         self.load_mapping_model(mapping_checkpoint, device)
 
     def load_audio_model(self, checkpoint_path, device):
@@ -753,13 +752,13 @@ class Generator(nn.Module):
 
     def __init__(self, sadtalker_cfg, device):
         super(Generator, self).__init__()
-        self.generator = Hourglass(block_expansion=sadtalker_cfg
-                                   num_blocks=sadtalker_cfg
-                                   max_features=sadtalker_cfg
+        self.generator = Hourglass(block_expansion=sadtalker_cfg.MODEL.SCALE,
+                                   num_blocks=sadtalker_cfg.MODEL.NUM_VOXEL_FRAMES,
+                                   max_features=sadtalker_cfg.MODEL.MAX_FEATURES,
                                    num_channels=3,
                                    kp_size=10,
-                                   num_deform_blocks=sadtalker_cfg
-        checkpoint_path = os.path.join(sadtalker_cfg
+                                   num_deform_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES).to(device)
+        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'generator.pth')
         self.load_generator(checkpoint_path, device)
 
     def load_generator(self, checkpoint_path, device):
@@ -786,7 +785,7 @@ class Mapping(nn.Module):
     def __init__(self, sadtalker_cfg, device):
         super(Mapping, self).__init__()
         self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
-        checkpoint_path = os.path.join(sadtalker_cfg
+        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'mapping.pth')
         self.load_mapping_net(checkpoint_path, device)
         self.f_3d_mean = torch.zeros(1, 64, device=device)
 
@@ -814,10 +813,10 @@ class OcclusionAwareDenseMotion(nn.Module):
         super(OcclusionAwareDenseMotion, self).__init__()
         self.dense_motion_network = DenseMotionNetwork(num_kp=10,
                                                        num_channels=3,
-                                                       block_expansion=sadtalker_cfg
-                                                       num_blocks=sadtalker_cfg
-                                                       max_features=sadtalker_cfg
-        checkpoint_path = os.path.join(sadtalker_cfg
+                                                       block_expansion=sadtalker_cfg.MODEL.SCALE,
+                                                       num_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES - 1,
+                                                       max_features=sadtalker_cfg.MODEL.MAX_FEATURES).to(device)
+        checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'dense_motion.pth')
         self.load_dense_motion_network(checkpoint_path, device)
 
     def load_dense_motion_network(self, checkpoint_path, device):
@@ -839,20 +838,20 @@ class FaceEnhancer(nn.Module):
 
     def __init__(self, sadtalker_cfg, device):
         super(FaceEnhancer, self).__init__()
-        enhancer_name = sadtalker_cfg
-        bg_upsampler = sadtalker_cfg
+        enhancer_name = sadtalker_cfg.MODEL.ENHANCER_NAME
+        bg_upsampler = sadtalker_cfg.MODEL.BG_UPSAMPLER
         if enhancer_name == 'gfpgan':
             from gfpgan import GFPGANer
-            self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg
+            self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'GFPGANv1.4.pth'),
                                           upscale=1,
                                           arch='clean',
                                           channel_multiplier=2,
                                           bg_upsampler=bg_upsampler)
         elif enhancer_name == 'realesrgan':
             from realesrgan import RealESRGANer
-            half = False if device == 'cpu' else sadtalker_cfg
+            half = False if device == 'cpu' else sadtalker_cfg.MODEL.IS_HALF
             self.face_enhancer = RealESRGANer(scale=2,
-                                              model_path=os.path.join(sadtalker_cfg
+                                              model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR,
                                                                       'RealESRGAN_x2plus.pth'),
                                               tile=0,
                                               tile_pad=10,
sentiment_api.py
CHANGED
@@ -2,25 +2,28 @@ from flask import jsonify
 from main import *
 import torch
 
-def analyze_sentiment(text
+def analyze_sentiment(text):
     if sentiment_model is None:
-        return "Sentiment model not initialized."
+        return {"error": "Sentiment model not initialized."}
+
+    features = [ord(c) for c in text[:10]]
+    while len(features) < 10:
+        features.append(0)
+    features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
 
-    input_tokens = sentiment_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
     with torch.no_grad():
-        ...
-        probability = torch.softmax(sentiment_logits, dim=-1)[0][predicted_class_id].item()
+        output = sentiment_model(features_tensor)
+        sentiment_idx = torch.argmax(output, dim=1).item()
+        sentiment_label = "positive" if sentiment_idx == 1 else "negative"
 
-    return {"sentiment": sentiment_label
+    return {"sentiment": sentiment_label}
 
 def sentiment_api():
     data = request.get_json()
     text = data.get('text')
     if not text:
         return jsonify({"error": "Text is required"}), 400
-    ...
-    if
-        return jsonify({"error": "
-    return jsonify(
+    output = analyze_sentiment(text)
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    return jsonify(output)
stt_api.py
CHANGED
@@ -5,9 +5,9 @@ from main import *
import torch
import torchaudio

-def speech_to_text_func(audio_path, output_path="output_stt.txt"):
+def speech_to_text_func(audio_path):
    if stt_model is None:
-        return "STT model not initialized."
+        return {"error": "STT model not initialized."}

    waveform, sample_rate = torchaudio.load(audio_path)
    if waveform.ndim > 1:
@@ -18,9 +18,7 @@ def speech_to_text_func(audio_path, output_path="output_stt.txt"):
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())

-    with open(output_path, "w") as file:
-        file.write(transcription)
-    return output_path
+    return {"text": transcription}

def stt_api():
    if 'audio' not in request.files:
@@ -28,8 +26,8 @@ def stt_api():
    audio_file = request.files['audio']
    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
    audio_file.save(temp_audio_path)
+    output = speech_to_text_func(temp_audio_path)
    os.remove(temp_audio_path)
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    return jsonify(output)
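
A small client-side sketch of the reworked STT endpoint; the base URL and route name are placeholders rather than values taken from this commit, only the JSON response shape comes from the diff above:

import requests

with open("sample.wav", "rb") as f:   # any short WAV file works as a test input
    resp = requests.post("http://localhost:7860/stt", files={"audio": f})   # route name assumed
print(resp.json())   # {"text": "..."} on success, {"error": "..."} with HTTP 500 otherwise
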
summarization_api.py
CHANGED
@@ -3,15 +3,14 @@ from main import *
import torch

def summarize_text(text, output_path="output_summary.txt"):
-    if summarization_model is None:
-        return "Summarization model not initialized."
+    if summarization_model is None or summarization_tokenizer is None:
+        return "Summarization model or tokenizer not initialized."

-    input_tensor = torch.tensor([input_tokens], dtype=torch.long).to(device)
+    input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)

    with torch.no_grad():
-        summary_ids = summarization_model.generate(
-        summary_text =
+        summary_ids = summarization_model.generate(input_ids, num_beams=4, max_length=100, early_stopping=True)
+        summary_text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    with open(output_path, "w") as file:
        file.write(summary_text)
@@ -23,6 +22,6 @@ def summarization_api():
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = summarize_text(text)
-    if output_file == "Summarization model not initialized.":
+    if output_file == "Summarization model or tokenizer not initialized.":
        return jsonify({"error": "Summarization failed"}), 500
    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
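
A small client-side sketch of consuming the endpoint's file response; the URL and route name are placeholders, and only the text/plain attachment behaviour comes from this diff:

import requests

resp = requests.post("http://localhost:7860/summarize", json={"text": "Long article text ..."})  # route name assumed
if resp.ok:
    with open("output_summary.txt", "wb") as f:
        f.write(resp.content)   # the server returns the summary as a text/plain attachment
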
text_generation.py
CHANGED
@@ -22,124 +22,114 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
    return logits

def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
-    start_time = time.time()
    context_tokens = enc.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
    generated = context_tokens
-    text_generated_count = 0
-    past_key_values = past if past is not None else None
+    past_key_values = None

    with torch.no_grad():
+        for _ in range(length):
+            outputs = model(context_tokens_tensor, past_key_values=past_key_values)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past_key_values = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated += next_token.tolist()[0]
+            token = next_token.tolist()[0][0]
+            yield enc.decode([token])
+            if token == enc.encoder[END_OF_TEXT_TOKEN]:
+                yield "<END_STREAM>"
+                return

def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
-    start_time = time.time()
    context_tokens = tokenizer.encode(prompt)
    context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device).unsqueeze(0)
    generated = context_tokens
-    text_generated_count = 0
+    past_key_values = None
    with torch.no_grad():
+        for _ in range(length):
+            outputs = model(input_ids=context_tokens_tensor, past_key_values=past_key_values, labels=None)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past_key_values = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated.append(next_token.tolist()[0][0])
+            token = next_token.tolist()[0][0]
+            yield tokenizer.decode([token])
+            if token == 50256:
+                yield "<END_STREAM>"
+                return

def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
-    reasoning_prompt = prompt_text
+    prompt_text = SYSTEM_PROMPT + "\n\n"
+    prompt_text += "User: " + text_input + "\nAssistant:"
+    reasoning_prompt = prompt_text

+    ddgs = DDGS()
+    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
+    if search_results:
+        prompt_text += "\nWeb Search Results:\n"
+        for result in search_results:
+            prompt_text += f"- {result['body']}\n"
+        prompt_text += "\n"

+    generated_text_stream = []
+    stream_type = "text"

+    if "code" in text_input.lower() or "program" in text_input.lower():
+        if codegen_model and codegen_tokenizer:
+            generated_text_stream = sample_sequence_codegen(
+                prompt=reasoning_prompt,
+                model=codegen_model,
+                tokenizer=codegen_tokenizer,
+                length=999999999,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                device=device
+            )
+            stream_type = "text"
+    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
+        if summarization_model:
+            summary = summarize_text(text_input)
+            yield f"SUMMARY_TEXT:{summary}"
+            yield "<END_STREAM>"
+            stream_type = "summary"
+    else:
+        if model_gpt2 and enc:
+            generated_text_stream = sample_sequence(
+                prompt=reasoning_prompt,
+                model=model_gpt2,
+                enc=enc,
+                length=999999999,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                device=device
+            )
+            stream_type = "text"

+    accumulated_text = ""
+    if stream_type == "text":
+        for token in generated_text_stream:
+            if token == "<END_STREAM>":
+                yield accumulated_text
+                yield "<END_STREAM>"
+                return
+            if token == END_OF_TEXT_TOKEN:
+                accumulated_text += END_OF_TEXT_TOKEN
+                continue
+            if token:
+                accumulated_text += token
-    except Exception as e:
-        print(f"Reasoning Error: {e}")
-        yield "Error during reasoning. Please try again."
-        yield "<END_STREAM>"
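
For readers unfamiliar with the sampling step shared by both generators above, a toy, self-contained illustration of temperature scaling followed by multinomial sampling (the logit values are made up):

import torch
import torch.nn.functional as F

next_token_logits = torch.tensor([[3.0, 1.5, 0.2, -0.5, -2.0]])   # made-up logits for a 5-token vocabulary
next_token_logits = next_token_logits / 0.8                        # temperature < 1 sharpens the distribution
probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)               # sample one token id, shape (1, 1)
print(next_token.item())
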
tokenxxx.py
CHANGED
@@ -139,4 +139,4 @@ def codegen_tokenize(text, tokenizer):
    return tokenizer.encode(text)

def codegen_decode(tokens, tokenizer):
-    return tokenizer.decode(tokens)
+    return tokenizer.decode(tokens)
translation_api.py
CHANGED
@@ -1,17 +1,15 @@
from flask import jsonify, send_file, request
from main import *

-def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'
+def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
    if translation_model is None:
-        return "Translation model not initialized."
+        return {"error": "Translation model not initialized."}

    encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
    generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
    translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

-    with open(output_path, "w") as file:
-        file.write(translation)
-    return output_path
+    return {"translated_text": translation}

def translation_api():
    data = request.get_json()
@@ -20,7 +18,7 @@ def translation_api():
    source_lang = data.get('source_lang', 'en')
    if not text:
        return jsonify({"error": "Text is required"}), 400
+    output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    return jsonify(output)
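
For reference, a sketch of calling the reworked helper directly, assuming the translation_model and device globals from main.py are loaded; the function now returns a plain dict instead of writing a file:

result = perform_translation("How are you?", target_language_code='es_XX', source_language_code='en_XX')
print(result)   # e.g. {"translated_text": "..."} or {"error": "Translation model not initialized."}
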
tts_api.py
CHANGED
@@ -1,15 +1,19 @@
import os
from flask import jsonify, send_file, request
from main import *
+import torch
+import torchaudio
+import uuid

-def text_to_speech_func(text
+def text_to_speech_func(text):
    if tts_model is None:
-        return "TTS model not initialized."
+        return {"error": "TTS model not initialized."}
    input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        audio_output = tts_model(input_tokens['input_ids'])
+    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
+    torchaudio.save(temp_audio_path, audio_output.cpu(), 16000)
+    return temp_audio_path

def tts_api():
    data = request.get_json()
@@ -17,6 +21,6 @@ def tts_api():
    if not text:
        return jsonify({"error": "Text is required"}), 400
    output_file = text_to_speech_func(text)
+    if "error" in output_file:
+        return jsonify({"error": output_file["error"]}), 500
    return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
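
One practical note on the new saving step: torchaudio.save expects a 2-D (channels, frames) tensor, so the raw model output may need reshaping before the call above succeeds. A self-contained sketch with a dummy waveform:

import torch
import torchaudio

waveform = torch.zeros(1, 16000)   # one second of mono silence, shape (channels, frames)
torchaudio.save("temp_audio_example.wav", waveform, 16000)   # 16000 matches the sample rate hard-coded above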