Hjgugugjhuhjggg committed on
Commit e83e49f · verified · 1 Parent(s): e3e53f6

Upload 28 files

api.py CHANGED
@@ -1,18 +1,17 @@
1
  from main import *
2
- from tts_api import *
3
- from stt_api import *
4
- from sentiment_api import *
5
- from imagegen_api import *
6
- from musicgen_api import *
7
- from translation_api import *
8
- from codegen_api import *
9
- from text_to_video_api import *
10
- from summarization_api import *
11
- from image_to_3d_api import *
12
  from flask import Flask, request, jsonify, Response, send_file, stream_with_context
13
  from flask_cors import CORS
14
  import torch
15
- import torch.nn as nn
16
  import torch.nn.functional as F
17
  import torchaudio
18
  import numpy as np
@@ -22,9 +21,12 @@ import tempfile
22
  import queue
23
  import json
24
  import base64
25
 
26
  app = Flask(__name__)
27
  CORS(app)
 
28
  html_code = """<!DOCTYPE html>
29
  <html lang="en">
30
  <head>
@@ -225,59 +227,6 @@ html_code = """<!DOCTYPE html>
225
  """
226
  feedback_queue = queue.Queue()
227
 
228
- class TextGenerationModel(nn.Module):
229
- def __init__(self, vocab_size, embed_dim, hidden_dim):
230
- super(TextGenerationModel, self).__init__()
231
- self.embedding = nn.Embedding(vocab_size, embed_dim)
232
- self.rnn = nn.GRU(embed_dim, hidden_dim, batch_first=True)
233
- self.fc = nn.Linear(hidden_dim, vocab_size)
234
- def forward(self, x, hidden=None):
235
- x = self.embedding(x)
236
- out, hidden = self.rnn(x, hidden)
237
- out = self.fc(out)
238
- return out, hidden
239
-
240
- vocab = ["hola", "mundo", "este", "es", "un", "ejemplo", "de", "texto", "generado", "con", "torch"]
241
- vocab_size = len(vocab)
242
- embed_dim = 16
243
- hidden_dim = 32
244
- text_model = TextGenerationModel(vocab_size, embed_dim, hidden_dim)
245
- text_model.eval()
246
-
247
- def tokenize(text):
248
- tokens = text.lower().split()
249
- indices = [vocab.index(token) if token in vocab else 0 for token in tokens]
250
- return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
251
-
252
- def perform_reasoning_stream(text, temperature, top_k, top_p, repetition_penalty):
253
- input_tensor = tokenize(text)
254
- hidden = None
255
- while True:
256
- outputs, hidden = text_model(input_tensor, hidden)
257
- logits = outputs[:, -1, :] / temperature
258
- probs = F.softmax(logits, dim=-1)
259
- topk_probs, topk_indices = torch.topk(probs, min(top_k, logits.shape[-1]))
260
- chosen_index = topk_indices[0, torch.multinomial(topk_probs[0], 1).item()].item()
261
- token_str = vocab[chosen_index]
262
- yield token_str
263
- input_tensor = torch.cat([input_tensor, torch.tensor([[chosen_index]], dtype=torch.long)], dim=1)
264
- if token_str == "mundo":
265
- yield "<END_STREAM>"
266
- break
267
-
268
-
269
- class SentimentModel(nn.Module):
270
- def __init__(self, input_dim, hidden_dim, output_dim):
271
- super(SentimentModel, self).__init__()
272
- self.fc1 = nn.Linear(input_dim, hidden_dim)
273
- self.fc2 = nn.Linear(hidden_dim, output_dim)
274
- def forward(self, x):
275
- x = F.relu(self.fc1(x))
276
- x = self.fc2(x)
277
- return x
278
-
279
- sentiment_model = SentimentModel(10, 16, 2)
280
- sentiment_model.eval()
281
 
282
  @app.route("/")
283
  def index():
@@ -290,16 +239,30 @@ def generate_stream():
290
  top_k = int(request.args.get("top_k", 40))
291
  top_p = float(request.args.get("top_p", 0.0))
292
  reppenalty = float(request.args.get("reppenalty", 1.2))
293
  @stream_with_context
294
  def event_stream():
295
- try:
296
- for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
297
- if token == "<END_STREAM>":
298
- yield "data: <END_STREAM>\n\n"
299
- break
300
- yield "data: " + token + "\n\n"
301
- except Exception as e:
302
- yield "data: <ERROR>\n\n"
303
  return Response(event_stream(), mimetype="text/event-stream")
304
 
305
  @app.route("/api/v1/generate", methods=["POST"])
@@ -310,15 +273,20 @@ def generate():
310
  top_k = int(data.get("top_k", 40))
311
  top_p = float(data.get("top_p", 0.0))
312
  reppenalty = float(data.get("reppenalty", 1.2))
313
- result = ""
314
- try:
315
- for token in perform_reasoning_stream(text, temperature=temp, top_k=top_k, top_p=top_p, repetition_penalty=reppenalty):
316
- if token == "<END_STREAM>":
317
- break
318
- result += token + " "
319
- except Exception as e:
320
- return jsonify({"error": str(e)}), 500
321
- return jsonify({"solidity": result.strip()})
322
 
323
  @app.route("/api/v1/feedback", methods=["POST"])
324
  def feedback():
@@ -332,116 +300,48 @@ def feedback():
332
 
333
  @app.route("/api/v1/tts", methods=["POST"])
334
  def tts_api():
335
- data = request.get_json()
336
- text = data.get("text", "")
337
- sr = 22050
338
- duration = 3.0
339
- t = torch.linspace(0, duration, int(sr * duration))
340
- frequency = 440.0
341
- audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
342
- audio = audio.unsqueeze(0)
343
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
344
- torchaudio.save(tmp.name, audio, sr)
345
- tmp_path = tmp.name
346
- return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
347
 
348
  @app.route("/api/v1/stt", methods=["POST"])
349
  def stt_api():
350
- data = request.get_json()
351
- audio_b64 = data.get("audio", "")
352
- if audio_b64:
353
- audio_bytes = base64.b64decode(audio_b64)
354
- buf = io.BytesIO(audio_bytes)
355
- waveform, sr = torchaudio.load(buf)
356
- mean_amp = waveform.abs().mean().item()
357
- recognized_text = f"Audio processed with mean amplitude {mean_amp:.3f}"
358
- return jsonify({"text": recognized_text})
359
- return jsonify({"text": ""})
360
 
361
  @app.route("/api/v1/sentiment", methods=["POST"])
362
  def sentiment_api():
363
- data = request.get_json()
364
- text = data.get("text", "")
365
- if not text:
366
- return jsonify({"sentiment": "neutral"})
367
- ascii_vals = [ord(c) for c in text[:10]]
368
- while len(ascii_vals) < 10:
369
- ascii_vals.append(0)
370
- features = torch.tensor(ascii_vals, dtype=torch.float32).unsqueeze(0)
371
- output = sentiment_model(features)
372
- sentiment_idx = torch.argmax(output, dim=1).item()
373
- sentiment = "positivo" if sentiment_idx == 1 else "negativo"
374
- return jsonify({"sentiment": sentiment})
375
 
376
  @app.route("/api/v1/imagegen", methods=["POST"])
377
  def imagegen_api():
378
- data = request.get_json()
379
- prompt = data.get("prompt", "")
380
- image_tensor = torch.rand(3, 256, 256)
381
- np_image = image_tensor.mul(255).clamp(0, 255).byte().numpy().transpose(1, 2, 0)
382
- img = Image.fromarray(np_image)
383
- buf = io.BytesIO()
384
- img.save(buf, format="PNG")
385
- buf.seek(0)
386
- return send_file(buf, mimetype="image/png", as_attachment=True, download_name="image.png")
387
 
388
  @app.route("/api/v1/musicgen", methods=["POST"])
389
  def musicgen_api():
390
- data = request.get_json()
391
- prompt = data.get("prompt", "")
392
- sr = 22050
393
- duration = 5.0
394
- t = torch.linspace(0, duration, int(sr * duration))
395
- frequency = 440.0
396
- audio = 0.5 * torch.sin(2 * torch.pi * frequency * t)
397
- audio = audio.unsqueeze(0)
398
- with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
399
- torchaudio.save(tmp.name, tmp.name, sr)
400
- tmp_path = tmp.name
401
- return send_file(tmp_path, mimetype="audio/wav", as_attachment=True, download_name="music.wav")
402
 
403
  @app.route("/api/v1/translation", methods=["POST"])
404
  def translation_api():
405
- data = request.get_json()
406
- text = data.get("text", "")
407
- translated = " ".join(text.split()[::-1])
408
- return jsonify({"translated_text": translated})
409
 
410
  @app.route("/api/v1/codegen", methods=["POST"])
411
  def codegen_api():
412
- data = request.get_json()
413
- prompt = data.get("prompt", "")
414
- generated_code = f"# Generated code based on prompt: {prompt}\nprint('Hello from Torch-generated code')"
415
- return jsonify({"code": generated_code})
416
 
417
  @app.route("/api/v1/text_to_video", methods=["POST"])
418
  def text_to_video_api():
419
- data = request.get_json()
420
- prompt = data.get("prompt", "")
421
- video_tensor = torch.randint(0, 255, (10, 3, 64, 64), dtype=torch.uint8)
422
- video_bytes = video_tensor.numpy().tobytes()
423
- buf = io.BytesIO(video_bytes)
424
- return send_file(buf, mimetype="video/mp4", as_attachment=True, download_name="video.mp4")
425
 
426
  @app.route("/api/v1/summarization", methods=["POST"])
427
  def summarization_api():
428
- data = request.get_json()
429
- text = data.get("text", "")
430
- sentences = text.split('.')
431
- summary = sentences[0] if sentences[0] else text
432
- return jsonify({"summary": summary})
433
 
434
  @app.route("/api/v1/image_to_3d", methods=["POST"])
435
  def image_to_3d_api():
436
- data = request.get_json()
437
- prompt = data.get("prompt", "")
438
- obj_data = "o Cube\nv 0 0 0\nv 1 0 0\nv 1 1 0\nv 0 1 0\nf 1 2 3 4"
439
- buf = io.BytesIO(obj_data.encode("utf-8"))
440
- return send_file(buf, mimetype="text/plain", as_attachment=True, download_name="model.obj")
441
 
442
- @app.route("/api/v1/sadtalker", methods=["GET"])
443
  def sadtalker():
444
- return jsonify({"message": "Respuesta de sadtalker"})
 
445
 
446
  if __name__ == "__main__":
447
  app.run(host="0.0.0.0", port=7860)
 
1
  from main import *
2
+ from tts_api import tts_api as tts_route
3
+ from stt_api import stt_api as stt_route
4
+ from sentiment_api import sentiment_api as sentiment_route
5
+ from imagegen_api import imagegen_api as imagegen_route
6
+ from musicgen_api import musicgen_api as musicgen_route
7
+ from translation_api import translation_api as translation_route
8
+ from codegen_api import codegen_api as codegen_route
9
+ from text_to_video_api import text_to_video_api as text_to_video_route
10
+ from summarization_api import summarization_api as summarization_route
11
+ from image_to_3d_api import image_to_3d_api as image_to_3d_route
12
  from flask import Flask, request, jsonify, Response, send_file, stream_with_context
13
  from flask_cors import CORS
14
  import torch
 
15
  import torch.nn.functional as F
16
  import torchaudio
17
  import numpy as np
 
21
  import queue
22
  import json
23
  import base64
24
+ from markupsafe import Markup
25
+ from markupsafe import escape
26
 
27
  app = Flask(__name__)
28
  CORS(app)
29
+
30
  html_code = """<!DOCTYPE html>
31
  <html lang="en">
32
  <head>
 
227
  """
228
  feedback_queue = queue.Queue()
229
 
230
 
231
  @app.route("/")
232
  def index():
 
239
  top_k = int(request.args.get("top_k", 40))
240
  top_p = float(request.args.get("top_p", 0.0))
241
  reppenalty = float(request.args.get("reppenalty", 1.2))
242
+ response_queue = queue.Queue()
243
+ reasoning_queue.put({
244
+ 'text_input': text,
245
+ 'temperature': temp,
246
+ 'top_k': top_k,
247
+ 'top_p': top_p,
248
+ 'repetition_penalty': reppenalty,
249
+ 'response_queue': response_queue
250
+ })
251
  @stream_with_context
252
  def event_stream():
253
+ while True:
254
+ output = response_queue.get()
255
+ if "error" in output:
256
+ yield "data: <ERROR>\n\n"
257
+ break
258
+ text_chunk = output.get("text")
259
+ if text_chunk:
260
+ for word in text_chunk.split(' '):
261
+ clean_word = word.strip()
262
+ if clean_word:
263
+ yield "data: " + clean_word + "\n\n"
264
+ yield "data: <END_STREAM>\n\n"
265
+ break
266
  return Response(event_stream(), mimetype="text/event-stream")
267
 
268
  @app.route("/api/v1/generate", methods=["POST"])
 
273
  top_k = int(data.get("top_k", 40))
274
  top_p = float(data.get("top_p", 0.0))
275
  reppenalty = float(data.get("reppenalty", 1.2))
276
+ response_queue = queue.Queue()
277
+ reasoning_queue.put({
278
+ 'text_input': text,
279
+ 'temperature': temp,
280
+ 'top_k': top_k,
281
+ 'top_p': top_p,
282
+ 'repetition_penalty': reppenalty,
283
+ 'response_queue': response_queue
284
+ })
285
+ output = response_queue.get()
286
+ if "error" in output:
287
+ return jsonify({"error": output["error"]}), 500
288
+ result_text = output.get("text", "").strip()
289
+ return jsonify({"response": result_text})
290
 
291
  @app.route("/api/v1/feedback", methods=["POST"])
292
  def feedback():
 
300
 
301
  @app.route("/api/v1/tts", methods=["POST"])
302
  def tts_api():
303
+ return tts_route()
304
 
305
  @app.route("/api/v1/stt", methods=["POST"])
306
  def stt_api():
307
+ return stt_route()
308
 
309
  @app.route("/api/v1/sentiment", methods=["POST"])
310
  def sentiment_api():
311
+ return sentiment_route()
312
 
313
  @app.route("/api/v1/imagegen", methods=["POST"])
314
  def imagegen_api():
315
+ return imagegen_route()
316
 
317
  @app.route("/api/v1/musicgen", methods=["POST"])
318
  def musicgen_api():
319
+ return musicgen_route()
320
 
321
  @app.route("/api/v1/translation", methods=["POST"])
322
  def translation_api():
323
+ return translation_route()
324
 
325
  @app.route("/api/v1/codegen", methods=["POST"])
326
  def codegen_api():
327
+ return codegen_route()
328
 
329
  @app.route("/api/v1/text_to_video", methods=["POST"])
330
  def text_to_video_api():
331
+ return text_to_video_route()
332
 
333
  @app.route("/api/v1/summarization", methods=["POST"])
334
  def summarization_api():
335
+ return summarization_route()
336
 
337
  @app.route("/api/v1/image_to_3d", methods=["POST"])
338
  def image_to_3d_api():
339
+ return image_to_3d_route()
340
 
341
+ @app.route("/api/v1/sadtalker", methods=["POST"])
342
  def sadtalker():
343
+ from sadtalker_api import router as sadtalker_router
344
+ return sadtalker_router.create_video()
345
 
346
  if __name__ == "__main__":
347
  app.run(host="0.0.0.0", port=7860)
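
Note: a minimal client sketch for the refactored api.py endpoints above. This is a hedged example, not part of the commit: it assumes the Flask app is running locally on port 7860, that the streaming route is exposed at /api/v1/generate_stream as wired up in constants.py below, and that the requests package is available.

import requests

BASE = "http://localhost:7860"

# Non-streaming generation: POST /api/v1/generate returns {"response": "..."}.
resp = requests.post(BASE + "/api/v1/generate", json={"text": "hello world"})
print(resp.json())

# Streaming generation: the SSE endpoint emits "data: <token>" lines until <END_STREAM>.
with requests.get(BASE + "/api/v1/generate_stream", params={"text": "hello world"}, stream=True) as r:
    for line in r.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        token = line[len("data: "):]
        if token == "<END_STREAM>":
            break
        print(token, end=" ")
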
background_tasks.py CHANGED
@@ -114,46 +114,12 @@ def background_training():
114
  except Exception:
115
  time.sleep(5)
116
 
117
- class ReasoningModel(nn.Module):
118
- def __init__(self, vocab_size, embed_dim=128, hidden_dim=128):
119
- super(ReasoningModel, self).__init__()
120
- self.embedding = nn.Embedding(vocab_size, embed_dim)
121
- self.rnn = nn.LSTM(embed_dim, hidden_dim, batch_first=True)
122
- self.fc = nn.Linear(hidden_dim, vocab_size)
123
- def forward(self, x, hidden=None):
124
- emb = self.embedding(x)
125
- output, hidden = self.rnn(emb, hidden)
126
- logits = self.fc(output)
127
- return logits, hidden
128
- def generate(self, input_seq, max_length=999999999, temperature=1.0):
129
- self.eval()
130
- tokens = input_seq.copy()
131
- hidden = None
132
- generated = []
133
- while True:
134
- input_tensor = torch.tensor([tokens], dtype=torch.long)
135
- logits, hidden = self.forward(input_tensor, hidden)
136
- next_token_logits = logits[0, -1, :] / temperature
137
- probabilities = torch.softmax(next_token_logits, dim=0)
138
- next_token = torch.multinomial(probabilities, 1).item()
139
- tokens.append(next_token)
140
- generated.append(next_token)
141
- if next_token == word_to_index.get("<EOS>"):
142
- break
143
- if len(generated) > max_length:
144
- break
145
- return generated
146
-
147
- reasoning_model = ReasoningModel(len(vocabulary))
148
-
149
  def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
150
- tokens = tokenize_text(text_input)
151
- update_vocabulary(tokens)
152
- tokens_indices = [word_to_index.get(token, 0) for token in tokens]
153
- generated_indices = reasoning_model.generate(tokens_indices, max_length=999999999, temperature=temperature)
154
- for idx in generated_indices:
155
- yield vocabulary[idx] + " "
156
- yield "<END_STREAM>"
157
 
158
  def background_reasoning_queue():
159
  global reasoning_queue, seen_responses
@@ -179,7 +145,7 @@ def background_reasoning_queue():
179
  if chunk == "<END_STREAM>":
180
  break
181
  full_response += chunk
182
- cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "")).strip()
183
  if cleaned_response in seen_responses:
184
  final_response = "**Response is repetitive. Please try again or rephrase your query.**";
185
  resp_queue.put({"text": final_response})
 
114
  except Exception:
115
  time.sleep(5)
116
 
117
  def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
118
+ for token in sample_sequence(text_input, model_gpt2, enc, length=999999999, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device):
119
+ if token == "<END_STREAM>":
120
+ yield "<END_STREAM>"
121
+ break
122
+ yield token + " "
123
 
124
  def background_reasoning_queue():
125
  global reasoning_queue, seen_responses
 
145
  if chunk == "<END_STREAM>":
146
  break
147
  full_response += chunk
148
+ cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
149
  if cleaned_response in seen_responses:
150
  final_response = "**Response is repetitive. Please try again or rephrase your query.**";
151
  resp_queue.put({"text": final_response})
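
The rewrite above hands each request to the background worker through reasoning_queue: the handler enqueues a job dict carrying its own response_queue, and the worker streams GPT-2 output back on it. A stripped-down, self-contained sketch of that handshake follows (the echo worker is a hypothetical stand-in, not the real sample_sequence-based consumer):

import queue
import threading

reasoning_queue = queue.Queue()

def worker():
    # Background consumer: pull a job and answer on its per-request queue.
    while True:
        job = reasoning_queue.get()
        resp_q = job["response_queue"]
        try:
            resp_q.put({"text": " ".join(job["text_input"].split()[:5])})  # stand-in for real generation
        except Exception as exc:
            resp_q.put({"error": str(exc)})

threading.Thread(target=worker, daemon=True).start()

# Producer side, mirroring what the Flask handlers do:
response_queue = queue.Queue()
reasoning_queue.put({"text_input": "hello queue world", "temperature": 0.7,
                     "top_k": 40, "top_p": 0.0, "repetition_penalty": 1.2,
                     "response_queue": response_queue})
print(response_queue.get())  # {'text': 'hello queue world'} or {'error': ...}
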
codegen_api.py CHANGED
@@ -2,10 +2,10 @@ from flask import jsonify, send_file, request
2
  from main import *
3
 
4
  def generate_code(prompt, output_path="output_code.py"):
5
- if codegen_model is None:
6
- return "Code generation model not initialized."
7
- input_ids = codegen_tokenizer.encode(prompt, return_tensors='pt').to(device)
8
- output = codegen_model.generate(input_ids, max_length=999999999, temperature=0.7, top_p=0.9)
9
  code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
10
  with open(output_path, "w") as file:
11
  file.write(code)
@@ -17,6 +17,6 @@ def codegen_api():
17
  if not prompt:
18
  return jsonify({"error": "Prompt is required"}), 400
19
  output_file = generate_code(prompt)
20
- if output_file == "Code generation model not initialized.":
21
  return jsonify({"error": "Code generation failed"}), 500
22
  return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
 
2
  from main import *
3
 
4
  def generate_code(prompt, output_path="output_code.py"):
5
+ if codegen_model is None or codegen_tokenizer is None:
6
+ return "Code generation model or tokenizer not initialized."
7
+ input_ids = codegen_tokenizer(prompt, return_tensors='pt').input_ids.to(device)
8
+ output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
9
  code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
10
  with open(output_path, "w") as file:
11
  file.write(code)
 
17
  if not prompt:
18
  return jsonify({"error": "Prompt is required"}), 400
19
  output_file = generate_code(prompt)
20
+ if output_file == "Code generation model or tokenizer not initialized.":
21
  return jsonify({"error": "Code generation failed"}), 500
22
  return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
configs.py CHANGED
@@ -58,11 +58,33 @@ class CodeGenConfig:
58
 
59
  class SummarizationConfig:
60
  def __init__(self):
61
- self.vocab_size = 10000
62
- self.embedding_dim = 256
63
- self.hidden_dim = 512
64
- self.num_layers = 2
65
- self.max_seq_len = 512
66
 
67
  class Clip4ClipConfig:
68
  def __init__(self, vocab_size=30522, hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
 
58
 
59
  class SummarizationConfig:
60
  def __init__(self):
61
+ self.vocab_size = 50265
62
+ self.max_position_embeddings = 1024
63
+ self.encoder_layers = 12
64
+ self.encoder_ffn_dim = 4096
65
+ self.encoder_attention_heads = 16
66
+ self.decoder_layers = 12
67
+ self.decoder_ffn_dim = 4096
68
+ self.decoder_attention_heads = 16
69
+ self.encoder_layerdrop = 0.0
70
+ self.decoder_layerdrop = 0.0
71
+ self.activation_function = "gelu"
72
+ self.d_model = 1024
73
+ self.dropout = 0.1
74
+ self.attention_dropout = 0.0
75
+ self.activation_dropout = 0.0
76
+ self.init_std = 0.02
77
+ self.classifier_dropout = 0.0
78
+ self.num_labels = 3
79
+ self.pad_token_id = 1
80
+ self.bos_token_id = 0
81
+ self.eos_token_id = 2
82
+ self.layer_norm_eps = 1e-05
83
+ self.num_beams = 4
84
+ self.early_stopping = True
85
+ self.max_length = 100
86
+ self.min_length = 30
87
+ self.scale_embedding = False
88
 
89
  class Clip4ClipConfig:
90
  def __init__(self, vocab_size=30522, hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12, pad_token_id=0, bos_token_id=1, eos_token_id=2, **kwargs):
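
The rewritten SummarizationConfig now carries BART-large-style hyperparameters plus generation defaults. A small sketch of how those generation fields would typically be collected into Hugging Face-style generate() keyword arguments; the consuming model is an assumption and is not part of this commit:

cfg = SummarizationConfig()

# Generation kwargs taken from the config; any seq2seq model following the
# Hugging Face generate() convention could consume them.
gen_kwargs = dict(
    num_beams=cfg.num_beams,
    max_length=cfg.max_length,
    min_length=cfg.min_length,
    early_stopping=cfg.early_stopping,
)
print(gen_kwargs)  # {'num_beams': 4, 'max_length': 100, 'min_length': 30, 'early_stopping': True}
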
constants.py CHANGED
@@ -158,13 +158,7 @@ html_code = """<!DOCTYPE html>
158
  top_p: top_p_val,
159
  reppenalty: repetition_penalty_val
160
  };
161
- eventSource = new EventSource('/generate_stream', {
162
- headers: {
163
- 'Content-Type': 'application/json'
164
- },
165
- method: 'POST',
166
- body: JSON.stringify(requestData)
167
- });
168
  eventSource.onmessage = function(event) {
169
  if (event.data === "<END_STREAM>") {
170
  eventSource.close();
 
158
  top_p: top_p_val,
159
  reppenalty: repetition_penalty_val
160
  };
161
+ eventSource = new EventSource('/api/v1/generate_stream?' + new URLSearchParams(requestData).toString());
 
162
  eventSource.onmessage = function(event) {
163
  if (event.data === "<END_STREAM>") {
164
  eventSource.close();
extensions.py CHANGED
@@ -159,6 +159,86 @@ class RealESRGANer():
159
  output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
160
  return [output_img, None]
161
 
162
  def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
163
  try:
164
  watermark = imageio.imread(watermark_path)
@@ -249,4 +329,4 @@ def get_prior_from_bfm(bfm_path):
249
  'u_tex': u_tex,
250
  'u_exp': u_exp
251
  }
252
- return prior_coeff
 
159
  output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
160
  return [output_img, None]
161
 
162
+ def enhance(self, img, outscale=None, tile=None, tile_pad=None, pre_pad=None, half=None):
163
+ h_input, w_input = img.shape[0:2]
164
+ if outscale is None:
165
+ outscale = self.scale
166
+ if tile is None:
167
+ tile = self.tile
168
+ if tile_pad is None:
169
+ tile_pad = self.tile_pad
170
+ if pre_pad is None:
171
+ pre_pad = self.pre_pad
172
+ if half is None:
173
+ half = self.half
174
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
175
+ img_tensor = img2tensor(img)
176
+ img_tensor = img_tensor.unsqueeze(0).to(self.device)
177
+ if half:
178
+ img_tensor = img_tensor.half()
179
+ mod_scale = self.mod_scale
180
+ h_pad, w_pad = 0, 0
181
+ if mod_scale is not None:
182
+ h_pad, w_pad = int(np.ceil(h_input / mod_scale) * mod_scale - h_input), int(np.ceil(w_input / mod_scale) * mod_scale - w_input)
183
+ img_tensor = F.pad(img_tensor, (0, w_pad, 0, h_pad), 'reflect')
184
+ window_size = 256
185
+ scale = self.scale
186
+ overlap_ratio = 0.5
187
+ if w_input * h_input < window_size**2:
188
+ tile = None
189
+ if tile is not None and tile > 0:
190
+ tile_overlap = tile * overlap_ratio
191
+ sf = scale
192
+ stride_w = math.ceil(tile - tile_overlap)
193
+ stride_h = math.ceil(tile - tile_overlap)
194
+ numW = math.ceil((w_input + tile_overlap) / stride_w)
195
+ numH = math.ceil((h_input + tile_overlap) / stride_h)
196
+ paddingW = (numW - 1) * stride_w + tile - w_input
197
+ paddingH = (numH - 1) * stride_h + tile - h_input
198
+ padding_bottom = int(max(paddingH, 0))
199
+ padding_right = int(max(paddingW, 0))
200
+ padding_left, padding_top = 0, 0
201
+ img_tensor = F.pad(img_tensor, (padding_left, padding_right, padding_top, padding_bottom), mode='reflect')
202
+ output_h, output_w = padding_top + h_input * scale + padding_bottom, padding_left + w_input * scale + padding_right
203
+ output_tensor = torch.zeros([1, 3, output_h, output_w], dtype=img_tensor.dtype, device=self.device)
204
+ windows = []
205
+ for row in range(numH):
206
+ for col in range(numW):
207
+ start_x = col * stride_w
208
+ start_y = row * stride_h
209
+ end_x = min(start_x + tile, img_tensor.shape[3])
210
+ end_y = min(start_y + tile, img_tensor.shape[2])
211
+ windows.append(img_tensor[:, :, start_y:end_y, start_x:end_x])
212
+ results = []
213
+ batch_size = 8
214
+ for i in range(0, len(windows), batch_size):
215
+ batch_windows = torch.stack(windows[i:min(i + batch_size, len(windows))], dim=0)
216
+ with torch.no_grad():
217
+ results.append(self.model(batch_windows))
218
+ results = torch.cat(results, dim=0)
219
+ count = 0
220
+ for row in range(numH):
221
+ for col in range(numW):
222
+ start_x = col * stride_w
223
+ start_y = row * stride_h
224
+ end_x = min(start_x + tile, img_tensor.shape[3])
225
+ end_y = min(start_y + tile, img_tensor.shape[2])
226
+ out_start_x, out_start_y = start_x * sf, start_y * sf
227
+ out_end_x, out_end_y = end_x * sf, end_y * sf
228
+ output_tensor[:, :, out_start_y:out_end_y, out_start_x:out_end_x] += results[count][:, :, :end_y * sf - out_start_y, :end_x * sf - out_start_x]
229
+ count += 1
230
+ forward_img = output_tensor[:, :, :h_input * sf, :w_input * sf]
231
+ else:
232
+ with torch.no_grad():
233
+ forward_img = self.model(img_tensor)
234
+ if half:
235
+ forward_img = forward_img.float()
236
+ output_img = tensor2img(forward_img.squeeze(0).clamp_(0, 1))
237
+ if mod_scale is not None:
238
+ output_img = output_img[:h_input * self.scale, :w_input * self.scale, ...]
239
+ output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB)
240
+ return [output_img, None]
241
+
242
  def save_video_with_watermark(video_frames, audio_path, output_path, watermark_path='./assets/sadtalker_logo.png'):
243
  try:
244
  watermark = imageio.imread(watermark_path)
 
329
  'u_tex': u_tex,
330
  'u_exp': u_exp
331
  }
332
+ return prior_coeff
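
Usage sketch for the re-added tiled enhance() method above (hedged: it assumes an already-constructed RealESRGANer instance named upsampler and an image file on disk; the 256-pixel window in the implementation caps per-tile memory):

import cv2

# upsampler is assumed to be a RealESRGANer built elsewhere (model, scale, device, ...).
img = cv2.cvtColor(cv2.imread("face.png"), cv2.COLOR_BGR2RGB)  # enhance() expects RGB input
output_img, _ = upsampler.enhance(img, outscale=2, tile=256, tile_pad=10, pre_pad=0, half=False)
cv2.imwrite("face_upscaled.png", cv2.cvtColor(output_img, cv2.COLOR_RGB2BGR))
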
imagegen_api.py CHANGED
@@ -10,10 +10,11 @@ def generate_image(prompt, output_path="output_image.png"):
10
  return "Image generation model not initialized."
11
 
12
  generator = torch.Generator(device=device).manual_seed(0)
13
- image = imagegen_model(
14
- prompt,
15
- generator=generator,
16
- ).images[0]
 
17
  image.save(output_path)
18
  return output_path
19
 
 
10
  return "Image generation model not initialized."
11
 
12
  generator = torch.Generator(device=device).manual_seed(0)
13
+ with torch.no_grad():
14
+ image = imagegen_model(
15
+ prompt,
16
+ generator=generator,
17
+ ).images[0]
18
  image.save(output_path)
19
  return output_path
20
 
main.py CHANGED
@@ -7,7 +7,7 @@ import re
7
  import json
8
  from flask import Flask
9
  from flask_cors import CORS
10
- from api import *
11
  from extensions import *
12
  from constants import *
13
  from configs import *
@@ -17,8 +17,7 @@ from model_loader import *
17
  from utils import *
18
  from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
19
  from text_generation import *
20
- from sadtalker_utils import *
21
- import torch
22
 
23
  state_dict = None
24
  enc = None
@@ -59,7 +58,7 @@ tts_model = None
59
  musicgen_model = None
60
 
61
  def load_models():
62
- global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, checkpoint_path, gfpgan_model_file, restoreformer_model_file, codeformer_model_file, realesrgan_model_file, kp_file, aud_file, wav_file, gen_file, mapx_file, den_file
63
  model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
64
  translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
65
  codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
@@ -71,35 +70,7 @@ def load_models():
71
  stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
72
  tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
73
  musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
74
-
75
- class SimpleClassifier(torch.nn.Module):
76
- def __init__(self, vocab_size, num_classes):
77
- super(SimpleClassifier, self).__init__()
78
- self.embedding = torch.nn.Embedding(vocab_size, 128)
79
- self.linear = torch.nn.Linear(128, num_classes)
80
- def forward(self, x):
81
- embedded = self.embedding(x)
82
- pooled = torch.mean(embedded, dim=1)
83
- return self.linear(pooled)
84
-
85
- def tokenize_text(text):
86
- global vocabulary, word_to_index, index_to_word
87
- tokens = text.lower().split()
88
- for token in tokens:
89
- if token not in vocabulary:
90
- vocabulary.add(token)
91
- word_to_index[token] = len(index_to_word)
92
- index_to_word.append(token)
93
- return tokens
94
-
95
- def text_to_vector(text):
96
- global vocabulary, word_to_index
97
- tokens = tokenize_text(text)
98
- vector = torch.zeros(len(vocabulary))
99
- for token in tokens:
100
- if token in word_to_index:
101
- vector[word_to_index[token]] += 1
102
- return vector
103
 
104
  if __name__ == "__main__":
105
  nltk.download('punkt')
@@ -115,4 +86,4 @@ if __name__ == "__main__":
115
  background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
116
  for thread in background_threads:
117
  thread.start()
118
- app.run(host='0.0.0.0', port=7860)
 
7
  import json
8
  from flask import Flask
9
  from flask_cors import CORS
10
+ from api import app
11
  from extensions import *
12
  from constants import *
13
  from configs import *
 
17
  from utils import *
18
  from background_tasks import generate_and_queue_text, background_training, background_reasoning_queue
19
  from text_generation import *
20
+ from sadtalker_utils import SadTalker
 
21
 
22
  state_dict = None
23
  enc = None
 
58
  musicgen_model = None
59
 
60
  def load_models():
61
+ global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model
62
  model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
63
  translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
64
  codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
 
70
  stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
71
  tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
72
  musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
73
+ sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
74
 
75
  if __name__ == "__main__":
76
  nltk.download('punkt')
 
86
  background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
87
  for thread in background_threads:
88
  thread.start()
89
+ app.run(host='0.0.0.0', port=7860)
model_loader.py CHANGED
@@ -265,7 +265,7 @@ class ResnetBlock(nn.Module):
265
  sc = self.conv_shortcut(x)
266
  h = F.silu(self.norm1(x))
267
  h = self.conv1(h)
268
- h = F.silu(self.norm2(h))
269
  h = self.conv2(h)
270
  return h + sc
271
 
 
265
  sc = self.conv_shortcut(x)
266
  h = F.silu(self.norm1(x))
267
  h = self.conv1(h)
268
+ h = F.silu(self.norm2(x))
269
  h = self.conv2(h)
270
  return h + sc
271
 
models.py CHANGED
@@ -91,4 +91,4 @@ class MusicGenModel(nn.Module):
91
  audio_output.append(predicted_token.cpu())
92
  input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
93
  audio_output = torch.cat(audio_output, dim=1).float()
94
- return audio_output
 
91
  audio_output.append(predicted_token.cpu())
92
  input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
93
  audio_output = torch.cat(audio_output, dim=1).float()
94
+ return audio_output
musicgen_api.py CHANGED
@@ -11,7 +11,7 @@ def generate_music(prompt, output_path="output_music.wav"):
11
 
12
  attributes = [prompt]
13
  sample_rate = 32000
14
- duration = 60
15
  audio_values = musicgen_model.sample(
16
  attributes=attributes,
17
  sample_rate=sample_rate,
 
11
 
12
  attributes = [prompt]
13
  sample_rate = 32000
14
+ duration = 10
15
  audio_values = musicgen_model.sample(
16
  attributes=attributes,
17
  sample_rate=sample_rate,
sadtalker_api.py CHANGED
@@ -157,7 +157,7 @@ async def websocket_endpoint(websocket: WebSocket):
157
  transcription_text_file = speech_to_text_func(tmp_audio_file.name)
158
  with open(transcription_text_file, 'r') as f:
159
  transcription_text = f.read()
160
- response_stream = perform_reasoning_stream(f"respond to this sentence in 10 words or less {transcription_text}", 0.7, 40, 0.0, 1.2)
161
  response_text = ""
162
  for chunk in response_stream:
163
  if chunk == "<END_STREAM>":
@@ -198,3 +198,7 @@ async def websocket_endpoint(websocket: WebSocket):
198
  except Exception as e:
199
  print(e)
200
  await websocket.send_json({"error":str(e)})
 
157
  transcription_text_file = speech_to_text_func(tmp_audio_file.name)
158
  with open(transcription_text_file, 'r') as f:
159
  transcription_text = f.read()
160
+ response_stream = perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2)
161
  response_text = ""
162
  for chunk in response_stream:
163
  if chunk == "<END_STREAM>":
 
198
  except Exception as e:
199
  print(e)
200
  await websocket.send_json({"error":str(e)})
201
+
202
+ router = APIRouter()
203
+ router.add_api_route("/sadtalker", create_video, methods=["POST"])
204
+ router.add_api_websocket_route("/ws", websocket_endpoint)
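
The Flask route in api.py calls this router's handler directly; if the router were served on its own, a hypothetical mounting would look like the sketch below (the uvicorn/FastAPI wiring is an illustration, not part of this commit):

import uvicorn
from fastapi import FastAPI
from sadtalker_api import router as sadtalker_router

app = FastAPI()
app.include_router(sadtalker_router, prefix="/api/v1")  # POST /api/v1/sadtalker, WS /api/v1/ws

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
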
sadtalker_utils.py CHANGED
@@ -269,32 +269,33 @@ class SadTalker:
269
  self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
270
 
271
  def get_cfg_defaults(self):
272
- return {
273
- 'MODEL': {
274
- 'CHECKPOINTS_DIR': '',
275
- 'CONFIG_DIR': '',
276
- 'DEVICE': self.device,
277
- 'SCALE': 64,
278
- 'NUM_VOXEL_FRAMES': 8,
279
- 'NUM_MOTION_FRAMES': 10,
280
- 'MAX_FEATURES': 256,
281
- 'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
282
- 'VIDEO_FPS': 25,
283
- 'OUTPUT_VIDEO_FPS': None,
284
- 'OUTPUT_AUDIO_SAMPLE_RATE': None,
285
- 'USE_ENHANCER': False,
286
- 'ENHANCER_NAME': '',
287
- 'BG_UPSAMPLER': None,
288
- 'IS_HALF': False
289
- },
290
- 'INPUT_IMAGE': {}
291
- }
292
 
293
  def merge_from_file(self, filepath):
294
  if os.path.exists(filepath):
295
  with open(filepath, 'r') as f:
296
  cfg_from_file = yaml.safe_load(f)
297
- self.cfg.update(cfg_from_file)
 
298
 
299
  def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
300
  batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
@@ -310,7 +311,7 @@ class SadTalkerModel:
310
 
311
  def __init__(self, sadtalker_cfg, device_id=[0]):
312
  self.cfg = sadtalker_cfg
313
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
314
  self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
315
  self.preprocesser = self.sadtalker.preprocesser
316
  self.kp_extractor = self.sadtalker.kp_extractor
@@ -389,7 +390,7 @@ class SadTalkerInner:
389
  ref_pose_coeff = None
390
  ref_expression_coeff = None
391
  audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
392
- self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
393
  batch = {
394
  'source_image': source_image_tensor.unsqueeze(0).to(self.device),
395
  'audio': audio_tensor.unsqueeze(0).to(self.device),
@@ -455,12 +456,11 @@ class SadTalkerInner:
455
  audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
456
  output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
457
  self.output_path = output_video_path
458
- video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
459
- 'OUTPUT_VIDEO_FPS'] is None else \
460
- self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
461
- audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
462
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
463
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
464
  if self.use_enhancer:
465
  enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
466
  save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
@@ -489,13 +489,12 @@ class SadTalkerInnerModel:
489
 
490
  def __init__(self, sadtalker_cfg, device_id=[0]):
491
  self.cfg = sadtalker_cfg
492
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
493
  self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
494
  self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
495
  self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
496
  self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
497
- self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
498
- 'USE_ENHANCER'] else None
499
  self.generator = Generator(sadtalker_cfg, self.device)
500
  self.mapping = Mapping(sadtalker_cfg, self.device)
501
  self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
@@ -506,10 +505,10 @@ class Preprocesser:
506
  def __init__(self, sadtalker_cfg, device):
507
  self.cfg = sadtalker_cfg
508
  self.device = device
509
- if self.cfg['INPUT_IMAGE'].get('OLD_VERSION', False):
510
- self.face3d_helper = Face3DHelperOld(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
511
  else:
512
- self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
513
  self.mouth_detector = MouthDetector()
514
 
515
  def crop(self, source_image_pil, preprocess_type, size=256):
@@ -543,7 +542,7 @@ class Preprocesser:
543
  cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
544
  source_image_tensor = self.img2tensor(cropped_image_pil)
545
  return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
546
- self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
547
 
548
  def img2tensor(self, img):
549
  img = np.array(img).astype(np.float32) / 255.0
@@ -577,7 +576,7 @@ class Preprocesser:
577
  return ref_expression_coeff
578
 
579
  def generate_idles_pose(self, length_of_audio, pose_style):
580
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
581
  ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
582
  start_pose = self.generate_still_pose(pose_style)
583
  end_pose = self.generate_still_pose(pose_style)
@@ -587,7 +586,7 @@ class Preprocesser:
587
  return ref_pose_coeff
588
 
589
  def generate_idles_expression(self, length_of_audio):
590
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
591
  ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
592
  start_exp = self.generate_still_expression(1.0)
593
  end_exp = self.generate_still_expression(1.0)
@@ -601,11 +600,11 @@ class KeyPointExtractor(nn.Module):
601
 
602
  def __init__(self, sadtalker_cfg, device):
603
  super(KeyPointExtractor, self).__init__()
604
- self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
605
  num_kp=10,
606
  num_dilation_blocks=2,
607
  dropout_rate=0.1).to(device)
608
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
609
  self.load_kp_detector(checkpoint_path, device)
610
 
611
  def load_kp_detector(self, checkpoint_path, device):
@@ -628,12 +627,12 @@ class Audio2Coeff(nn.Module):
628
  def __init__(self, sadtalker_cfg, device):
629
  super(Audio2Coeff, self).__init__()
630
  self.audio_model = Wav2Vec2Model().to(device)
631
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
632
  self.load_audio_model(checkpoint_path, device)
633
  self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
634
  self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
635
  self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
636
- mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'audio2pose_00140-model.pth')
637
  self.load_mapping_model(mapping_checkpoint, device)
638
 
639
  def load_audio_model(self, checkpoint_path, device):
@@ -753,13 +752,13 @@ class Generator(nn.Module):
753
 
754
  def __init__(self, sadtalker_cfg, device):
755
  super(Generator, self).__init__()
756
- self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
757
- num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
758
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
759
  num_channels=3,
760
  kp_size=10,
761
- num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
762
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
763
  self.load_generator(checkpoint_path, device)
764
 
765
  def load_generator(self, checkpoint_path, device):
@@ -786,7 +785,7 @@ class Mapping(nn.Module):
786
  def __init__(self, sadtalker_cfg, device):
787
  super(Mapping, self).__init__()
788
  self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
789
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
790
  self.load_mapping_net(checkpoint_path, device)
791
  self.f_3d_mean = torch.zeros(1, 64, device=device)
792
 
@@ -814,10 +813,10 @@ class OcclusionAwareDenseMotion(nn.Module):
814
  super(OcclusionAwareDenseMotion, self).__init__()
815
  self.dense_motion_network = DenseMotionNetwork(num_kp=10,
816
  num_channels=3,
817
- block_expansion=sadtalker_cfg['MODEL']['SCALE'],
818
- num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
819
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
820
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
821
  self.load_dense_motion_network(checkpoint_path, device)
822
 
823
  def load_dense_motion_network(self, checkpoint_path, device):
@@ -839,20 +838,20 @@ class FaceEnhancer(nn.Module):
839
 
840
  def __init__(self, sadtalker_cfg, device):
841
  super(FaceEnhancer, self).__init__()
842
- enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
843
- bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
844
  if enhancer_name == 'gfpgan':
845
  from gfpgan import GFPGANer
846
- self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
847
  upscale=1,
848
  arch='clean',
849
  channel_multiplier=2,
850
  bg_upsampler=bg_upsampler)
851
  elif enhancer_name == 'realesrgan':
852
  from realesrgan import RealESRGANer
853
- half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
854
  self.face_enhancer = RealESRGANer(scale=2,
855
- model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
856
  'RealESRGAN_x2plus.pth'),
857
  tile=0,
858
  tile_pad=10,
 
269
  self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
270
 
271
  def get_cfg_defaults(self):
272
+ return CN(
273
+ MODEL=CN(
274
+ CHECKPOINTS_DIR='',
275
+ CONFIG_DIR='',
276
+ DEVICE=self.device,
277
+ SCALE=64,
278
+ NUM_VOXEL_FRAMES=8,
279
+ NUM_MOTION_FRAMES=10,
280
+ MAX_FEATURES=256,
281
+ DRIVEN_AUDIO_SAMPLE_RATE=16000,
282
+ VIDEO_FPS=25,
283
+ OUTPUT_VIDEO_FPS=None,
284
+ OUTPUT_AUDIO_SAMPLE_RATE=None,
285
+ USE_ENHANCER=False,
286
+ ENHANCER_NAME='',
287
+ BG_UPSAMPLER=None,
288
+ IS_HALF=False
289
+ ),
290
+ INPUT_IMAGE=CN()
291
+ )
292
 
293
  def merge_from_file(self, filepath):
294
  if os.path.exists(filepath):
295
  with open(filepath, 'r') as f:
296
  cfg_from_file = yaml.safe_load(f)
297
+ self.cfg.MODEL.update(CN(cfg_from_file['MODEL']))
298
+ self.cfg.INPUT_IMAGE.update(CN(cfg_from_file['INPUT_IMAGE']))
299
 
300
  def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
301
  batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
 
311
 
312
  def __init__(self, sadtalker_cfg, device_id=[0]):
313
  self.cfg = sadtalker_cfg
314
+ self.device = sadtalker_cfg.MODEL.get('DEVICE', 'cpu')
315
  self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
316
  self.preprocesser = self.sadtalker.preprocesser
317
  self.kp_extractor = self.sadtalker.kp_extractor
 
390
  ref_pose_coeff = None
391
  ref_expression_coeff = None
392
  audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
393
+ self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE)
394
  batch = {
395
  'source_image': source_image_tensor.unsqueeze(0).to(self.device),
396
  'audio': audio_tensor.unsqueeze(0).to(self.device),
 
456
  audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
457
  output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
458
  self.output_path = output_video_path
459
+ video_fps = self.sadtalker_model.cfg.MODEL.VIDEO_FPS if self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS is None else \
460
+ self.sadtalker_model.cfg.MODEL.OUTPUT_VIDEO_FPS
461
+ audio_output_sample_rate = self.sadtalker_model.cfg.MODEL.DRIVEN_AUDIO_SAMPLE_RATE if \
462
+ self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE is None else \
463
+ self.sadtalker_model.cfg.MODEL.OUTPUT_AUDIO_SAMPLE_RATE
 
464
  if self.use_enhancer:
465
  enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
466
  save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
 
489
 
490
  def __init__(self, sadtalker_cfg, device_id=[0]):
491
  self.cfg = sadtalker_cfg
492
+ self.device = sadtalker_cfg.MODEL.DEVICE
493
  self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
494
  self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
495
  self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
496
  self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
497
+ self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg.MODEL.USE_ENHANCER else None
 
498
  self.generator = Generator(sadtalker_cfg, self.device)
499
  self.mapping = Mapping(sadtalker_cfg, self.device)
500
  self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
 
505
  def __init__(self, sadtalker_cfg, device):
506
  self.cfg = sadtalker_cfg
507
  self.device = device
508
+ if self.cfg.INPUT_IMAGE.get('OLD_VERSION', False):
509
+ self.face3d_helper = Face3DHelperOld(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
510
  else:
511
+ self.face3d_helper = Face3DHelper(self.cfg.INPUT_IMAGE.get('LOCAL_PCA_PATH', ''), device)
512
  self.mouth_detector = MouthDetector()
513
 
514
  def crop(self, source_image_pil, preprocess_type, size=256):
 
542
  cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
543
  source_image_tensor = self.img2tensor(cropped_image_pil)
544
  return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
545
+ self.cfg.INPUT_IMAGE.get('SOURCE_IMAGE', ''))
546
 
547
  def img2tensor(self, img):
548
  img = np.array(img).astype(np.float32) / 255.0
 
576
  return ref_expression_coeff
577
 
578
  def generate_idles_pose(self, length_of_audio, pose_style):
579
+ num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
580
  ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
581
  start_pose = self.generate_still_pose(pose_style)
582
  end_pose = self.generate_still_pose(pose_style)
 
586
  return ref_pose_coeff
587
 
588
  def generate_idles_expression(self, length_of_audio):
589
+ num_frames = int(length_of_audio * self.cfg.MODEL.VIDEO_FPS)
590
  ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
591
  start_exp = self.generate_still_expression(1.0)
592
  end_exp = self.generate_still_expression(1.0)
 
600
 
601
  def __init__(self, sadtalker_cfg, device):
602
  super(KeyPointExtractor, self).__init__()
603
+ self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES,
604
  num_kp=10,
605
  num_dilation_blocks=2,
606
  dropout_rate=0.1).to(device)
607
+ checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'kp_detector.safetensors')
608
  self.load_kp_detector(checkpoint_path, device)
609
 
610
  def load_kp_detector(self, checkpoint_path, device):
 
627
  def __init__(self, sadtalker_cfg, device):
628
  super(Audio2Coeff, self).__init__()
629
  self.audio_model = Wav2Vec2Model().to(device)
630
+ checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'wav2vec2.pth')
631
  self.load_audio_model(checkpoint_path, device)
632
  self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
633
  self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
634
  self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
635
+ mapping_checkpoint = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'audio2pose_00140-model.pth')
636
  self.load_mapping_model(mapping_checkpoint, device)
637
 
638
  def load_audio_model(self, checkpoint_path, device):
 
752
 
753
  def __init__(self, sadtalker_cfg, device):
754
  super(Generator, self).__init__()
755
+ self.generator = Hourglass(block_expansion=sadtalker_cfg.MODEL.SCALE,
756
+ num_blocks=sadtalker_cfg.MODEL.NUM_VOXEL_FRAMES,
757
+ max_features=sadtalker_cfg.MODEL.MAX_FEATURES,
758
  num_channels=3,
759
  kp_size=10,
760
+ num_deform_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES).to(device)
761
+ checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'generator.pth')
762
  self.load_generator(checkpoint_path, device)
763
 
764
  def load_generator(self, checkpoint_path, device):
 
785
  def __init__(self, sadtalker_cfg, device):
786
  super(Mapping, self).__init__()
787
  self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
788
+ checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'mapping.pth')
789
  self.load_mapping_net(checkpoint_path, device)
790
  self.f_3d_mean = torch.zeros(1, 64, device=device)
791
 
 
813
  super(OcclusionAwareDenseMotion, self).__init__()
814
  self.dense_motion_network = DenseMotionNetwork(num_kp=10,
815
  num_channels=3,
816
+ block_expansion=sadtalker_cfg.MODEL.SCALE,
817
+ num_blocks=sadtalker_cfg.MODEL.NUM_MOTION_FRAMES - 1,
818
+ max_features=sadtalker_cfg.MODEL.MAX_FEATURES).to(device)
819
+ checkpoint_path = os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'dense_motion.pth')
820
  self.load_dense_motion_network(checkpoint_path, device)
821
 
822
  def load_dense_motion_network(self, checkpoint_path, device):
 
838
 
839
  def __init__(self, sadtalker_cfg, device):
840
  super(FaceEnhancer, self).__init__()
841
+ enhancer_name = sadtalker_cfg.MODEL.ENHANCER_NAME
842
+ bg_upsampler = sadtalker_cfg.MODEL.BG_UPSAMPLER
843
  if enhancer_name == 'gfpgan':
844
  from gfpgan import GFPGANer
845
+ self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR, 'GFPGANv1.4.pth'),
846
  upscale=1,
847
  arch='clean',
848
  channel_multiplier=2,
849
  bg_upsampler=bg_upsampler)
850
  elif enhancer_name == 'realesrgan':
851
  from realesrgan import RealESRGANer
852
+ half = False if device == 'cpu' else sadtalker_cfg.MODEL.IS_HALF
853
  self.face_enhancer = RealESRGANer(scale=2,
854
+ model_path=os.path.join(sadtalker_cfg.MODEL.CHECKPOINTS_DIR,
855
  'RealESRGAN_x2plus.pth'),
856
  tile=0,
857
  tile_pad=10,
sentiment_api.py CHANGED
@@ -2,25 +2,28 @@ from flask import jsonify
2
  from main import *
3
  import torch
4
 
5
- def analyze_sentiment(text, output_path="output_sentiment.json"):
6
  if sentiment_model is None:
7
- return "Sentiment model not initialized."
8
 
9
- input_tokens = sentiment_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
10
  with torch.no_grad():
11
- sentiment_logits = sentiment_model(input_tokens['input_ids'])
12
- predicted_class_id = torch.argmax(sentiment_logits, dim=-1).item()
13
- sentiment_label = sentiment_model.config.id2label[predicted_class_id]
14
- probability = torch.softmax(sentiment_logits, dim=-1)[0][predicted_class_id].item()
15
 
16
- return {"sentiment": sentiment_label, "probability": probability}
17
 
18
  def sentiment_api():
19
  data = request.get_json()
20
  text = data.get('text')
21
  if not text:
22
  return jsonify({"error": "Text is required"}), 400
23
- output_file = analyze_sentiment(text)
24
- if output_file == "Sentiment model not initialized.":
25
- return jsonify({"error": "Sentiment analysis failed"}), 500
26
- return jsonify(output_file)
 
2
  from main import *
3
  import torch
4
 
5
+ def analyze_sentiment(text):
6
  if sentiment_model is None:
7
+ return {"error": "Sentiment model not initialized."}
8
+
9
+ features = [ord(c) for c in text[:10]]
10
+ while len(features) < 10:
11
+ features.append(0)
12
+ features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
13
 
 
14
  with torch.no_grad():
15
+ output = sentiment_model(features_tensor)
16
+ sentiment_idx = torch.argmax(output, dim=1).item()
17
+ sentiment_label = "positive" if sentiment_idx == 1 else "negative"
 
18
 
19
+ return {"sentiment": sentiment_label}
20
 
21
  def sentiment_api():
22
  data = request.get_json()
23
  text = data.get('text')
24
  if not text:
25
  return jsonify({"error": "Text is required"}), 400
26
+ output = analyze_sentiment(text)
27
+ if "error" in output:
28
+ return jsonify({"error": output["error"]}), 500
29
+ return jsonify(output)
stt_api.py CHANGED
@@ -5,9 +5,9 @@ from main import *
5
  import torch
6
  import torchaudio
7
 
8
- def speech_to_text_func(audio_path, output_path="output_stt.txt"):
9
  if stt_model is None:
10
- return "STT model not initialized."
11
 
12
  waveform, sample_rate = torchaudio.load(audio_path)
13
  if waveform.ndim > 1:
@@ -18,9 +18,7 @@ def speech_to_text_func(audio_path, output_path="output_stt.txt"):
18
  predicted_ids = torch.argmax(logits, dim=-1)
19
  transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
20
 
21
- with open(output_path, "w") as file:
22
- file.write(transcription)
23
- return output_path
24
 
25
  def stt_api():
26
  if 'audio' not in request.files:
@@ -28,8 +26,8 @@ def stt_api():
28
  audio_file = request.files['audio']
29
  temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
30
  audio_file.save(temp_audio_path)
31
- output_file = speech_to_text_func(temp_audio_path)
32
  os.remove(temp_audio_path)
33
- if output_file == "STT model not initialized.":
34
- return jsonify({"error": "STT failed"}), 500
35
- return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output.txt")
 
5
  import torch
6
  import torchaudio
7
 
8
+ def speech_to_text_func(audio_path):
9
  if stt_model is None:
10
+ return {"error": "STT model not initialized."}
11
 
12
  waveform, sample_rate = torchaudio.load(audio_path)
13
  if waveform.ndim > 1:
 
18
  predicted_ids = torch.argmax(logits, dim=-1)
19
  transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
20
 
21
+ return {"text": transcription}
22
 
23
  def stt_api():
24
  if 'audio' not in request.files:
 
26
  audio_file = request.files['audio']
27
  temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
28
  audio_file.save(temp_audio_path)
29
+ output = speech_to_text_func(temp_audio_path)
30
  os.remove(temp_audio_path)
31
+ if "error" in output:
32
+ return jsonify({"error": output["error"]}), 500
33
+ return jsonify(output)
summarization_api.py CHANGED
@@ -3,15 +3,14 @@ from main import *
 import torch
 
 def summarize_text(text, output_path="output_summary.txt"):
-    if summarization_model is None:
-        return "Summarization model not initialized."
+    if summarization_model is None or summarization_tokenizer is None:
+        return "Summarization model or tokenizer not initialized."
 
-    input_tokens = [summarization_word_to_index.get(word.lower(), 1) for word in text.split()]
-    input_tensor = torch.tensor([input_tokens], dtype=torch.long).to(device)
+    input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)
 
     with torch.no_grad():
-        summary_ids = summarization_model.generate(input_tensor, num_beams=4, max_length=999999999, early_stopping=True)
-        summary_text = summarization_model.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+        summary_ids = summarization_model.generate(input_ids, num_beams=4, max_length=100, early_stopping=True)
+        summary_text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
     with open(output_path, "w") as file:
         file.write(summary_text)
@@ -23,6 +22,6 @@ def summarization_api():
     if not text:
         return jsonify({"error": "Text is required"}), 400
     output_file = summarize_text(text)
-    if output_file == "Summarization model not initialized.":
+    if output_file == "Summarization model or tokenizer not initialized.":
        return jsonify({"error": "Summarization failed"}), 500
     return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
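Two behavioural changes are worth noting here: generation is now capped at max_length=100 beam-search tokens instead of the previous effectively unbounded value, and input is tokenized with summarization_tokenizer rather than a word-to-index map. A hedged client sketch follows; the "/summarize" route name is an assumption, while the text-attachment response matches the code above.

# Hypothetical client call; the "/summarize" route name is an assumption.
import requests

resp = requests.post("http://localhost:5000/summarize", json={"text": "Long article text ..."})
with open("summary.txt", "wb") as f:
    f.write(resp.content)  # the endpoint sends output_summary.txt as an attachment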
text_generation.py CHANGED
@@ -22,124 +22,114 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')
     return logits
 
 def sample_sequence(prompt, model, enc, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
-    start_time = time.time()
     context_tokens = enc.encode(prompt)
     context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
     generated = context_tokens
-    past = None
-    text_generated_count = 0
-    past_key_values = past if past is not None else None
+    past_key_values = None
 
     with torch.no_grad():
-        outputs = model(context_tokens_tensor, past_key_values=past_key_values)
-        next_token_logits = outputs[0][:, -1, :] / temperature
-        past = outputs[1]
-        for token_index in set(generated):
-            next_token_logits[0, token_index] /= repetition_penalty
-        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        if temperature == 0:
-            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
-        else:
-            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-        generated += next_token.tolist()[0]
-        text_generated_count += 1
-        token = next_token.tolist()[0][0]
-        yield enc.decode([token])
-        if token == enc.encoder[END_OF_TEXT_TOKEN]:
-            yield "<END_STREAM>"
-
+        for _ in range(length):
+            outputs = model(context_tokens_tensor, past_key_values=past_key_values)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past_key_values = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated += next_token.tolist()[0]
+            token = next_token.tolist()[0][0]
+            yield enc.decode([token])
+            if token == enc.encoder[END_OF_TEXT_TOKEN]:
+                yield "<END_STREAM>"
+                return
 
 def sample_sequence_codegen(prompt, model, tokenizer, length, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
-    start_time = time.time()
     context_tokens = tokenizer.encode(prompt)
     context_tokens_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device).unsqueeze(0)
     generated = context_tokens
-    past = None
-    text_generated_count = 0
+    past_key_values = None
     with torch.no_grad():
-        outputs = model(input_ids=context_tokens_tensor, past_key_values=past, labels=None)
-        next_token_logits = outputs[0][:, -1, :] / temperature
-        past = outputs[1]
-        for token_index in set(generated):
-            next_token_logits[0, token_index] /= repetition_penalty
-        filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
-        if temperature == 0:
-            next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
-        else:
-            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
-        generated.append(next_token.tolist()[0][0])
-        text_generated_count += 1
-        token = next_token.tolist()[0][0]
-        yield tokenizer.decode([token])
-        if token == 50256:
-            yield "<END_STREAM>"
-
+        for _ in range(length):
+            outputs = model(input_ids=context_tokens_tensor, past_key_values=past_key_values, labels=None)
+            next_token_logits = outputs[0][:, -1, :] / temperature
+            past_key_values = outputs[1]
+            for token_index in set(generated):
+                next_token_logits[0, token_index] /= repetition_penalty
+            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
+            if temperature == 0:
+                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
+            else:
+                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
+            generated.append(next_token.tolist()[0][0])
+            token = next_token.tolist()[0][0]
+            yield tokenizer.decode([token])
+            if token == 50256:
+                yield "<END_STREAM>"
+                return
 
 def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty):
-    try:
-        prompt_text = system_prompt + "\n\n"
-        prompt_text += "User: " + text_input + "\nCyrah: "
-        reasoning_prompt = prompt_text
+    prompt_text = SYSTEM_PROMPT + "\n\n"
+    prompt_text += "User: " + text_input + "\nAssistant:"
+    reasoning_prompt = prompt_text
 
-        ddgs = DDGS()
-        search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
-        if search_results:
-            prompt_text += "\nWeb Search Results:\n"
-            for result in search_results:
-                prompt_text += f"- {result['body']}\n"
-            prompt_text += "\n"
+    ddgs = DDGS()
+    search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
+    if search_results:
+        prompt_text += "\nWeb Search Results:\n"
+        for result in search_results:
+            prompt_text += f"- {result['body']}\n"
+        prompt_text += "\n"
 
-        generated_text_stream = []
-        stream_type = "text"
+    generated_text_stream = []
+    stream_type = "text"
 
-        if "code" in text_input.lower() or "program" in text_input.lower():
-            if codegen_model and codegen_tokenizer:
-                generated_text_stream = sample_sequence_codegen(
-                    prompt=reasoning_prompt,
-                    model=codegen_model,
-                    tokenizer=codegen_tokenizer,
-                    length=999999999,
-                    temperature=temperature,
-                    top_k=top_k,
-                    top_p=top_p,
-                    repetition_penalty=repetition_penalty,
-                    device=device
-                )
-            stream_type = "text"
-        elif "summarize" in text_input.lower() or "summary" in text_input.lower():
-            if summarization_model:
-                summary = summarize_func(text_input)
-                yield f"SUMMARY_TEXT:{summary}"
-                yield "<END_STREAM>"
-            stream_type = "summary"
-        else:
-            if model_gpt2 and enc:
-                generated_text_stream = sample_sequence(
-                    prompt=reasoning_prompt,
-                    model=model_gpt2,
-                    enc=enc,
-                    length=999999999,
-                    temperature=temperature,
-                    top_k=top_k,
-                    top_p=top_p,
-                    repetition_penalty=repetition_penalty,
-                    device=device
-                )
-            stream_type = "text"
+    if "code" in text_input.lower() or "program" in text_input.lower():
+        if codegen_model and codegen_tokenizer:
+            generated_text_stream = sample_sequence_codegen(
+                prompt=reasoning_prompt,
+                model=codegen_model,
+                tokenizer=codegen_tokenizer,
+                length=999999999,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                device=device
+            )
+        stream_type = "text"
+    elif "summarize" in text_input.lower() or "summary" in text_input.lower():
+        if summarization_model:
+            summary = summarize_text(text_input)
+            yield f"SUMMARY_TEXT:{summary}"
+            yield "<END_STREAM>"
+        stream_type = "summary"
+    else:
+        if model_gpt2 and enc:
+            generated_text_stream = sample_sequence(
+                prompt=reasoning_prompt,
+                model=model_gpt2,
+                enc=enc,
+                length=999999999,
+                temperature=temperature,
+                top_k=top_k,
+                top_p=top_p,
+                repetition_penalty=repetition_penalty,
+                device=device
+            )
+        stream_type = "text"
 
-        accumulated_text = ""
-        if stream_type == "text":
-            for token in generated_text_stream:
-                if token == "<END_STREAM>":
-                    yield accumulated_text
-                    yield "<END_STREAM>"
-                    return
-                if token == END_OF_TEXT_TOKEN:
-                    accumulated_text += END_OF_TEXT_TOKEN
-                    continue
-                if token:
-                    accumulated_text += token
-    except Exception as e:
-        print(f"Reasoning Error: {e}")
-        yield "Error during reasoning. Please try again."
-        yield "<END_STREAM>"
+    accumulated_text = ""
+    if stream_type == "text":
+        for token in generated_text_stream:
+            if token == "<END_STREAM>":
+                yield accumulated_text
+                yield "<END_STREAM>"
+                return
+            if token == END_OF_TEXT_TOKEN:
+                accumulated_text += END_OF_TEXT_TOKEN
+                continue
+            if token:
+                accumulated_text += token
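Both samplers route their logits through top_k_top_p_filtering (defined earlier in this file) before sampling. A toy sketch of the intended effect follows, assuming the helper follows the usual convention of masking filtered positions to -inf; the example only demonstrates the top-k part.

# Toy illustration of top-k filtering followed by sampling; assumes -inf masking semantics.
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_k = 2
kth_best = torch.topk(logits, top_k)[0][..., -1, None]        # value of the k-th best logit
filtered = logits.masked_fill(logits < kth_best, float("-inf"))
probs = F.softmax(filtered, dim=-1)
print(probs)                                   # only the two highest-scoring tokens keep mass
next_token = torch.multinomial(probs, num_samples=1)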
tokenxxx.py CHANGED
@@ -139,4 +139,4 @@ def codegen_tokenize(text, tokenizer):
     return tokenizer.encode(text)
 
 def codegen_decode(tokens, tokenizer):
-    return tokenizer.decode(tokens)
+    return tokenizer.decode(tokens)
translation_api.py CHANGED
@@ -1,17 +1,15 @@
 from flask import jsonify, send_file, request
 from main import *
 
-def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX', output_path="output_translation.txt"):
+def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
     if translation_model is None:
-        return "Translation model not initialized."
+        return {"error": "Translation model not initialized."}
 
     encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
     generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
     translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
 
-    with open(output_path, "w") as file:
-        file.write(translation)
-    return output_path
+    return {"translated_text": translation}
 
 def translation_api():
     data = request.get_json()
@@ -20,7 +18,7 @@ def translation_api():
     source_lang = data.get('source_lang', 'en')
     if not text:
         return jsonify({"error": "Text is required"}), 400
-    output_file = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
-    if output_file == "Translation model not initialized.":
-        return jsonify({"error": "Translation failed"}), 500
-    return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_translation.txt")
+    output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
+    if "error" in output:
+        return jsonify({"error": output["error"]}), 500
+    return jsonify(output)
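The translation endpoint now answers with JSON instead of a file download. A hedged request/response sketch follows; the "/translate" route name is an assumption, and the *_XX suffixes follow the mBART-style language codes already used above.

# Hypothetical client call; the "/translate" route name is an assumption.
import requests

payload = {"text": "Hello world", "target_lang": "es", "source_lang": "en"}
resp = requests.post("http://localhost:5000/translate", json=payload)
print(resp.json())  # e.g. {"translated_text": "..."} or {"error": "..."} with HTTP 500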
tts_api.py CHANGED
@@ -1,15 +1,19 @@
 import os
 from flask import jsonify, send_file, request
 from main import *
+import torch
+import torchaudio
+import uuid
 
-def text_to_speech_func(text, output_path="output_tts.wav"):
+def text_to_speech_func(text):
     if tts_model is None:
-        return "TTS model not initialized."
+        return {"error": "TTS model not initialized."}
     input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
     with torch.no_grad():
         audio_output = tts_model(input_tokens['input_ids'])
-    torchaudio.save(output_path, audio_output.cpu(), 16000)
-    return output_path
+    temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
+    torchaudio.save(temp_audio_path, audio_output.cpu(), 16000)
+    return temp_audio_path
 
 def tts_api():
     data = request.get_json()
@@ -17,6 +21,6 @@ def tts_api():
     if not text:
         return jsonify({"error": "Text is required"}), 400
     output_file = text_to_speech_func(text)
-    if output_file == "TTS model not initialized.":
-        return jsonify({"error": "TTS generation failed"}), 500
+    if isinstance(output_file, dict) and "error" in output_file:
+        return jsonify({"error": output_file["error"]}), 500
     return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
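A hedged client sketch for the audio endpoint follows; the "/tts" route name is an assumption. One design note: text_to_speech_func now writes a uuid-named temp file that send_file streams back but nothing deletes, so cleaning it up after the response (for example in a Flask teardown hook or a background task) would keep the working directory from accumulating temp_audio_*.wav files.

# Hypothetical client call; the "/tts" route name is an assumption.
import requests

resp = requests.post("http://localhost:5000/tts", json={"text": "Hello there"})
if resp.status_code == 200:
    with open("output.wav", "wb") as f:
        f.write(resp.content)       # raw WAV bytes returned by send_file
else:
    print(resp.json())              # {"error": "..."}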