Kfjjdjdjdhdhd committed on
Commit 7b74407 · verified · 1 Parent(s): caa5001

Upload 26 files

Files changed (22)
  1. api.py +422 -509
  2. background_tasks.py +37 -110
  3. codegen_api.py +8 -17
  4. coder.py +139 -0
  5. image_to_3d_api.py +13 -25
  6. imagegen_api.py +7 -23
  7. main.py +9 -56
  8. model_loader.py +229 -725
  9. models.py +53 -42
  10. musicgen_api.py +9 -28
  11. sadtalker_api.py +16 -183
  12. sadtalker_utils.py +209 -820
  13. sentiment_api.py +7 -20
  14. stt_api.py +11 -27
  15. summarization_api.py +8 -21
  16. text_generation.py +93 -194
  17. text_to_video_api.py +13 -25
  18. tokenxxx.py +44 -114
  19. translation_api.py +7 -16
  20. tts_api.py +9 -20
  21. xtts_api.py +21 -0
  22. xxx.py +43 -114
api.py CHANGED
@@ -1,509 +1,422 @@
1
- from main import *
2
- from tts_api import *
3
- from stt_api import *
4
- from sentiment_api import *
5
- from imagegen_api import *
6
- from musicgen_api import *
7
- from translation_api import *
8
- from codegen_api import *
9
- from text_to_video_api import *
10
- from summarization_api import *
11
- from image_to_3d_api import *
12
- from flask import Flask, request, jsonify, Response, send_file, stream_with_context
13
- from flask_cors import CORS
14
- import torch
15
- import torch.nn.functional as F
16
- import torchaudio
17
- import numpy as np
18
- from PIL import Image
19
- import io
20
- import tempfile
21
- import queue
22
- import json
23
- import base64
24
- from markupsafe import Markup
25
- from markupsafe import escape
26
-
27
- app = Flask(__name__)
28
- CORS(app)
29
-
30
- html_code = """<!DOCTYPE html>
31
- <html lang="en">
32
- <head>
33
- <meta charset="UTF-8">
34
- <meta name="viewport" content="width=device-width, initial-scale=1.0">
35
- <title>AI Conversational Avatar</title>
36
- <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
37
- <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css"/>
38
- <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
39
- <style>
40
- body {
41
- font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
42
- background: #f0f0f0;
43
- color: #333;
44
- margin: 0;
45
- padding: 0;
46
- display: flex;
47
- flex-direction: column;
48
- align-items: center;
49
- min-height: 100vh;
50
- }
51
- .container {
52
- width: 95%;
53
- max-width: 900px;
54
- padding: 20px;
55
- background-color: #fff;
56
- box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
57
- border-radius: 8px;
58
- margin-top: 20px;
59
- margin-bottom: 20px;
60
- display: flex;
61
- flex-direction: column;
62
- }
63
- .header {
64
- text-align: center;
65
- margin-bottom: 20px;
66
- }
67
- .header h1 {
68
- font-size: 2em;
69
- color: #333;
70
- }
71
- .form-group {
72
- margin-bottom: 15px;
73
- }
74
- .form-group textarea, .form-group input[type="text"] {
75
- width: 100%;
76
- padding: 10px;
77
- border: 1px solid #ccc;
78
- border-radius: 5px;
79
- font-size: 16px;
80
- box-sizing: border-box;
81
- }
82
- button, #recordButton, #stopButton {
83
- padding: 10px 15px;
84
- border: none;
85
- border-radius: 5px;
86
- background-color: #007bff;
87
- color: white;
88
- font-size: 18px;
89
- cursor: pointer;
90
- transition: background-color 0.3s ease;
91
- margin-right: 5px;
92
- }
93
- button:hover, #recordButton:hover, #stopButton:hover {
94
- background-color: #0056b3;
95
- }
96
- #output {
97
- margin-top: 20px;
98
- padding: 15px;
99
- border: 1px solid #ddd;
100
- border-radius: 5px;
101
- background-color: #f9f9f9;
102
- white-space: pre-wrap;
103
- word-break: break-word;
104
- overflow-y: auto;
105
- max-height: 300px;
106
- }
107
- #videoOutput {
108
- margin-top: 20px;
109
- border: 1px solid #ddd;
110
- border-radius: 5px;
111
- overflow: hidden;
112
- }
113
- #videoOutput video {
114
- width: 100%;
115
- display: block;
116
- }
117
- #animatedText {
118
- position: fixed;
119
- top: 20px;
120
- left: 20px;
121
- font-size: 1.5em;
122
- color: rgba(0, 0, 0, 0.1);
123
- pointer-events: none;
124
- z-index: -1;
125
- }
126
- #transcriptionOutput {
127
- margin-top: 10px;
128
- padding: 10px;
129
- border: 1px solid #ddd;
130
- border-radius: 5px;
131
- background-color: #f9f9f9;
132
- font-size: 14px;
133
- word-break: break-word;
134
- }
135
- @media (max-width: 768px) {
136
- .container {
137
- width: 98%;
138
- margin-top: 10px;
139
- margin-bottom: 10px;
140
- padding: 15px;
141
- }
142
- .header h1 {
143
- font-size: 1.8em;
144
- }
145
- .form-group textarea, .form-group input[type="text"] {
146
- font-size: 14px;
147
- padding: 8px;
148
- }
149
- button, #recordButton, #stopButton {
150
- font-size: 16px;
151
- padding: 8px 12px;
152
- }
153
- #output, #transcriptionOutput {
154
- font-size: 14px;
155
- padding: 10px;
156
- margin-top: 15px;
157
- }
158
- }
159
- </style>
160
- </head>
161
- <body>
162
- <div id="animatedText" class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
163
- <div class="container">
164
- <div class="header animate__animated animate__fadeInDown">
165
- <h1>Conversational Avatar</h1>
166
- </div>
167
- <div class="form-group animate__animated animate__fadeInLeft">
168
- <textarea id="textInput" rows="3" placeholder="Or type your request here"></textarea>
169
- </div>
170
- <div class="form-group animate__animated animate__fadeInRight" style="text-align: center;">
171
- <button onclick="generateResponse()" class="animate__animated animate__fadeInUp">Generate Avatar Response</button>
172
- </div>
173
-
174
- <div style="text-align: center; margin-bottom: 15px;">
175
- <button id="recordButton" class="animate__animated animate__fadeInUp"><i class="fas fa-microphone"></i> Start Recording</button>
176
- <button id="stopButton" class="animate__animated animate__fadeInUp" disabled><i class="fas fa-stop-circle"></i> Stop Recording</button>
177
- </div>
178
-
179
- <div id="transcriptionOutput" class="animate__animated animate__fadeIn">
180
- <strong>Transcription:</strong>
181
- <span id="transcriptionText"></span>
182
- </div>
183
-
184
- <div id="output" class="animate__animated animate__fadeIn">
185
- <strong>Response:</strong><br>
186
- <span id="responseText"></span>
187
- </div>
188
-
189
- <div id="videoOutput" class="animate__animated animate__fadeIn">
190
- <video id="avatarVideo" controls></video>
191
- </div>
192
- </div>
193
-
194
- <script>
195
- let mediaRecorder;
196
- let audioChunks = [];
197
- let lastResponse = "";
198
- let accumulatedText = "";
199
- let eventSource = null;
200
- let audioURL;
201
-
202
- const recordButton = document.getElementById('recordButton');
203
- const stopButton = document.getElementById('stopButton');
204
- const transcriptionTextSpan = document.getElementById('transcriptionText');
205
- const responseTextSpan = document.getElementById('responseText');
206
- const avatarVideoPlayer = document.getElementById('avatarVideo');
207
- const textInputField = document.getElementById('textInput');
208
-
209
- recordButton.onclick = async () => {
210
- try {
211
- const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
212
- mediaRecorder = new MediaRecorder(stream);
213
- audioChunks = [];
214
-
215
- mediaRecorder.ondataavailable = event => {
216
- audioChunks.push(event.data);
217
- };
218
-
219
- mediaRecorder.onstop = async () => {
220
- const audioBlob = new Blob(audioChunks, { type: 'audio/wav' });
221
- const formData = new FormData();
222
- formData.append('audio', audioBlob, 'recording.wav');
223
-
224
- transcriptionTextSpan.innerText = "Transcribing...";
225
- responseTextSpan.innerText = "";
226
- avatarVideoPlayer.src = "";
227
-
228
- try {
229
- const sttResponse = await fetch('/api/v1/stt', {
230
- method: 'POST',
231
- body: formData
232
- });
233
-
234
- if (!sttResponse.ok) {
235
- throw new Error(\`HTTP error! status: ${sttResponse.status}\`);
236
- }
237
-
238
- const sttData = await sttResponse.json();
239
- const transcribedText = sttData.text;
240
- transcriptionTextSpan.innerText = transcribedText || "Transcription failed.";
241
-
242
- if (transcribedText) {
243
- await generateAvatarVideoResponse(transcribedText);
244
- }
245
-
246
- } catch (error) {
247
- console.error("STT or subsequent error:", error);
248
- transcriptionTextSpan.innerText = "Transcription error.";
249
- responseTextSpan.innerText = "Error processing audio.";
250
- } finally {
251
- recordButton.disabled = false;
252
- stopButton.disabled = true;
253
- }
254
- };
255
-
256
- recordButton.disabled = true;
257
- stopButton.disabled = false;
258
- transcriptionTextSpan.innerText = "Recording...";
259
- mediaRecorder.start();
260
-
261
- } catch (error) {
262
- console.error("Error accessing microphone:", error);
263
- transcriptionTextSpan.innerText = "Microphone access denied or error.";
264
- recordButton.disabled = false;
265
- stopButton.disabled = true;
266
- }
267
- };
268
-
269
- stopButton.onclick = () => {
270
- if (mediaRecorder && mediaRecorder.state === "recording") {
271
- transcriptionTextSpan.innerText = "Processing...";
272
- mediaRecorder.stop();
273
- recordButton.disabled = true;
274
- stopButton.disabled = true;
275
- }
276
- };
277
-
278
- async function generateResponse() {
279
- const inputText = textInputField.value;
280
- if (!inputText.trim()) {
281
- alert("Please enter text or record audio.");
282
- return;
283
- }
284
- transcriptionTextSpan.innerText = inputText;
285
- await generateAvatarVideoResponse(inputText);
286
- }
287
-
288
-
289
- async function generateAvatarVideoResponse(inputText) {
290
- responseTextSpan.innerText = "Generating response...";
291
- avatarVideoPlayer.src = "";
292
- accumulatedText = "";
293
- lastResponse = "";
294
-
295
- const temp = 0.7;
296
- const top_k_val = 40;
297
- const top_p_val = 0.0;
298
- const repetition_penalty_val = 1.2;
299
-
300
- const requestData = {
301
- text: inputText,
302
- temp: temp,
303
- top_k: top_k_val,
304
- top_p: top_p_val,
305
- reppenalty: repetition_penalty_val
306
- };
307
-
308
- if (eventSource) {
309
- eventSource.close();
310
- }
311
-
312
- eventSource = new EventSource('/api/v1/generate_stream?' + new URLSearchParams(requestData).toString());
313
-
314
- eventSource.onmessage = async function(event) {
315
- if (event.data === "<END_STREAM>") {
316
- eventSource.close();
317
- const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
318
- if (currentResponse === lastResponse.trim()) {
319
- accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
320
- } else {
321
- lastResponse = currentResponse;
322
- }
323
- responseTextSpan.innerHTML = marked.parse(accumulatedText);
324
-
325
- try {
326
- const ttsResponse = await fetch('/api/v1/tts', {
327
- method: 'POST',
328
- headers: {
329
- 'Content-Type': 'application/json'
330
- },
331
- body: JSON.stringify({ text: currentResponse })
332
- });
333
-
334
- if (!ttsResponse.ok) {
335
- throw new Error(\`TTS HTTP error! status: ${ttsResponse.status}\`);
336
- }
337
-
338
- const ttsBlob = await ttsResponse.blob();
339
- audioURL = URL.createObjectURL(ttsBlob);
340
-
341
-
342
- const sadTalkerResponse = await fetch('/api/v1/sadtalker', {
343
- method: 'POST',
344
- body: new URLSearchParams({
345
- 'source_image': './examples/source_image/full_body_female.png',
346
- 'driven_audio': audioURL,
347
- 'preprocess': 'full',
348
- 'still_mode': false,
349
- 'use_enhancer': true
350
- })
351
- });
352
-
353
- if (!sadTalkerResponse.ok) {
354
- throw new Error(\`SadTalker HTTP error! status: ${sadTalkerResponse.status}\`);
355
- }
356
-
357
- const sadTalkerData = await sadTalkerResponse.json();
358
- const videoURL = sadTalkerData.video_url;
359
- avatarVideoPlayer.src = videoURL;
360
-
361
-
362
- } catch (ttsError) {
363
- console.error("TTS or SadTalker error:", ttsError);
364
- responseTextSpan.innerHTML += "<br><br>Error generating audio or video avatar.";
365
- }
366
-
367
- return;
368
- }
369
- accumulatedText += event.data;
370
- let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
371
- responseTextSpan.innerHTML = marked.parse(partialText);
372
- };
373
-
374
- eventSource.onerror = function(error) {
375
- console.error("SSE error", error);
376
- eventSource.close();
377
- responseTextSpan.innerText = "Error generating response stream.";
378
- };
379
-
380
- const outputDiv = document.getElementById("output");
381
- outputDiv.classList.add("show");
382
- }
383
-
384
-
385
- </script>
386
- </body>
387
- </html>
388
- """
389
-
390
- feedback_queue = queue.Queue()
391
-
392
-
393
- @app.route("/")
394
- def index():
395
- return html_code
396
-
397
- @app.route("/api/v1/generate_stream", methods=["GET"])
398
- def generate_stream():
399
- text = request.args.get("text", "")
400
- temp = float(request.args.get("temp", 0.7))
401
- top_k = int(request.args.get("top_k", 40))
402
- top_p = float(request.args.get("top_p", 0.0))
403
- reppenalty = float(request.args.get("reppenalty", 1.2))
404
- response_queue = queue.Queue()
405
- reasoning_queue.put({
406
- 'text_input': text,
407
- 'temperature': temp,
408
- 'top_k': top_k,
409
- 'top_p': top_p,
410
- 'repetition_penalty': reppenalty,
411
- 'response_queue': response_queue
412
- })
413
- @stream_with_context
414
- def event_stream():
415
- while True:
416
- output = response_queue.get()
417
- if "error" in output:
418
- yield "data: <ERROR>\n\n"
419
- break
420
- text_chunk = output.get("text")
421
- if text_chunk:
422
- for word in text_chunk.split(' '):
423
- clean_word = word.strip()
424
- if clean_word:
425
- yield "data: " + clean_word + "\n\n"
426
- yield "data: <END_STREAM>\n\n"
427
- break
428
- return Response(event_stream(), mimetype="text/event-stream")
429
-
430
- @app.route("/api/v1/generate", methods=["POST"])
431
- def generate():
432
- data = request.get_json()
433
- text = data.get("text", "")
434
- temp = float(data.get("temp", 0.7))
435
- top_k = int(data.get("top_k", 40))
436
- top_p = float(data.get("top_p", 0.0))
437
- reppenalty = float(data.get("reppenalty", 1.2))
438
- response_queue = queue.Queue()
439
- reasoning_queue.put({
440
- 'text_input': text,
441
- 'temperature': temp,
442
- 'top_k': top_k,
443
- 'top_p': top_p,
444
- 'repetition_penalty': reppenalty,
445
- 'response_queue': response_queue
446
- })
447
- output = response_queue.get()
448
- if "error" in output:
449
- return jsonify({"error": output["error"]}), 500
450
- result_text = output.get("text", "").strip()
451
- return jsonify({"response": result_text})
452
-
453
- @app.route("/api/v1/feedback", methods=["POST"])
454
- def feedback():
455
- data = request.get_json()
456
- feedback_text = data.get("feedback_text")
457
- correct_category = data.get("correct_category")
458
- if feedback_text and correct_category:
459
- feedback_queue.put((feedback_text, correct_category))
460
- return jsonify({"status": "feedback received"})
461
- return jsonify({"status": "feedback failed"}), 400
462
-
463
- @app.route("/api/v1/tts", methods=["POST"])
464
- def tts_api():
465
- return tts_route()
466
-
467
- @app.route("/api/v1/stt", methods=["POST"])
468
- def stt_api():
469
- return stt_route()
470
-
471
- @app.route("/api/v1/sentiment", methods=["POST"])
472
- def sentiment_api():
473
- return sentiment_route()
474
-
475
- @app.route("/api/v1/imagegen", methods=["POST"])
476
- def imagegen_api():
477
- return imagegen_route()
478
-
479
- @app.route("/api/v1/musicgen", methods=["POST"])
480
- def musicgen_api():
481
- return musicgen_route()
482
-
483
- @app.route("/api/v1/translation", methods=["POST"])
484
- def translation_api():
485
- return translation_route()
486
-
487
- @app.route("/api/v1/codegen", methods=["POST"])
488
- def codegen_api():
489
- return codegen_route()
490
-
491
- @app.route("/api/v1/text_to_video", methods=["POST"])
492
- def text_to_video_api():
493
- return text_to_video_route()
494
-
495
- @app.route("/api/v1/summarization", methods=["POST"])
496
- def summarization_api():
497
- return summarization_route()
498
-
499
- @app.route("/api/v1/image_to_3d", methods=["POST"])
500
- def image_to_3d_api():
501
- return image_to_3d_route()
502
-
503
- @app.route("/api/v1/sadtalker", methods=["POST"])
504
- def sadtalker():
505
- from sadtalker_api import router as sadtalker_router
506
- return sadtalker_router.create_video()
507
-
508
- if __name__ == "__main__":
509
- app.run(host="0.0.0.0", port=7860)
 
1
+ from main import *
2
+ from tts_api import tts_api as tts_module_api
3
+ from stt_api import stt_api as stt_module_api
4
+ from sentiment_api import sentiment_api as sentiment_module_api
5
+ from imagegen_api import imagegen_api as imagegen_module_api
6
+ from musicgen_api import musicgen_api as musicgen_module_api
7
+ from translation_api import translation_api as translation_module_api
8
+ from codegen_api import codegen_api as codegen_module_api
9
+ from text_to_video_api import text_to_video_api as text_to_video_module_api
10
+ from summarization_api import summarization_api as summarization_module_api
11
+ from image_to_3d_api import image_to_3d_api as image_to_3d_module_api
12
+ from xtts_api import xtts_api as xtts_module_api
13
+ from flask import Flask, request, jsonify, Response, send_file, stream_with_context
14
+ from flask_cors import CORS
15
+ import io
16
+ import queue
17
+ import base64
18
+ import gradio as gr
19
+
20
+ app = Flask(__name__)
21
+ CORS(app)
22
+
23
+ html_code = """<!DOCTYPE html>
24
+ <html lang="en">
25
+ <head>
26
+ <meta charset="UTF-8">
27
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
28
+ <title>AI Text Generation</title>
29
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/animate.css/4.1.1/animate.min.css"/>
30
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css" integrity="sha512-9usAa10IRO0HhonpyAIVpjrylPvoDwiPUiKdWk5t3PyolY1cOd4DSE0Ga+ri4AuTroPR5aQvXU9xC6qOPnzFeg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
31
+ <script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
32
+ <style>
33
+ body {
34
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
35
+ background: #f0f0f0;
36
+ color: #333;
37
+ margin: 0;
38
+ padding: 0;
39
+ display: flex;
40
+ flex-direction: column;
41
+ align-items: center;
42
+ min-height: 100vh;
43
+ }
44
+ .container {
45
+ width: 95%;
46
+ max-width: 900px;
47
+ padding: 20px;
48
+ background-color: #fff;
49
+ box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
50
+ border-radius: 8px;
51
+ margin-top: 20px;
52
+ margin-bottom: 20px;
53
+ display: flex;
54
+ flex-direction: column;
55
+ }
56
+ .header {
57
+ text-align: center;
58
+ margin-bottom: 20px;
59
+ }
60
+ .header h1 {
61
+ font-size: 2em;
62
+ color: #333;
63
+ }
64
+ .form-group {
65
+ margin-bottom: 15px;
66
+ }
67
+ .form-group textarea {
68
+ width: 100%;
69
+ padding: 10px;
70
+ border: 1px solid #ccc;
71
+ border-radius: 5px;
72
+ font-size: 16px;
73
+ box-sizing: border-box;
74
+ resize: vertical;
75
+ }
76
+ button {
77
+ padding: 10px 15px;
78
+ border: none;
79
+ border-radius: 5px;
80
+ background-color: #007bff;
81
+ color: white;
82
+ font-size: 18px;
83
+ cursor: pointer;
84
+ transition: background-color 0.3s ease;
85
+ }
86
+ button:hover {
87
+ background-color: #0056b3;
88
+ }
89
+ #output {
90
+ margin-top: 20px;
91
+ padding: 15px;
92
+ border: 1px solid #ddd;
93
+ border-radius: 5px;
94
+ background-color: #f9f9f9;
95
+ white-space: pre-wrap;
96
+ word-break: break-word;
97
+ overflow-y: auto;
98
+ max-height: 100vh;
99
+ }
100
+ #output strong {
101
+ font-weight: bold;
102
+ }
103
+ .animated-text {
104
+ position: fixed;
105
+ top: 20px;
106
+ left: 20px;
107
+ font-size: 1.5em;
108
+ color: rgba(0, 0, 0, 0.1);
109
+ pointer-events: none;
110
+ z-index: -1;
111
+ }
112
+ @media (max-width: 768px) {
113
+ .container {
114
+ width: 98%;
115
+ margin-top: 10px;
116
+ margin-bottom: 10px;
117
+ padding: 15px;
118
+ }
119
+ .header h1 {
120
+ font-size: 1.8em;
121
+ }
122
+ .form-group textarea, .form-group input[type="text"] {
123
+ font-size: 14px;
124
+ padding: 8px;
125
+ }
126
+ button {
127
+ font-size: 16px;
128
+ padding: 8px 12px;
129
+ }
130
+ #output {
131
+ font-size: 14px;
132
+ padding: 10px;
133
+ margin-top: 15px;
134
+ }
135
+ }
136
+ </style>
137
+ </head>
138
+ <body>
139
+ <div class="animated-text animate__animated animate__fadeIn animate__infinite infinite">AI POWERED</div>
140
+ <div class="container">
141
+ <div class="header animate__animated animate__fadeInDown">
142
+ </div>
143
+ <div class="form-group animate__animated animate__fadeInLeft">
144
+ <textarea id="text" rows="5" placeholder="Enter text"></textarea>
145
+ </div>
146
+ <button onclick="generateText()" class="animate__animated animate__fadeInUp">Generate Reasoning</button>
147
+ <div id="output" class="animate__animated">
148
+ <strong>Response:</strong><br>
149
+ <span id="generatedText"></span>
150
+ </div>
151
+ </div>
152
+ <script>
153
+ let eventSource = null;
154
+ let accumulatedText = "";
155
+ let lastResponse = "";
156
+ async function generateText() {
157
+ const inputText = document.getElementById("text").value;
158
+ document.getElementById("generatedText").innerText = "";
159
+ accumulatedText = "";
160
+ if (eventSource) {
161
+ eventSource.close();
162
+ }
163
+ const temp = 0.7;
164
+ const top_k_val = 40;
165
+ const top_p_val = 0.0;
166
+ const repetition_penalty_val = 1.2;
167
+ const requestData = {
168
+ text: inputText,
169
+ temp: temp,
170
+ top_k: top_k_val,
171
+ top_p: top_p_val,
172
+ reppenalty: repetition_penalty_val
173
+ };
174
+ const params = new URLSearchParams(requestData).toString();
175
+ eventSource = new EventSource('/api/v1/generate_stream?' + params);
176
+ eventSource.onmessage = function(event) {
177
+ if (event.data === "<END_STREAM>") {
178
+ eventSource.close();
179
+ const currentResponse = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
180
+ if (currentResponse === lastResponse.trim()) {
181
+ accumulatedText = "**Response is repetitive. Please try again or rephrase your query.**";
182
+ } else {
183
+ lastResponse = currentResponse;
184
+ }
185
+ document.getElementById("generatedText").innerHTML = marked.parse(accumulatedText);
186
+ return;
187
+ }
188
+ accumulatedText += event.data;
189
+ let partialText = accumulatedText.replace("<|endoftext|>", "").replace(/\s+(?=[.,,。])/g, '').trim();
190
+ document.getElementById("generatedText").innerHTML = marked.parse(partialText);
191
+ };
192
+ eventSource.onerror = function(error) {
193
+ console.error("SSE error", error);
194
+ eventSource.close();
195
+ };
196
+ const outputDiv = document.getElementById("output");
197
+ outputDiv.classList.add("show");
198
+ }
199
+ function base64ToBlob(base64Data, contentType) {
200
+ contentType = contentType || '';
201
+ const sliceSize = 1024;
202
+ const byteCharacters = atob(base64Data);
203
+ const bytesLength = byteCharacters.length;
204
+ const slicesCount = Math.ceil(bytesLength / sliceSize);
205
+ const byteArrays = new Array(slicesCount);
206
+ for (let sliceIndex = 0; sliceIndex < slicesCount; ++sliceIndex) {
207
+ const begin = sliceIndex * sliceSize;
208
+ const end = Math.min(begin + sliceSize, bytesLength);
209
+ const bytes = new Array(end - begin);
210
+ for (let offset = begin, i = 0; offset < end; ++i, ++offset) {
211
+ bytes[i] = byteCharacters[offset].charCodeAt(0);
212
+ }
213
+ byteArrays[sliceIndex] = new Uint8Array(bytes);
214
+ }
215
+ return new Blob(byteArrays, { type: contentType });
216
+ }
217
+ </script>
218
+ </body>
219
+ </html>
220
+ """
221
+ feedback_queue = queue.Queue()
222
+
223
+
224
+ @app.route("/")
225
+ def index():
226
+ return html_code
227
+
228
+ @app.route("/api/v1/generate_stream", methods=["GET"])
229
+ def generate_stream():
230
+ text = request.args.get("text", "")
231
+ temp = float(request.args.get("temp", 0.7))
232
+ top_k = int(request.args.get("top_k", 40))
233
+ top_p = float(request.args.get("top_p", 0.0))
234
+ reppenalty = float(request.args.get("reppenalty", 1.2))
235
+ response_queue = queue.Queue()
236
+ reasoning_queue.put({
237
+ 'text_input': text,
238
+ 'temperature': temp,
239
+ 'top_k': top_k,
240
+ 'top_p': top_p,
241
+ 'repetition_penalty': reppenalty,
242
+ 'response_queue': response_queue
243
+ })
244
+ @stream_with_context
245
+ def event_stream():
246
+ while True:
247
+ output = response_queue.get()
248
+ if "error" in output:
249
+ yield "data: <ERROR>\n\n"
250
+ break
251
+ text_chunk = output.get("text")
252
+ if text_chunk:
253
+ for word in text_chunk.split(' '):
254
+ clean_word = word.strip()
255
+ if clean_word:
256
+ yield "data: " + clean_word + "\n\n"
257
+ yield "data: <END_STREAM>\n\n"
258
+ break
259
+ return Response(event_stream(), mimetype="text/event-stream")
260
+
261
+ @app.route("/api/v1/generate", methods=["POST"])
262
+ def generate():
263
+ data = request.get_json()
264
+ text = data.get("text", "")
265
+ temp = float(data.get("temp", 0.7))
266
+ top_k = int(data.get("top_k", 40))
267
+ top_p = float(data.get("top_p", 0.0))
268
+ reppenalty = float(data.get("reppenalty", 1.2))
269
+ response_queue = queue.Queue()
270
+ reasoning_queue.put({
271
+ 'text_input': text,
272
+ 'temperature': temp,
273
+ 'top_k': top_k,
274
+ 'top_p': top_p,
275
+ 'repetition_penalty': reppenalty,
276
+ 'response_queue': response_queue
277
+ })
278
+ output = response_queue.get()
279
+ if "error" in output:
280
+ return jsonify({"error": output["error"]}), 500
281
+ result_text = output.get("text", "").strip()
282
+ return jsonify({"response": result_text})
283
+
284
+ @app.route("/api/v1/feedback", methods=["POST"])
285
+ def feedback():
286
+ data = request.get_json()
287
+ feedback_text = data.get("feedback_text")
288
+ correct_category = data.get("correct_category")
289
+ if feedback_text and correct_category:
290
+ feedback_queue.put((feedback_text, correct_category))
291
+ return jsonify({"status": "feedback received"})
292
+ return jsonify({"status": "feedback failed"}), 400
293
+
294
+ @app.route("/api/v1/tts", methods=["POST"])
295
+ def tts_api():
296
+ return tts_module_api()
297
+
298
+ @app.route("/api/v1/stt", methods=["POST"])
299
+ def stt_api():
300
+ return stt_module_api()
301
+
302
+ @app.route("/api/v1/sentiment", methods=["POST"])
303
+ def sentiment_api():
304
+ return sentiment_module_api()
305
+
306
+ @app.route("/api/v1/imagegen", methods=["POST"])
307
+ def imagegen_api():
308
+ return imagegen_module_api()
309
+
310
+ @app.route("/api/v1/musicgen", methods=["POST"])
311
+ def musicgen_api():
312
+ return musicgen_module_api()
313
+
314
+ @app.route("/api/v1/translation", methods=["POST"])
315
+ def translation_api():
316
+ return translation_module_api()
317
+
318
+ @app.route("/api/v1/codegen", methods=["POST"])
319
+ def codegen_api():
320
+ return codegen_module_api()
321
+
322
+ @app.route("/api/v1/text_to_video", methods=["POST"])
323
+ def text_to_video_api():
324
+ return text_to_video_module_api()
325
+
326
+ @app.route("/api/v1/summarization", methods=["POST"])
327
+ def summarization_api():
328
+ return summarization_module_api()
329
+
330
+ @app.route("/api/v1/image_to_3d", methods=["POST"])
331
+ def image_to_3d_api():
332
+ return image_to_3d_module_api()
333
+
334
+ @app.route("/api/v1/xtts_clone", methods=["POST"])
335
+ def xtts_clone_api():
336
+ return xtts_module_api()
337
+
338
+ @app.route("/api/v1/sadtalker", methods=["POST"])
339
+ def sadtalker():
340
+ from sadtalker_api import router as sadtalker_router
341
+ return sadtalker_router.create_video()
342
+
343
+ if __name__ == "__main__":
344
+ with gr.Blocks() as demo:
345
+ gr.Markdown("## AI Powerhouse")
346
+ with gr.Tab("Text Generation"):
347
+ text_input = gr.Textbox(lines=5, placeholder="Enter text")
348
+ text_output = gr.Markdown()
349
+ text_button = gr.Button("Generate Text")
350
+ text_button.click(generate, inputs=text_input, outputs=text_output)
351
+
352
+ with gr.Tab("Image Generation"):
353
+ image_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for image")
354
+ image_output = gr.Image()
355
+ image_button = gr.Button("Generate Image")
356
+ image_button.click(imagegen_api, inputs=image_text_input, outputs=image_output)
357
+
358
+ with gr.Tab("Music Generation"):
359
+ music_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for music")
360
+ music_output = gr.Audio()
361
+ music_button = gr.Button("Generate Music")
362
+ music_button.click(musicgen_api, inputs=music_text_input, outputs=music_output)
363
+
364
+ with gr.Tab("Code Generation"):
365
+ code_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for code")
366
+ code_output = gr.File()
367
+ code_button = gr.Button("Generate Code")
368
+ code_button.click(codegen_api, inputs=code_text_input, outputs=code_output)
369
+
370
+ with gr.Tab("Text to Video"):
371
+ video_text_input = gr.Textbox(lines=3, placeholder="Enter prompt for video")
372
+ video_output = gr.Video()
373
+ video_button = gr.Button("Generate Video")
374
+ video_button.click(text_to_video_api, inputs=video_text_input, outputs=video_output)
375
+
376
+ with gr.Tab("Summarization"):
377
+ summary_text_input = gr.Textbox(lines=5, placeholder="Enter text to summarize")
378
+ summary_output = gr.Textbox()
379
+ summary_button = gr.Button("Summarize")
380
+ summary_button.click(summarization_api, inputs=summary_text_input, outputs=summary_output)
381
+
382
+ with gr.Tab("Translation"):
383
+ translate_text_input = gr.Textbox(lines=3, placeholder="Enter text to translate")
384
+ translate_lang_dropdown = gr.Dropdown(['es', 'en', 'fr', 'de'], value='es', label="Target Language")
385
+ translation_output = gr.Textbox()
386
+ translate_button = gr.Button("Translate")
387
+ translate_button.click(translation_api, inputs=[translate_text_input, translate_lang_dropdown], outputs=translation_output)
388
+
389
+ with gr.Tab("Sentiment Analysis"):
390
+ sentiment_text_input = gr.Textbox(lines=3, placeholder="Enter text for sentiment analysis")
391
+ sentiment_output = gr.Textbox()
392
+ sentiment_button = gr.Button("Analyze Sentiment")
393
+ sentiment_button.click(sentiment_api, inputs=sentiment_text_input, outputs=sentiment_output)
394
+
395
+ with gr.Tab("Text to Speech"):
396
+ tts_text_input = gr.Textbox(lines=3, placeholder="Enter text for speech")
397
+ tts_output = gr.Audio()
398
+ tts_button = gr.Button("Generate Speech")
399
+ tts_button.click(tts_api, inputs=tts_text_input, outputs=tts_output)
400
+
401
+ with gr.Tab("Voice Cloning (XTTS)"):
402
+ xtts_text_input = gr.Textbox(lines=3, placeholder="Enter text for voice cloning")
403
+ xtts_audio_input = gr.Audio(source="upload", type="filepath", label="Reference Audio for Voice Cloning")
404
+ xtts_output = gr.Audio()
405
+ xtts_button = gr.Button("Clone Voice")
406
+ xtts_button.click(xtts_module_api, inputs=[xtts_text_input, xtts_audio_input], outputs=xtts_output)
407
+
408
+ with gr.Tab("Speech to Text"):
409
+ stt_audio_input = gr.Audio(source="microphone", type="filepath")
410
+ stt_output = gr.Textbox()
411
+ stt_button = gr.Button("Transcribe Speech")
412
+ stt_button.click(stt_api, inputs=stt_audio_input, outputs=stt_output)
413
+
414
+ with gr.Tab("Image to 3D"):
415
+ image_3d_input = gr.Image(source="upload", type="filepath")
416
+ model_3d_output = gr.File()
417
+ image_3d_button = gr.Button("Generate 3D Model")
418
+ image_3d_button.click(image_to_3d_api, inputs=image_3d_input, outputs=model_3d_output)
419
+
420
+ app = gr.routes.App(demo)
421
+
422
+ app.run(host="0.0.0.0", port=7860)
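
The rewritten __main__ block wraps the Gradio Blocks UI as the served app, so the Flask routes above are easiest to exercise over HTTP. A minimal sketch of consuming the /api/v1/generate_stream SSE endpoint from Python, assuming the server from this commit is running locally on port 7860 (the query parameters and the <END_STREAM> sentinel are taken from the route above):

import requests

params = {"text": "Hello there", "temp": 0.7, "top_k": 40, "top_p": 0.0, "reppenalty": 1.2}
url = "http://127.0.0.1:7860/api/v1/generate_stream"

# Read the Server-Sent Events stream line by line until the sentinel arrives.
with requests.get(url, params=params, stream=True) as resp:
    words = []
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue  # skip blank separator lines between events
        payload = raw[len("data: "):]
        if payload in ("<END_STREAM>", "<ERROR>"):
            break
        words.append(payload)

print(" ".join(words))

Each event carries a single whitespace-stripped word, so the client re-joins them with spaces, mirroring what the browser script above does before passing the text to marked.parse.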
 
background_tasks.py CHANGED
@@ -1,18 +1,9 @@
1
- import time
2
- import threading
3
- import queue
4
- import uuid
5
- import unicodedata
6
- import re
7
  from deep_translator import GoogleTranslator
8
  from duckduckgo_search import DDGS
9
- import nltk
10
- import torch
11
- import torch.nn as nn
12
- import math
13
 
14
  nltk.download('punkt')
15
-
16
  categories = ['News', 'Sports', 'Entertainment']
17
  TEXT_GENERATION_RATE = 10
18
  text_queue = queue.Queue()
@@ -25,7 +16,7 @@ news_clf = None
25
 
26
  class SimpleClassifier(nn.Module):
27
  def __init__(self, vocab_size, num_classes, embedding_dim=128):
28
- super(SimpleClassifier, self).__init__()
29
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
30
  self.fc = nn.Linear(embedding_dim, num_classes)
31
  def forward(self, x):
@@ -34,91 +25,46 @@ class SimpleClassifier(nn.Module):
34
  out = self.fc(pooled)
35
  return out
36
 
37
- def tokenize_text(text):
38
- return nltk.word_tokenize(text)
39
-
40
- def update_vocabulary(tokens):
41
- global vocabulary, word_to_index
42
- for token in tokens:
43
- if token not in word_to_index:
44
- word_to_index[token] = len(vocabulary)
45
- vocabulary.append(token)
46
-
47
- def text_to_vector(text):
48
- tokens = tokenize_text(text)
49
- update_vocabulary(tokens)
50
- indices = [word_to_index.get(token, 0) for token in tokens]
51
- return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
52
 
53
  def generate_and_queue_text(language):
54
  global categories, text_queue
55
- num_categories = len(categories)
56
- num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
57
  while True:
58
  for category in categories:
59
  for _ in range(num_texts_per_category):
60
- uid = uuid.uuid4()
61
- base_text = f"Category: {category}. ID:{uid}"
62
- try:
63
- translator = GoogleTranslator(source='auto', target=language)
64
- text = translator.translate(base_text)
65
- except Exception:
66
- text = base_text
67
- processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable())
68
- text_queue.put((processed_text, category))
69
- time.sleep(0)
70
 
71
  def background_training():
72
  global categories, news_clf, feedback_queue, vocabulary
73
- if categories is None:
74
- categories = ['DefaultCategory']
75
- num_classes = len(categories)
76
- learning_rate = 0.01
77
- epochs = 1
78
- if news_clf is None:
79
- news_clf = SimpleClassifier(len(vocabulary), num_classes)
80
- optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
81
- criterion = nn.CrossEntropyLoss()
82
  while True:
83
  try:
84
  feedback_item = feedback_queue.get(timeout=10)
85
  if feedback_item:
86
- input_text, generated_text = feedback_item
87
- input_vector = text_to_vector(input_text)
88
- if len(vocabulary) == 0:
89
- vocabulary.extend(["<PAD>", "<EOS>"])
90
- news_clf = SimpleClassifier(len(vocabulary), num_classes)
91
- optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
92
- if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0:
93
- news_clf = SimpleClassifier(len(vocabulary), num_classes)
94
- optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
95
- input_vector = text_to_vector(input_text)
96
- tokens = tokenize_text(input_text)
97
- update_vocabulary(tokens)
98
- tokens_indices = [word_to_index.get(word, 0) for word in tokens]
99
- input_tensor = torch.tensor([tokens_indices], dtype=torch.long)
100
- target_index = categories.index(generated_text) if generated_text in categories else 0
101
  target_category_index = torch.tensor([target_index], dtype=torch.long)
102
- if num_classes <= 1:
103
- num_classes = 2
104
- news_clf.fc = nn.Linear(128, num_classes)
105
- for _ in range(epochs):
106
- optimizer.zero_grad()
107
- output = news_clf(input_tensor)
108
- loss = criterion(output, target_category_index)
109
- loss.backward()
110
- optimizer.step()
111
  feedback_queue.task_done()
112
- except queue.Empty:
113
- pass
114
- except Exception:
115
- time.sleep(5)
116
 
117
  def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
118
  for token in sample_sequence(text_input, model_gpt2, enc, length=999999999, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device):
119
- if token == "<END_STREAM>":
120
- yield "<END_STREAM>"
121
- break
122
  yield token + " "
123
 
124
  def background_reasoning_queue():
@@ -126,40 +72,21 @@ def background_reasoning_queue():
126
  while True:
127
  try:
128
  item = reasoning_queue.get(timeout=1)
129
- if item is None:
130
- reasoning_queue.task_done()
131
- continue
132
- text_input = item.get('text_input')
133
- temperature = item.get('temperature', 0.7)
134
- top_k = item.get('top_k', 40)
135
- top_p = item.get('top_p', 0.0)
136
- repetition_penalty = item.get('repetition_penalty', 1.2)
137
  resp_queue = item.get('response_queue', queue.Queue())
138
- if not text_input:
139
- resp_queue.put({"error": "Empty text input received."})
140
- reasoning_queue.task_done()
141
- continue
142
  generated_text_stream = perform_reasoning_stream(text_input, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
143
- full_response = ""
144
  for chunk in generated_text_stream:
145
- if chunk == "<END_STREAM>":
146
- break
147
  full_response += chunk
148
  cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
149
- if cleaned_response in seen_responses:
150
- final_response = "**Response is repetitive. Please try again or rephrase your query.**";
151
- resp_queue.put({"text": final_response})
152
- else:
153
- seen_responses.add(cleaned_response)
154
- final_response = cleaned_response
155
- resp_queue.put({"text": final_response})
156
  reasoning_queue.task_done()
157
- except queue.Empty:
158
- pass
159
- except Exception as e:
160
- try:
161
- resp_queue.put({"error": str(e)})
162
- except Exception:
163
- pass
164
- if reasoning_queue and not reasoning_queue.empty():
165
- reasoning_queue.task_done()
 
1
+ import time, threading, queue, uuid, unicodedata, re
 
 
 
 
 
2
  from deep_translator import GoogleTranslator
3
  from duckduckgo_search import DDGS
4
+ import nltk, torch, torch.nn as nn
 
 
 
5
 
6
  nltk.download('punkt')
 
7
  categories = ['News', 'Sports', 'Entertainment']
8
  TEXT_GENERATION_RATE = 10
9
  text_queue = queue.Queue()
 
16
 
17
  class SimpleClassifier(nn.Module):
18
  def __init__(self, vocab_size, num_classes, embedding_dim=128):
19
+ super().__init__()
20
  self.embedding = nn.Embedding(vocab_size, embedding_dim)
21
  self.fc = nn.Linear(embedding_dim, num_classes)
22
  def forward(self, x):
 
25
  out = self.fc(pooled)
26
  return out
27
 
28
+ def tokenize_text(text): return nltk.word_tokenize(text)
29
+ def update_vocabulary(tokens):
+     global vocabulary, word_to_index
+     for token in tokens:
+         if token not in word_to_index: word_to_index[token] = len(vocabulary); vocabulary.append(token)
30
+ def text_to_vector(text): tokens = tokenize_text(text); update_vocabulary(tokens); indices = [word_to_index.get(token, 0) for token in tokens]; return torch.tensor(indices, dtype=torch.long).unsqueeze(0)
31
 
32
  def generate_and_queue_text(language):
33
  global categories, text_queue
34
+ num_categories = len(categories); num_texts_per_category = TEXT_GENERATION_RATE // (2 * num_categories)
 
35
  while True:
36
  for category in categories:
37
  for _ in range(num_texts_per_category):
38
+ uid = uuid.uuid4(); base_text = f"Category: {category}. ID:{uid}"
39
+ try: translator = GoogleTranslator(source='auto', target=language); text = translator.translate(base_text)
40
+ except Exception: text = base_text
41
+ processed_text = ''.join(c for c in unicodedata.normalize('NFKC', text) if c.isprintable()); text_queue.put((processed_text, category)); time.sleep(0)
 
 
 
 
 
 
42
 
43
  def background_training():
44
  global categories, news_clf, feedback_queue, vocabulary
45
+ if categories is None: categories = ['DefaultCategory']
46
+ num_classes = len(categories); learning_rate = 0.01; epochs = 1
47
+ if news_clf is None: news_clf = SimpleClassifier(len(vocabulary), num_classes)
48
+ optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate); criterion = nn.CrossEntropyLoss()
 
 
 
 
 
49
  while True:
50
  try:
51
  feedback_item = feedback_queue.get(timeout=10)
52
  if feedback_item:
53
+ input_text, generated_text = feedback_item; input_vector = text_to_vector(input_text)
54
+ if len(vocabulary) == 0: vocabulary.extend(["<PAD>", "<EOS>"]); news_clf = SimpleClassifier(len(vocabulary), num_classes); optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate)
55
+ if input_vector.size(0) != len(vocabulary) and len(vocabulary) > 0: news_clf = SimpleClassifier(len(vocabulary), num_classes); optimizer = torch.optim.SGD(news_clf.parameters(), lr=learning_rate); input_vector = text_to_vector(input_text)
56
+ tokens = tokenize_text(input_text); update_vocabulary(tokens); tokens_indices = [word_to_index.get(word, 0) for word in tokens]
57
+ input_tensor = torch.tensor([tokens_indices], dtype=torch.long); target_index = categories.index(generated_text) if generated_text in categories else 0
58
  target_category_index = torch.tensor([target_index], dtype=torch.long)
59
+ if num_classes <= 1: num_classes = 2; news_clf.fc = nn.Linear(128, num_classes)
60
+ for _ in range(epochs): optimizer.zero_grad(); output = news_clf(input_tensor); loss = criterion(output, target_category_index); loss.backward(); optimizer.step()
 
 
 
 
 
 
 
61
  feedback_queue.task_done()
62
+ except queue.Empty: pass
63
+ except Exception: time.sleep(5)
 
 
64
 
65
  def perform_reasoning_stream(text_input, temperature=0.7, top_k=40, top_p=0.0, repetition_penalty=1.2):
66
  for token in sample_sequence(text_input, model_gpt2, enc, length=999999999, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, device=device):
67
+ if token == "<END_STREAM>": yield "<END_STREAM>"; break
 
 
68
  yield token + " "
69
 
70
  def background_reasoning_queue():
 
72
  while True:
73
  try:
74
  item = reasoning_queue.get(timeout=1)
75
+ if item is None: reasoning_queue.task_done(); continue
76
+ text_input = item.get('text_input'); temperature = item.get('temperature', 0.7); top_k = item.get('top_k', 40); top_p = item.get('top_p', 0.0); repetition_penalty = item.get('repetition_penalty', 1.2)
 
 
 
 
 
 
77
  resp_queue = item.get('response_queue', queue.Queue())
78
+ if not text_input: resp_queue.put({"error": "Empty text input received."}); reasoning_queue.task_done(); continue
 
 
 
79
  generated_text_stream = perform_reasoning_stream(text_input, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty)
80
+ full_response = "";
81
  for chunk in generated_text_stream:
82
+ if chunk == "<END_STREAM>": break
 
83
  full_response += chunk
84
  cleaned_response = re.sub(r'\s+(?=[.,,。])', '', full_response.replace("<|endoftext|>", "").strip())
85
+ if cleaned_response in seen_responses: final_response = "**Response is repetitive. Please try again or rephrase your query.**"; resp_queue.put({"text": final_response})
86
+ else: seen_responses.add(cleaned_response); final_response = cleaned_response; resp_queue.put({"text": final_response})
 
 
 
 
 
87
  reasoning_queue.task_done()
88
+ except queue.Empty: pass
89
+ except Exception as e:
90
+ try: resp_queue.put({"error": str(e)})
91
+ except Exception: pass
92
+ if reasoning_queue and not reasoning_queue.empty(): reasoning_queue.task_done()
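
For reference, a minimal sketch of the hand-off protocol this worker implements, as used by the /api/v1/generate route in api.py. It assumes main.py exposes the shared reasoning_queue and has already started background_reasoning_queue on a thread; the exact import path is an assumption:

import queue
from main import reasoning_queue  # assumed shared queue the worker thread reads from

response_queue = queue.Queue()
reasoning_queue.put({
    "text_input": "Summarize the plot of Hamlet.",
    "temperature": 0.7,
    "top_k": 40,
    "top_p": 0.0,
    "repetition_penalty": 1.2,
    "response_queue": response_queue,
})

result = response_queue.get()  # blocks until the worker pushes a reply
print(result.get("text") or result.get("error"))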
 
 
 
 
codegen_api.py CHANGED
@@ -1,22 +1,13 @@
1
  from flask import jsonify, send_file, request
2
  from main import *
 
3
 
4
  def generate_code(prompt, output_path="output_code.py"):
5
- if codegen_model is None or codegen_tokenizer is None:
6
- return "Code generation model or tokenizer not initialized."
7
- input_ids = codegen_tokenizer(prompt, return_tensors='pt').to(device)
8
- output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
9
- code = codegen_tokenizer.decode(output[0], skip_special_tokens=True)
10
- with open(output_path, "w") as file:
11
- file.write(code)
12
- return output_path
13
 
14
- def codegen_api():
15
- data = request.get_json()
16
- prompt = data.get('prompt')
17
- if not prompt:
18
- return jsonify({"error": "Prompt is required"}), 400
19
- output_file = generate_code(prompt)
20
- if output_file == "Code generation model or tokenizer not initialized.":
21
- return jsonify({"error": "Code generation failed"}), 500
22
- return send_file(output_file, mimetype="text/x-python", as_attachment=True, download_name="output.py")
 
1
  from flask import jsonify, send_file, request
2
  from main import *
3
+ import io, base64
4
 
5
  def generate_code(prompt, output_path="output_code.py"):
6
+ if codegen_model is None or codegen_tokenizer is None: return {"error": "Code generation model or tokenizer not initialized."}
7
+ input_ids = codegen_tokenizer(prompt, return_tensors='pt').to(device); output = codegen_model.generate(input_ids, max_length=2048, temperature=0.7, top_p=0.9)
8
+ code = codegen_tokenizer.decode(output[0], skip_special_tokens=True); return {"code": code}
 
 
 
 
 
9
 
10
+ def codegen_api(prompt):
11
+ output = generate_code(prompt)
12
+ if "error" in output: return {"error": output["error"]}
13
+ code_base64 = base64.b64encode(output['code'].encode('utf-8')).decode('utf-8'); return {"code_base64": code_base64, "mimetype": "text/x-python", "filename": "output.py"}
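
Since codegen_api now returns a plain dict with a base64-encoded payload instead of a Flask file response, a caller has to decode it itself. A minimal sketch, assuming the code-generation model and tokenizer from main.py are already loaded:

import base64
from codegen_api import codegen_api

result = codegen_api("write a function that reverses a string")
if "error" in result:
    raise RuntimeError(result["error"])

# Decode the base64 payload back into source code and save it under the suggested name.
source = base64.b64decode(result["code_base64"]).decode("utf-8")
with open(result["filename"], "w") as fh:  # "output.py" per the dict above
    fh.write(source)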
 
 
 
 
 
coder.py ADDED
@@ -0,0 +1,139 @@
1
+ import dash
2
+ from dash import Dash, html, dcc, callback, Output, Input, State
3
+ from dash.exceptions import PreventUpdate
4
+ import base64, uuid, requests, io, re
5
+
6
+ app_style = None
7
+ external_stylesheets = ['style.css']
8
+ app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
9
+ app.layout = html.Div([dcc.Location(id='url', refresh=False), html.Div(id='page-content')])
10
+ index_page = html.Div([html.H1("AI Powerhouse", className='index-title animate__animated animate__fadeInDown'), html.Div([dcc.Link('Text Generation', href='/text-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Audio Video', href='/audio-video', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Image Generation', href='/image-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Image to 3D', href='/image-to-3d', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Text to Video', href='/text-to-video', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Music Generation', href='/music-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Sentiment Analysis', href='/sentiment-analysis', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Translation', href='/translation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Code Generation', href='/code-generation', className='index-link animate__animated animate__fadeInUp'), dcc.Link('Summarization', href='/summarization', className='index-link animate__animated animate__fadeInUp'), dcc.Link('SadTalker', href='/sad-talker', className='index-link animate__animated animate__fadeInUp')], className='index-links-container')], className='index-page page-layout')
11
+ text_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text Generation'), html.Div(className='chat-container page-layout', children=[html.Div(className='chat-header animate__animated animate__fadeInDown', children="Text Generation Interface"), html.Div(className='chat-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='text-input', placeholder='Enter text prompt...', rows=4, className='chat-text-area'), html.Button('Generate', id='generate-button', n_clicks=0, className='chat-generate-button')]), html.Div(id='output', className='chat-output animate__animated animate__fadeInUp', children=[html.Div(className='response-header', children="Response:"), dcc.Markdown(id='generated-text', className='response-text')]), dcc.Link('Back to Home', href='/', className='chat-back-link')])], className='page-layout text-gen-layout')
12
+ audio_video_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Audio & Video Tools'), html.Div(className='av-container page-layout', children=[html.Div(className='av-header animate__animated animate__fadeInDown', children="Audio & Video Processing"), html.Div(className='av-upload-section animate__animated animate__fadeInLeft', children=[html.Div(className='upload-box', children=[dcc.Upload(id='upload-audio', children=html.Div(['Drag and Drop or ', html.A('Select Audio')]), className='upload-area'), html.Div(id='audio-output-text', className='upload-output', children="STT Output")]), html.Div(className='upload-box', children=[dcc.Upload(id='upload-image', children=html.Div(['Drag and Drop or ', html.A('Select Image')]), className='upload-area'), html.Div(id='image-output', className='upload-output', children="Image Uploaded")])]), html.Div(id='video-output', className='av-video-output animate__animated animate__fadeInUp', children="Video Output Area"), dcc.Link('Back to Home', href='/', className='av-back-link')])], className='page-layout av-layout')
13
+ image_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Image Generation'), html.Div(className='imagegen-container page-layout', children=[html.Div(className='imagegen-header animate__animated animate__fadeInDown', children="Image Generation Interface"), html.Div(className='imagegen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='imagegen-text-input', placeholder='Enter prompt for image...', rows=4, className='imagegen-text-area'), html.Button('Generate Image', id='generate-image-button', n_clicks=0, className='imagegen-generate-button')]), html.Div(id='image-output-display', className='imagegen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='imagegen-back-link')])], className='page-layout imagegen-layout')
14
+ image_to_3d_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Image to 3D Conversion'), html.Div(className='imagetod-container page-layout', children=[html.Div(className='imagetod-header animate__animated animate__fadeInDown', children="Image to 3D Model Conversion"), html.Div(className='imagetod-upload animate__animated animate__fadeInLeft', children=[dcc.Upload(id='upload-image-3d', children=html.Div(['Drag and Drop or ', html.A('Select Image')]), className='imagetod-upload-area')]), html.Div(id='3d-model-output', className='imagetod-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='imagetod-back-link')])], className='page-layout imagetod-layout')
15
+ text_to_video_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text to Video Generation'), html.Div(className='textvideo-container page-layout', children=[html.Div(className='textvideo-header animate__animated animate__fadeInDown', children="Text to Video Generation Interface"), html.Div(className='textvideo-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='text-video-input', placeholder='Enter prompt for video...', rows=4, className='textvideo-text-area'), html.Button('Generate Video', id='generate-video-button', n_clicks=0, className='textvideo-generate-button')]), html.Div(id='video-gen-output', className='textvideo-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='textvideo-back-link')])], className='page-layout textvideo-layout')
16
+ music_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Music Generation'), html.Div(className='musicgen-container page-layout', children=[html.Div(className='musicgen-header animate__animated animate__fadeInDown', children="Music Generation Interface"), html.Div(className='musicgen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='musicgen-text-input', placeholder='Enter prompt for music...', rows=4, className='musicgen-text-area'), html.Button('Generate Music', id='generate-music-button', n_clicks=0, className='musicgen-generate-button')]), html.Div(id='music-output-display', className='musicgen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='musicgen-back-link')])], className='page-layout musicgen-layout')
17
+ sentiment_analysis_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Sentiment Analysis'), html.Div(className='sentiment-container page-layout', children=[html.Div(className='sentiment-header animate__animated animate__fadeInDown', children="Sentiment Analysis Tool"), html.Div(className='sentiment-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='sentiment-text-input', placeholder='Enter text for analysis...', rows=4, className='sentiment-text-area'), html.Button('Analyze Sentiment', id='analyze-sentiment-button', n_clicks=0, className='sentiment-analyze-button')]), html.Div(id='sentiment-output-display', className='sentiment-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='sentiment-back-link')])], className='page-layout sentiment-layout')
18
+ translation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Translation Services'), html.Div(className='translation-container page-layout', children=[html.Div(className='translation-header animate__animated animate__fadeInDown', children="Translation Interface"), html.Div(className='translation-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='translate-text-input', placeholder='Enter text to translate...', rows=4, className='translation-text-area'), dcc.Dropdown(id='target-language-dropdown', options=[{'label': 'Spanish', 'value': 'es'},{'label': 'English', 'value': 'en'},{'label': 'French', 'value': 'fr'},{'label': 'German', 'value': 'de'}], value='es', className='translation-dropdown'), html.Button('Translate', id='translate-button', n_clicks=0, className='translation-translate-button')]), html.Div(id='translation-output-display', className='translation-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='translation-back-link')])], className='page-layout translation-layout')
19
+ code_generation_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Code Generation'), html.Div(className='codegen-container page-layout', children=[html.Div(className='codegen-header animate__animated animate__fadeInDown', children="Code Generation Interface"), html.Div(className='codegen-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='codegen-text-input', placeholder='Enter prompt for code...', rows=4, className='codegen-text-area'), html.Button('Generate Code', id='generate-code-button', n_clicks=0, className='codegen-generate-button')]), html.Div(id='codegen-output-display', className='codegen-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='codegen-back-link')])], className='page-layout codegen-layout')
20
+ summarization_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED Text Summarization'), html.Div(className='summarization-container page-layout', children=[html.Div(className='summarization-header animate__animated animate__fadeInDown', children="Text Summarization Tool"), html.Div(className='summarization-form animate__animated animate__fadeInLeft', children=[dcc.Textarea(id='summarize-text-input', placeholder='Enter text to summarize...', rows=4, className='summarization-text-area'), html.Button('Summarize', id='summarize-button', n_clicks=0, className='summarization-summarize-button')]), html.Div(id='summarization-output-display', className='summarization-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='summarization-back-link')])], className='page-layout summarization-layout')
21
+ sadtalker_layout = html.Div([html.Div(className='animated-text animate__animated animate__fadeIn animate__infinite infinite', children='AI POWERED SadTalker'), html.Div(className='sadtalker-container page-layout', children=[html.Div(className='sadtalker-header animate__animated animate__fadeInDown', children="SadTalker Interface"), html.Div(className='sadtalker-upload animate__animated animate__fadeInLeft', children=[dcc.Upload(id='upload-sadtalker-image', children=html.Div(['Drag and Drop Image', html.Br(), html.I(className="fas fa-image upload-icon")]), className='sadtalker-upload-area', multiple=False), dcc.Upload(id='upload-sadtalker-audio', children=html.Div(['Drag and Drop Audio', html.Br(), html.I(className="fas fa-file-audio upload-icon")]), className='sadtalker-upload-area', multiple=False)]), html.Div(id='sadtalker-video-output', className='sadtalker-output animate__animated animate__fadeInUp'), dcc.Link('Back to Home', href='/', className='sadtalker-back-link')])], className='page-layout sadtalker-layout')
22
+
23
+ @app.callback(Output('page-content', 'children'), [Input('url', 'pathname')])
24
+ def display_page(pathname):
25
+ if pathname == '/text-generation': return text_generation_layout
26
+ elif pathname == '/audio-video': return audio_video_layout
27
+ elif pathname == '/image-generation': return image_generation_layout
28
+ elif pathname == '/image-to-3d': return image_to_3d_layout
29
+ elif pathname == '/text-to-video': return text_to_video_layout
30
+ elif pathname == '/music-generation': return music_generation_layout
31
+ elif pathname == '/sentiment-analysis': return sentiment_analysis_layout
32
+ elif pathname == '/translation': return translation_layout
33
+ elif pathname == '/code-generation': return code_generation_layout
34
+ elif pathname == '/summarization': return summarization_layout
35
+ elif pathname == '/sad-talker': return sadtalker_layout
36
+ else: return index_page
37
+
38
+ @app.callback(Output('generated-text', 'children'), Output('output', 'className'), Input('generate-button', 'n_clicks'), State('text-input', 'value'), prevent_initial_call=True)
39
+ def generate_reasoning_dash(n_clicks, text_input):
40
+ if not text_input: return "Please enter text.", 'chat-output animate__animated animate__fadeInUp error-output'
41
+ api_url = "/api/v1/generate"; payload = {"text": text_input}
42
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); generated_text = data.get("response", "Error in backend response."); return generated_text, 'chat-output animate__animated animate__fadeInUp'
43
+ except requests.exceptions.RequestException as e: return f"Error communicating with backend: {e}", 'chat-output animate__animated animate__fadeInUp error-output'
44
+
45
+ @app.callback(Output('audio-output-text', 'children'), Output('video-output', 'children'), Output('audio-output-text', 'className'), Output('video-output', 'className'), Input('upload-audio', 'contents'), State('upload-audio', 'filename'), Input('upload-image', 'contents'), State('upload-image', 'filename'), prevent_initial_call=True)
46
+ def process_audio_video_dash(audio_contents, audio_filename, image_contents, image_filename):
47
+ stt_output_text = ""; video_display = ""; stt_class = 'upload-output'; video_class = 'av-video-output animate__animated animate__fadeInUp'
48
+ if audio_contents:
49
+ try: content_type, content_string = audio_contents.split(','); decoded_audio = base64.b64decode(content_string); audio_io = io.BytesIO(decoded_audio); files = {'audio': (audio_filename, audio_io, content_type)}; response = requests.post("http://127.0.0.1:7860/api/v1/stt", files=files); response.raise_for_status(); data = response.json(); stt_output_text = f"STT Output: {data.get('text', 'Transcription failed')}"; stt_class = 'upload-output success'
50
+ except requests.exceptions.RequestException as e: stt_output_text = f"STT Error: {e}"; stt_class = 'upload-output error'
51
+
52
+ if image_contents:
53
+ try: content_type, content_string = image_contents.split(','); decoded_image = base64.b64decode(content_string); image_io = io.BytesIO(decoded_image); files = {'image': (image_filename, image_io, content_type)}; response = requests.post("http://127.0.0.1:7860/api/v1/image_to_3d", files=files); response.raise_for_status(); video_display = "3D Model Feature Extracted (Check Backend Logs for Output)."; video_class = 'av-video-output animate__animated animate__fadeInUp success'
54
+ except requests.exceptions.RequestException as e: video_display = f"3D Error: {e}"; video_class = 'av-video-output animate__animated animate__fadeInUp error'
55
+
56
+ video_output_component = html.Div(video_display) if video_display else ""; return stt_output_text, video_output_component, stt_class, video_class
57
+
58
+ @app.callback(Output('image-output-display', 'children'), Output('image-output-display', 'className'), Input('generate-image-button', 'n_clicks'), State('imagegen-text-input', 'value'), prevent_initial_call=True)
59
+ def generate_image_dash(n_clicks, prompt):
60
+ if not prompt: return "Please enter a prompt for image generation.", 'imagegen-output animate__animated animate__fadeInUp error-output'
61
+ api_url = "/api/v1/imagegen"; payload = {"prompt": prompt}
62
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); image_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Img(src=f'data:image/png;base64,{image_base64}', className='generated-image'), 'imagegen-output animate__animated animate__fadeInUp success-output'
63
+ except requests.exceptions.RequestException as e: return f"Image Generation Error: {e}", 'imagegen-output animate__animated animate__fadeInUp error-output'
64
+
65
+ @app.callback(Output('3d-model-output', 'children'), Output('3d-model-output', 'className'), Input('upload-image-3d', 'contents'), State('upload-image-3d', 'filename'), prevent_initial_call=True)
66
+ def process_image_to_3d_dash(contents, filename):
67
+ if contents is None: raise PreventUpdate
68
+ try:
69
+ content_type, content_string = contents.split(',')
70
+ decoded_image = base64.b64decode(content_string); image_io = io.BytesIO(decoded_image); files = {'image': (filename, image_io, content_type)}
71
+ response = requests.post("http://127.0.0.1:7860/api/v1/image_to_3d", files=files); response.raise_for_status()
72
+ content_disposition = response.headers.get('Content-Disposition'); download_filename = 'model_3d.obj'
73
+ if content_disposition: filenames = re.findall('filename="([^"]+)"', content_disposition)
74
+ if content_disposition and filenames: download_filename = filenames[0]
75
+ model_base64 = base64.b64encode(response.content).decode('utf-8'); href = f'data:model/obj;base64,{model_base64}'
76
+ download_link = html.A('Download 3D Model', href=href, download=download_filename, className='download-link'); return download_link, 'imagetod-output animate__animated animate__fadeInUp success-output'
77
+ except requests.exceptions.RequestException as e: return f"3D Conversion Error: {e}", 'imagetod-output animate__animated animate__fadeInUp error-output'
78
+
79
+ @app.callback(Output('video-gen-output', 'children'), Output('video-gen-output', 'className'), Input('generate-video-button', 'n_clicks'), State('text-video-input', 'value'), prevent_initial_call=True)
80
+ def generate_video_dash(n_clicks, prompt):
81
+ if not prompt: return "Please enter a prompt for video generation.", 'textvideo-output animate__animated animate__fadeInUp error-output'
82
+ api_url = "/api/v1/text_to_video"; payload = {"prompt": prompt}
83
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); video_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Video(src=f'data:video/mp4;base64,{video_base64}', controls=True, className='generated-video'), 'textvideo-output animate__animated animate__fadeInUp success-output'
84
+ except requests.exceptions.RequestException as e: return f"Video Generation Error: {e}", 'textvideo-output animate__animated animate__fadeInUp error-output'
85
+
86
+ @app.callback(Output('music-output-display', 'children'), Output('music-output-display', 'className'), Input('generate-music-button', 'n_clicks'), State('musicgen-text-input', 'value'), prevent_initial_call=True)
87
+ def generate_music_dash(n_clicks, prompt):
88
+ if not prompt: return "Please enter a prompt for music generation.", 'musicgen-output animate__animated animate__fadeInUp error-output'
89
+ api_url = "/api/v1/musicgen"; payload = {"prompt": prompt}
90
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); audio_base64 = base64.b64encode(response.content).decode('utf-8'); return html.Audio(src=f'data:audio/wav;base64,{audio_base64}', controls=True, className='generated-audio'), 'musicgen-output animate__animated animate__fadeInUp success-output'
91
+ except requests.exceptions.RequestException as e: return f"Music Generation Error: {e}", 'musicgen-output animate__animated animate__fadeInUp error-output'
92
+
93
+ @app.callback(Output('sentiment-output-display', 'children'), Output('sentiment-output-display', 'className'), Input('analyze-sentiment-button', 'n_clicks'), State('sentiment-text-input', 'value'), prevent_initial_call=True)
94
+ def analyze_sentiment_dash(n_clicks, text):
95
+ if not text: return "Please enter text for sentiment analysis.", 'sentiment-output animate__animated animate__fadeInUp error-output'
96
+ api_url = "/api/v1/sentiment"; payload = {"text": text}
97
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); sentiment_label = data.get('sentiment', 'Analysis Failed'); return f"Sentiment: {sentiment_label}", 'sentiment-output animate__animated animate__fadeInUp success-output'
98
+ except requests.exceptions.RequestException as e: return f"Sentiment Analysis Error: {e}", 'sentiment-output animate__animated animate__fadeInUp error-output'
99
+
100
+ @app.callback(Output('translation-output-display', 'children'), Output('translation-output-display', 'className'), Input('translate-button', 'n_clicks'), State('translate-text-input', 'value'), State('target-language-dropdown', 'value'), prevent_initial_call=True)
101
+ def translate_text_dash(n_clicks, text, target_lang):
102
+ if not text: return "Please enter text for translation.", 'translation-output animate__animated animate__fadeInUp error-output'
103
+ api_url = "/api/v1/translation"; payload = {"text": text, "target_lang": target_lang}
104
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status(); data = response.json(); translation = data.get('translated_text', 'Translation Failed'); return f"Translation ({target_lang.upper()}): {translation}", 'translation-output animate__animated animate__fadeInUp success-output'
105
+ except requests.exceptions.RequestException as e: return f"Translation Error: {e}", 'translation-output animate__animated animate__fadeInUp error-output'
106
+
107
+ @app.callback(Output('codegen-output-display', 'children'), Output('codegen-output-display', 'className'), Input('generate-code-button', 'n_clicks'), State('codegen-text-input', 'value'), prevent_initial_call=True)
108
+ def generate_code_dash(n_clicks, prompt):
109
+ if not prompt: return "Please enter a prompt for code generation.", 'codegen-output animate__animated animate__fadeInUp error-output'
110
+ api_url = "/api/v1/codegen"; payload = {"prompt": prompt}
111
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status()
112
+ content_disposition = response.headers.get('Content-Disposition'); download_filename = 'code.py'
113
+ if content_disposition: filenames = re.findall('filename="([^"]+)"', content_disposition); download_filename = filenames[0] if filenames else download_filename
114
+ code_base64 = base64.b64encode(response.content).decode('utf-8'); download_link = html.A('Download Code', href=f'data:text/x-python;base64,{code_base64}', download=download_filename, className='download-link'); return download_link, 'codegen-output animate__animated animate__fadeInUp success-output'
115
+ except requests.exceptions.RequestException as e: return f"Code Generation Error: {e}", 'codegen-output animate__animated animate__fadeInUp error-output'
116
+
117
+ @app.callback(Output('summarization-output-display', 'children'), Output('summarization-output-display', 'className'), Input('summarize-button', 'n_clicks'), State('summarize-text-input', 'value'), prevent_initial_call=True)
118
+ def summarize_text_dash(n_clicks, text):
119
+ if not text: return "Please enter text for summarization.", 'summarization-output animate__animated animate__fadeInUp error-output'
120
+ api_url = "/api/v1/summarization"; payload = {"text": text}
121
+ try: response = requests.post("http://127.0.0.1:7860" + api_url, json=payload); response.raise_for_status()
122
+ content_disposition = response.headers.get('Content-Disposition'); download_filename = 'summary.txt'
123
+ if content_disposition: filenames = re.findall('filename="([^"]+)"', content_disposition); download_filename = filenames[0] if filenames else download_filename
124
+ summary_base64 = base64.b64encode(response.content).decode('utf-8'); download_link = html.A('Download Summary', href=f'data:text/plain;base64,{summary_base64}', download=download_filename, className='download-link'); return download_link, 'summarization-output animate__animated animate__fadeInUp success-output'
125
+ except requests.exceptions.RequestException as e: return f"Summarization Error: {e}", 'summarization-output animate__animated animate__fadeInUp error-output'
126
+
127
+ @app.callback(Output('sadtalker-video-output', 'children'), Output('sadtalker-video-output', 'className'), Input('upload-sadtalker-image', 'contents'), State('upload-sadtalker-image', 'filename'), Input('upload-sadtalker-audio', 'contents'), State('upload-sadtalker-audio', 'filename'), prevent_initial_call=True)
128
+ def process_sadtalker_dash(image_contents, image_filename, audio_contents, audio_filename):
129
+ if not image_contents or not audio_contents: return "Please upload both image and audio for SadTalker.", 'sadtalker-output animate__animated animate__fadeInUp error-output'
130
+ try:
131
+ image_content_type, image_content_string = image_contents.split(','); decoded_image = base64.b64decode(image_content_string); image_io = io.BytesIO(decoded_image)
132
+ audio_content_type, audio_content_string = audio_contents.split(','); decoded_audio = base64.b64decode(audio_content_string); audio_io = io.BytesIO(decoded_audio)
133
+ files = {'source_image_file': (image_filename, image_io, image_content_type), 'driven_audio_file': (audio_filename, audio_io, audio_content_type)}
134
+ response = requests.post("http://127.0.0.1:7860/api/v1/sadtalker", files=files); response.raise_for_status(); data = response.json(); video_url = data.get('video_url')
135
+ if video_url: video_base64 = base64.b64encode(requests.get(video_url).content).decode('utf-8'); return html.Video(src=f'data:video/mp4;base64,{video_base64}', controls=True, className='generated-video'), 'sadtalker-output animate__animated animate__fadeInUp success-output'
136
+ else: return "SadTalker video generation failed, check backend logs.", 'sadtalker-output animate__animated animate__fadeInUp error-output'
137
+ except requests.exceptions.RequestException as e: return f"SadTalker Error: {e}", 'sadtalker-output animate__animated animate__fadeInUp error-output'
138
+
139
+ if __name__ == '__main__': app.run_server(host='0.0.0.0', port=7861, debug=False)
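
Every callback above proxies to the Flask backend hard-coded at http://127.0.0.1:7860, so the API can also be exercised without the Dash UI. A minimal smoke-test sketch, assuming the backend started by main.py is running locally and that /api/v1/generate returns a JSON body with a "response" field, as the generate_reasoning_dash callback expects:

```python
import requests

# Hypothetical smoke test; endpoint path and payload key are taken from the callback above.
resp = requests.post("http://127.0.0.1:7860/api/v1/generate", json={"text": "Hello"}, timeout=60)
resp.raise_for_status()
print(resp.json().get("response", "no 'response' field in backend reply"))
```
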
image_to_3d_api.py CHANGED
@@ -1,31 +1,19 @@
1
- import os
2
- import uuid
 
 
 
3
  from flask import jsonify, send_file, request
4
  from main import *
5
  from PIL import Image
6
- import torch
7
- import numpy as np
8
 
9
  def image_to_3d_func(image_path, output_path="output_3d.obj"):
10
- if image_to_3d_model is None:
11
- return "Image-to-3D model not initialized."
12
- pil_image = Image.open(image_path).convert("RGB")
13
- image = torch.tensor(np.array(pil_image)).float().permute(2,0,1).unsqueeze(0) / 255.0
14
- image = image.to(device)
15
- with torch.no_grad():
16
- mesh_obj = image_to_3d_model(image)
17
- with open(output_path, 'w') as f:
18
- f.write(mesh_obj)
19
- return output_path
20
 
21
- def image_to_3d_api():
22
- if 'image' not in request.files:
23
- return jsonify({"error": "Image file is required"}), 400
24
- image_file = request.files['image']
25
- temp_image_path = f"temp_image_{uuid.uuid4()}.png"
26
- image_file.save(temp_image_path)
27
- output_file = image_to_3d_func(temp_image_path)
28
- os.remove(temp_image_path)
29
- if output_file == "Image-to-3D model not initialized.":
30
- return jsonify({"error": "Image to 3D failed"}), 500
31
- return send_file(output_file, mimetype="model/obj", as_attachment=True, download_name="output_3d.obj")
 
5
+ import os, uuid
6
  from flask import jsonify, send_file, request
7
  from main import *
8
  from PIL import Image
9
+ import torch, numpy as np, io, base64
 
10
 
11
  def image_to_3d_func(image_path, output_path="output_3d.obj"):
12
+ if image_to_3d_model is None: return {"error": "Image-to-3D model not initialized."}
13
+ pil_image = Image.open(image_path).convert("RGB"); image = torch.tensor(np.array(pil_image)).float().permute(2,0,1).unsqueeze(0) / 255.0; image = image.to(device)
14
+ with torch.no_grad(): mesh_obj = image_to_3d_model(image); return {"model_3d": mesh_obj}
 
 
 
 
 
 
 
15
 
16
+ def image_to_3d_api(image_path):
17
+ output = image_to_3d_func(image_path)
18
+ if "error" in output: return {"error": output["error"]}
19
+ model_3d_base64 = base64.b64encode(output['model_3d'].encode('utf-8')).decode('utf-8'); return {"model_3d_base64": model_3d_base64, "mimetype": "model/obj", "filename": "output_3d.obj"}
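
image_to_3d_api now returns a plain dict (a base64 payload plus metadata) instead of a Flask send_file response, so the route that wires it up has to build the HTTP reply itself. A minimal caller sketch under that assumption; the image path is a placeholder, and the code assumes the model's output is OBJ text, as the .encode('utf-8') above implies:

```python
import base64

result = image_to_3d_api("face.png")  # placeholder input path
if "error" not in result:
    # Write the decoded OBJ bytes under the filename suggested by the helper.
    with open(result["filename"], "wb") as f:
        f.write(base64.b64decode(result["model_3d_base64"]))
```
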
 
 
 
 
 
 
 
imagegen_api.py CHANGED
@@ -3,31 +3,15 @@ from flask import jsonify, send_file, request
3
  from io import BytesIO
4
  from PIL import Image
5
  from main import *
6
- import torch
7
 
8
  def generate_image(prompt, output_path="output_image.png"):
9
- if imagegen_model is None:
10
- return "Image generation model not initialized."
11
-
12
  generator = torch.Generator(device=device).manual_seed(0)
13
- with torch.no_grad():
14
- image = imagegen_model(
15
- prompt,
16
- generator=generator,
17
- ).images[0]
18
- image.save(output_path)
19
- return output_path
20
 
21
- def imagegen_api():
22
- data = request.get_json()
23
- prompt = data.get('prompt')
24
- if not prompt:
25
- return jsonify({"error": "Prompt is required"}), 400
26
  output_file = generate_image(prompt)
27
- if output_file == "Image generation model not initialized.":
28
- return jsonify({"error": "Image generation failed"}), 500
29
- image_io = BytesIO()
30
- pil_image = Image.open(output_file)
31
- pil_image.save(image_io, 'PNG')
32
- image_io.seek(0)
33
- return send_file(image_io, mimetype='image/png', as_attachment=True, download_name="output.png")
 
3
  from io import BytesIO
4
  from PIL import Image
5
  from main import *
6
+ import torch, base64
7
 
8
  def generate_image(prompt, output_path="output_image.png"):
9
+ if imagegen_model is None: return {"error": "Image generation model not initialized."}
 
 
10
  generator = torch.Generator(device=device).manual_seed(0)
11
+ with torch.no_grad(): image = imagegen_model(prompt, generator=generator,).images[0]; image.save(output_path); return output_path
 
 
 
 
 
 
12
 
13
+ def imagegen_api(prompt):
 
 
 
 
14
  output_file = generate_image(prompt)
15
+ if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
16
+ image_io = BytesIO(); pil_image = Image.open(output_file); pil_image.save(image_io, 'PNG'); image_base64 = base64.b64encode(image_io.getvalue()).decode('utf-8')
17
+ os.remove(output_file); return {"image_base64": image_base64, "mimetype": "image/png"}
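
imagegen_api follows the same convention: callers receive {"image_base64", "mimetype"} and decide whether to stream raw bytes or embed the string in JSON. A short decode sketch under the same assumptions (the prompt and output filename are placeholders):

```python
import base64

result = imagegen_api("a lighthouse at dusk")  # placeholder prompt
if "error" not in result:
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["image_base64"]))
```
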
 
 
 
 
main.py CHANGED
@@ -1,10 +1,4 @@
1
- import threading
2
- import queue
3
- import time
4
- import os
5
- import nltk
6
- import re
7
- import json
8
  from flask import Flask
9
  from flask_cors import CORS
10
  from api import *
@@ -19,50 +13,13 @@ from background_tasks import *
19
  from text_generation import *
20
  from sadtalker_utils import *
21
 
22
-
23
- state_dict = None
24
- enc = None
25
- config = None
26
- model_gpt2 = None
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
- news_clf = None
29
- tfidf_vectorizer = None
30
- text_queue = queue.Queue()
31
- categories = None
32
- background_threads = []
33
- feedback_queue = queue.Queue()
34
- reasoning_queue = queue.Queue()
35
- seen_responses = set()
36
- dialogue_history = []
37
- vocabulary = set()
38
- word_to_index = {}
39
- index_to_word = []
40
- translation_model = None
41
- sp = None
42
- codegen_model = None
43
- codegen_tokenizer = None
44
- codegen_vocabulary = None
45
- codegen_index_to_word = None
46
- codegen_word_to_index = None
47
- summarization_model = None
48
- summarization_vocabulary = set()
49
- summarization_word_to_index = {}
50
- summarization_index_to_word = []
51
- sadtalker_instance = None
52
- imagegen_model = None
53
- image_to_3d_model = None
54
- text_to_video_model = None
55
- stream_type = "text"
56
- sentiment_model = None
57
- stt_model = None
58
- tts_model = None
59
- musicgen_model = None
60
 
61
  def load_models():
62
- global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model
63
  model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
64
  translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
65
- codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
66
  summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
67
  imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
68
  image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
@@ -71,6 +28,7 @@ def load_models():
71
  stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
72
  tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
73
  musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
 
74
  sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
75
 
76
  if __name__ == "__main__":
@@ -78,13 +36,8 @@ if __name__ == "__main__":
78
  load_models()
79
  categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
80
  import background_tasks
81
- background_tasks.categories = categories
82
- background_tasks.text_queue = text_queue
83
- background_tasks.reasoning_queue = reasoning_queue
84
- background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True))
85
- background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
86
- background_threads.append(threading.Thread(target=background_training, daemon=True))
87
- background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
88
- for thread in background_threads:
89
- thread.start()
90
  app.run(host='0.0.0.0', port=7860)
 
1
+ import threading, queue, time, os, nltk, re, json
 
 
 
 
 
 
2
  from flask import Flask
3
  from flask_cors import CORS
4
  from api import *
 
13
  from text_generation import *
14
  from sadtalker_utils import *
15
 
16
+ state_dict, enc, config, model_gpt2, device, news_clf, tfidf_vectorizer, text_queue, categories, background_threads, feedback_queue, reasoning_queue, seen_responses, dialogue_history, vocabulary, word_to_index, index_to_word, translation_model, sp, codegen_model, codegen_tokenizer, codegen_vocabulary, codegen_index_to_word, codegen_word_to_index, summarization_model, summarization_vocabulary, summarization_word_to_index, summarization_index_to_word, sadtalker_instance, imagegen_model, image_to_3d_model, text_to_video_model, stream_type, sentiment_model, stt_model, tts_model, musicgen_model, xtts_model = None, None, None, None, torch.device("cuda" if torch.cuda.is_available() else "cpu"), None, None, queue.Queue(), None, [], queue.Queue(), queue.Queue(), set(), [], set(), {}, [], None, None, None, None, None, None, None, None, set(), {}, [], None, None, None, None, "text", None, None, None, None, None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def load_models():
19
+ global model_gpt2, enc, translation_model, codegen_model, codegen_tokenizer, summarization_model, imagegen_model, image_to_3d_model, text_to_video_model, sadtalker_instance, sentiment_model, stt_model, tts_model, musicgen_model, xtts_model
20
  model_gpt2, enc = initialize_gpt2_model(GPT2_FOLDER, {MODEL_FILE: MODEL_URL, ENCODER_FILE: ENCODER_URL, VOCAB_FILE: VOCAB_URL, CONFIG_FILE: GPT2CONFHG})
21
  translation_model = initialize_translation_model(TRANSLATION_FOLDER, TRANSLATION_MODEL_FILES_URLS)
22
+ codegen_model, codegen_tokenizer, _, _, _ = initialize_codegen_model(CODEGEN_FOLDER, CODEGEN_FILES_URLS)
23
  summarization_model, _, _, _ = initialize_summarization_model(SUMMARIZATION_FOLDER, SUMMARIZATION_FILES_URLS)
24
  imagegen_model = initialize_imagegen_model(IMAGEGEN_FOLDER, IMAGEGEN_FILES_URLS)
25
  image_to_3d_model = initialize_image_to_3d_model(IMAGE_TO_3D_FOLDER, IMAGE_TO_3D_FILES_URLS)
 
28
  stt_model = initialize_stt_model(STT_FOLDER, STT_FILES_URLS)
29
  tts_model = initialize_tts_model(TTS_FOLDER, TTS_FILES_URLS)
30
  musicgen_model = initialize_musicgen_model(MUSICGEN_FOLDER, MUSICGEN_FILES_URLS)
31
+ xtts_model = initialize_xtts_model(XTTS_FOLDER, XTTS_FILES_URLS)
32
  sadtalker_instance = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config')
33
 
34
  if __name__ == "__main__":
 
36
  load_models()
37
  categories = ['Category1', 'Category2', 'Category3', 'Category4', 'Category5']
38
  import background_tasks
39
+ background_tasks.categories = categories; background_tasks.text_queue = text_queue; background_tasks.reasoning_queue = reasoning_queue
40
+ background_threads.append(threading.Thread(target=generate_and_queue_text, args=('en',), daemon=True)); background_threads.append(threading.Thread(target=generate_and_queue_text, args=('es',), daemon=True))
41
+ background_threads.append(threading.Thread(target=background_training, daemon=True)); background_threads.append(threading.Thread(target=background_reasoning_queue, daemon=True))
42
+ for thread in background_threads: thread.start()
 
 
 
 
 
43
  app.run(host='0.0.0.0', port=7860)
model_loader.py CHANGED
@@ -1,725 +1,229 @@
1
- from tokenxxx import *
2
- from constants import *
3
- from utils import *
4
- import os
5
- import json
6
- import urllib.request
7
- import urllib.parse
8
- import torch
9
- import hashlib
10
- from tqdm import tqdm
11
- from skimage import img_as_ubyte
12
- from torch import nn
13
- import torch.nn.functional as F
14
- import inspect
15
-
16
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
-
18
- def filter_kwargs(cls, kwargs):
19
- sig = inspect.signature(cls.__init__)
20
- accepted = set(sig.parameters.keys()) - {"self"}
21
- return {k: v for k, v in kwargs.items() if k in accepted}
22
-
23
- def sanitize_filename(name, url=None):
24
- for c in '<>:"/\\|?*':
25
- name = name.replace(c, '')
26
- if not name and url is not None:
27
- name = hashlib.md5(url.encode()).hexdigest()
28
- return name
29
-
30
- def download_file(url, filepath):
31
- d = os.path.dirname(filepath)
32
- if d and not os.path.exists(d):
33
- os.makedirs(d, exist_ok=True)
34
- while not os.path.exists(filepath):
35
- try:
36
- def prog(t):
37
- last = [0]
38
- def inner(n, bs, ts):
39
- if ts > 0:
40
- t.total = ts
41
- t.update(n * bs - last[0])
42
- last[0] = n * bs
43
- return inner
44
- with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t:
45
- urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
46
- except Exception:
47
- continue
48
-
49
- def download_files(folder, files_spec):
50
- if isinstance(files_spec, dict):
51
- for fn, url in files_spec.items():
52
- fn = sanitize_filename(fn, url)
53
- fp = os.path.join(folder, fn)
54
- download_file(url, fp)
55
- elif isinstance(files_spec, list):
56
- for item in files_spec:
57
- if isinstance(item, str):
58
- url = item
59
- parsed = urllib.parse.urlparse(url)
60
- fn = os.path.basename(parsed.path)
61
- if not fn:
62
- fn = hashlib.md5(url.encode()).hexdigest()
63
- fn = sanitize_filename(fn, url)
64
- elif isinstance(item, (list, tuple)) and len(item) == 2:
65
- url, fn = item
66
- fn = sanitize_filename(fn, url)
67
- elif isinstance(item, dict) and "filename" in item and "url" in item:
68
- fn = sanitize_filename(item["filename"], item["url"])
69
- url = item["url"]
70
- else:
71
- raise ValueError("Invalid file specification")
72
- fp = os.path.join(folder, fn)
73
- download_file(url, fp)
74
- else:
75
- raise ValueError("files_spec must be dict or list")
76
-
77
- def read_json(fp):
78
- with open(fp, 'r', encoding='utf-8') as f:
79
- return json.load(f)
80
-
81
- def get_codegen_tokenizer(vocab_path, merges_path):
82
- with open(vocab_path, 'r', encoding='utf-8') as f:
83
- vocab = json.load(f)
84
- with open(merges_path, 'r', encoding='utf-8') as f:
85
- merges = f.read().splitlines()
86
- merge_ranks = {}
87
- for i, merge in enumerate(merges):
88
- parts = merge.strip().split()
89
- if len(parts) == 2:
90
- merge_ranks[tuple(parts)] = i
91
- def bpe(token):
92
- word = list(token)
93
- pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
94
- while True:
95
- candidate = None
96
- candidate_rank = None
97
- candidate_index = None
98
- for i, pair in enumerate(pairs):
99
- if pair in merge_ranks:
100
- rank = merge_ranks[pair]
101
- if candidate is None or rank < candidate_rank:
102
- candidate = pair
103
- candidate_rank = rank
104
- candidate_index = i
105
- if candidate is None:
106
- break
107
- first, second = candidate
108
- new_word = []
109
- i = 0
110
- while i < len(word):
111
- if i < len(word) - 1 and word[i] == first and word[i+1] == second:
112
- new_word.append(first + second)
113
- i += 2
114
- else:
115
- new_word.append(word[i])
116
- i += 1
117
- word = new_word
118
- if len(word) == 1:
119
- break
120
- pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
121
- return word
122
- def tokenizer(text):
123
- tokens = []
124
- for token in text.split():
125
- bpe_tokens = bpe(token)
126
- for subtoken in bpe_tokens:
127
- tokens.append(vocab.get(subtoken, 0))
128
- return tokens
129
- return tokenizer
130
-
131
- def simple_tokenizer(text, vocab, max_length=77):
132
- toks = text.split()
133
- ids = [vocab.get(t, 1) for t in toks]
134
- if len(ids) < max_length:
135
- ids = ids + [0] * (max_length - len(ids))
136
- else:
137
- ids = ids[:max_length]
138
- return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
139
-
140
- def load_state_dict_safe(model, loaded_state_dict):
141
- model_state = model.state_dict()
142
- new_state = {}
143
- for key, value in model_state.items():
144
- if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
145
- new_state[key] = loaded_state_dict[key]
146
- else:
147
- new_state[key] = value
148
- model.load_state_dict(new_state, strict=False)
149
-
150
- class GPT2Config:
151
- def __init__(self, vocab_size=50257, **kwargs):
152
- self.vocab_size = vocab_size
153
- self.__dict__.update(kwargs)
154
- @classmethod
155
- def from_dict(cls, d):
156
- return cls(**d)
157
-
158
- class MBartConfig:
159
- def __init__(self, vocab_size=50265, **kwargs):
160
- self.vocab_size = vocab_size
161
- self.__dict__.update(kwargs)
162
- @classmethod
163
- def from_dict(cls, d):
164
- return cls(**d)
165
-
166
- class CodeGenConfig:
167
- def __init__(self, vocab_size=50257, **kwargs):
168
- self.vocab_size = vocab_size
169
- self.__dict__.update(kwargs)
170
- @classmethod
171
- def from_dict(cls, d):
172
- return cls(**d)
173
-
174
- class BartConfig:
175
- def __init__(self, vocab_size=50265, **kwargs):
176
- self.vocab_size = vocab_size
177
- self.__dict__.update(kwargs)
178
- @classmethod
179
- def from_dict(cls, d):
180
- return cls(**d)
181
-
182
- class AutoencoderKLConfig:
183
- def __init__(self, **kwargs):
184
- self.__dict__.update(kwargs)
185
- @classmethod
186
- def from_dict(cls, d):
187
- return cls(**d)
188
-
189
- class OpenLRMConfig:
190
- def __init__(self, **kwargs):
191
- self.__dict__.update(kwargs)
192
- @classmethod
193
- def from_dict(cls, d):
194
- return cls(**d)
195
-
196
- class UNet2DConditionModelConfig:
197
- def __init__(self, **kwargs):
198
- self.__dict__.update(kwargs)
199
- @classmethod
200
- def from_dict(cls, d):
201
- return cls(**d)
202
-
203
- class MusicGenConfig:
204
- def __init__(self, **kwargs):
205
- self.__dict__.update(kwargs)
206
- @classmethod
207
- def from_dict(cls, d):
208
- return cls(**d)
209
-
210
- class GPT2LMHeadModel(nn.Module):
211
- def __init__(self, config):
212
- super().__init__()
213
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
214
- self.transformer = nn.TransformerEncoder(layer, num_layers=12)
215
- self.lm_head = nn.Linear(768, config.vocab_size)
216
- def forward(self, x):
217
- return self.lm_head(self.transformer(x))
218
-
219
- class MBartForConditionalGeneration(nn.Module):
220
- def __init__(self, config):
221
- super().__init__()
222
- self.config = config
223
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
224
- self.encoder = nn.TransformerEncoder(layer, num_layers=6)
225
- dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
226
- self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
227
- self.output_layer = nn.Linear(768, config.vocab_size)
228
- def forward(self, src, tgt):
229
- return self.output_layer(self.decoder(tgt, self.encoder(src)))
230
-
231
- class CodeGenForCausalLM(nn.Module):
232
- def __init__(self, config):
233
- super().__init__()
234
- d_model = getattr(config, "d_model", 1024)
235
- n_head = getattr(config, "n_head", 16)
236
- num_layers = getattr(config, "num_layers", 12)
237
- dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head)
238
- self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers)
239
- self.lm_head = nn.Linear(d_model, config.vocab_size)
240
- def forward(self, tgt, memory=None):
241
- if memory is None:
242
- memory = torch.zeros_like(tgt)
243
- return self.lm_head(self.transformer_decoder(tgt, memory))
244
-
245
- class BartForConditionalGeneration(nn.Module):
246
- def __init__(self, config):
247
- super().__init__()
248
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
249
- self.encoder = nn.TransformerEncoder(layer, num_layers=6)
250
- dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12)
251
- self.decoder = nn.TransformerDecoder(dlayer, num_layers=6)
252
- self.output_layer = nn.Linear(768, config.vocab_size)
253
- def forward(self, src, tgt):
254
- return self.output_layer(self.decoder(tgt, self.encoder(src)))
255
-
256
- class ResnetBlock(nn.Module):
257
- def __init__(self, in_ch, out_ch):
258
- super().__init__()
259
- self.norm1 = nn.GroupNorm(32, in_ch)
260
- self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
261
- self.norm2 = nn.GroupNorm(32, out_ch)
262
- self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
263
- self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
264
- def forward(self, x):
265
- sc = self.conv_shortcut(x)
266
- h = F.silu(self.norm1(x))
267
- h = self.conv1(h)
268
- h = F.silu(self.norm2(x))
269
- h = self.conv2(h)
270
- return h + sc
271
-
272
- class Downsample(nn.Module):
273
- def __init__(self, in_ch, out_ch):
274
- super().__init__()
275
- self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
276
- def forward(self, x):
277
- return self.conv(x)
278
-
279
- class DownBlock(nn.Module):
280
- def __init__(self, in_ch, out_ch, num_res):
281
- super().__init__()
282
- self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
283
- self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
284
- def forward(self, x):
285
- for r in self.resnets:
286
- x = r(x)
287
- for ds in self.downsamplers:
288
- x = ds(x)
289
- return x
290
-
291
- class Upsample(nn.Module):
292
- def __init__(self, in_ch, out_ch):
293
- super().__init__()
294
- self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
295
- def forward(self, x):
296
- return self.conv(x)
297
-
298
- class UpBlock(nn.Module):
299
- def __init__(self, in_ch, out_ch, num_res):
300
- super().__init__()
301
- self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)])
302
- self.upsampler = Upsample(out_ch, out_ch)
303
- def forward(self, x):
304
- for r in self.resnets:
305
- x = r(x)
306
- return self.upsampler(x)
307
-
308
- class AttentionBlock(nn.Module):
309
- def __init__(self, ch):
310
- super().__init__()
311
- self.norm = nn.GroupNorm(32, ch)
312
- self.query = nn.Conv2d(ch, ch, 1)
313
- self.key = nn.Conv2d(ch, ch, 1)
314
- self.value = nn.Conv2d(ch, ch, 1)
315
- self.proj_attn = nn.Conv2d(ch, ch, 1)
316
- def forward(self, x):
317
- b, c, h, w = x.shape
318
- xn = self.norm(x)
319
- q = self.query(xn).view(b, c, -1).permute(0, 2, 1)
320
- k = self.key(xn).view(b, c, -1)
321
- v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
322
- attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1)
323
- out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w)
324
- return x + self.proj_attn(out)
325
-
326
- class Encoder(nn.Module):
327
- def __init__(self, in_ch=3, base_ch=128, latent_ch=4):
328
- super().__init__()
329
- self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1)
330
- self.down_blocks = nn.ModuleList([
331
- DownBlock(base_ch, base_ch, 2),
332
- DownBlock(base_ch, base_ch * 2, 2),
333
- DownBlock(base_ch * 2, base_ch * 4, 2),
334
- DownBlock(base_ch * 4, base_ch * 4, 2)
335
- ])
336
- self.mid_block = nn.ModuleList([
337
- ResnetBlock(base_ch * 4, base_ch * 4),
338
- AttentionBlock(base_ch * 4),
339
- ResnetBlock(base_ch * 4, base_ch * 4)
340
- ])
341
- self.conv_norm_out = nn.GroupNorm(32, base_ch * 4)
342
- self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1)
343
- self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
344
- def forward(self, x):
345
- x = self.conv_in(x)
346
- for blk in self.down_blocks:
347
- x = blk(x)
348
- for m in self.mid_block:
349
- x = m(x)
350
- x = self.conv_norm_out(x)
351
- x = self.conv_out(x)
352
- return self.quant_conv(x)
353
-
354
- class Decoder(nn.Module):
355
- def __init__(self, out_ch=3, base_ch=128, latent_ch=4):
356
- super().__init__()
357
- self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1)
358
- self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1)
359
- self.mid_block = nn.ModuleList([
360
- ResnetBlock(base_ch * 4, base_ch * 4),
361
- AttentionBlock(base_ch * 4),
362
- ResnetBlock(base_ch * 4, base_ch * 4)
363
- ])
364
- self.up_blocks = nn.ModuleList([
365
- UpBlock(base_ch * 4, base_ch * 4, 3),
366
- UpBlock(base_ch * 4, base_ch * 2, 3),
367
- UpBlock(base_ch * 2, base_ch, 3),
368
- UpBlock(base_ch, base_ch, 3)
369
- ])
370
- self.conv_norm_out = nn.GroupNorm(32, base_ch)
371
- self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
372
- def forward(self, x):
373
- x = self.post_quant_conv(x)
374
- x = self.conv_in(x)
375
- for m in self.mid_block:
376
- x = m(x)
377
- for up in self.up_blocks:
378
- x = up(x)
379
- x = self.conv_norm_out(x)
380
- return self.conv_out(x)
381
-
382
- class AutoencoderKL(nn.Module):
383
- def __init__(self, config):
384
- super().__init__()
385
- in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
386
- out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3)
387
- base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
388
- latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4)
389
- self.encoder = Encoder(in_ch, base_ch, latent_ch)
390
- self.decoder = Decoder(out_ch, base_ch, latent_ch)
391
- def forward(self, x):
392
- return self.decoder(self.encoder(x))
393
- def decode(self, x):
394
- return self.decoder(x)
395
-
396
- class TransformerBlock(nn.Module):
397
- def __init__(self, embed_dim, num_heads):
398
- super().__init__()
399
- self.norm1 = nn.LayerNorm(embed_dim)
400
- self.attn = nn.MultiheadAttention(embed_dim, num_heads)
401
- self.norm2 = nn.LayerNorm(embed_dim)
402
- hidden_dim = embed_dim * 4
403
- self.mlp = nn.Sequential(
404
- nn.Linear(embed_dim, hidden_dim),
405
- nn.GELU(),
406
- nn.Linear(hidden_dim, embed_dim)
407
- )
408
- def forward(self, x):
409
- res = x
410
- x = self.norm1(x)
411
- x = x.transpose(0, 1)
412
- attn, _ = self.attn(x, x, x)
413
- x = attn.transpose(0, 1)
414
- x = res + x
415
- return x + self.mlp(self.norm2(x))
416
-
417
- class VisionTransformer(nn.Module):
418
- def __init__(self, config):
419
- super().__init__()
420
- if isinstance(config, dict):
421
- self.img_size = config.get("img_size", 592)
422
- self.patch_size = config.get("patch_size", 16)
423
- self.embed_dim = config.get("hidden_size", 768)
424
- depth = config.get("depth", 12)
425
- num_heads = config.get("num_heads", 12)
426
- else:
427
- self.img_size = config.__dict__.get("img_size", 592)
428
- self.patch_size = config.__dict__.get("patch_size", 16)
429
- self.embed_dim = config.__dict__.get("hidden_size", 768)
430
- depth = config.__dict__.get("depth", 12)
431
- num_heads = config.__dict__.get("num_heads", 12)
432
- num_patches = (self.img_size // self.patch_size) ** 2
433
- self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
434
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
435
- self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
436
- self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)])
437
- self.norm = nn.LayerNorm(self.embed_dim)
438
- self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim))
439
- self._init_weights()
440
- def _init_weights(self):
441
- nn.init.normal_(self.cls_token, std=0.02)
442
- nn.init.normal_(self.pos_embed, std=0.02)
443
- def forward(self, x):
444
- x = self.patch_embed(x)
445
- x = x.flatten(2).transpose(1, 2)
446
- cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
447
- x = torch.cat((cls_tokens, x), dim=1)
448
- x = x + self.pos_embed
449
- for blk in self.blocks:
450
- x = blk(x)
451
- return self.norm(x)[:, 0]
452
-
453
- class OpenLRM(nn.Module):
454
- def __init__(self, config):
455
- super().__init__()
456
- self.encoder = nn.ModuleDict({"model": VisionTransformer(config)})
457
- hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768)
458
- self.linear = nn.Linear(hidden, hidden)
459
- def forward(self, x):
460
- return self.linear(self.encoder["model"](x))
461
-
462
- class VideoUNet(nn.Module):
463
- def __init__(self, in_ch=4, out_ch=4, features=None):
464
- super().__init__()
465
- if features is None:
466
- features = [64, 128, 256]
467
- self.encoder = nn.ModuleList()
468
- self.pool = nn.MaxPool3d(2, 2)
469
- self.decoder = nn.ModuleList()
470
- for f in features:
471
- self.encoder.append(nn.Sequential(
472
- nn.Conv3d(in_ch, f, 3, padding=1),
473
- nn.ReLU(inplace=True),
474
- nn.Conv3d(f, f, 3, padding=1),
475
- nn.ReLU(inplace=True)
476
- ))
477
- in_ch = f
478
- for f in reversed(features):
479
- self.decoder.append(nn.Sequential(
480
- nn.Conv3d(f * 2, f, 3, padding=1),
481
- nn.ReLU(inplace=True),
482
- nn.Conv3d(f, f, 3, padding=1),
483
- nn.ReLU(inplace=True)
484
- ))
485
- self.final_conv = nn.Conv3d(features[0], out_ch, 1)
486
- def forward(self, x, t, encoder_hidden_states):
487
- skips = []
488
- for enc in self.encoder:
489
- x = enc(x)
490
- skips.append(x)
491
- x = self.pool(x)
492
- for dec in self.decoder:
493
- skip = skips.pop()
494
- x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False)
495
- x = torch.cat([x, skip], dim=1)
496
- x = dec(x)
497
- return self.final_conv(x)
498
-
499
- class SentimentClassifierModel(nn.Module):
500
- def __init__(self, config):
501
- super().__init__()
502
- self.classifier = nn.Sequential(
503
- nn.Linear(768, 256),
504
- nn.ReLU(),
505
- nn.Linear(256, 2)
506
- )
507
- def forward(self, x):
508
- return self.classifier(x)
509
-
510
- class STTModel(nn.Module):
511
- def __init__(self, config):
512
- super().__init__()
513
- self.net = nn.Sequential(
514
- nn.Linear(768, 512),
515
- nn.ReLU(),
516
- nn.Linear(512, 768)
517
- )
518
- def forward(self, x):
519
- return self.net(x)
520
-
521
- class TTSModel(nn.Module):
522
- def __init__(self, config):
523
- super().__init__()
524
- self.net = nn.Sequential(
525
- nn.Linear(768, 512),
526
- nn.ReLU(),
527
- nn.Linear(512, 768)
528
- )
529
- def forward(self, x):
530
- return self.net(x)
531
-
532
- class MusicGenModel(nn.Module):
533
- def __init__(self, config):
534
- super().__init__()
535
- layer = nn.TransformerEncoderLayer(d_model=768, nhead=12)
536
- self.transformer = nn.TransformerEncoder(layer, num_layers=12)
537
- self.linear = nn.Linear(768, 768)
538
- def forward(self, x):
539
- return self.linear(self.transformer(x))
540
-
541
- class SimpleTextEncoder(nn.Module):
542
- def __init__(self, vocab_size=10000, embed_dim=768, max_length=77):
543
- super().__init__()
544
- self.embedding = nn.Embedding(vocab_size, embed_dim)
545
- self.max_length = max_length
546
- def forward(self, text_tokens):
547
- return self.embedding(text_tokens)
548
-
549
- class DiffusionScheduler:
550
- def __init__(self, steps):
551
- self.steps = steps
552
- self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device)
553
- self.alphas = 1 - self.betas
554
- self.alpha_bars = torch.cumprod(self.alphas, dim=0)
555
- def step(self, noise, t, sample):
556
- alpha_bar = self.alpha_bars[t]
557
- alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
558
- x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar)
559
- new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise
560
- return new_sample
561
-
562
- class VideoOutput:
563
- def __init__(self, frames):
564
- self.frames = [img_as_ubyte(frame) for frame in frames[0]]
565
-
566
- class VideoPipeline(nn.Module):
567
- def __init__(self, unet, vae, text_encoder, vocab):
568
- super().__init__()
569
- self.unet = unet
570
- self.vae = vae
571
- self.text_encoder = text_encoder
572
- self.vocab = vocab
573
- def forward(self, prompt: str, steps: int = 25, num_frames: int = 24):
574
- token_ids = simple_tokenizer(prompt, self.vocab)
575
- text_emb = self.text_encoder(token_ids)
576
- latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half()
577
- sched = DiffusionScheduler(steps)
578
- for t in range(steps):
579
- noise = self.unet(latent, t, text_emb)
580
- latent = sched.step(noise, t, latent)
581
- frames = self.vae.decode(latent / 0.18215)
582
- frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy()
583
- return VideoOutput(frames)
584
-
585
- def initialize_gpt2_model(folder, files):
586
- download_files(folder, files)
587
- config = GPT2Config()
588
- model = GPT2LMHeadModel(config).to(device)
589
- sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device)
590
- load_state_dict_safe(model, sd)
591
- model.eval()
592
- enc = read_json(os.path.join(folder, sanitize_filename("encoder.json")))
593
- return model, enc
594
-
595
- def initialize_translation_model(folder, files):
596
- download_files(folder, files)
597
- config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
598
- model = MBartForConditionalGeneration(config).to(device)
599
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
600
- load_state_dict_safe(model, sd)
601
- model.eval()
602
- vp = os.path.join(folder, "vocab.json")
603
- if os.path.exists(vp):
604
- vocab = read_json(vp)
605
- model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
606
- else:
607
- model.tokenizer = lambda txt: txt
608
- model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}
609
- return model
610
-
611
- def initialize_codegen_model(folder, files):
612
- download_files(folder, files)
613
- config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
614
- model = CodeGenForCausalLM(config).to(device)
615
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
616
- load_state_dict_safe(model, sd)
617
- model.eval()
618
- tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt"))
619
- vocab = read_json(os.path.join(folder, "vocab.json"))
620
- idx2w = {v: k for k, v in vocab.items()}
621
- model.tokenizer = tok
622
- return model, tok, vocab, idx2w, vocab
623
-
624
- def initialize_summarization_model(folder, files):
625
- download_files(folder, files)
626
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
627
- model = BartForConditionalGeneration(config).to(device)
628
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
629
- load_state_dict_safe(model, sd)
630
- model.eval()
631
- vp = os.path.join(folder, "vocab.json")
632
- if os.path.exists(vp):
633
- vocab_json = read_json(vp)
634
- vocab = set(vocab_json.keys())
635
- return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
636
- return model, None, None, None
637
-
638
- def initialize_imagegen_model(folder, files):
639
- download_files(folder, files)
640
- config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
641
- vae = AutoencoderKL(config).to(device)
642
- sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
643
- load_state_dict_safe(vae, sd)
644
- vae.eval()
645
- return vae
646
-
647
- def initialize_image_to_3d_model(folder, files):
648
- download_files(folder, files)
649
- config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
650
- model3d = OpenLRM(config).to(device)
651
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
652
- load_state_dict_safe(model3d, sd)
653
- model3d.eval()
654
- return model3d
655
-
656
- def initialize_text_to_video_model(folder, files):
657
- download_files(folder, files)
658
- unet_cfg = read_json(os.path.join(folder, "config.json"))
659
- unet_cfg = filter_kwargs(VideoUNet, unet_cfg)
660
- unet = VideoUNet(**unet_cfg).half().to(device)
661
- sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device)
662
- load_state_dict_safe(unet, sd_unet)
663
- unet.eval()
664
- vae_cfg = read_json(os.path.join(folder, "config.json"))
665
- vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg)
666
- vae = AutoencoderKL(vae_cfg).half().to(device)
667
- sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device)
668
- load_state_dict_safe(vae, sd_vae)
669
- vae.eval()
670
- vp = os.path.join(folder, "vocab.json")
671
- text_vocab = read_json(vp) if os.path.exists(vp) else {}
672
- te_path = os.path.join(folder, "text_encoder.bin")
673
- if os.path.exists(te_path):
674
- text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
675
- sd_te = torch.load(te_path, map_location=device)
676
- load_state_dict_safe(text_encoder, sd_te)
677
- else:
678
- text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
679
- text_encoder.eval()
680
- return VideoPipeline(unet, vae, text_encoder, text_vocab)
681
-
682
- def initialize_sentiment_model(folder, files):
683
- download_files(folder, files)
684
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
685
- model = SentimentClassifierModel(config).to(device)
686
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
687
- load_state_dict_safe(model, sd)
688
- model.eval()
689
- vp = os.path.join(folder, "vocab.json")
690
- if os.path.exists(vp):
691
- read_json(vp)
692
- return model
693
-
694
- def initialize_stt_model(folder, files):
695
- download_files(folder, files)
696
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
697
- model = STTModel(config).to(device)
698
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
699
- load_state_dict_safe(model, sd)
700
- model.eval()
701
- vp = os.path.join(folder, "vocab.json")
702
- if os.path.exists(vp):
703
- read_json(vp)
704
- return model
705
-
706
- def initialize_tts_model(folder, files):
707
- download_files(folder, files)
708
- config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
709
- model = TTSModel(config).to(device)
710
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
711
- load_state_dict_safe(model, sd)
712
- model.eval()
713
- vp = os.path.join(folder, "vocab.json")
714
- if os.path.exists(vp):
715
- read_json(vp)
716
- return model
717
-
718
- def initialize_musicgen_model(folder, files):
719
- download_files(folder, files)
720
- config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
721
- model = MusicGenModel(config).to(device)
722
- sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device)
723
- load_state_dict_safe(model, sd)
724
- model.eval()
725
- return model
 
1
+ from tokenxxx import *
2
+ from constants import *
3
+ from utils import *
4
+ import os, json, urllib.request, urllib.parse, torch, hashlib, inspect
5
+ from tqdm import tqdm
6
+ from TTS.config import load_config
7
+ from TTS.tts.models.xtts import Xtts
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ def filter_kwargs(cls, kwargs): sig = inspect.signature(cls.__init__); accepted = set(sig.parameters.keys()) - {"self"}; return {k: v for k, v in kwargs.items() if k in accepted}
11
+ def sanitize_filename(name, url=None):
+     for c in '<>:"/\\|?*': name = name.replace(c, '')
+     if not name and url is not None: name = hashlib.md5(url.encode()).hexdigest()
+     return name
+ def download_file(url, filepath):
+     d = os.path.dirname(filepath)
+     if d and not os.path.exists(d): os.makedirs(d, exist_ok=True)
+     # Retry until the file exists; any download error simply triggers another attempt.
+     while not os.path.exists(filepath):
+         try:
+             def prog(t):
+                 last = [0]
+                 def inner(n, bs, ts):
+                     if ts > 0: t.total = ts
+                     t.update(n * bs - last[0]); last[0] = n * bs
+                 return inner
+             with tqdm(unit='B', unit_scale=True, unit_divisor=1024, desc=os.path.basename(filepath)) as t: urllib.request.urlretrieve(url, filepath, reporthook=prog(t))
+         except Exception: continue
18
+ def download_files(folder, files_spec):
+     if isinstance(files_spec, dict):
+         for fn, url in files_spec.items():
+             fn = sanitize_filename(fn, url); download_file(url, os.path.join(folder, fn))
+     elif isinstance(files_spec, list):
+         for item in files_spec:
+             if isinstance(item, str):
+                 url = item; parsed = urllib.parse.urlparse(url)
+                 fn = sanitize_filename(os.path.basename(parsed.path) or hashlib.md5(url.encode()).hexdigest(), url)
+             elif isinstance(item, (list, tuple)) and len(item) == 2: url, fn = item; fn = sanitize_filename(fn, url)
+             elif isinstance(item, dict) and "filename" in item and "url" in item: fn = sanitize_filename(item["filename"], item["url"]); url = item["url"]
+             else: raise ValueError("Invalid file specification")
+             download_file(url, os.path.join(folder, fn))
+     else: raise ValueError("files_spec must be dict or list")
28
+
29
+ def read_json(fp):
+     with open(fp, 'r', encoding='utf-8') as f: return json.load(f)
30
+ def get_codegen_tokenizer(vocab_path, merges_path):
31
+ with open(vocab_path, 'r', encoding='utf-8') as f: vocab = json.load(f)
32
+ with open(merges_path, 'r', encoding='utf-8') as f: merges = f.read().splitlines()
33
+ merge_ranks = {};
34
+ for i, merge in enumerate(merges): parts = merge.strip().split(); if len(parts) == 2: merge_ranks[tuple(parts)] = i
35
+ def bpe(token):
36
+ word = list(token); pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
37
+ while True: candidate = None; candidate_rank = None; candidate_index = None
38
+ for i, pair in enumerate(pairs): if pair in merge_ranks: rank = merge_ranks[pair]; if candidate is None or rank < candidate_rank: candidate = pair; candidate_rank = rank; candidate_index = i
39
+ if candidate is None: break
40
+ first, second = candidate; new_word = []; i = 0
41
+ while i < len(word):
42
+ try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
43
+ except ValueError: new_word.extend(word[i:]); break
44
+ if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
45
+ else: new_word.append(word[i]); i += 1
46
+ word = new_word;
47
+ if len(word) == 1: break
48
+ pairs = [(word[i], word[i+1]) for i in range(len(word)-1)]
49
+ return word
50
+ def tokenizer(text): tokens = []; for token in text.split(): bpe_tokens = bpe(token); for subtoken in bpe_tokens: tokens.append(vocab.get(subtoken, 0)); return tokens
51
+ return tokenizer
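# Usage sketch (illustrative only): paths are placeholders for wherever the CodeGen
# vocab.json / merges.txt were downloaded by download_files.
tok = get_codegen_tokenizer("checkpoints/codegen/vocab.json", "checkpoints/codegen/merges.txt")
ids = tok("def add(a, b): return a + b")  # list of vocab ids; unknown subtokens map to 0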
52
+
53
+ def simple_tokenizer(text, vocab, max_length=77):
+     ids = [vocab.get(t, 1) for t in text.split()]
+     ids = ids + [0] * (max_length - len(ids)) if len(ids) < max_length else ids[:max_length]
+     return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(device)
+ def load_state_dict_safe(model, loaded_state_dict):
+     new_state = {}
+     for key, value in model.state_dict().items():
+         if key in loaded_state_dict and loaded_state_dict[key].shape == value.shape:
+             new_state[key] = loaded_state_dict[key]
+         else:
+             new_state[key] = value
+     model.load_state_dict(new_state, strict=False)
57
+
58
+ class GPT2Config:
+     def __init__(self, vocab_size=50257, **kwargs):
+         self.vocab_size = vocab_size
+         self.__dict__.update(kwargs)
+     @classmethod
+     def from_dict(cls, d):
+         return cls(**d)
+ class MBartConfig:
+     def __init__(self, vocab_size=50265, **kwargs):
+         self.vocab_size = vocab_size
+         self.__dict__.update(kwargs)
+     @classmethod
+     def from_dict(cls, d):
+         return cls(**d)
+ class CodeGenConfig:
+     def __init__(self, vocab_size=50257, **kwargs):
+         self.vocab_size = vocab_size
+         self.__dict__.update(kwargs)
+     @classmethod
+     def from_dict(cls, d):
+         return cls(**d)
+ class BartConfig:
+     def __init__(self, vocab_size=50265, **kwargs):
+         self.vocab_size = vocab_size
+         self.__dict__.update(kwargs)
+     @classmethod
+     def from_dict(cls, d):
+         return cls(**d)
70
+ class AutoencoderKLConfig:
71
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
72
+ @classmethod
73
+ def from_dict(cls, d): return cls(**d)
74
+ class OpenLRMConfig:
75
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
76
+ @classmethod
77
+ def from_dict(cls, d): return cls(**d)
78
+ class UNet2DConditionModelConfig:
79
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
80
+ @classmethod
81
+ def from_dict(cls, d): return cls(**d)
82
+ class MusicGenConfig:
83
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
84
+ @classmethod
85
+ def from_dict(cls, d): return cls(**d)
86
+ class XTTSConfig:
87
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
88
+ @classmethod
89
+ def from_dict(cls, d): return cls(**d)
90
+
91
+ class GPT2LMHeadModel(nn.Module): def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.lm_head = nn.Linear(768, config.vocab_size)
92
+ def forward(self, x): return self.lm_head(self.transformer(x))
93
+ class MBartForConditionalGeneration(nn.Module): def __init__(self, config): super().__init__(); self.config = config; layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
94
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
95
+ def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
96
+ class CodeGenForCausalLM(nn.Module): def __init__(self, config): super().__init__(); d_model = getattr(config, "d_model", 1024); n_head = getattr(config, "n_head", 16)
97
+ num_layers = getattr(config, "num_layers", 12); dlayer = nn.TransformerDecoderLayer(d_model=d_model, nhead=n_head); self.transformer_decoder = nn.TransformerDecoder(dlayer, num_layers=num_layers); self.lm_head = nn.Linear(d_model, config.vocab_size)
98
+ def forward(self, tgt, memory=None): if memory is None: memory = torch.zeros_like(tgt); return self.lm_head(self.transformer_decoder(tgt, memory))
99
+ class BartForConditionalGeneration(nn.Module): def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.encoder = nn.TransformerEncoder(layer, num_layers=6)
100
+ dlayer = nn.TransformerDecoderLayer(d_model=768, nhead=12); self.decoder = nn.TransformerDecoder(dlayer, num_layers=6); self.output_layer = nn.Linear(768, config.vocab_size)
101
+ def forward(self, src, tgt): return self.output_layer(self.decoder(tgt, self.encoder(src)))
102
+ class ResnetBlock(nn.Module): def __init__(self, in_ch, out_ch): super().__init__(); self.norm1 = nn.GroupNorm(32, in_ch); self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1); self.norm2 = nn.GroupNorm(32, out_ch); self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1); self.conv_shortcut = nn.Conv2d(in_ch, out_ch, 1)
103
+ def forward(self, x): sc = self.conv_shortcut(x); h = F.silu(self.norm1(x)); h = self.conv1(h); h = F.silu(self.norm2(h)); h = self.conv2(h); return h + sc
104
+ class Downsample(nn.Module): def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1)
105
+ def forward(self, x): return self.conv(x)
106
+ class DownBlock(nn.Module): def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.downsamplers = nn.ModuleList([Downsample(out_ch, out_ch)])
107
+ def forward(self, x): for r in self.resnets: x = r(x); for ds in self.downsamplers: x = ds(x); return x
108
+ class Upsample(nn.Module): def __init__(self, in_ch, out_ch): super().__init__(); self.conv = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
109
+ def forward(self, x): return self.conv(x)
110
+ class UpBlock(nn.Module): def __init__(self, in_ch, out_ch, num_res): super().__init__(); self.resnets = nn.ModuleList([ResnetBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(num_res)]); self.upsampler = Upsample(out_ch, out_ch)
111
+ def forward(self, x): for r in self.resnets: x = r(x); return self.upsampler(x)
112
+ class AttentionBlock(nn.Module): def __init__(self, ch): super().__init__(); self.norm = nn.GroupNorm(32, ch); self.query = nn.Conv2d(ch, ch, 1); self.key = nn.Conv2d(ch, ch, 1); self.value = nn.Conv2d(ch, ch, 1); self.proj_attn = nn.Conv2d(ch, ch, 1)
113
+ def forward(self, x): b, c, h, w = x.shape; xn = self.norm(x); q = self.query(xn).view(b, c, -1).permute(0, 2, 1); k = self.key(xn).view(b, c, -1); v = self.value(xn).view(b, c, -1).permute(0, 2, 1)
114
+ attn = torch.softmax(torch.bmm(q, k) / (c ** 0.5), dim=-1); out = torch.bmm(attn, v).permute(0, 2, 1).view(b, c, h, w); return x + self.proj_attn(out)
115
+ class Encoder(nn.Module): def __init__(self, in_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.conv_in = nn.Conv2d(in_ch, base_ch, 3, padding=1); self.down_blocks = nn.ModuleList([DownBlock(base_ch, base_ch, 2), DownBlock(base_ch, base_ch * 2, 2), DownBlock(base_ch * 2, base_ch * 4, 2), DownBlock(base_ch * 4, base_ch * 4, 2)]); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.conv_norm_out = nn.GroupNorm(32, base_ch * 4); self.conv_out = nn.Conv2d(base_ch * 4, latent_ch * 2, 3, padding=1); self.quant_conv = nn.Conv2d(latent_ch * 2, latent_ch, 1)
116
+ def forward(self, x): x = self.conv_in(x); for blk in self.down_blocks: x = blk(x); for m in self.mid_block: x = m(x); x = self.conv_norm_out(x); x = self.conv_out(x); return self.quant_conv(x)
117
+ class Decoder(nn.Module): def __init__(self, out_ch=3, base_ch=128, latent_ch=4): super().__init__(); self.post_quant_conv = nn.Conv2d(latent_ch, latent_ch * 2, 1); self.conv_in = nn.Conv2d(latent_ch, base_ch * 4, 3, padding=1); self.mid_block = nn.ModuleList([ResnetBlock(base_ch * 4, base_ch * 4), AttentionBlock(base_ch * 4), ResnetBlock(base_ch * 4, base_ch * 4)]); self.up_blocks = nn.ModuleList([UpBlock(base_ch * 4, base_ch * 4, 3), UpBlock(base_ch * 4, base_ch * 2, 3), UpBlock(base_ch * 2, base_ch, 3), UpBlock(base_ch, base_ch, 3)]); self.conv_norm_out = nn.GroupNorm(32, base_ch); self.conv_out = nn.Conv2d(base_ch, out_ch, 3, padding=1)
118
+ def forward(self, x): x = self.post_quant_conv(x); x = self.conv_in(x); for m in self.mid_block: x = m(x); for up in self.up_blocks: x = up(x); x = self.conv_norm_out(x); return self.conv_out(x)
119
+ class AutoencoderKL(nn.Module): def __init__(self, config): super().__init__(); in_ch = config.get("in_channels", 3) if isinstance(config, dict) else config.__dict__.get("in_channels", 3)
120
+ out_ch = config.get("out_channels", 3) if isinstance(config, dict) else config.__dict__.get("out_channels", 3); base_ch = config.get("base_channels", 128) if isinstance(config, dict) else config.__dict__.get("base_channels", 128)
121
+ latent_ch = config.get("latent_channels", 4) if isinstance(config, dict) else config.__dict__.get("latent_channels", 4); self.encoder = Encoder(in_ch, base_ch, latent_ch); self.decoder = Decoder(out_ch, base_ch, latent_ch)
122
+ def forward(self, x): return self.decoder(self.encoder(x))
123
+ def decode(self, x): return self.decoder(x)
124
+ class TransformerBlock(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__(); self.norm1 = nn.LayerNorm(embed_dim); self.attn = nn.MultiheadAttention(embed_dim, num_heads); self.norm2 = nn.LayerNorm(embed_dim)
125
+ hidden_dim = embed_dim * 4; self.mlp = nn.Sequential(nn.Linear(embed_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, embed_dim))
126
+ def forward(self, x): res = x; x = self.norm1(x); x = x.transpose(0, 1); attn, _ = self.attn(x, x, x); x = attn.transpose(0, 1); x = res + x; return x + self.mlp(self.norm2(x))
127
+ class VisionTransformer(nn.Module): def __init__(self, config): super().__init__(); self.img_size = config.get("img_size", 592)
128
+ self.patch_size = config.get("patch_size", 16); self.embed_dim = config.get("hidden_size", 768); depth = config.get("depth", 12); num_heads = config.get("num_heads", 12)
129
+ num_patches = (self.img_size // self.patch_size) ** 2; self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)); self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, self.embed_dim))
130
+ self.patch_embed = nn.Conv2d(3, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size); self.blocks = nn.ModuleList([TransformerBlock(self.embed_dim, num_heads) for _ in range(depth)]); self.norm = nn.LayerNorm(self.embed_dim)
131
+ self.register_tokens = nn.Parameter(torch.zeros(1, 4, self.embed_dim)); self._init_weights()
132
+ def _init_weights(self): nn.init.normal_(self.cls_token, std=0.02); nn.init.normal_(self.pos_embed, std=0.02)
133
+ def forward(self, x): x = self.patch_embed(x); x = x.flatten(2).transpose(1, 2); cls_tokens = self.cls_token.expand(x.shape[0], -1, -1); x = torch.cat((cls_tokens, x), dim=1); x = x + self.pos_embed
134
+ for blk in self.blocks: x = blk(x); return self.norm(x)[:, 0]
135
+ class OpenLRM(nn.Module): def __init__(self, config): super().__init__(); self.encoder = nn.ModuleDict({"model": VisionTransformer(config)}); hidden = config.get("hidden_size", 768) if isinstance(config, dict) else config.__dict__.get("hidden_size", 768); self.linear = nn.Linear(hidden, hidden)
136
+ def forward(self, x): return self.linear(self.encoder["model"](x))
137
+ class VideoUNet(nn.Module): def __init__(self, in_ch=4, out_ch=4, features=None): super().__init__();
138
+ if features is None: features = [64, 128, 256]
139
+ self.encoder = nn.ModuleList(); self.pool = nn.MaxPool3d(2, 2); self.decoder = nn.ModuleList()
140
+ for f in features: self.encoder.append(nn.Sequential(nn.Conv3d(in_ch, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True))); in_ch = f
141
+ for f in reversed(features): self.decoder.append(nn.Sequential(nn.Conv3d(f * 2, f, 3, padding=1), nn.ReLU(inplace=True), nn.Conv3d(f, f, 3, padding=1), nn.ReLU(inplace=True)))
142
+ self.final_conv = nn.Conv3d(features[0], out_ch, 1)
143
+ def forward(self, x, t, encoder_hidden_states): skips = []; for enc in self.encoder: x = enc(x); skips.append(x); x = self.pool(x)
144
+ for dec in self.decoder: skip = skips.pop(); x = F.interpolate(x, scale_factor=2, mode='trilinear', align_corners=False); x = torch.cat([x, skip], dim=1); x = dec(x)
145
+ return self.final_conv(x)
146
+ class SentimentClassifierModel(nn.Module): def __init__(self, config): super().__init__(); self.classifier = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 2))
147
+ def forward(self, x): return self.classifier(x)
148
+ class STTModel(nn.Module): def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
149
+ def forward(self, x): return self.net(x)
150
+ class TTSModel(nn.Module): def __init__(self, config): super().__init__(); self.net = nn.Sequential(nn.Linear(768, 512), nn.ReLU(), nn.Linear(512, 768))
151
+ def forward(self, x): return self.net(x)
152
+ class MusicGenModel(nn.Module): def __init__(self, config): super().__init__(); layer = nn.TransformerEncoderLayer(d_model=768, nhead=12); self.transformer = nn.TransformerEncoder(layer, num_layers=12); self.linear = nn.Linear(768, 768)
153
+ def forward(self, x): return self.linear(self.transformer(x))
154
+ class SimpleTextEncoder(nn.Module): def __init__(self, vocab_size=10000, embed_dim=768, max_length=77): super().__init__(); self.embedding = nn.Embedding(vocab_size, embed_dim); self.max_length = max_length
155
+ def forward(self, text_tokens): return self.embedding(text_tokens)
156
+ class DiffusionScheduler: def __init__(self, steps): self.steps = steps; self.betas = torch.linspace(0.1, 0.001, steps=steps).to(device); self.alphas = 1 - self.betas; self.alpha_bars = torch.cumprod(self.alphas, dim=0)
157
+ def step(self, noise, t, sample): alpha_bar = self.alpha_bars[t]; alpha_bar_prev = self.alpha_bars[t-1] if t > 0 else torch.tensor(1.0, device=sample.device)
158
+ x0 = (sample - torch.sqrt(1 - alpha_bar) * noise) / torch.sqrt(alpha_bar); new_sample = torch.sqrt(alpha_bar_prev) * x0 + torch.sqrt(1 - alpha_bar_prev) * noise; return new_sample
159
+ class VideoOutput: def __init__(self, frames): self.frames = [img_as_ubyte(frame) for frame in frames[0]]
160
+ class VideoPipeline(nn.Module): def __init__(self, unet, vae, text_encoder, vocab): super().__init__(); self.unet = unet; self.vae = vae; self.text_encoder = text_encoder; self.vocab = vocab
161
+ def forward(self, prompt: str, steps: int = 25, num_frames: int = 24): token_ids = simple_tokenizer(prompt, self.vocab); text_emb = self.text_encoder(token_ids)
162
+ latent = torch.randn((1, 4, num_frames, 64, 64), device=device).half(); sched = DiffusionScheduler(steps)
163
+ for t in range(steps): noise = self.unet(latent, t, text_emb); latent = sched.step(noise, t, latent)
164
+ frames = self.vae.decode(latent / 0.18215); frames = frames.clamp(0, 1).float().cpu().permute(0, 2, 3, 4, 1).numpy(); return VideoOutput(frames)
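# Note: DiffusionScheduler.step above encodes a DDIM-style update --
#   x0_hat = (sample - sqrt(1 - alpha_bar_t) * noise) / sqrt(alpha_bar_t)
#   sample <- sqrt(alpha_bar_{t-1}) * x0_hat + sqrt(1 - alpha_bar_{t-1}) * noise
# i.e. each step re-estimates the clean latent and re-noises it at the previous level.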
165
+ class XTTSModelClass(nn.Module):
166
+ def __init__(self, config):
167
+ super().__init__()
168
+ self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25) # Adjust num_speakers, num_languages as needed
169
+
170
+ def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
171
+ return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)
172
+
173
+ def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
174
+ return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)
175
+
176
+
177
+ def initialize_gpt2_model(folder, files): download_files(folder, files); config = GPT2Config(); model = GPT2LMHeadModel(config).to(device)
178
+ sd = torch.load(os.path.join(folder, sanitize_filename("gpt2-pytorch_model.bin")), map_location=device); load_state_dict_safe(model, sd); model.eval(); enc = read_json(os.path.join(folder, sanitize_filename("encoder.json"))); return model, enc
179
+ def initialize_translation_model(folder, files): download_files(folder, files); config = MBartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
180
+ model = MBartForConditionalGeneration(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
181
+ vp = os.path.join(folder, "vocab.json");
182
+ if os.path.exists(vp): vocab = read_json(vp); model.tokenizer = lambda txt: [vocab.get(t, 0) for t in txt.split()]
183
+ else: model.tokenizer = lambda txt: txt
184
+ model.config.lang_code_to_id = {'en_XX': 0, 'es_XX': 1}; return model
185
+ def initialize_codegen_model(folder, files): download_files(folder, files); config = CodeGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
186
+ model = CodeGenForCausalLM(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
187
+ tok = get_codegen_tokenizer(os.path.join(folder, "vocab.json"), os.path.join(folder, "merges.txt")); vocab = read_json(os.path.join(folder, "vocab.json")); idx2w = {v: k for k, v in vocab.items()}
188
+ model.tokenizer = tok; return model, tok, vocab, idx2w, vocab
189
+ def initialize_summarization_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
190
+ model = BartForConditionalGeneration(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
191
+ vp = os.path.join(folder, "vocab.json");
192
+ if os.path.exists(vp): vocab_json = read_json(vp); vocab = set(vocab_json.keys()); return model, vocab, vocab_json, {v: k for k, v in vocab_json.items()}
193
+ return model, None, None, None
194
+ def initialize_imagegen_model(folder, files): download_files(folder, files); config = AutoencoderKLConfig.from_dict(read_json(os.path.join(folder, "config.json")))
195
+ vae = AutoencoderKL(config).to(device); sd = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd); vae.eval(); return vae
196
+ def initialize_image_to_3d_model(folder, files): download_files(folder, files); config = OpenLRMConfig.from_dict(read_json(os.path.join(folder, "config.json")))
197
+ model3d = OpenLRM(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model3d, sd); model3d.eval(); return model3d
198
+ def initialize_text_to_video_model(folder, files): download_files(folder, files); unet_cfg = read_json(os.path.join(folder, "config.json"))
199
+ unet_cfg = filter_kwargs(VideoUNet, unet_cfg); unet = VideoUNet(**unet_cfg).half().to(device)
200
+ sd_unet = torch.load(os.path.join(folder, "diffusion_pytorch_model.fp16.bin"), map_location=device); load_state_dict_safe(unet, sd_unet); unet.eval()
201
+ vae_cfg = read_json(os.path.join(folder, "config.json")); vae_cfg = filter_kwargs(AutoencoderKL, vae_cfg); vae = AutoencoderKL(vae_cfg).half().to(device)
202
+ sd_vae = torch.load(os.path.join(folder, "diffusion_pytorch_model.bin"), map_location=device); load_state_dict_safe(vae, sd_vae); vae.eval()
203
+ vp = os.path.join(folder, "vocab.json"); text_vocab = read_json(vp) if os.path.exists(vp) else {}; te_path = os.path.join(folder, "text_encoder.bin")
204
+ if os.path.exists(te_path): text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device); sd_te = torch.load(te_path, map_location=device); load_state_dict_safe(text_encoder, sd_te)
205
+ else: text_encoder = SimpleTextEncoder(vocab_size=(max(text_vocab.values())+1) if text_vocab else 10000, embed_dim=768, max_length=77).to(device)
206
+ text_encoder.eval(); return VideoPipeline(unet, vae, text_encoder, text_vocab)
207
+ def initialize_sentiment_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
208
+ model = SentimentClassifierModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
209
+ vp = os.path.join(folder, "vocab.json");
210
+ if os.path.exists(vp): read_json(vp)
+ return model
211
+ def initialize_stt_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
212
+ model = STTModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
213
+ vp = os.path.join(folder, "vocab.json");
214
+ if os.path.exists(vp): read_json(vp)
+ return model
215
+ def initialize_tts_model(folder, files): download_files(folder, files); config = BartConfig.from_dict(read_json(os.path.join(folder, "config.json")))
216
+ model = TTSModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval()
217
+ vp = os.path.join(folder, "vocab.json");
218
+ if os.path.exists(vp): read_json(vp)
+ return model
219
+ def initialize_musicgen_model(folder, files): download_files(folder, files); config = MusicGenConfig.from_dict(read_json(os.path.join(folder, "config.json")))
220
+ model = MusicGenModel(config).to(device); sd = torch.load(os.path.join(folder, "pytorch_model.bin"), map_location=device); load_state_dict_safe(model, sd); model.eval(); return model
221
+
222
+ def initialize_xtts_model(folder, files):
223
+ download_files(folder, files)
224
+ config_xtts = XTTSConfig.from_dict(read_json(os.path.join(folder, "config.json")))
225
+ model = XTTSModelClass(config_xtts).to(device)
226
+ checkpoint = torch.load(os.path.join(folder, "model.pth"), map_location=torch.device(device))
227
+ model.load_state_dict(checkpoint["model"], strict=False)
228
+ model.eval()
229
+ return model.xtts
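# Wiring sketch (illustrative only): the folder name and file spec below are hypothetical
# placeholders; the real specs are expected to arrive via `from constants import *`.
gpt2_files = {
    "gpt2-pytorch_model.bin": "https://example.com/gpt2-pytorch_model.bin",
    "encoder.json": "https://example.com/encoder.json",
}
gpt2_model, gpt2_encoder = initialize_gpt2_model("checkpoints/gpt2", gpt2_files)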
 
models.py CHANGED
@@ -2,9 +2,11 @@ import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
- import copy
6
- from configs import *
7
- from extensions import *
 
 
8
 
9
  class SentimentClassifierModel(nn.Module):
10
  def __init__(self, config):
@@ -13,15 +15,12 @@ class SentimentClassifierModel(nn.Module):
13
  self.embedding = nn.Embedding(config.vocab_size, config.d_model)
14
  self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
15
  self.fc = nn.Linear(config.d_model * 2, 3)
16
-
17
  def forward(self, input_ids):
18
  embedded = self.embedding(input_ids)
19
  packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
20
  packed_output, _ = self.lstm(packed_embedded)
21
  output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
22
- pooled = output[:, -1, :]
23
- logits = self.fc(pooled)
24
- return logits
25
 
26
  class STTModel(nn.Module):
27
  def __init__(self, config):
@@ -35,17 +34,11 @@ class STTModel(nn.Module):
35
  self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
36
  self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
37
  self.fc = nn.Linear(128 * 2, config.vocab_size)
38
-
39
  def forward(self, audio_data):
40
  x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
41
- x = self.pool2(self.relu2(self.conv2(x)))
42
- x = x.transpose(1, 2).contiguous()
43
- x = x.view(x.size(0), -1, x.size(2))
44
- packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False)
45
- packed_output, _ = self.lstm(packed_output)
46
- output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
47
- logits = self.fc(output)
48
- return logits
49
 
50
  class TTSModel(nn.Module):
51
  def __init__(self, config):
@@ -55,15 +48,9 @@ class TTSModel(nn.Module):
55
  self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
56
  self.fc = nn.Linear(config.d_model * 2, 1)
57
  self.sigmoid = nn.Sigmoid()
58
-
59
  def forward(self, input_ids):
60
- embedded = self.embedding(input_ids)
61
- packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
62
- packed_output, _ = self.lstm(packed_embedded)
63
- output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
64
- logits = self.fc(output)
65
- audio = self.sigmoid(logits)
66
- return audio
67
 
68
  class MusicGenModel(nn.Module):
69
  def __init__(self, config: MusicGenConfig):
@@ -72,23 +59,47 @@ class MusicGenModel(nn.Module):
72
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
73
  self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
74
  self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)
75
-
76
  def forward(self, input_ids):
77
- embedded_tokens = self.embedding(input_ids)
78
- hidden_states = embedded_tokens
79
- for layer in self.transformer_layers:
80
- hidden_states = layer(hidden_states)
81
- logits = self.fc_out(hidden_states)
82
- return logits
83
-
84
  def sample(self, attributes, sample_rate, duration):
85
- input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device)
86
- audio_output = []
87
- num_steps = int(duration * sample_rate / 1024)
88
- for _ in tqdm(range(num_steps), desc="Generating music"):
89
- logits = self.forward(input_tokens)
90
- predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
91
- audio_output.append(predicted_token.cpu())
92
- input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
93
- audio_output = torch.cat(audio_output, dim=1).float()
94
- return audio_output
 
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import math
5
+ from configs import MusicGenConfig
6
+ from extensions import CodeGenBlock
7
+ from TTS.tts.layers.xtts.transformer import XTransformerEncoder, XTransformerDecoder
8
+ from TTS.tts.layers.xtts.flow import VitsFlowModules
+ from TTS.tts.layers.xtts.tokenizer import VoiceBPE
+ from tqdm import tqdm
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
 
11
  class SentimentClassifierModel(nn.Module):
12
  def __init__(self, config):
 
15
  self.embedding = nn.Embedding(config.vocab_size, config.d_model)
16
  self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
17
  self.fc = nn.Linear(config.d_model * 2, 3)
 
18
  def forward(self, input_ids):
19
  embedded = self.embedding(input_ids)
20
  packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
21
  packed_output, _ = self.lstm(packed_embedded)
22
  output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
23
+ pooled = output[:, -1, :]; logits = self.fc(pooled); return logits
 
 
24
 
25
  class STTModel(nn.Module):
26
  def __init__(self, config):
 
34
  self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)
35
  self.lstm = nn.LSTM(32 * (config.max_position_embeddings // 8), 128, batch_first=True, bidirectional=True)
36
  self.fc = nn.Linear(128 * 2, config.vocab_size)
 
37
  def forward(self, audio_data):
38
  x = self.pool1(self.relu1(self.conv1(audio_data.unsqueeze(1))))
39
+ x = self.pool2(self.relu2(self.conv2(x))); x = x.transpose(1, 2).contiguous(); x = x.view(x.size(0), -1, x.size(2))
40
+ packed_output = nn.utils.rnn.pack_padded_sequence(x, lengths=[x.size(1)]*x.size(0), batch_first=True, enforce_sorted=False); packed_output, _ = self.lstm(packed_output)
41
+ output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True); logits = self.fc(output); return logits
 
 
 
 
 
42
 
43
  class TTSModel(nn.Module):
44
  def __init__(self, config):
 
48
  self.lstm = nn.LSTM(config.d_model, config.d_model, batch_first=True, bidirectional=True)
49
  self.fc = nn.Linear(config.d_model * 2, 1)
50
  self.sigmoid = nn.Sigmoid()
 
51
  def forward(self, input_ids):
52
+ embedded = self.embedding(input_ids); packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, lengths=[input_ids.size(1)]*input_ids.size(0), batch_first=True, enforce_sorted=False)
53
+ packed_output, _ = self.lstm(packed_embedded); output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True); logits = self.fc(output); audio = self.sigmoid(logits); return audio
 
 
 
 
 
54
 
55
  class MusicGenModel(nn.Module):
56
  def __init__(self, config: MusicGenConfig):
 
59
  self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
60
  self.transformer_layers = nn.ModuleList([CodeGenBlock(config) for _ in range(config.num_hidden_layers)])
61
  self.fc_out = nn.Linear(config.hidden_size, config.vocab_size)
 
62
  def forward(self, input_ids):
63
+ embedded_tokens = self.embedding(input_ids); hidden_states = embedded_tokens
64
+ for layer in self.transformer_layers: hidden_states = layer(hidden_states)
65
+ logits = self.fc_out(hidden_states); return logits
 
 
 
 
66
  def sample(self, attributes, sample_rate, duration):
67
+ input_tokens = torch.randint(0, self.config.vocab_size, (1, 1), dtype=torch.long).to(device); audio_output = []; num_steps = int(duration * sample_rate / 1024)
68
+ for _ in tqdm(range(num_steps), desc="Generating music"): logits = self.forward(input_tokens); predicted_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True); audio_output.append(predicted_token.cpu()); input_tokens = torch.cat((input_tokens, predicted_token), dim=1)
69
+ audio_output = torch.cat(audio_output, dim=1).float(); return audio_output
70
+
71
+ class XTTSModelClass(nn.Module):
72
+ def __init__(self, config):
73
+ super().__init__()
74
+ self.xtts = XTTSModel(config, num_speakers=1024, num_languages=25)
75
+
76
+ def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
77
+ return self.xtts.forward(text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths)
78
+
79
+ def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
80
+ return self.xtts.inference(text, language_id, speaker_id, voice_sample, temperature, length_penalty)
81
+
82
+ class XTTSModel(nn.Module):
83
+ def __init__(self, config, num_speakers, num_languages):
84
+ super().__init__()
85
+ self.config = config
86
+ self.num_speakers = num_speakers
87
+ self.num_languages = num_languages
88
+ self.encoder = XTransformerEncoder(**config.encoder_config)
89
+ self.decoder = XTransformerDecoder(**config.decoder_config)
90
+ self.flow_modules = VitsFlowModules(**config.flow_config)
91
+ self.voice_tokenizer = VoiceBPE(vocab_path=config.voice_tokenizer_config.vocab_path, vocab_size=config.voice_tokenizer_config.vocab_size)
92
+ self.language_embedding = nn.Embedding(num_languages, config.embedding_dim)
93
+ self.speaker_embedding = nn.Embedding(num_speakers, config.embedding_dim)
94
+ self.text_embedding = nn.Embedding(config.num_chars, config.embedding_dim)
95
+
96
+ def forward(self, text_tokens, text_lengths, speaker_ids, language_ids, voice_samples, voice_sample_lengths):
97
+ lang_embed = self.language_embedding(language_ids); spk_embed = self.speaker_embedding(speaker_ids); text_embed = self.text_embedding(text_tokens)
98
+ encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed); mel_outputs, _ = self.decoder(encoder_outputs, lang_embed + spk_embed, voice_samples); return mel_outputs, None
99
+
100
+ def inference(self, text, language_id, speaker_id, voice_sample, temperature=0.7, length_penalty=1.0):
101
+ language_ids = torch.tensor([language_id], dtype=torch.long).to(device); speaker_ids = torch.tensor([speaker_id], dtype=torch.long).to(device)
102
+ text_tokens = self.voice_tokenizer.text_to_ids(text).to(device); text_lengths = torch.tensor([text_tokens.shape[0]], dtype=torch.long).to(device); voice_sample_lengths = torch.tensor([voice_sample.shape[0]], dtype=torch.long).to(device)
103
+ lang_embed = self.language_embedding(language_ids); spk_embed = self.speaker_embedding(speaker_ids); text_embed = self.text_embedding(text_tokens)
104
+ encoder_outputs, _ = self.encoder(text_embed, text_lengths, lang_embed + spk_embed); mel_outputs, _ = self.decoder.inference(encoder_outputs, lang_embed + spk_embed, voice_sample, temperature=temperature, length_penalty=length_penalty)
105
+ return mel_outputs
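# Config sketch (assumptions, not values defined in this commit): XTTSModel reads these
# attributes from its config object --
#   config.encoder_config / decoder_config / flow_config : kwargs dicts for the TTS layers
#   config.voice_tokenizer_config.vocab_path / .vocab_size : VoiceBPE tokenizer files
#   config.embedding_dim, config.num_chars : sizes of the speaker/language/text embeddings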
musicgen_api.py CHANGED
@@ -1,34 +1,15 @@
1
  from flask import jsonify, send_file, request
2
  from main import *
3
- import torch
4
- import soundfile as sf
5
- import numpy as np
6
- import io
7
 
8
  def generate_music(prompt, output_path="output_music.wav"):
9
- if musicgen_model is None:
10
- return "Music generation model not initialized."
 
 
11
 
12
- attributes = [prompt]
13
- sample_rate = 32000
14
- duration = 10
15
- audio_values = musicgen_model.sample(
16
- attributes=attributes,
17
- sample_rate=sample_rate,
18
- duration=duration,
19
- )
20
- output_audio = audio_values.cpu().numpy().squeeze()
21
- sf.write(output_path, output_audio, sample_rate)
22
- return output_path
23
-
24
- def musicgen_api():
25
- data = request.get_json()
26
- prompt = data.get('prompt')
27
- if not prompt:
28
- return jsonify({"error": "Prompt is required"}), 400
29
  output_file = generate_music(prompt)
30
- if output_file == "Music generation model not initialized.":
31
- return jsonify({"error": "Music generation failed"}), 500
32
- with open(output_file, 'rb') as f:
33
- audio_content = f.read()
34
- return send_file(io.BytesIO(audio_content), mimetype="audio/wav", as_attachment=True, download_name="output.wav")
 
1
  from flask import jsonify, send_file, request
2
  from main import *
3
+ import os, torch, soundfile as sf, numpy as np, io, base64
 
 
 
4
 
5
  def generate_music(prompt, output_path="output_music.wav"):
6
+ if musicgen_model is None: return {"error": "Music generation model not initialized."}
7
+ attributes = [prompt]; sample_rate = 32000; duration = 8
8
+ audio_values = musicgen_model.sample(attributes=attributes, sample_rate=sample_rate, duration=duration); output_audio = audio_values.cpu().numpy().squeeze()
9
+ sf.write(output_path, output_audio, sample_rate); return output_path
10
 
11
+ def musicgen_api(prompt):
 
12
  output_file = generate_music(prompt)
13
+ if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
14
+ with open(output_file, 'rb') as f: audio_content = f.read()
15
+ audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
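# Caller-side sketch (assumed usage, not part of this module): decode the base64
# payload returned by musicgen_api back into a playable WAV file. The prompt and
# output file name are placeholders.
result = musicgen_api("lofi piano loop")
if "audio_base64" in result:
    with open("output_from_api.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_base64"]))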
 
 
sadtalker_api.py CHANGED
@@ -1,9 +1,4 @@
1
- import os
2
- import tempfile
3
- import uuid
4
- import asyncio
5
- import shutil
6
- import requests
7
  from urllib.parse import urlparse
8
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
9
  from fastapi.responses import JSONResponse
@@ -19,186 +14,24 @@ from text_generation import *
19
  router = APIRouter()
20
 
21
  @router.post("/sadtalker")
22
- async def create_video(
23
- source_image: str = Form(None),
24
- source_image_file: UploadFile = File(None),
25
- driven_audio: str = Form(None),
26
- driven_audio_file: UploadFile = File(None),
27
- preprocess: str = Form('crop'),
28
- still_mode: bool = Form(False),
29
- use_enhancer: bool = Form(False),
30
- batch_size: int = Form(1),
31
- size: int = Form(256),
32
- pose_style: int = Form(0),
33
- exp_scale: float = Form(1.0),
34
- use_ref_video: bool = Form(False),
35
- ref_video: str = Form(None),
36
- ref_video_file: UploadFile = File(None),
37
- ref_info: str = Form(None),
38
- use_idle_mode: bool = Form(False),
39
- length_of_audio: int = Form(0),
40
- use_blink: bool = Form(True),
41
- checkpoint_dir: str = Form('checkpoints'),
42
- config_dir: str = Form('src/config'),
43
- old_version: bool = Form(False),
44
- tts_text: str = Form(None),
45
- tts_lang: str = Form('en'),
46
- ):
47
- if source_image_file and source_image:
48
- raise HTTPException(status_code=400, detail="source_image and source_image_file cannot be both not None")
49
- if driven_audio and driven_audio_file:
50
- raise HTTPException(status_code=400, detail="driven_audio and driven_audio_file cannot be both not None")
51
- if ref_video and ref_video_file:
52
- raise HTTPException(status_code=400, detail="ref_video and ref_video_file cannot be both not None")
53
- tmp_source_image = None
54
- if source_image_file:
55
- tmp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
56
- content = await source_image_file.read()
57
- tmp_source_image.write(content)
58
- source_image_path = tmp_source_image.name
59
- elif source_image:
60
- if urlparse(source_image).scheme in ["http", "https"]:
61
- response = requests.get(source_image, stream=True)
62
- response.raise_for_status()
63
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp_source_image:
64
- for chunk in response.iter_content(chunk_size=8192):
65
- tmp_source_image.write(chunk)
66
- source_image_path = tmp_source_image.name
67
- else:
68
- source_image_path = source_image
69
- else:
70
- raise HTTPException(status_code=400, detail="source_image not provided")
71
- tmp_driven_audio = None
72
- if driven_audio_file:
73
- tmp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
74
- content = await driven_audio_file.read()
75
- tmp_driven_audio.write(content)
76
- driven_audio_path = tmp_driven_audio.name
77
- elif driven_audio:
78
- if urlparse(driven_audio).scheme in ["http", "https"]:
79
- response = requests.get(driven_audio, stream=True)
80
- response.raise_for_status()
81
- with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_driven_audio:
82
- for chunk in response.iter_content(chunk_size=8192):
83
- tmp_driven_audio.write(chunk)
84
- driven_audio_path = tmp_driven_audio.name
85
- else:
86
- driven_audio_path = driven_audio
87
- else:
88
- driven_audio_path = None
89
- tmp_ref_video = None
90
- if ref_video_file:
91
- tmp_ref_video = tempfile.NamedTemporaryFile(suffix=os.path.splitext(ref_video_file.filename)[1], delete=False)
92
- content = await ref_video_file.read()
93
- tmp_ref_video.write(content)
94
- ref_video_path = tmp_ref_video.name
95
- elif ref_video:
96
- if urlparse(ref_video).scheme in ["http", "https"]:
97
- response = requests.get(ref_video, stream=True)
98
- response.raise_for_status()
99
- with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_ref_video:
100
- for chunk in response.iter_content(chunk_size=8192):
101
- tmp_ref_video.write(chunk)
102
- ref_video_path = tmp_ref_video.name
103
- else:
104
- ref_video_path = ref_video
105
- else:
106
- ref_video_path=None
107
- try:
108
- loop = asyncio.get_running_loop()
109
- output_path = await loop.run_in_executor(None, sadtalker_instance.test,
110
- source_image_path,
111
- driven_audio_path,
112
- preprocess,
113
- still_mode,
114
- use_enhancer,
115
- batch_size,
116
- size,
117
- pose_style,
118
- exp_scale,
119
- use_ref_video,
120
- ref_video_path,
121
- ref_info,
122
- use_idle_mode,
123
- length_of_audio,
124
- use_blink,
125
- './results/',
126
- tts_text=tts_text,
127
- tts_lang=tts_lang,
128
- )
129
- return {"video_url": output_path}
130
- except Exception as e:
131
- raise HTTPException(status_code=500, detail=str(e))
132
- finally:
133
- if tmp_source_image:
134
- os.remove(tmp_source_image.name)
135
- if tmp_driven_audio:
136
- os.remove(tmp_driven_audio.name)
137
- if tmp_ref_video:
138
- os.remove(tmp_ref_video.name)
139
 
140
- @router.websocket("/ws")
141
- async def websocket_endpoint(websocket: WebSocket):
142
- await websocket.accept()
143
- tts_model = TTSTalker()
144
  try:
145
- while True:
146
- data = await websocket.receive_json()
147
- text = data.get("text")
148
- audio_base64 = data.get("audio")
149
- if text:
150
- audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, text)
151
- elif audio_base64:
152
- try:
153
- audio_bytes = base64.b64decode(audio_base64)
154
- tmp_audio_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
155
- tmp_audio_file.write(audio_bytes)
156
- audio_path = tmp_audio_file.name
157
- transcription_text_file = speech_to_text_func(tmp_audio_file.name)
158
- with open(transcription_text_file, 'r') as f:
159
- transcription_text = f.read()
160
- response_stream = perform_reasoning_stream(transcription_text, 0.7, 40, 0.0, 1.2)
161
- response_text = ""
162
- for chunk in response_stream:
163
- if chunk == "<END_STREAM>":
164
- break
165
- response_text += chunk
166
- audio_path = await asyncio.get_running_loop().run_in_executor(None, tts_model.test, response_text)
167
 
168
- except Exception as e:
169
- await websocket.send_json({"error":str(e)})
170
- continue
171
- finally:
172
- if 'tmp_audio_file' in locals() and tmp_audio_file:
173
- os.remove(tmp_audio_file.name)
174
- else:
175
- continue
176
- source_image_path = './examples/source_image/cyarh.png'
177
- ref_video_path='./examples/driven_video/vid_xdd.mp4'
178
- loop = asyncio.get_running_loop()
179
- output = await loop.run_in_executor(None, sadtalker_instance.test,
180
- source_image_path,
181
- audio_path,
182
- 'full',
183
- True,
184
- True,
185
- 1,
186
- 256,
187
- 0,
188
- 1,
189
- True,
190
- ref_video_path,
191
- "pose+blink",
192
- False,
193
- 0,
194
- True,
195
- './results/'
196
- )
197
- await websocket.send_json({"video_url": output})
198
- except Exception as e:
199
- print(e)
200
- await websocket.send_json({"error":str(e)})
201
 
202
  router = APIRouter()
203
  router.add_api_route("/sadtalker", create_video, methods=["POST"])
204
- router.add_api_websocket_route("/ws", websocket_endpoint)
 
1
+ import os, tempfile, uuid, asyncio, shutil, requests, base64
 
 
 
 
 
2
  from urllib.parse import urlparse
3
  from fastapi import FastAPI, UploadFile, File, HTTPException, Form, WebSocket
4
  from fastapi.responses import JSONResponse
 
14
  router = APIRouter()
15
 
16
  @router.post("/sadtalker")
17
+ async def create_video(source_image_file: UploadFile = File(...), driven_audio_file: UploadFile = File(...)):
18
+ if not source_image_file: raise HTTPException(status_code=400, detail="Source image file is required")
19
+ if not driven_audio_file: raise HTTPException(status_code=400, detail="Driven audio file is required")
20
+
21
+ temp_source_image = tempfile.NamedTemporaryFile(suffix=os.path.splitext(source_image_file.filename)[1], delete=False)
22
+ content_image = await source_image_file.read(); temp_source_image.write(content_image); source_image_path = temp_source_image.name
23
+ temp_driven_audio = tempfile.NamedTemporaryFile(suffix=os.path.splitext(driven_audio_file.filename)[1], delete=False)
24
+ content_audio = await driven_audio_file.read(); temp_driven_audio.write(content_audio); driven_audio_path = temp_driven_audio.name
 
25
 
 
 
 
 
26
  try:
27
+ loop = asyncio.get_running_loop()
28
+ output_path = await loop.run_in_executor(None, sadtalker_instance.test, source_image_path, driven_audio_path)
29
+ video_base64 = None
30
+ with open(output_path, 'rb') as video_file: video_bytes = video_file.read(); video_base64 = base64.b64encode(video_bytes).decode('utf-8')
31
+ os.remove(output_path); return {"video_base64": video_base64, "mimetype": "video/mp4"}
 
32
 
33
+ except Exception as e: raise HTTPException(status_code=500, detail=str(e))
34
+ finally: os.remove(temp_source_image.name); os.remove(temp_driven_audio.name)
 
35
 
36
  router = APIRouter()
37
  router.add_api_route("/sadtalker", create_video, methods=["POST"])
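# Client-side sketch (assumed usage): POST an image and an audio clip to the /sadtalker
# route as multipart files; the host/port and local file names are placeholders.
import base64, requests
with open("face.png", "rb") as img, open("speech.wav", "rb") as wav:
    r = requests.post("http://localhost:8000/sadtalker",
                      files={"source_image_file": img, "driven_audio_file": wav})
with open("talking_head.mp4", "wb") as f:
    f.write(base64.b64decode(r.json()["video_base64"]))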
 
sadtalker_utils.py CHANGED
@@ -1,820 +1,209 @@
1
- import os
2
- import shutil
3
- import uuid
4
- import cv2
5
- import numpy as np
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- import yaml
10
- from PIL import Image
11
- from skimage import img_as_ubyte, transform
12
- import safetensors
13
- import librosa
14
- from pydub import AudioSegment
15
- import imageio
16
- from scipy import signal
17
- from scipy.io import loadmat, savemat, wavfile
18
- import glob
19
- import tempfile
20
- import tqdm
21
- import math
22
- import torchaudio
23
- import urllib.request
24
- from safetensors.torch import load_file, save_file
25
-
26
- REALESRGAN_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth"
27
- CODEFORMER_URL = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth"
28
- RESTOREFORMER_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth"
29
- GFPGAN_URL = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
30
- kp_url = "https://huggingface.co/usyd-community/vitpose-base-simple/resolve/main/model.safetensors"
31
- kp_file = "kp_detector.safetensors"
32
- aud_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/auido2pose_00140-model.pth"
33
- aud_file = "auido2pose_00140-model.pth"
34
- wav_url = "https://huggingface.co/facebook/wav2vec2-base/resolve/main/pytorch_model.bin"
35
- wav_file = "wav2vec2.pth"
36
- gen_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/wav2lip.pth"
37
- gen_file = "generator.pth"
38
- mapx_url = "https://huggingface.co/vinthony/SadTalker/resolve/main/mapping_00229-model.pth.tar"
39
- mapx_file = "mapping.pth"
40
- den_url = "https://huggingface.co/KwaiVGI/LivePortrait/resolve/main/liveportrait/base_models/motion_extractor.pth"
41
- den_file = "dense_motion.pth"
42
-
43
-
44
- def download_model(url, filename, checkpoint_dir):
45
- if not os.path.exists(os.path.join(checkpoint_dir, filename)):
46
- print(f"Downloading {filename}...")
47
- os.makedirs(checkpoint_dir, exist_ok=True)
48
- urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename))
49
- print(f"{filename} downloaded.")
50
- else:
51
- print(f"{filename} already exists.")
52
-
53
-
54
- def mp3_to_wav_util(mp3_filename, wav_filename, frame_rate):
55
- AudioSegment.from_file(mp3_filename).set_frame_rate(frame_rate).export(wav_filename, format="wav")
56
-
57
-
58
- def load_wav_util(path, sr):
59
- return librosa.core.load(path, sr=sr)[0]
60
-
61
-
62
- def save_wav_util(wav, path, sr):
63
- wav *= 32767 / max(0.01, np.max(np.abs(wav)))
64
- wavfile.write(path, sr, wav.astype(np.int16))
65
-
66
-
67
- def load_state_dict_robust(model, checkpoint_path, device, model_name="model"):
68
- if not os.path.exists(checkpoint_path):
69
- raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")
70
- if checkpoint_path.endswith('safetensors'):
71
- checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
72
- else:
73
- checkpoint = torch.load(checkpoint_path, map_location=device)
74
-
75
- state_dict = checkpoint.get(model_name, checkpoint)
76
- try:
77
- model.load_state_dict(state_dict)
78
- except RuntimeError as e:
79
- print(f"Error loading {model_name} state_dict: {e}")
80
- print(f"Trying to load state_dict with key mapping for {model_name}.")
81
- model_state_dict = model.state_dict()
82
- mapped_state_dict = {}
83
- for key, value in state_dict.items():
84
- if key in model_state_dict and model_state_dict[key].shape == value.shape:
85
- mapped_state_dict[key] = value
86
- else:
87
- print(f"Skipping key {key} due to shape mismatch or missing in model.")
88
- missing_keys, unexpected_keys = model.load_state_dict(mapped_state_dict, strict=False)
89
- if missing_keys or unexpected_keys:
90
- print(f"Missing keys: {missing_keys}")
91
- print(f"Unexpected keys: {unexpected_keys}")
92
- print(f"Successfully loaded {model_name} state_dict with key mapping.")
93
-
94
-
95
- class OcclusionAwareKPDetector(nn.Module):
96
-
97
- def __init__(self, kp_channels, num_kp, num_dilation_blocks, dropout_rate):
98
- super(OcclusionAwareKPDetector, self).__init__()
99
- self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
100
- self.bn1 = nn.BatchNorm2d(64)
101
- self.relu = nn.ReLU()
102
- self.conv2 = nn.Conv2d(64, num_kp, kernel_size=3, padding=1)
103
-
104
- def forward(self, x):
105
- x = self.relu(self.bn1(self.conv1(x)))
106
- x = self.conv2(x)
107
- kp = {'value': x.view(x.size(0), -1)}
108
- return kp
109
-
110
-
111
- class Wav2Vec2Model(nn.Module):
112
-
113
- def __init__(self):
114
- super(Wav2Vec2Model, self).__init__()
115
- self.conv = nn.Conv1d(1, 64, kernel_size=10, stride=5, padding=5)
116
- self.bn = nn.BatchNorm1d(64)
117
- self.relu = nn.ReLU()
118
- self.fc = nn.Linear(64, 2048)
119
-
120
- def forward(self, audio):
121
- x = audio.unsqueeze(1)
122
- x = self.relu(self.bn(self.conv(x)))
123
- x = torch.mean(x, dim=-1)
124
- x = self.fc(x)
125
- return x
126
-
127
-
128
- class AudioCoeffsPredictor(nn.Module):
129
-
130
- def __init__(self, input_dim, output_dim):
131
- super(AudioCoeffsPredictor, self).__init__()
132
- self.linear = nn.Linear(input_dim, output_dim)
133
-
134
- def forward(self, audio_embedding):
135
- return self.linear(audio_embedding)
136
-
137
-
138
- class MappingNet(nn.Module):
139
-
140
- def __init__(self, num_coeffs, num_layers, hidden_dim):
141
- super(MappingNet, self).__init__()
142
- layers = []
143
- input_dim = num_coeffs * 2
144
- for _ in range(num_layers):
145
- layers.append(nn.Linear(input_dim, hidden_dim))
146
- layers.append(nn.ReLU())
147
- input_dim = hidden_dim
148
- layers.append(nn.Linear(hidden_dim, num_coeffs))
149
- self.net = nn.Sequential(*layers)
150
-
151
- def forward(self, x):
152
- return self.net(x)
153
-
154
-
155
- class DenseMotionNetwork(nn.Module):
156
-
157
- def __init__(self, num_kp, num_channels, block_expansion, num_blocks, max_features):
158
- super(DenseMotionNetwork, self).__init__()
159
- self.conv1 = nn.Conv2d(num_channels, max_features, kernel_size=3, padding=1)
160
- self.relu = nn.ReLU()
161
- self.conv2 = nn.Conv2d(max_features, num_channels, kernel_size=3, padding=1)
162
-
163
- def forward(self, kp_source, kp_driving, jacobian):
164
- x = self.relu(self.conv1(kp_source))
165
- x = self.conv2(x)
166
- sparse_motion = {'dense_motion': x}
167
- return sparse_motion
168
-
169
-
170
- class Hourglass(nn.Module):
171
-
172
- def __init__(self, block_expansion, num_blocks, max_features, num_channels, kp_size, num_deform_blocks):
173
- super(Hourglass, self).__init__()
174
- self.encoder = nn.Sequential(nn.Conv2d(num_channels, max_features, kernel_size=7, stride=2, padding=3),
175
- nn.BatchNorm2d(max_features), nn.ReLU())
176
- self.decoder = nn.Sequential(
177
- nn.ConvTranspose2d(max_features, num_channels, kernel_size=4, stride=2, padding=1), nn.Tanh())
178
-
179
- def forward(self, source_image, kp_driving, **kwargs):
180
- x = self.encoder(source_image)
181
- x = self.decoder(x)
182
- B, C, H, W = x.size()
183
- video = []
184
- for _ in range(10):
185
- frame = (x[0].cpu().detach().numpy().transpose(1, 2, 0) * 127.5 + 127.5).clip(0, 255).astype(
186
- np.uint8)
187
- video.append(frame)
188
- return video
189
-
190
-
191
- class Face3DHelper:
192
-
193
- def __init__(self, local_pca_path, device):
194
- self.local_pca_path = local_pca_path
195
- self.device = device
196
-
197
- def run(self, source_image):
198
- h, w, _ = source_image.shape
199
- x_min = w // 4
200
- y_min = h // 4
201
- x_max = x_min + w // 2
202
- y_max = y_min + h // 2
203
- return [x_min, y_min, x_max, y_max]
204
-
205
-
206
- class MouthDetector:
207
-
208
- def __init__(self):
209
- pass
210
-
211
- def detect(self, image):
212
- h, w = image.shape[:2]
213
- return (w // 2, h // 2)
214
-
215
-
216
- class KeypointNorm(nn.Module):
217
-
218
- def __init__(self, device):
219
- super(KeypointNorm, self).__init__()
220
- self.device = device
221
-
222
- def forward(self, kp_driving):
223
- return kp_driving
224
-
225
-
226
- def save_video_with_watermark(video_frames, audio_path, output_path):
227
- H, W, _ = video_frames[0].shape
228
- out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
229
- for frame in video_frames:
230
- out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
231
- out.release()
232
-
233
-
234
- def paste_pic(video_path, source_image_crop, crop_info, audio_path, output_path):
235
- shutil.copy(video_path, output_path)
236
-
237
-
238
- class TTSTalker:
239
-
240
- def __init__(self):
241
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
242
- self.tts_model = None
243
-
244
- def load_model(self):
245
- self.tts_model = self
246
-
247
- def tokenizer(self, text):
248
- return [ord(c) for c in text]
249
-
250
- def __call__(self, input_tokens):
251
- return torch.zeros(1, 16000, device=self.device)
252
-
253
- def test(self, text, lang='en'):
254
- if self.tts_model is None:
255
- self.load_model()
256
- output_path = os.path.join('./results', str(uuid.uuid4()) + '.wav')
257
- os.makedirs('./results', exist_ok=True)
258
- tokens = self.tokenizer(text)
259
- input_tokens = torch.tensor([tokens], dtype=torch.long).to(self.device)
260
- with torch.no_grad():
261
- audio_output = self(input_tokens)
262
- torchaudio.save(output_path, audio_output.cpu(), 16000)
263
- return output_path
264
-
265
-
266
- class SadTalker:
267
-
268
- def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop',
269
- old_version=False):
270
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
271
- self.cfg = self.get_cfg_defaults()
272
- self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
273
- self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
274
- self.cfg['MODEL']['CONFIG_DIR'] = config_path
275
- self.cfg['MODEL']['DEVICE'] = self.device
276
- self.cfg['INPUT_IMAGE'] = {}
277
- self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
278
- self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
279
- self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
280
- self.cfg['INPUT_IMAGE']['SIZE'] = size
281
- self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
282
-
283
- for filename, url in [
284
- (kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url),
285
- (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL),
286
- ('RealESRGAN_x2plus.pth', REALESRGAN_URL)
287
- ]:
288
- download_model(url, filename, checkpoint_path)
289
-
290
- self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
291
-
292
- def get_cfg_defaults(self):
293
- return {
294
- 'MODEL': {
295
- 'CHECKPOINTS_DIR': '',
296
- 'CONFIG_DIR': '',
297
- 'DEVICE': self.device,
298
- 'SCALE': 64,
299
- 'NUM_VOXEL_FRAMES': 8,
300
- 'NUM_MOTION_FRAMES': 10,
301
- 'MAX_FEATURES': 256,
302
- 'DRIVEN_AUDIO_SAMPLE_RATE': 16000,
303
- 'VIDEO_FPS': 25,
304
- 'OUTPUT_VIDEO_FPS': None,
305
- 'OUTPUT_AUDIO_SAMPLE_RATE': None,
306
- 'USE_ENHANCER': False,
307
- 'ENHANCER_NAME': '',
308
- 'BG_UPSAMPLER': None,
309
- 'IS_HALF': False
310
- },
311
- 'INPUT_IMAGE': {}
312
- }
313
-
314
- def merge_from_file(self, filepath):
315
- if os.path.exists(filepath):
316
- with open(filepath, 'r') as f:
317
- cfg_from_file = yaml.safe_load(f)
318
- self.cfg.update(cfg_from_file)
319
-
320
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
321
- batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
322
- ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
323
- tts_text=None, tts_lang='en'):
324
- self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size,
325
- pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
326
- length_of_audio, use_blink, result_dir, tts_text, tts_lang)
327
- return self.sadtalker_model.save_result()
328
-
329
-
330
- class SadTalkerModel:
331
-
332
- def __init__(self, sadtalker_cfg, device_id=[0]):
333
- self.cfg = sadtalker_cfg
334
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
335
- self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
336
- self.preprocesser = self.sadtalker.preprocesser
337
- self.kp_extractor = self.sadtalker.kp_extractor
338
- self.generator = self.sadtalker.generator
339
- self.mapping = self.sadtalker.mapping
340
- self.he_estimator = self.sadtalker.he_estimator
341
- self.audio_to_coeff = self.sadtalker.audio_to_coeff
342
- self.animate_from_coeff = self.sadtalker.animate_from_coeff
343
- self.face_enhancer = self.sadtalker.face_enhancer
344
-
345
- def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
346
- batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None,
347
- ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/',
348
- tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
349
- self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer,
350
- batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info,
351
- use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang,
352
- jitter_amount, jitter_source_image)
353
- return self.inner_test.test()
354
-
355
- def save_result(self):
356
- return self.inner_test.save_result()
357
-
358
-
359
- class SadTalkerInner:
360
-
361
- def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer,
362
- batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode,
363
- length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
364
- self.sadtalker_model = sadtalker_model
365
- self.source_image = source_image
366
- self.driven_audio = driven_audio
367
- self.preprocess = preprocess
368
- self.still_mode = still_mode
369
- self.use_enhancer = use_enhancer
370
- self.batch_size = batch_size
371
- self.size = size
372
- self.pose_style = pose_style
373
- self.exp_scale = exp_scale
374
- self.use_ref_video = use_ref_video
375
- self.ref_video = ref_video
376
- self.ref_info = ref_info
377
- self.use_idle_mode = use_idle_mode
378
- self.length_of_audio = length_of_audio
379
- self.use_blink = use_blink
380
- self.result_dir = result_dir
381
- self.tts_text = tts_text
382
- self.tts_lang = tts_lang
383
- self.jitter_amount = jitter_amount
384
- self.jitter_source_image = jitter_source_image
385
- self.device = self.sadtalker_model.device
386
- self.output_path = None
387
-
388
- def get_test_data(self):
389
- proc = self.sadtalker_model.preprocesser
390
- if self.tts_text is not None:
391
- temp_dir = tempfile.mkdtemp()
392
- audio_path = os.path.join(temp_dir, 'audio.wav')
393
- tts = TTSTalker()
394
- tts.test(self.tts_text, self.tts_lang)
395
- self.driven_audio = audio_path
396
- source_image_pil = Image.open(self.source_image).convert('RGB')
397
- if self.jitter_source_image:
398
- jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
399
- jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1)
400
- source_image_pil = Image.fromarray(
401
- np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
402
- source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
403
- if self.still_mode or self.use_idle_mode:
404
- ref_pose_coeff = proc.generate_still_pose(self.pose_style)
405
- ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
406
- else:
407
- ref_pose_coeff = None
408
- ref_expression_coeff = None
409
- audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio,
410
- self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
411
- batch = {
412
- 'source_image': source_image_tensor.unsqueeze(0).to(self.device),
413
- 'audio': audio_tensor.unsqueeze(0).to(self.device),
414
- 'ref_pose_coeff': ref_pose_coeff,
415
- 'ref_expression_coeff': ref_expression_coeff,
416
- 'source_image_crop': cropped_image,
417
- 'crop_info': crop_info,
418
- 'use_blink': self.use_blink,
419
- 'pose_style': self.pose_style,
420
- 'exp_scale': self.exp_scale,
421
- 'ref_video': self.ref_video,
422
- 'use_ref_video': self.use_ref_video,
423
- 'ref_info': self.ref_info,
424
- }
425
- return batch, audio_sample_rate
426
-
427
- def run_inference(self, batch):
428
- kp_extractor = self.sadtalker_model.kp_extractor
429
- generator = self.sadtalker_model.generator
430
- mapping = self.sadtalker_model.mapping
431
- he_estimator = self.sadtalker_model.he_estimator
432
- audio_to_coeff = self.sadtalker_model.audio_to_coeff
433
- animate_from_coeff = self.sadtalker_model.animate_from_coeff
434
- face_enhancer = self.sadtalker_model.face_enhancer if self.use_enhancer else None
435
- with torch.no_grad():
436
- kp_source = kp_extractor(batch['source_image'])
437
- if self.still_mode or self.use_idle_mode:
438
- ref_pose_coeff = batch['ref_pose_coeff']
439
- ref_expression_coeff = batch['ref_expression_coeff']
440
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
441
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
442
- elif self.use_idle_mode:
443
- ref_pose_coeff = batch['ref_pose_coeff']
444
- ref_expression_coeff = batch['ref_expression_coeff']
445
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], ref_pose_coeff)
446
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], ref_expression_coeff)
447
- else:
448
- if self.use_ref_video:
449
- kp_ref = kp_extractor(batch['source_image'])
450
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref,
451
- use_ref_info=batch['ref_info'])
452
- else:
453
- pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
454
- expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
455
- coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
456
- if self.use_blink:
457
- coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
458
- else:
459
- coeff['blink_coeff'] = None
460
- kp_driving = audio_to_coeff(batch['audio'])[0]
461
- kp_norm = animate_from_coeff.normalize_kp(kp_driving)
462
- coeff['kp_driving'] = kp_norm
463
- coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
464
- output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping,
465
- he_estimator, batch['audio'], batch['source_image_crop'],
466
- face_enhancer=face_enhancer)
467
- return output_video
468
-
469
- def post_processing(self, output_video, audio_sample_rate, batch):
470
- proc = self.sadtalker_model.preprocesser
471
- base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]
472
- audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
473
- output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4')
474
- self.output_path = output_video_path
475
- video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL'][
476
- 'OUTPUT_VIDEO_FPS'] is None else \
477
- self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
478
- audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if \
479
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else \
480
- self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
481
- if self.use_enhancer:
482
- enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4')
483
- save_video_with_watermark(output_video, self.driven_audio, enhanced_path)
484
- paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio,
485
- output_video_path)
486
- os.remove(enhanced_path)
487
- else:
488
- save_video_with_watermark(output_video, self.driven_audio, output_video_path)
489
- if self.tts_text is not None:
490
- shutil.rmtree(os.path.dirname(self.driven_audio))
491
-
492
- def save_result(self):
493
- return self.output_path
494
-
495
- def __call__(self):
496
- return self.output_path
497
-
498
- def test(self):
499
- batch, audio_sample_rate = self.get_test_data()
500
- output_video = self.run_inference(batch)
501
- self.post_processing(output_video, audio_sample_rate, batch)
502
- return self.save_result()
503
-
504
-
505
- class SadTalkerInnerModel:
506
-
507
- def __init__(self, sadtalker_cfg, device_id=[0]):
508
- self.cfg = sadtalker_cfg
509
- self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
510
- self.preprocesser = Preprocesser(sadtalker_cfg, self.device)
511
- self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
512
- self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device)
513
- self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
514
- self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL'][
515
- 'USE_ENHANCER'] else None
516
- self.generator = Generator(sadtalker_cfg, self.device)
517
- self.mapping = Mapping(sadtalker_cfg, self.device)
518
- self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
519
-
520
-
521
- class Preprocesser:
522
-
523
- def __init__(self, sadtalker_cfg, device):
524
- self.cfg = sadtalker_cfg
525
- self.device = device
526
- self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device)
527
- self.mouth_detector = MouthDetector()
528
-
529
- def crop(self, source_image_pil, preprocess_type, size=256):
530
- source_image = np.array(source_image_pil)
531
- face_info = self.face3d_helper.run(source_image)
532
- if face_info is None:
533
- raise Exception("No face detected")
534
- x_min, y_min, x_max, y_max = face_info[:4]
535
- old_size = (x_max - x_min, y_max - y_min)
536
- x_center = (x_max + x_min) / 2
537
- y_center = (y_max + y_min) / 2
538
- if preprocess_type == 'crop':
539
- face_size = max(x_max - x_min, y_max - y_min)
540
- x_min = int(x_center - face_size / 2)
541
- y_min = int(y_center - face_size / 2)
542
- x_max = int(x_center + face_size / 2)
543
- y_max = int(y_center + face_size / 2)
544
- else:
545
- x_min -= int((x_max - x_min) * 0.1)
546
- y_min -= int((y_max - y_min) * 0.1)
547
- x_max += int((x_max - x_min) * 0.1)
548
- y_max += int((y_max - y_min) * 0.1)
549
- h, w = source_image.shape[:2]
550
- x_min = max(0, x_min)
551
- y_min = max(0, y_min)
552
- x_max = min(w, x_max)
553
- y_max = min(h, y_max)
554
- cropped_image = source_image[y_min:y_max, x_min:x_max]
555
- cropped_image_pil = Image.fromarray(cropped_image)
556
- if size is not None and size != 0:
557
- cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
558
- source_image_tensor = self.img2tensor(cropped_image_pil)
559
- return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(
560
- self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
561
-
562
- def img2tensor(self, img):
563
- img = np.array(img).astype(np.float32) / 255.0
564
- img = np.transpose(img, (2, 0, 1))
565
- return torch.FloatTensor(img)
566
-
567
- def video_to_tensor(self, video, device):
568
- video_tensor_list = []
569
- import torchvision.transforms as transforms
570
- transform_func = transforms.ToTensor()
571
- for frame in video:
572
- frame_pil = Image.fromarray(frame)
573
- frame_tensor = transform_func(frame_pil).unsqueeze(0).to(device)
574
- video_tensor_list.append(frame_tensor)
575
- video_tensor = torch.cat(video_tensor_list, dim=0)
576
- return video_tensor
577
-
578
- def process_audio(self, audio_path, sample_rate):
579
- wav = load_wav_util(audio_path, sample_rate)
580
- wav_tensor = torch.FloatTensor(wav).unsqueeze(0)
581
- return wav_tensor, sample_rate
582
-
583
- def generate_still_pose(self, pose_style):
584
- ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
585
- ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32)
586
- return ref_pose_coeff
587
-
588
- def generate_still_expression(self, exp_scale):
589
- ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device)
590
- ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32)
591
- return ref_expression_coeff
592
-
593
- def generate_idles_pose(self, length_of_audio, pose_style):
594
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
595
- ref_pose_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
596
- start_pose = self.generate_still_pose(pose_style)
597
- end_pose = self.generate_still_pose(pose_style)
598
- for frame_idx in range(num_frames):
599
- alpha = frame_idx / num_frames
600
- ref_pose_coeff[frame_idx] = (1 - alpha) * start_pose + alpha * end_pose
601
- return ref_pose_coeff
602
-
603
- def generate_idles_expression(self, length_of_audio):
604
- num_frames = int(length_of_audio * self.cfg['MODEL']['VIDEO_FPS'])
605
- ref_expression_coeff = torch.zeros((num_frames, 64), dtype=torch.float32).to(self.device)
606
- start_exp = self.generate_still_expression(1.0)
607
- end_exp = self.generate_still_expression(1.0)
608
- for frame_idx in range(num_frames):
609
- alpha = frame_idx / num_frames
610
- ref_expression_coeff[frame_idx] = (1 - alpha) * start_exp + alpha * end_exp
611
- return ref_expression_coeff
612
-
613
-
614
- class KeyPointExtractor(nn.Module):
615
-
616
- def __init__(self, sadtalker_cfg, device):
617
- super(KeyPointExtractor, self).__init__()
618
- self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'],
619
- num_kp=10,
620
- num_dilation_blocks=2,
621
- dropout_rate=0.1).to(device)
622
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors')
623
- load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
624
-
625
- def forward(self, x):
626
- kp = self.kp_extractor(x)
627
- return kp
628
-
629
-
630
- class Audio2Coeff(nn.Module):
631
-
632
- def __init__(self, sadtalker_cfg, device):
633
- super(Audio2Coeff, self).__init__()
634
- self.audio_model = Wav2Vec2Model().to(device)
635
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth')
636
- load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
637
- self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device)
638
- self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device)
639
- self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
640
- mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth')
641
- load_state_dict_robust(self, mapping_checkpoint, device)
642
-
643
- def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
644
- audio_embedding = self.audio_model(audio_tensor)
645
- pose_coeff = self.pose_mapper(audio_embedding)
646
- if ref_pose_coeff is not None:
647
- pose_coeff = ref_pose_coeff
648
- if kp_ref is not None and use_ref_info == 'pose':
649
- ref_pose_6d = kp_ref['value'][:, :6]
650
- pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
651
- return pose_coeff
652
-
653
- def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
654
- audio_embedding = self.audio_model(audio_tensor)
655
- expression_coeff = self.exp_mapper(audio_embedding)
656
- if ref_expression_coeff is not None:
657
- expression_coeff = ref_expression_coeff
658
- return expression_coeff
659
-
660
- def get_blink_coeff(self, audio_tensor):
661
- audio_embedding = self.audio_model(audio_tensor)
662
- blink_coeff = self.blink_mapper(audio_embedding)
663
- return blink_coeff
664
-
665
- def forward(self, audio):
666
- audio_embedding = self.audio_model(audio)
667
- pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(
668
- audio_embedding), self.blink_mapper(audio_embedding)
669
- return pose_coeff, expression_coeff, blink_coeff
670
-
671
- def mean_std_normalize(self, coeff):
672
- mean = coeff.mean(dim=1, keepdim=True)
673
- std = coeff.std(dim=1, keepdim=True)
674
- return (coeff - mean) / std
675
-
676
-
677
- class AnimateFromCoeff(nn.Module):
678
-
679
- def __init__(self, sadtalker_cfg, device):
680
- super(AnimateFromCoeff, self).__init__()
681
- self.generator = Generator(sadtalker_cfg, device)
682
- self.mapping = Mapping(sadtalker_cfg, device)
683
- self.kp_norm = KeypointNorm(device=device)
684
- self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
685
-
686
- def normalize_kp(self, kp_driving):
687
- return self.kp_norm(kp_driving)
688
-
689
- def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop,
690
- face_enhancer=None):
691
- kp_driving = coeff['kp_driving']
692
- jacobian = coeff['jacobian']
693
- pose_coeff = coeff['pose_coeff']
694
- expression_coeff = coeff['expression_coeff']
695
- blink_coeff = coeff['blink_coeff']
696
- face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff)
697
- sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
698
- dense_motion = sparse_motion['dense_motion']
699
- video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
700
- video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d)
701
- video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
702
- if face_enhancer is not None:
703
- video_output_enhanced = []
704
- for frame in tqdm(video_output, 'Face enhancer running'):
705
- pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
706
- enhanced_image = face_enhancer.forward(np.array(pil_image))
707
- video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
708
- video_output = video_output_enhanced
709
- return video_output
710
-
711
- def make_animation(self, video_array):
712
- H, W, _ = video_array[0].shape
713
- out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
714
- for img in video_array:
715
- out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
716
- out.release()
717
- video = imageio.mimread('./tmp.mp4')
718
- os.remove('./tmp.mp4')
719
- return video
720
-
721
-
722
- class Generator(nn.Module):
723
-
724
- def __init__(self, sadtalker_cfg, device):
725
- super(Generator, self).__init__()
726
- self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'],
727
- num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'],
728
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'],
729
- num_channels=3,
730
- kp_size=10,
731
- num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
732
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth')
733
- load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
734
-
735
- def forward(self, source_image, dense_motion, bg_param, face_3d_param=None):
736
- video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param)
737
- return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
738
-
739
-
740
- class Mapping(nn.Module):
741
-
742
- def __init__(self, sadtalker_cfg, device):
743
- super(Mapping, self).__init__()
744
- self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
745
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth')
746
- load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
747
- self.f_3d_mean = torch.zeros(1, 64, device=device)
748
-
749
- def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
750
- coeff = torch.cat([expression_coeff, pose_coeff], dim=1)
751
- face_3d = self.mapping_net(coeff) + self.f_3d_mean
752
- if blink_coeff is not None:
753
- face_3d[:, -1:] = blink_coeff
754
- return face_3d
755
-
756
-
757
- class OcclusionAwareDenseMotion(nn.Module):
758
-
759
- def __init__(self, sadtalker_cfg, device):
760
- super(OcclusionAwareDenseMotion, self).__init__()
761
- self.dense_motion_network = DenseMotionNetwork(num_kp=10,
762
- num_channels=3,
763
- block_expansion=sadtalker_cfg['MODEL']['SCALE'],
764
- num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1,
765
- max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
766
- checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth')
767
- load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
768
-
769
- def forward(self, kp_source, kp_driving, jacobian):
770
- sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian)
771
- return sparse_motion
772
-
773
-
774
- class FaceEnhancer(nn.Module):
775
-
776
- def __init__(self, sadtalker_cfg, device):
777
- super(FaceEnhancer, self).__init__()
778
- enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']
779
- bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
780
- if enhancer_name == 'gfpgan':
781
- from gfpgan import GFPGANer
782
- self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'),
783
- upscale=1,
784
- arch='clean',
785
- channel_multiplier=2,
786
- bg_upsampler=bg_upsampler)
787
- elif enhancer_name == 'realesrgan':
788
- from realesrgan import RealESRGANer
789
- half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']
790
- self.face_enhancer = RealESRGANer(scale=2,
791
- model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'],
792
- 'RealESRGAN_x2plus.pth'),
793
- tile=0,
794
- tile_pad=10,
795
- pre_pad=0,
796
- half=half,
797
- device=device)
798
- else:
799
- self.face_enhancer = None
800
-
801
- def forward(self, x):
802
- if self.face_enhancer:
803
- return self.face_enhancer.enhance(x, outscale=1)[0]
804
- return x
805
-
806
-
807
- def load_models():
808
- checkpoint_path = './checkpoints'
809
- config_path = './src/config'
810
- size = 256
811
- preprocess = 'crop'
812
- old_version = False
813
-
814
- sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version)
815
- print("SadTalker models loaded successfully!")
816
- return sadtalker_instance
817
-
818
-
819
- if __name__ == '__main__':
820
- sadtalker_instance = load_models()
 
1
+ import os, shutil, uuid, tempfile, urllib.request, cv2, numpy as np, torch, torch.nn as nn, torch.nn.functional as F, yaml, safetensors, librosa, imageio
+ from tqdm import tqdm
2
+ from PIL import Image
3
+ from skimage import img_as_ubyte, transform
4
+ from scipy.io import loadmat, wavfile
5
+
6
+ class SadTalker():
7
+ def __init__(self, checkpoint_path='checkpoints', config_path='src/config', size=256, preprocess='crop', old_version=False):
8
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ self.cfg = self.get_cfg_defaults()
10
+ self.merge_from_file(os.path.join(config_path, 'sadtalker_config.yaml'))
11
+ self.cfg['MODEL']['CHECKPOINTS_DIR'] = checkpoint_path
12
+ self.cfg['MODEL']['CONFIG_DIR'] = config_path
13
+ self.cfg['MODEL']['DEVICE'] = self.device
14
+ self.cfg['INPUT_IMAGE'] = {}
15
+ self.cfg['INPUT_IMAGE']['SOURCE_IMAGE'] = 'None'
16
+ self.cfg['INPUT_IMAGE']['DRIVEN_AUDIO'] = 'None'
17
+ self.cfg['INPUT_IMAGE']['PREPROCESS'] = preprocess
18
+ self.cfg['INPUT_IMAGE']['SIZE'] = size
19
+ self.cfg['INPUT_IMAGE']['OLD_VERSION'] = old_version
20
+ for filename, url in [(kp_file, kp_url), (aud_file, aud_url), (wav_file, wav_url), (gen_file, gen_url), (mapx_file, mapx_url), (den_file, den_url), ('GFPGANv1.4.pth', GFPGAN_URL), ('RealESRGAN_x2plus.pth', REALESRGAN_URL)]: download_model(url, filename, checkpoint_path)
21
+ self.sadtalker_model = SadTalkerModel(self.cfg, device_id=[0])
22
+
23
+ def get_cfg_defaults(self):
24
+ return {'MODEL': {'CHECKPOINTS_DIR': '', 'CONFIG_DIR': '', 'DEVICE': self.device, 'SCALE': 64, 'NUM_VOXEL_FRAMES': 8, 'NUM_MOTION_FRAMES': 10, 'MAX_FEATURES': 256, 'DRIVEN_AUDIO_SAMPLE_RATE': 16000, 'VIDEO_FPS': 25, 'OUTPUT_VIDEO_FPS': None, 'OUTPUT_AUDIO_SAMPLE_RATE': None, 'USE_ENHANCER': False, 'ENHANCER_NAME': '', 'BG_UPSAMPLER': None, 'IS_HALF': False}, 'INPUT_IMAGE': {}}
25
+
26
+ def merge_from_file(self, filepath):
27
+ if os.path.exists(filepath):
28
+ with open(filepath, 'r') as f: cfg_from_file = yaml.safe_load(f); self.cfg.update(cfg_from_file)
29
+
30
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en'):
31
+ self.sadtalker_model.test(source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang); return self.sadtalker_model.save_result()
32
+
33
+ class SadTalkerModel():
34
+ def __init__(self, sadtalker_cfg, device_id=[0]):
35
+ self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
36
+ self.sadtalker = SadTalkerInnerModel(sadtalker_cfg, device_id)
37
+ self.preprocesser = self.sadtalker.preprocesser
38
+ self.kp_extractor = self.sadtalker.kp_extractor; self.generator = self.sadtalker.generator
39
+ self.mapping = self.sadtalker.mapping; self.he_estimator = self.sadtalker.he_estimator
40
+ self.audio_to_coeff = self.sadtalker.audio_to_coeff; self.animate_from_coeff = self.sadtalker.animate_from_coeff; self.face_enhancer = self.sadtalker.face_enhancer
41
+
42
+ def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, batch_size=1, size=256, pose_style=0, exp_scale=1.0, use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False, length_of_audio=0, use_blink=True, result_dir='./results/', tts_text=None, tts_lang='en', jitter_amount=10, jitter_source_image=False):
43
+ self.inner_test = SadTalkerInner(self, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image); return self.inner_test.test()
44
+
45
+ def save_result(self):
46
+ return self.inner_test.save_result()
47
+
48
+ class SadTalkerInner():
49
+ def __init__(self, sadtalker_model, source_image, driven_audio, preprocess, still_mode, use_enhancer, batch_size, size, pose_style, exp_scale, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, use_blink, result_dir, tts_text, tts_lang, jitter_amount, jitter_source_image):
50
+ self.sadtalker_model = sadtalker_model; self.source_image = source_image; self.driven_audio = driven_audio
51
+ self.preprocess = preprocess; self.still_mode = still_mode; self.use_enhancer = use_enhancer
52
+ self.batch_size = batch_size; self.size = size; self.pose_style = pose_style; self.exp_scale = exp_scale
53
+ self.use_ref_video = use_ref_video; self.ref_video = ref_video; self.ref_info = ref_info
54
+ self.use_idle_mode = use_idle_mode; self.length_of_audio = length_of_audio; self.use_blink = use_blink
55
+ self.result_dir = result_dir; self.tts_text = tts_text; self.tts_lang = tts_lang
56
+ self.jitter_amount = jitter_amount; self.jitter_source_image = jitter_source_image; self.device = self.sadtalker_model.device; self.output_path = None
57
+
58
+ def get_test_data(self):
59
+ proc = self.sadtalker_model.preprocesser
60
+ if self.tts_text is not None: temp_dir = tempfile.mkdtemp(); audio_path = os.path.join(temp_dir, 'audio.wav'); tts = TTSTalker(); tts.test(self.tts_text, self.tts_lang); self.driven_audio = audio_path
61
+ source_image_pil = Image.open(self.source_image).convert('RGB')
62
+ if self.jitter_source_image: jitter_dx = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); jitter_dy = np.random.randint(-self.jitter_amount, self.jitter_amount + 1); source_image_pil = Image.fromarray(np.roll(np.roll(np.array(source_image_pil), jitter_dx, axis=1), jitter_dy, axis=0))
63
+ source_image_tensor, crop_info, cropped_image = proc.crop(source_image_pil, self.preprocess, self.size)
64
+ if self.still_mode or self.use_idle_mode: ref_pose_coeff = proc.generate_still_pose(self.pose_style); ref_expression_coeff = proc.generate_still_expression(self.exp_scale)
65
+ else: ref_pose_coeff = None; ref_expression_coeff = None
66
+ audio_tensor, audio_sample_rate = proc.process_audio(self.driven_audio, self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'])
67
+ batch = {'source_image': source_image_tensor.unsqueeze(0).to(self.device), 'audio': audio_tensor.unsqueeze(0).to(self.device), 'ref_pose_coeff': ref_pose_coeff, 'ref_expression_coeff': ref_expression_coeff, 'source_image_crop': cropped_image, 'crop_info': crop_info, 'use_blink': self.use_blink, 'pose_style': self.pose_style, 'exp_scale': self.exp_scale, 'ref_video': self.ref_video, 'use_ref_video': self.use_ref_video, 'ref_info': self.ref_info}
68
+ return batch, audio_sample_rate
69
+
70
+ def run_inference(self, batch):
71
+ kp_extractor, generator, mapping, he_estimator, audio_to_coeff, animate_from_coeff, face_enhancer = self.sadtalker_model.kp_extractor, self.sadtalker_model.generator, self.sadtalker_model.mapping, self.sadtalker_model.he_estimator, self.sadtalker_model.audio_to_coeff, self.sadtalker_model.animate_from_coeff, self.sadtalker_model.face_enhancer
72
+ with torch.no_grad():
73
+ kp_source = kp_extractor(batch['source_image'])
74
+ if self.still_mode or self.use_idle_mode: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff']); expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
75
+ elif self.use_idle_mode: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], batch['ref_pose_coeff']); expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'], batch['ref_expression_coeff'])
76
+ else:
77
+ if self.use_ref_video: kp_ref = kp_extractor(batch['source_image']); pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'], kp_ref=kp_ref, use_ref_info=batch['ref_info'])
78
+ else: pose_coeff = audio_to_coeff.get_pose_coeff(batch['audio'])
79
+ expression_coeff = audio_to_coeff.get_exp_coeff(batch['audio'])
80
+ coeff = {'pose_coeff': pose_coeff, 'expression_coeff': expression_coeff}
81
+ if self.use_blink: coeff['blink_coeff'] = audio_to_coeff.get_blink_coeff(batch['audio'])
82
+ else: coeff['blink_coeff'] = None
83
+ kp_driving = audio_to_coeff(batch['audio'])[0]; kp_norm = animate_from_coeff.normalize_kp(kp_driving); coeff['kp_driving'] = kp_norm; coeff['jacobian'] = [torch.eye(2).unsqueeze(0).unsqueeze(0).to(self.device)] * 4
84
+ output_video = animate_from_coeff.generate(batch['source_image'], kp_source, coeff, generator, mapping, he_estimator, batch['audio'], batch['source_image_crop'], face_enhancer=face_enhancer)
85
+ return output_video
86
+
87
+ def post_processing(self, output_video, audio_sample_rate, batch):
88
+ proc = self.sadtalker_model.preprocesser; base_name = os.path.splitext(os.path.basename(batch['source_image_crop']))[0]; audio_name = os.path.splitext(os.path.basename(self.driven_audio))[0]
89
+ output_video_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '.mp4'); self.output_path = output_video_path
90
+ video_fps = self.sadtalker_model.cfg['MODEL']['VIDEO_FPS'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_VIDEO_FPS']
91
+ audio_output_sample_rate = self.sadtalker_model.cfg['MODEL']['DRIVEN_AUDIO_SAMPLE_RATE'] if self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE'] is None else self.sadtalker_model.cfg['MODEL']['OUTPUT_AUDIO_SAMPLE_RATE']
92
+ if self.use_enhancer: enhanced_path = os.path.join(self.result_dir, base_name + '_' + audio_name + '_enhanced.mp4'); save_video_with_watermark(output_video, self.driven_audio, enhanced_path); paste_pic(enhanced_path, batch['source_image_crop'], batch['crop_info'], self.driven_audio, output_video_path); os.remove(enhanced_path)
93
+ else: save_video_with_watermark(output_video, self.driven_audio, output_video_path)
94
+ if self.tts_text is not None: shutil.rmtree(os.path.dirname(self.driven_audio))
95
+
96
+ def save_result(self):
97
+ return self.output_path
98
+
99
+ def __call__(self):
100
+ return self.output_path
101
+
102
+ def test(self):
103
+ batch, audio_sample_rate = self.get_test_data(); output_video = self.run_inference(batch); self.post_processing(output_video, audio_sample_rate, batch); return self.save_result()
104
+
105
+ class SadTalkerInnerModel():
106
+ def __init__(self, sadtalker_cfg, device_id=[0]):
107
+ self.cfg = sadtalker_cfg; self.device = sadtalker_cfg['MODEL'].get('DEVICE', 'cpu')
109
+ self.preprocesser = Preprocesser(sadtalker_cfg, self.device); self.kp_extractor = KeyPointExtractor(sadtalker_cfg, self.device)
110
+ self.audio_to_coeff = Audio2Coeff(sadtalker_cfg, self.device); self.animate_from_coeff = AnimateFromCoeff(sadtalker_cfg, self.device)
111
+ self.face_enhancer = FaceEnhancer(sadtalker_cfg, self.device) if sadtalker_cfg['MODEL']['USE_ENHANCER'] else None
112
+ self.generator = Generator(sadtalker_cfg, self.device); self.mapping = Mapping(sadtalker_cfg, self.device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, self.device)
113
+
114
+ class Preprocesser():
115
+ def __init__(self, sadtalker_cfg, device):
116
+ self.cfg = sadtalker_cfg; self.device = device
117
+ self.face3d_helper = Face3DHelper(self.cfg['INPUT_IMAGE'].get('LOCAL_PCA_PATH', ''), device); self.mouth_detector = MouthDetector()
118
+
119
+ def crop(self, source_image_pil, preprocess_type, size=256):
120
+ source_image = np.array(source_image_pil); face_info = self.face3d_helper.run(source_image)
121
+ if face_info is None: raise Exception("No face detected")
122
+ x_min, y_min, x_max, y_max = face_info[:4]; old_size = (x_max - x_min, y_max - y_min); x_center = (x_max + x_min) / 2; y_center = (y_max + y_min) / 2
123
+ if preprocess_type == 'crop': face_size = max(x_max - x_min, y_max - y_min); x_min = int(x_center - face_size / 2); y_min = int(y_center - face_size / 2); x_max = int(x_center + face_size / 2); y_max = int(y_center + face_size / 2)
124
+ else: x_min -= int((x_max - x_min) * 0.1); y_min -= int((y_max - y_min) * 0.1); x_max += int((x_max - x_min) * 0.1); y_max += int((y_max - y_min) * 0.1)
125
+ h, w = source_image.shape[:2]; x_min = max(0, x_min); y_min = max(0, y_min); x_max = min(w, x_max); y_max = min(h, y_max)
126
+ cropped_image = source_image[y_min:y_max, x_min:x_max]; cropped_image_pil = Image.fromarray(cropped_image)
127
+ if size is not None and size != 0: cropped_image_pil = cropped_image_pil.resize((size, size), Image.Resampling.LANCZOS)
128
+ source_image_tensor = self.img2tensor(cropped_image_pil); return source_image_tensor, [[y_min, y_max], [x_min, x_max], old_size, cropped_image_pil.size], os.path.basename(self.cfg['INPUT_IMAGE'].get('SOURCE_IMAGE', ''))
129
+
130
+ def img2tensor(self, img):
131
+ img = np.array(img).astype(np.float32) / 255.0; img = np.transpose(img, (2, 0, 1)); return torch.FloatTensor(img)
132
+ def video_to_tensor(self, video, device): return 0
133
+ def process_audio(self, audio_path, sample_rate): wav = load_wav_util(audio_path, sample_rate); wav_tensor = torch.FloatTensor(wav).unsqueeze(0); return wav_tensor, sample_rate
134
+ def generate_still_pose(self, pose_style): ref_pose_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_pose_coeff[:, :3] = torch.tensor([0, 0, pose_style * 0.3], dtype=torch.float32); return ref_pose_coeff
135
+ def generate_still_expression(self, exp_scale): ref_expression_coeff = torch.zeros((1, 64), dtype=torch.float32).to(self.device); ref_expression_coeff[:, :3] = torch.tensor([0, 0, exp_scale * 0.3], dtype=torch.float32); return ref_expression_coeff
136
+ def generate_idles_pose(self, length_of_audio, pose_style): return 0
137
+ def generate_idles_expression(self, length_of_audio): return 0
138
+
139
+ class KeyPointExtractor(nn.Module):
140
+ def __init__(self, sadtalker_cfg, device):
141
+ super(KeyPointExtractor, self).__init__(); self.kp_extractor = OcclusionAwareKPDetector(kp_channels=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'], num_kp=10, num_dilation_blocks=2, dropout_rate=0.1).to(device)
142
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'kp_detector.safetensors'); load_state_dict_robust(self.kp_extractor, checkpoint_path, device, model_name='kp_detector')
143
+ def forward(self, x): kp = self.kp_extractor(x); return kp
144
+
145
+ class Audio2Coeff(nn.Module):
146
+ def __init__(self, sadtalker_cfg, device):
147
+ super(Audio2Coeff, self).__init__(); self.audio_model = Wav2Vec2Model().to(device)
148
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'wav2vec2.pth'); load_state_dict_robust(self.audio_model, checkpoint_path, device, model_name='wav2vec2')
149
+ self.pose_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.exp_mapper = AudioCoeffsPredictor(2048, 64).to(device); self.blink_mapper = AudioCoeffsPredictor(2048, 1).to(device)
150
+ mapping_checkpoint = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'auido2pose_00140-model.pth'); load_state_dict_robust(self, mapping_checkpoint, device)
151
+ def get_pose_coeff(self, audio_tensor, ref_pose_coeff=None, kp_ref=None, use_ref_info=''):
+ audio_embedding = self.audio_model(audio_tensor); pose_coeff = self.pose_mapper(audio_embedding)
152
+ if ref_pose_coeff is not None: pose_coeff = ref_pose_coeff
153
+ if kp_ref is not None and use_ref_info == 'pose': ref_pose_6d = kp_ref['value'][:, :6]; pose_coeff[:, :6] = self.mean_std_normalize(ref_pose_6d).mean(dim=1)
154
+ return pose_coeff
155
+ def get_exp_coeff(self, audio_tensor, ref_expression_coeff=None):
+ audio_embedding = self.audio_model(audio_tensor); expression_coeff = self.exp_mapper(audio_embedding)
156
+ if ref_expression_coeff is not None: expression_coeff = ref_expression_coeff
+ return expression_coeff
157
+ def get_blink_coeff(self, audio_tensor): audio_embedding = self.audio_model(audio_tensor); blink_coeff = self.blink_mapper(audio_embedding); return blink_coeff
158
+ def forward(self, audio): audio_embedding = self.audio_model(audio); pose_coeff, expression_coeff, blink_coeff = self.pose_mapper(audio_embedding), self.exp_mapper(audio_embedding), self.blink_mapper(audio_embedding); return pose_coeff, expression_coeff, blink_coeff
159
+ def mean_std_normalize(self, coeff): mean = coeff.mean(dim=1, keepdim=True); std = coeff.std(dim=1, keepdim=True); return (coeff - mean) / std
160
+
161
+ class AnimateFromCoeff(nn.Module):
162
+ def __init__(self, sadtalker_cfg, device):
163
+ super(AnimateFromCoeff, self).__init__(); self.generator = Generator(sadtalker_cfg, device); self.mapping = Mapping(sadtalker_cfg, device); self.kp_norm = KeypointNorm(device=device); self.he_estimator = OcclusionAwareDenseMotion(sadtalker_cfg, device)
164
+ def normalize_kp(self, kp_driving): return self.kp_norm(kp_driving)
165
+ def generate(self, source_image, kp_source, coeff, generator, mapping, he_estimator, audio, source_image_crop, face_enhancer=None):
166
+ kp_driving, jacobian, pose_coeff, expression_coeff, blink_coeff = coeff['kp_driving'], coeff['jacobian'], coeff['pose_coeff'], coeff['expression_coeff'], coeff['blink_coeff']
167
+ face_3d = mapping(expression_coeff, pose_coeff, blink_coeff) if blink_coeff is not None else mapping(expression_coeff, pose_coeff); sparse_motion = he_estimator(kp_source, kp_driving, jacobian)
168
+ dense_motion = sparse_motion['dense_motion']; video_deocclusion = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None})
169
+ video_3d = generator(source_image, dense_motion, bg_param={'mask': None, 'color': None}, face_3d_param=face_3d); video_output = video_deocclusion['video_no_reocclusion'] + video_3d['video_3d']
170
+ if face_enhancer is not None:
+ video_output_enhanced = []
+ for frame in tqdm(video_output, 'Face enhancer running'): pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)); enhanced_image = face_enhancer.forward(np.array(pil_image)); video_output_enhanced.append(cv2.cvtColor(enhanced_image, cv2.COLOR_BGR2RGB))
+ video_output = video_output_enhanced
171
+ return video_output
172
+ def make_animation(self, video_array):
+ H, W, _ = video_array[0].shape; out = cv2.VideoWriter('./tmp.mp4', cv2.VideoWriter_fourcc(*'mp4v'), 25, (W, H))
+ for img in video_array: out.write(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
+ out.release(); video = imageio.mimread('./tmp.mp4'); os.remove('./tmp.mp4'); return video
173
+
174
+ class Generator(nn.Module):
175
+ def __init__(self, sadtalker_cfg, device):
176
+ super(Generator, self).__init__(); self.generator = Hourglass(block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_VOXEL_FRAMES'], max_features=sadtalker_cfg['MODEL']['MAX_FEATURES'], num_channels=3, kp_size=10, num_deform_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES']).to(device)
177
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'generator.pth'); load_state_dict_robust(self.generator, checkpoint_path, device, model_name='generator')
178
+ def forward(self, source_image, dense_motion, bg_param, face_3d_param=None): video_3d = self.generator(source_image, kp_driving=dense_motion, bg_param=bg_param, face_3d_param=face_3d_param); return {'video_3d': video_3d, 'video_no_reocclusion': video_3d}
179
+
180
+ class Mapping(nn.Module):
181
+ def __init__(self, sadtalker_cfg, device):
182
+ super(Mapping, self).__init__(); self.mapping_net = MappingNet(num_coeffs=64, num_layers=3, hidden_dim=128).to(device)
183
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'mapping.pth'); load_state_dict_robust(self.mapping_net, checkpoint_path, device, model_name='mapping')
184
+ self.f_3d_mean = torch.zeros(1, 64, device=device)
185
+ def forward(self, expression_coeff, pose_coeff, blink_coeff=None):
+ coeff = torch.cat([expression_coeff, pose_coeff], dim=1); face_3d = self.mapping_net(coeff) + self.f_3d_mean
+ if blink_coeff is not None: face_3d[:, -1:] = blink_coeff
+ return face_3d
186
+
187
+ class OcclusionAwareDenseMotion(nn.Module):
188
+ def __init__(self, sadtalker_cfg, device):
189
+ super(OcclusionAwareDenseMotion, self).__init__(); self.dense_motion_network = DenseMotionNetwork(num_kp=10, num_channels=3, block_expansion=sadtalker_cfg['MODEL']['SCALE'], num_blocks=sadtalker_cfg['MODEL']['NUM_MOTION_FRAMES'] - 1, max_features=sadtalker_cfg['MODEL']['MAX_FEATURES']).to(device)
190
+ checkpoint_path = os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'dense_motion.pth'); load_state_dict_robust(self.dense_motion_network, checkpoint_path, device, model_name='dense_motion')
191
+ def forward(self, kp_source, kp_driving, jacobian): sparse_motion = self.dense_motion_network(kp_source, kp_driving, jacobian); return sparse_motion
192
+
193
+ class FaceEnhancer(nn.Module):
194
+ def __init__(self, sadtalker_cfg, device):
195
+ super(FaceEnhancer, self).__init__(); enhancer_name = sadtalker_cfg['MODEL']['ENHANCER_NAME']; bg_upsampler = sadtalker_cfg['MODEL']['BG_UPSAMPLER']
196
+ if enhancer_name == 'gfpgan': from gfpgan import GFPGANer; self.face_enhancer = GFPGANer(model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'GFPGANv1.4.pth'), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=bg_upsampler)
197
+ elif enhancer_name == 'realesrgan': from realesrgan import RealESRGANer; half = False if device == 'cpu' else sadtalker_cfg['MODEL']['IS_HALF']; self.face_enhancer = RealESRGANer(scale=2, model_path=os.path.join(sadtalker_cfg['MODEL']['CHECKPOINTS_DIR'], 'RealESRGAN_x2plus.pth'), tile=0, tile_pad=10, pre_pad=0, half=half, device=device)
198
+ else: self.face_enhancer = None
199
+ def forward(self, x): return self.face_enhancer.enhance(x, outscale=1)[0] if self.face_enhancer else x
200
+
201
+ def download_model(url, filename, checkpoint_dir):
202
+ if not os.path.exists(os.path.join(checkpoint_dir, filename)): print(f"Downloading {filename}..."); os.makedirs(checkpoint_dir, exist_ok=True); urllib.request.urlretrieve(url, os.path.join(checkpoint_dir, filename)); print(f"{filename} downloaded.")
203
+ else: print(f"{filename} already exists.")
204
+
205
+ def load_models():
206
+ checkpoint_path = './checkpoints'; config_path = './src/config'; size = 256; preprocess = 'crop'; old_version = False
207
+ sadtalker_instance = SadTalker(checkpoint_path, config_path, size, preprocess, old_version); print("SadTalker models loaded successfully!"); return sadtalker_instance
208
+
209
+ if __name__ == '__main__': sadtalker_instance = load_models()
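Note: a minimal end-to-end usage sketch of the rewritten module (not part of the diff). It assumes the classes above live in sadtalker_utils.py, that the checkpoint name/URL constants (kp_file, kp_url, ...) and the helpers they reference (OcclusionAwareKPDetector, Wav2Vec2Model, Hourglass, MappingNet, DenseMotionNetwork, load_state_dict_robust, load_wav_util, save_video_with_watermark, paste_pic, TTSTalker) are importable from elsewhere in this repo, and that the image/audio paths are illustrative placeholders:

    from sadtalker_utils import SadTalker

    talker = SadTalker(checkpoint_path='./checkpoints', config_path='./src/config',
                       size=256, preprocess='crop')
    # Placeholder inputs: any RGB face image and a 16 kHz wav should work here.
    video_path = talker.test(source_image='examples/face.png',
                             driven_audio='examples/speech.wav',
                             still_mode=False, use_enhancer=False,
                             result_dir='./results/')
    print('talking-head video written to', video_path)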
 
sentiment_api.py CHANGED
@@ -3,27 +3,14 @@ from main import *
3
  import torch
4
 
5
  def analyze_sentiment(text):
6
- if sentiment_model is None:
7
- return {"error": "Sentiment model not initialized."}
8
-
9
- features = [ord(c) for c in text[:10]]
10
- while len(features) < 10:
11
- features.append(0)
12
  features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
13
-
14
- with torch.no_grad():
15
- output = sentiment_model(features_tensor)
16
- sentiment_idx = torch.argmax(output, dim=1).item()
17
- sentiment_label = "positive" if sentiment_idx == 1 else "negative"
18
-
19
  return {"sentiment": sentiment_label}
20
 
21
- def sentiment_api():
22
- data = request.get_json()
23
- text = data.get('text')
24
- if not text:
25
- return jsonify({"error": "Text is required"}), 400
26
  output = analyze_sentiment(text)
27
- if "error" in output:
28
- return jsonify({"error": output["error"]}), 500
29
- return jsonify(output)
 
3
  import torch
4
 
5
  def analyze_sentiment(text):
6
+ if sentiment_model is None: return {"error": "Sentiment model not initialized."}
7
+ features = [ord(c) for c in text[:10]]
8
+ while len(features) < 10: features.append(0)
 
 
 
9
  features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
10
+ with torch.no_grad(): output = sentiment_model(features_tensor); sentiment_idx = torch.argmax(output, dim=1).item(); sentiment_label = "positive" if sentiment_idx == 1 else "negative"
 
 
 
 
 
11
  return {"sentiment": sentiment_label}
12
 
13
+ def sentiment_api(text):
 
 
 
 
14
  output = analyze_sentiment(text)
15
+ if "error" in output: return {"error": output["error"]}
16
+ return output
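Since the refactored sentiment_api() takes the text directly instead of reading the Flask request, the HTTP handling moves to whichever module registers the route (presumably api.py). A hedged sketch of such a caller; the route name and import path are assumptions:

    from flask import Flask, request, jsonify
    from sentiment_api import sentiment_api  # assumed import path

    app = Flask(__name__)

    @app.route('/sentiment', methods=['POST'])
    def sentiment_route():
        text = (request.get_json() or {}).get('text')
        if not text:
            return jsonify({"error": "Text is required"}), 400
        result = sentiment_api(text)
        return jsonify(result), (500 if "error" in result else 200)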
 
stt_api.py CHANGED
@@ -1,33 +1,17 @@
1
- import os
2
- import uuid
3
  from flask import jsonify, send_file, request
4
  from main import *
5
- import torch
6
- import torchaudio
7
 
8
  def speech_to_text_func(audio_path):
9
- if stt_model is None:
10
- return {"error": "STT model not initialized."}
11
-
12
- waveform, sample_rate = torchaudio.load(audio_path)
13
- if waveform.ndim > 1:
14
- waveform = torch.mean(waveform, dim=0, keepdim=True)
15
  waveform = waveform.to(device)
16
- with torch.no_grad():
17
- logits = stt_model(waveform)
18
- predicted_ids = torch.argmax(logits, dim=-1)
19
- transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist())
20
-
21
- return {"text": transcription}
22
 
23
- def stt_api():
24
- if 'audio' not in request.files:
25
- return jsonify({"error": "Audio file is required"}), 400
26
- audio_file = request.files['audio']
27
- temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
28
- audio_file.save(temp_audio_path)
29
- output = speech_to_text_func(temp_audio_path)
30
- os.remove(temp_audio_path)
31
- if "error" in output:
32
- return jsonify({"error": output["error"]}), 500
33
- return jsonify(output)
 
1
+ import os, uuid
 
2
  from flask import jsonify, send_file, request
3
  from main import *
4
+ import torch, torchaudio
 
5
 
6
  def speech_to_text_func(audio_path):
7
+ if stt_model is None: return {"error": "STT model not initialized."}
8
+ waveform, sample_rate = torchaudio.load(audio_path)
9
+ if waveform.ndim > 1: waveform = torch.mean(waveform, dim=0, keepdim=True)
 
 
 
10
  waveform = waveform.to(device)
11
+ with torch.no_grad(): logits = stt_model(waveform)
12
+ predicted_ids = torch.argmax(logits, dim=-1); transcription = stt_model.tokenizer.decode(predicted_ids[0].cpu().tolist()); return {"text": transcription}
 
 
 
 
13
 
14
+ def stt_api(audio_filepath):
15
+ output = speech_to_text_func(audio_filepath)
16
+ if "error" in output: return {"error": output["error"]}
17
+ return output
 
 
summarization_api.py CHANGED
@@ -1,27 +1,14 @@
1
  from flask import jsonify, send_file, request
2
  from main import *
3
- import torch
4
 
5
  def summarize_text(text, output_path="output_summary.txt"):
6
- if summarization_model is None or summarization_tokenizer is None:
7
- return "Summarization model or tokenizer not initialized."
8
-
9
  input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)
 
 
10
 
11
- with torch.no_grad():
12
- summary_ids = summarization_model.generate(input_ids, num_beams=4, max_length=100, early_stopping=True)
13
- summary_text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
14
-
15
- with open(output_path, "w") as file:
16
- file.write(summary_text)
17
- return output_path
18
-
19
- def summarization_api():
20
- data = request.get_json()
21
- text = data.get('text')
22
- if not text:
23
- return jsonify({"error": "Text is required"}), 400
24
- output_file = summarize_text(text)
25
- if output_file == "Summarization model or tokenizer not initialized.":
26
- return jsonify({"error": "Summarization failed"}), 500
27
- return send_file(output_file, mimetype="text/plain", as_attachment=True, download_name="output_summary.txt")
 
1
  from flask import jsonify, send_file, request
2
  from main import *
3
+ import torch, io, base64
4
 
5
  def summarize_text(text, output_path="output_summary.txt"):
6
+ if summarization_model is None or summarization_tokenizer is None: return {"error": "Summarization model or tokenizer not initialized."}
 
 
7
  input_ids = summarization_tokenizer.encode(text, return_tensors="pt").to(device)
8
+ with torch.no_grad(): summary_ids = summarization_model.generate(input_ids, num_beams=4, max_length=100, early_stopping=True); summary_text = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
9
+ return {"summary_text": summary_text}
10
 
11
+ def summarization_api(text):
12
+ output = summarize_text(text)
13
+ if "error" in output: return {"error": output["error"]}
14
+ return output
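summarize_text() now returns the summary in a dict rather than writing output_summary.txt, so callers read result["summary_text"] directly; the old send_file behaviour would have to be rebuilt by whichever route serves it. A minimal sketch:

    result = summarization_api("Long article text goes here ...")
    if "error" in result:
        print("Summarization failed:", result["error"])
    else:
        print(result["summary_text"])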
 
text_generation.py CHANGED
@@ -1,194 +1,93 @@
1
- import torch
2
- import torch.nn.functional as F
3
- from tqdm import trange
4
- import time
5
- from tokenxxx import *
6
- from main import *
7
- from duckduckgo_search import DDGS
8
-
9
- try:
10
- END_OF_TEXT_TOKEN
11
- except NameError:
12
- END_OF_TEXT_TOKEN = ""
13
- try:
14
- SYSTEM_PROMPT
15
- except NameError:
16
- SYSTEM_PROMPT = "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo."
17
- try:
18
- MAX_XDD
19
- except NameError:
20
- MAX_XDD = 5
21
- try:
22
- codegen_model
23
- except NameError:
24
- codegen_model = None
25
- try:
26
- codegen_tokenizer
27
- except NameError:
28
- codegen_tokenizer = None
29
- try:
30
- summarization_model
31
- except NameError:
32
- summarization_model = None
33
- try:
34
- summarization_tokenizer
35
- except NameError:
36
- summarization_tokenizer = None
37
- try:
38
- model_gpt2
39
- except NameError:
40
- model_gpt2 = None
41
- try:
42
- enc
43
- except NameError:
44
- enc = None
45
- try:
46
- device
47
- except NameError:
48
- device = "cpu"
49
-
50
- if torch.device(device).type == "cuda":
51
- torch.backends.cudnn.benchmark = True
52
-
53
- MAX_GENERATION_LENGTH = 512
54
-
55
- def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
56
- top_k = min(top_k, logits.size(-1))
57
- if top_k > 0:
58
- indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]
59
- logits[indices_to_remove] = filter_value
60
- if top_p > 0.0:
61
- sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
62
- cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
63
- sorted_indices_to_remove = cumulative_probs > top_p
64
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
65
- sorted_indices_to_remove[..., 0] = 0
66
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
67
- logits[indices_to_remove] = filter_value
68
- return logits
69
-
70
- def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
71
- past_key_values = None
72
- last_token = None
73
- repetition_count = 0
74
- for _ in range(max_length):
75
- try:
76
- outputs = model_call(context_tensor, past_key_values)
77
- except Exception as e:
78
- yield "<ERROR:" + str(e) + ">"
79
- yield "<END_STREAM>"
80
- return
81
- next_token_logits = outputs[0][:, -1, :] / temperature
82
- past_key_values = outputs[1]
83
- for token_index in set(generated):
84
- next_token_logits[0, token_index] /= repetition_penalty
85
- filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
86
- if temperature == 0:
87
- next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
88
- else:
89
- next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
90
- token_id = next_token.tolist()[0][0]
91
- if token_id == last_token:
92
- repetition_count += 1
93
- else:
94
- repetition_count = 0
95
- last_token = token_id
96
- if repetition_count >= 10:
97
- yield "<END_STREAM>"
98
- return
99
- generated.append(token_id)
100
- token_decoded = decode_fn(token_id)
101
- yield token_decoded
102
- if end_token_condition(token_id):
103
- yield "<END_STREAM>"
104
- return
105
-
106
- def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
107
- context_tokens = enc.encode(prompt)
108
- context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
109
- return _generate_sequence(
110
- lambda ct, past: model(ct, past_key_values=past),
111
- context_tensor,
112
- list(context_tokens),
113
- lambda token: enc.decode([token]),
114
- lambda token: token == enc.encoder[END_OF_TEXT_TOKEN],
115
- temperature, top_k, top_p, repetition_penalty, max_length
116
- )
117
-
118
- def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
119
- context_tokens = tokenizer.encode(prompt)
120
- context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
121
- return _generate_sequence(
122
- lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None),
123
- context_tensor,
124
- list(context_tokens),
125
- lambda token: tokenizer.decode([token]),
126
- lambda token: token == 50256,
127
- temperature, top_k, top_p, repetition_penalty, max_length
128
- )
129
-
130
- def summarize_text(text):
131
- if summarization_model and summarization_tokenizer:
132
- input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device)
133
- summary_ids = summarization_model.generate(
134
- input_ids,
135
- max_length=150,
136
- min_length=40,
137
- length_penalty=2.0,
138
- num_beams=4,
139
- early_stopping=True
140
- )
141
- return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
142
- return text[:300] + "..." if len(text) > 300 else text
143
-
144
- def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
145
- initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"
146
- reasoning_prompt = prev_context if prev_context else initial_prompt
147
- ddgs = DDGS()
148
- search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)]
149
- if search_results:
150
- reasoning_prompt += "\nWeb Search Results:\n"
151
- for result in search_results:
152
- reasoning_prompt += "- " + result['body'] + "\n"
153
- reasoning_prompt += "\n"
154
- if "code" in text_input.lower() or "program" in text_input.lower():
155
- model_type = "code"
156
- elif "summarize" in text_input.lower() or "summary" in text_input.lower():
157
- model_type = "summarize"
158
- elif model_gpt2 and enc:
159
- model_type = "gpt2"
160
- else:
161
- yield "<ERROR: No se encontró un modelo adecuado>"
162
- yield "<END_STREAM>"
163
- return
164
- if model_type == "summarize":
165
- if summarization_model:
166
- summary = summarize_text(text_input)
167
- yield "SUMMARY_TEXT:" + summary
168
- yield "<END_STREAM>"
169
- return
170
- accumulated_text = ""
171
- current_context = reasoning_prompt
172
- overlap = 256
173
- while True:
174
- if model_type == "code":
175
- generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
176
- elif model_type == "gpt2":
177
- generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
178
- chunk_text = ""
179
- finished = False
180
- for token in generator:
181
- if token == "<END_STREAM>":
182
- finished = True
183
- break
184
- chunk_text += token
185
- if accumulated_text:
186
- overlap_text = accumulated_text[-overlap:]
187
- if chunk_text.startswith(overlap_text):
188
- chunk_text = chunk_text[len(overlap_text):]
189
- accumulated_text += chunk_text
190
- yield chunk_text
191
- if finished:
192
- yield "<END_STREAM>"
193
- break
194
- current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
 
1
+ import torch, torch.nn.functional as F
2
+ from tqdm import trange
3
+ import time
4
+ from tokenxxx import *
5
+ from main import *
6
+ from duckduckgo_search import DDGS
7
+
8
+ try: END_OF_TEXT_TOKEN
9
+ except NameError: END_OF_TEXT_TOKEN = ""
10
+ try: SYSTEM_PROMPT
11
+ except NameError: SYSTEM_PROMPT = "Sistema: Proporcione respuestas ultra rápidas, coherentes, similares, precisas y con sentido, con razonamiento lógico y profundo."
12
+ try: MAX_XDD
13
+ except NameError: MAX_XDD = 5
14
+ try: codegen_model
15
+ except NameError: codegen_model = None
16
+ try: codegen_tokenizer
17
+ except NameError: codegen_tokenizer = None
18
+ try: summarization_model
19
+ except NameError: summarization_model = None
20
+ try: summarization_tokenizer
21
+ except NameError: summarization_tokenizer = None
22
+ try: model_gpt2
23
+ except NameError: model_gpt2 = None
24
+ try: enc
25
+ except NameError: enc = None
26
+ try: device
27
+ except NameError: device = "cpu"
28
+
29
+ if torch.device(device).type == "cuda": torch.backends.cudnn.benchmark = True
30
+ MAX_GENERATION_LENGTH = 512
31
+
32
+ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
33
+ top_k = min(top_k, logits.size(-1));
34
+ if top_k > 0: indices_to_remove = logits < torch.topk(logits, top_k)[0][..., [-1]]; logits[indices_to_remove] = filter_value
35
+ if top_p > 0.0:
+     sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1); cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+     sorted_indices_to_remove = cumulative_probs > top_p; sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone(); sorted_indices_to_remove[..., 0] = 0
+     indices_to_remove = sorted_indices[sorted_indices_to_remove]; logits[indices_to_remove] = filter_value
+ return logits
37
+
38
+ def _generate_sequence(model_call, context_tensor, generated, decode_fn, end_token_condition, temperature, top_k, top_p, repetition_penalty, max_length):
39
+ past_key_values = None; last_token = None; repetition_count = 0
40
+ for _ in range(max_length):
41
+ try: outputs = model_call(context_tensor, past_key_values)
42
+ except Exception as e: yield "<ERROR:" + str(e) + ">"; yield "<END_STREAM>"; return
43
+ next_token_logits = outputs[0][:, -1, :] / temperature; past_key_values = outputs[1]
44
+ for token_index in set(generated): next_token_logits[0, token_index] /= repetition_penalty
45
+ filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
46
+ if temperature == 0: next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(0)
47
+ else: next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
48
+ token_id = next_token.tolist()[0][0]
49
+ if token_id == last_token: repetition_count += 1
50
+ else: repetition_count = 0; last_token = token_id
51
+ if repetition_count >= 10: yield "<END_STREAM>"; return
52
+ generated.append(token_id); token_decoded = decode_fn(token_id); yield token_decoded
53
+ if end_token_condition(token_id): yield "<END_STREAM>"; return
54
+
55
+ def sample_sequence(prompt, model, enc, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
56
+ context_tokens = enc.encode(prompt); context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
57
+ return _generate_sequence(lambda ct, past: model(ct, past_key_values=past), context_tensor, list(context_tokens), lambda token: enc.decode([token]), lambda token: token == enc.encoder[END_OF_TEXT_TOKEN], temperature, top_k, top_p, repetition_penalty, max_length)
58
+
59
+ def sample_sequence_codegen(prompt, model, tokenizer, max_length=MAX_GENERATION_LENGTH, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device="cpu"):
60
+ context_tokens = tokenizer.encode(prompt); context_tensor = torch.tensor([context_tokens], dtype=torch.long, device=device)
61
+ return _generate_sequence(lambda ct, past: model(input_ids=ct, past_key_values=past, labels=None), context_tensor, list(context_tokens), lambda token: tokenizer.decode([token]), lambda token: token == 50256, temperature, top_k, top_p, repetition_penalty, max_length)
62
+
63
+ def summarize_text(text):
64
+ if summarization_model and summarization_tokenizer:
65
+ input_ids = summarization_tokenizer.encode(text, return_tensors="pt", truncation=True, max_length=1024).to(device); summary_ids = summarization_model.generate(input_ids, max_length=150, min_length=40, length_penalty=2.0, num_beams=4, early_stopping=True)
66
+ return summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
67
+ return text[:300] + "..." if len(text) > 300 else text
68
+
69
+ def perform_reasoning_stream(text_input, temperature, top_k, top_p, repetition_penalty, prev_context=""):
70
+ initial_prompt = SYSTEM_PROMPT + "\n\nUser: " + text_input + "\nAssistant:"; reasoning_prompt = prev_context if prev_context else initial_prompt; ddgs = DDGS()
71
+ search_results = [r for r in ddgs.text(text_input, max_results=MAX_XDD)];
72
+ if search_results:
+     reasoning_prompt += "\nWeb Search Results:\n"
+     for result in search_results: reasoning_prompt += "- " + result['body'] + "\n"
+     reasoning_prompt += "\n"
75
+ if "code" in text_input.lower() or "program" in text_input.lower(): model_type = "code"
76
+ elif "summarize" in text_input.lower() or "summary" in text_input.lower(): model_type = "summarize"
77
+ elif model_gpt2 and enc: model_type = "gpt2"
78
+ else: yield "<ERROR: No se encontró un modelo adecuado>"; yield "<END_STREAM>"; return
79
+ if model_type == "summarize":
80
+ if summarization_model: summary = summarize_text(text_input); yield "SUMMARY_TEXT:" + summary; yield "<END_STREAM>"; return
+ else: yield "<ERROR: Summarization model not initialized>"; yield "<END_STREAM>"; return
81
+ accumulated_text = ""; current_context = reasoning_prompt; overlap = 256
82
+ while True:
83
+ if model_type == "code": generator = sample_sequence_codegen(current_context, codegen_model, codegen_tokenizer, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
84
+ elif model_type == "gpt2": generator = sample_sequence(current_context, model_gpt2, enc, MAX_GENERATION_LENGTH, temperature, top_k, top_p, repetition_penalty, device)
85
+ chunk_text = ""; finished = False
86
+ for token in generator:
87
+ if token == "<END_STREAM>": finished = True; break
88
+ chunk_text += token
89
+ if accumulated_text:
+     overlap_text = accumulated_text[-overlap:]
+     if chunk_text.startswith(overlap_text): chunk_text = chunk_text[len(overlap_text):]
91
+ accumulated_text += chunk_text; yield chunk_text
92
+ if finished: yield "<END_STREAM>"; break
93
+ current_context = accumulated_text[-overlap:] if len(accumulated_text) > overlap else accumulated_text
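
A sketch of how the streaming generator above might be consumed, assuming the GPT-2 weights, tokenizer and network access for DDGS are available; <END_STREAM> and <ERROR...> are the in-band markers yielded by the code above:

# Hypothetical driver loop for the token stream.
from text_generation import perform_reasoning_stream

chunks = []
for piece in perform_reasoning_stream("Explain BPE tokenization", temperature=0.7, top_k=40, top_p=0.9, repetition_penalty=1.2):
    if piece == "<END_STREAM>":
        break
    if piece.startswith("<ERROR"):
        raise RuntimeError(piece)
    chunks.append(piece)
print("".join(chunks))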
 
text_to_video_api.py CHANGED
@@ -1,36 +1,24 @@
1
- import os
2
- import uuid
3
  from flask import jsonify, send_file, request
4
  from main import *
5
- import torch
6
- import io
7
  from skimage import img_as_ubyte
8
- import imageio
9
 
10
  def text_to_video_func(prompt, output_path="output_video.mp4"):
11
- if text_to_video_model is None:
12
- return "Text-to-Video model not initialized."
13
  video_frames_list = text_to_video_model(prompt)
14
- if video_frames_list and hasattr(video_frames_list, 'frames'):
15
- video_frames = video_frames_list.frames
16
- export_to_video_pure(video_frames, output_video=output_path)
17
- return output_path
18
- return "Video generation failed."
19
 
20
  def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
21
  writer = imageio.get_writer(output_video, fps=fps)
22
- for frame in video_frames:
23
- writer.append_data(img_as_ubyte(frame))
24
  writer.close()
25
 
26
- def text_to_video_api():
27
- data = request.get_json()
28
- prompt = data.get('prompt')
29
- if not prompt:
30
- return jsonify({"error": "Prompt is required"}), 400
31
- output_file = text_to_video_func(prompt)
32
- if output_file == "Text-to-Video model not initialized." or output_file == "Video generation failed.":
33
- return jsonify({"error": "Text to video failed"}), 500
34
- with open(output_file, 'rb') as f:
35
- video_content = f.read()
36
- return send_file(io.BytesIO(video_content), mimetype='video/mp4', as_attachment=True, download_name="output_video.mp4")
 
1
+ import os, uuid
 
2
  from flask import jsonify, send_file, request
3
  from main import *
4
+ import torch, io
 
5
  from skimage import img_as_ubyte
6
+ import imageio, base64
7
 
8
  def text_to_video_func(prompt, output_path="output_video.mp4"):
9
+ if text_to_video_model is None: return {"error": "Text-to-Video model not initialized."}
 
10
  video_frames_list = text_to_video_model(prompt)
11
+ if video_frames_list and hasattr(video_frames_list, 'frames'): export_to_video_pure(video_frames_list.frames, output_video=output_path); return output_path
12
+ return {"error": "Video generation failed."}
 
 
 
13
 
14
  def export_to_video_pure(video_frames, output_video="output_video.mp4", fps=25):
15
  writer = imageio.get_writer(output_video, fps=fps)
16
+ for frame in video_frames: writer.append_data(img_as_ubyte(frame))
 
17
  writer.close()
18
 
19
+ def text_to_video_api(prompt):
20
+ output_data = text_to_video_func(prompt)
21
+ if "error" in output_data: return {"error": output_data["error"]}
22
+ output_file = output_data;
23
+ with open(output_file, 'rb') as f: video_content = f.read()
24
+ video_base64 = base64.b64encode(video_content).decode('utf-8'); os.remove(output_file); return {"video_base64": video_base64, "mimetype": "video/mp4"}
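
For completeness, a small sketch of what a caller might do with the base64 payload returned above, assuming text_to_video_model was initialized by model_loader.py (the prompt and output filename are illustrative):

# Hypothetical caller: turn the returned payload back into an .mp4 on disk.
import base64
from text_to_video_api import text_to_video_api

result = text_to_video_api("a timelapse of clouds over mountains")
if "error" not in result:
    with open("clouds.mp4", "wb") as f:
        f.write(base64.b64decode(result["video_base64"]))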
 
 
 
 
 
tokenxxx.py CHANGED
@@ -1,142 +1,72 @@
1
- import json
2
- import re
3
- import unicodedata
4
  from functools import lru_cache
5
- import wget
6
- import os
7
- from constants import *
8
  import nltk
9
 
10
  @lru_cache()
11
  def bytes_to_unicode():
12
  bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
13
- cs = bs[:]
14
- n = 0
15
- for b in range(2**8):
16
- if b not in bs:
17
- bs.append(b)
18
- cs.append(2**8 + n)
19
- n += 1
20
- cs = [chr(n) for n in cs]
21
- return dict(zip(bs, cs))
22
 
23
  def get_pairs(word):
24
- pairs = set()
25
- prev_char = word[0]
26
- for char in word[1:]:
27
- pairs.add((prev_char, char))
28
- prev_char = char
29
- return pairs
30
 
31
  class Encoder:
32
  def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
33
- self.encoder = encoder
34
- self.decoder = {v:k for k,v in self.encoder.items()}
35
- self.errors = errors
36
- self.byte_encoder = bytes_to_unicode()
37
- self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
38
  self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
39
- self.cache = {}
40
- if tokenize is None:
41
- self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
42
- self.tokenize = lambda text: re.findall(self.pat, text)
43
- else:
44
- self.tokenize = tokenize
45
 
46
  def bpe(self, token):
47
- if token in self.cache:
48
- return self.cache[token]
49
- word = tuple(token)
50
- pairs = get_pairs(word)
51
- if not pairs:
52
- return token
53
  while True:
54
  bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
55
- if bigram not in self.bpe_ranks:
56
- break
57
- first, second = bigram
58
- new_word = []
59
- i = 0
60
  while i < len(word):
61
- try:
62
- j = word.index(first, i)
63
- new_word.extend(word[i:j])
64
- i = j
65
- except ValueError:
66
- new_word.extend(word[i:])
67
- break
68
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
69
- new_word.append(first+second)
70
- i += 2
71
- else:
72
- new_word.append(word[i])
73
- i += 1
74
- new_word = tuple(new_word)
75
- word = new_word
76
- if len(word) == 1:
77
- break
78
- else:
79
- pairs = get_pairs(word)
80
- word = ' '.join(word)
81
- self.cache[token] = word
82
- return word
83
 
84
  def encode(self, text):
85
- bpe_tokens = []
86
- normalized_text = unicodedata.normalize('NFKC', text)
87
- normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
88
- normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
89
- for token in self.tokenize(normalized_text):
90
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
91
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
92
  return bpe_tokens
93
 
94
  def decode(self, tokens):
95
- text = ''.join([self.decoder[token] for token in tokens])
96
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
97
  decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
98
- sentences = nltk.sent_tokenize(decoded_text)
99
- return ' '.join(sentences).replace("<br>", "<br>\n")
100
 
101
  def get_encoder_gpt2():
102
- encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
103
- vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
104
- if not os.path.exists(GPT2_FOLDER):
105
- os.makedirs(GPT2_FOLDER)
106
- if not os.path.exists(encoder_path):
107
- wget.download(ENCODER_URL, out=encoder_path)
108
- if not os.path.exists(vocab_path):
109
- wget.download(VOCAB_URL, out=vocab_path)
110
-
111
- with open(encoder_path, 'r') as f:
112
- encoder = json.load(f)
113
- with open(vocab_path, 'r', encoding="utf-8") as f:
114
- bpe_data = f.read()
115
- bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
116
- encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
117
- encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
118
- encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
119
- return encoder_obj
120
 
121
  def get_codegen_tokenizer_pure(vocab_file, merges_file):
122
- vocab = json.load(open(vocab_file))
123
- merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
124
- bpe_merges = [tuple(m.split()) for m in merges]
125
- byte_encoder = bytes_to_unicode()
126
- byte_decoder = {v: k for k, v in byte_encoder.items()}
127
- tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
128
- tokenize = lambda text: re.findall(tokenizer_regex, text)
129
- encoder_obj = Encoder(
130
- encoder=vocab,
131
- bpe_merges=bpe_merges,
132
- byte_encoder=byte_encoder,
133
- byte_decoder=byte_decoder,
134
- tokenize=tokenize
135
- )
136
- return encoder_obj
137
-
138
- def codegen_tokenize(text, tokenizer):
139
- return tokenizer.encode(text)
140
 
141
- def codegen_decode(tokens, tokenizer):
142
- return tokenizer.decode(tokens)
 
1
+ import json, re, unicodedata
 
 
2
  from functools import lru_cache
3
+ import wget, os
4
+ from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN, ENCODER_URL, VOCAB_URL
 
5
  import nltk
6
 
7
  @lru_cache()
8
  def bytes_to_unicode():
9
  bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
10
+ cs = bs[:]; n = 0
11
+ for b in range(2**8):
+     if b not in bs: bs.append(b); cs.append(2**8 + n); n += 1
12
+ cs = [chr(n) for n in cs]; return dict(zip(bs, cs))
 
 
 
 
 
 
13
 
14
  def get_pairs(word):
15
+ pairs = set(); prev_char = word[0]
16
+ for char in word[1:]: pairs.add((prev_char, char)); prev_char = char
+ return pairs
 
 
 
 
17
 
18
  class Encoder:
19
  def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
20
+ self.encoder = encoder; self.decoder = {v:k for k,v in self.encoder.items()}; self.errors = errors
21
+ self.byte_encoder = bytes_to_unicode(); self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
 
 
 
22
  self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
23
+ self.cache = {};
24
+ if tokenize is None: self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE); self.tokenize = lambda text: re.findall(self.pat, text)
25
+ else: self.tokenize = tokenize
 
 
 
26
 
27
  def bpe(self, token):
28
+ if token in self.cache: return self.cache[token]
29
+ word = tuple(token); pairs = get_pairs(word)
30
+ if not pairs: return token
 
 
 
31
  while True:
32
  bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
33
+ if bigram not in self.bpe_ranks: break
34
+ first, second = bigram; new_word = []; i = 0
 
 
 
35
  while i < len(word):
36
+ try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
37
+ except ValueError: new_word.extend(word[i:]); break
38
+ if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
39
+ else: new_word.append(word[i]); i += 1
40
+ new_word = tuple(new_word); word = new_word
41
+ if len(word) == 1: break
42
+ else: pairs = get_pairs(word)
43
+ word = ' '.join(word); self.cache[token] = word; return word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
  def encode(self, text):
46
+ bpe_tokens = []; normalized_text = unicodedata.normalize('NFKC', text); normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t'); normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
47
+ for token in self.tokenize(normalized_text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore')); bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
 
 
 
 
 
48
  return bpe_tokens
49
 
50
  def decode(self, tokens):
51
+ text = ''.join([self.decoder[token] for token in tokens]); text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
 
52
  decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
53
+ sentences = nltk.sent_tokenize(decoded_text); return ' '.join(sentences).replace("<br>", "<br>\n")
 
54
 
55
  def get_encoder_gpt2():
56
+ encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE); vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
57
+ if not os.path.exists(GPT2_FOLDER): os.makedirs(GPT2_FOLDER)
58
+ if not os.path.exists(encoder_path): wget.download(ENCODER_URL, out=encoder_path)
59
+ if not os.path.exists(vocab_path): wget.download(VOCAB_URL, out=vocab_path)
60
+ with open(encoder_path, 'r') as f: encoder = json.load(f)
61
+ with open(vocab_path, 'r', encoding="utf-8") as f: bpe_data = f.read()
62
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]; encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
63
+ encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder); encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN; return encoder_obj
 
 
 
 
 
 
 
 
 
 
64
 
65
  def get_codegen_tokenizer_pure(vocab_file, merges_file):
66
+ vocab = json.load(open(vocab_file)); merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]; bpe_merges = [tuple(m.split()) for m in merges]
67
+ byte_encoder = bytes_to_unicode(); byte_decoder = {v: k for k, v in byte_encoder.items()}
68
+ tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'''); tokenize = lambda text: re.findall(tokenizer_regex, text)
69
+ encoder_obj = Encoder(encoder=vocab, bpe_merges=bpe_merges, tokenize=tokenize); return encoder_obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ def codegen_tokenize(text, tokenizer): return tokenizer.encode(text)
72
+ def codegen_decode(tokens, tokenizer): return tokenizer.decode(tokens)
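
A short round-trip sketch for the byte-level BPE encoder defined above, assuming constants.py provides the GPT2_FOLDER paths plus ENCODER_URL/VOCAB_URL and that the NLTK punkt data needed by decode() is installed:

# Hypothetical round trip through the GPT-2 byte-level BPE encoder.
from tokenxxx import get_encoder_gpt2

enc = get_encoder_gpt2()
ids = enc.encode("Hello, world!")
print(ids)              # BPE token ids
print(enc.decode(ids))  # decoded text (sentence-split, newlines rendered as <br>)
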
translation_api.py CHANGED
@@ -2,23 +2,14 @@ from flask import jsonify, send_file, request
2
  from main import *
3
 
4
  def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
5
- if translation_model is None:
6
- return {"error": "Translation model not initialized."}
7
-
8
  encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
9
  generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
10
- translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
11
-
12
- return {"translated_text": translation}
13
 
14
- def translation_api():
15
- data = request.get_json()
16
- text = data.get('text')
17
- target_lang = data.get('target_lang', 'es')
18
- source_lang = data.get('source_lang', 'en')
19
- if not text:
20
- return jsonify({"error": "Text is required"}), 400
21
  output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
22
- if "error" in output:
23
- return jsonify({"error": output["error"]}), 500
24
- return jsonify(output)
 
2
  from main import *
3
 
4
  def perform_translation(text, target_language_code='es_XX', source_language_code='en_XX'):
5
+ if translation_model is None: return {"error": "Translation model not initialized."}
 
 
6
  encoded_text = translation_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
7
  generated_tokens = translation_model.generate(input_ids=encoded_text['input_ids'], attention_mask=encoded_text['attention_mask'], forced_bos_token_id=translation_model.config.lang_code_to_id[target_language_code])
8
+ translation = translation_model.tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]; return {"translated_text": translation}
 
 
9
 
10
+ def translation_api(text, target_lang='es', source_lang='en'):
+ if not text: return {"error": "Text is required"}
 
 
 
 
13
  output = perform_translation(text, target_language_code=f'{target_lang}_XX', source_language_code=f'{source_lang}_XX')
14
+ if "error" in output: return jsonify({"error": output["error"]}), 500
15
+ return output
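
A minimal sketch of calling the translation helper directly with mBART-style language codes, assuming translation_model and its tokenizer were loaded by model_loader.py:

# Hypothetical caller; language codes follow the en_XX / es_XX convention used above.
from translation_api import translation_api

result = translation_api("How are you today?", target_lang="es", source_lang="en")
print(result.get("translated_text", result.get("error")))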
 
tts_api.py CHANGED
@@ -1,26 +1,15 @@
1
- import os
2
  from flask import jsonify, send_file, request
3
  from main import *
4
- import torch
5
- import torchaudio
6
- import uuid
7
 
8
- def text_to_speech_func(text):
9
- if tts_model is None:
10
- return {"error": "TTS model not initialized."}
11
  input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
12
- with torch.no_grad():
13
- audio_output = tts_model(input_tokens['input_ids'])
14
- temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"
15
- torchaudio.save(temp_audio_path, audio_output.cpu(), 16000)
16
- return temp_audio_path
17
 
18
- def tts_api():
19
- data = request.get_json()
20
- text = data.get('text')
21
- if not text:
22
- return jsonify({"error": "Text is required"}), 400
23
  output_file = text_to_speech_func(text)
24
- if "error" in output:
25
- return jsonify({"error": output["error"]}), 500
26
- return send_file(output_file, mimetype="audio/wav", as_attachment=True, download_name="output.wav")
 
 
1
  from flask import jsonify, send_file, request
2
  from main import *
3
+ import torch, torchaudio, uuid, io, base64, os
 
 
4
 
5
+ def text_to_speech_func(text, output_path="output_audio.wav"):
6
+ if tts_model is None: return {"error": "TTS model not initialized."}
 
7
  input_tokens = tts_model.tokenizer(text, return_tensors="pt", padding=True).to(device)
8
+ with torch.no_grad(): audio_output = tts_model(input_tokens['input_ids'])
9
+ torchaudio.save(output_path, audio_output.cpu(), 16000); return output_path
 
 
 
10
 
11
+ def tts_api(text):
 
 
 
 
12
  output_file = text_to_speech_func(text)
13
+ if isinstance(output_file, dict) and "error" in output_file: return {"error": output_file["error"]}
14
+ with open(output_file, 'rb') as f: audio_content = f.read()
15
+ audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
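
Because these helpers now return plain dicts with base64-encoded media, the HTTP wiring is left to api.py; the route below is only an illustrative sketch, not the actual api.py code:

# Hypothetical Flask wiring; the /tts route name and payload shape are assumptions.
from flask import Flask, request, jsonify
from tts_api import tts_api

app = Flask(__name__)

@app.route("/tts", methods=["POST"])
def tts_route():
    text = (request.get_json() or {}).get("text")
    if not text:
        return jsonify({"error": "Text is required"}), 400
    result = tts_api(text)
    return jsonify(result), (500 if "error" in result else 200)
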
xtts_api.py ADDED
@@ -0,0 +1,21 @@
1
+ from flask import jsonify, send_file, request
2
+ from main import *
3
+ import torch, torchaudio, io, base64, uuid, os
4
+
5
+ def xtts_clone_func(text, audio_sample_path, output_path="output_xtts_audio.wav"):
6
+ if xtts_model is None: return {"error": "XTTS model not initialized."}
7
+ language = "en"; speaker_id = 0
8
+ try:
9
+ with torch.no_grad(): wav = xtts_model.inference(text=text, language_id=language, speaker_id=speaker_id, voice_sample=audio_sample_path, temperature=0.7, length_penalty=1.0)
10
+ except Exception as e: return {"error": f"XTTS inference failed: {e}"}
11
+ torchaudio.save(output_path, wav, 24000); return output_path
12
+
13
+ def xtts_api(inputs):
14
+ text = inputs[0]; audio_sample_filepath = inputs[1]
15
+ temp_audio_path = f"temp_audio_{uuid.uuid4()}.wav"; os.rename(audio_sample_filepath, temp_audio_path)
16
+ output = xtts_clone_func(text, temp_audio_path); os.remove(temp_audio_path)
17
+ if isinstance(output, dict) and "error" in output: return {"error": output["error"]}
18
+ output_file = output
19
+ with open(output_file, 'rb') as f: audio_content = f.read()
20
+ audio_base64 = base64.b64encode(audio_content).decode('utf-8'); os.remove(output_file); return {"audio_base64": audio_base64, "mimetype": "audio/wav"}
21
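
A usage sketch for the new voice-cloning helper, assuming xtts_model was initialized by model_loader.py; reference.wav is an example path, and note that xtts_api renames the sample to a temporary file and deletes it afterwards:

# Hypothetical caller; the inputs list is [text, path_to_voice_sample].
import base64
from xtts_api import xtts_api

result = xtts_api(["Hello from a cloned voice.", "reference.wav"])
if "error" not in result:
    with open("cloned.wav", "wb") as f:
        f.write(base64.b64decode(result["audio_base64"]))
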
xxx.py CHANGED
@@ -1,142 +1,71 @@
1
- import json
2
- import re
3
- import unicodedata
4
  from functools import lru_cache
5
- import wget
6
- import os
7
  from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN
8
- import nltk
9
 
10
  @lru_cache()
11
  def bytes_to_unicode():
12
  bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
13
- cs = bs[:]
14
- n = 0
15
- for b in range(2**8):
16
- if b not in bs:
17
- bs.append(b)
18
- cs.append(2**8 + n)
19
- n += 1
20
- cs = [chr(n) for n in cs]
21
- return dict(zip(bs, cs))
22
 
23
  def get_pairs(word):
24
- pairs = set()
25
- prev_char = word[0]
26
- for char in word[1:]:
27
- pairs.add((prev_char, char))
28
- prev_char = char
29
- return pairs
30
 
31
  class Encoder:
32
  def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
33
- self.encoder = encoder
34
- self.decoder = {v:k for k,v in self.encoder.items()}
35
- self.errors = errors
36
- self.byte_encoder = bytes_to_unicode()
37
- self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
38
  self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
39
- self.cache = {}
40
- if tokenize is None:
41
- self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE)
42
- self.tokenize = lambda text: re.findall(self.pat, text)
43
- else:
44
- self.tokenize = tokenize
45
 
46
  def bpe(self, token):
47
- if token in self.cache:
48
- return self.cache[token]
49
- word = tuple(token)
50
- pairs = get_pairs(word)
51
- if not pairs:
52
- return token
53
  while True:
54
  bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
55
- if bigram not in self.bpe_ranks:
56
- break
57
- first, second = bigram
58
- new_word = []
59
- i = 0
60
  while i < len(word):
61
- try:
62
- j = word.index(first, i)
63
- new_word.extend(word[i:j])
64
- i = j
65
- except ValueError:
66
- new_word.extend(word[i:])
67
- break
68
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
69
- new_word.append(first+second)
70
- i += 2
71
- else:
72
- new_word.append(word[i])
73
- i += 1
74
- new_word = tuple(new_word)
75
- word = new_word
76
- if len(word) == 1:
77
- break
78
- else:
79
- pairs = get_pairs(word)
80
- word = ' '.join(word)
81
- self.cache[token] = word
82
- return word
83
 
84
  def encode(self, text):
85
- bpe_tokens = []
86
- normalized_text = unicodedata.normalize('NFKC', text)
87
- normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t')
88
- normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
89
- for token in self.tokenize(normalized_text):
90
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore'))
91
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
92
  return bpe_tokens
93
 
94
  def decode(self, tokens):
95
- text = ''.join([self.decoder[token] for token in tokens])
96
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
97
  decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
98
- sentences = nltk.sent_tokenize(decoded_text)
99
- return ' '.join(sentences).replace("<br>", "<br>\n")
100
 
101
  def get_encoder_gpt2():
102
- encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE)
103
- vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
104
- if not os.path.exists(GPT2_FOLDER):
105
- os.makedirs(GPT2_FOLDER)
106
- if not os.path.exists(encoder_path):
107
- wget.download(ENCODER_URL, out=encoder_path)
108
- if not os.path.exists(vocab_path):
109
- wget.download(VOCAB_URL, out=vocab_path)
110
-
111
- with open(encoder_path, 'r') as f:
112
- encoder = json.load(f)
113
- with open(vocab_path, 'r', encoding="utf-8") as f:
114
- bpe_data = f.read()
115
- bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]
116
- encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
117
- encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder)
118
- encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN
119
- return encoder_obj
120
 
121
  def get_codegen_tokenizer_pure(vocab_file, merges_file):
122
- vocab = json.load(open(vocab_file))
123
- merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]
124
- bpe_merges = [tuple(m.split()) for m in merges]
125
- byte_encoder = bytes_to_unicode()
126
- byte_decoder = {v: k for k, v in byte_encoder.items()}
127
- tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+''')
128
- tokenize = lambda text: re.findall(tokenizer_regex, text)
129
- encoder_obj = Encoder(
130
- encoder=vocab,
131
- bpe_merges=bpe_merges,
132
- byte_encoder=byte_encoder,
133
- byte_decoder=byte_decoder,
134
- tokenize=tokenize
135
- )
136
- return encoder_obj
137
-
138
- def codegen_tokenize(text, tokenizer):
139
- return tokenizer.encode(text)
140
 
141
- def codegen_decode(tokens, tokenizer):
142
- return tokenizer.decode(tokens)
 
1
+ import json, re, unicodedata
 
 
2
  from functools import lru_cache
3
+ import wget, os, nltk
 
4
 from constants import GPT2_FOLDER, ENCODER_FILE, VOCAB_FILE, END_OF_TEXT_TOKEN, ENCODER_URL, VOCAB_URL
 
5
 
6
  @lru_cache()
7
  def bytes_to_unicode():
8
  bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
9
+ cs = bs[:]; n = 0
10
+ for b in range(2**8):
+     if b not in bs: bs.append(b); cs.append(2**8 + n); n += 1
11
+ cs = [chr(n) for n in cs]; return dict(zip(bs, cs))
 
 
 
 
 
 
12
 
13
  def get_pairs(word):
14
+ pairs = set(); prev_char = word[0]
15
+ for char in word[1:]: pairs.add((prev_char, char)); prev_char = char
+ return pairs
 
 
 
 
16
 
17
  class Encoder:
18
  def __init__(self, encoder, bpe_merges, errors='replace', tokenize=None):
19
+ self.encoder = encoder; self.decoder = {v:k for k,v in self.encoder.items()}; self.errors = errors
20
+ self.byte_encoder = bytes_to_unicode(); self.byte_decoder = {v:k for k, v in self.byte_encoder.items()}
 
 
 
21
  self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
22
+ self.cache = {};
23
+ if tokenize is None: self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?[^\s\w]+|\s+(?!\S)|\s+""", re.UNICODE); self.tokenize = lambda text: re.findall(self.pat, text)
24
+ else: self.tokenize = tokenize
 
 
 
25
 
26
  def bpe(self, token):
27
+ if token in self.cache: return self.cache[token]
28
+ word = tuple(token); pairs = get_pairs(word)
29
+ if not pairs: return token
 
 
 
30
  while True:
31
  bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
32
+ if bigram not in self.bpe_ranks: break
33
+ first, second = bigram; new_word = []; i = 0
 
 
 
34
  while i < len(word):
35
+ try: j = word.index(first, i); new_word.extend(word[i:j]); i = j
36
+ except ValueError: new_word.extend(word[i:]); break
37
+ if word[i] == first and i < len(word)-1 and word[i+1] == second: new_word.append(first+second); i += 2
38
+ else: new_word.append(word[i]); i += 1
39
+ new_word = tuple(new_word); word = new_word
40
+ if len(word) == 1: break
41
+ else: pairs = get_pairs(word)
42
+ word = ' '.join(word); self.cache[token] = word; return word
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def encode(self, text):
45
+ bpe_tokens = []; normalized_text = unicodedata.normalize('NFKC', text); normalized_text = ''.join(c for c in normalized_text if c.isascii() and c != '\t'); normalized_text = ''.join(c for c in normalized_text if not unicodedata.category(c).startswith('C'))
46
+ for token in self.tokenize(normalized_text): token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8', errors='ignore')); bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
 
 
 
 
 
47
  return bpe_tokens
48
 
49
  def decode(self, tokens):
50
+ text = ''.join([self.decoder[token] for token in tokens]); text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors='replace')
 
51
  decoded_text = text.replace(" .", ".").replace(" ,", ",").replace(" '", "'").replace(" ?", "?").replace(" !", "!").replace(" :", ":").replace('\n', '<br>')
52
+ sentences = nltk.sent_tokenize(decoded_text); return ' '.join(sentences).replace("<br>", "<br>\n")
 
53
 
54
  def get_encoder_gpt2():
55
+ encoder_path = os.path.join(GPT2_FOLDER, ENCODER_FILE); vocab_path = os.path.join(GPT2_FOLDER, VOCAB_FILE)
56
+ if not os.path.exists(GPT2_FOLDER): os.makedirs(GPT2_FOLDER)
57
+ if not os.path.exists(encoder_path): wget.download(ENCODER_URL, out=encoder_path)
58
+ if not os.path.exists(vocab_path): wget.download(VOCAB_URL, out=vocab_path)
59
+ with open(encoder_path, 'r') as f: encoder = json.load(f)
60
+ with open(vocab_path, 'r', encoding="utf-8") as f: bpe_data = f.read()
61
+ bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]]; encoder_obj = Encoder(encoder=encoder, bpe_merges=bpe_merges)
62
+ encoder_obj.encoder[END_OF_TEXT_TOKEN] = len(encoder_obj.encoder); encoder_obj.decoder[len(encoder_obj.decoder)] = END_OF_TEXT_TOKEN; return encoder_obj
 
 
 
 
 
 
 
 
 
 
63
 
64
  def get_codegen_tokenizer_pure(vocab_file, merges_file):
65
+ vocab = json.load(open(vocab_file)); merges = open(merges_file, 'r', encoding="utf-8").read().split('\n')[1:-1]; bpe_merges = [tuple(m.split()) for m in merges]
66
+ byte_encoder = bytes_to_unicode(); byte_decoder = {v: k for k, v in byte_encoder.items()}
67
+ tokenizer_regex = re.compile(r'''<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+'''); tokenize = lambda text: re.findall(tokenizer_regex, text)
68
+ encoder_obj = Encoder(encoder=vocab, bpe_merges=bpe_merges, tokenize=tokenize); return encoder_obj
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ def codegen_tokenize(text, tokenizer): return tokenizer.encode(text)
71
+ def codegen_decode(tokens, tokenizer): return tokenizer.decode(tokens)
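
Finally, a simplified toy version of the greedy merge loop that Encoder.bpe implements above; the two-entry merge table is made up purely for illustration and the loop is a sketch of the idea, not the exact implementation:

# Toy BPE merge loop: repeatedly merge the lowest-ranked adjacent pair.
toy_merges = [("l", "o"), ("lo", "w")]
ranks = {pair: i for i, pair in enumerate(toy_merges)}
word = tuple("low")
while len(word) > 1:
    pairs = {(word[i], word[i + 1]) for i in range(len(word) - 1)}
    bigram = min(pairs, key=lambda p: ranks.get(p, float("inf")))
    if bigram not in ranks:
        break
    first, second = bigram
    merged, i = [], 0
    while i < len(word):
        if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
            merged.append(first + second); i += 2
        else:
            merged.append(word[i]); i += 1
    word = tuple(merged)
print(word)  # ('low',)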